oauth refresh + revoke

This commit is contained in:
Daniel O'Connell 2025-06-06 17:07:25 +02:00
parent 4556ef2c48
commit 986d5b9957
6 changed files with 294 additions and 56 deletions

View File

@ -1,8 +1,8 @@
"""oauth codes """oauth codes
Revision ID: 66771d293b27 Revision ID: 1d6bc8015ea9
Revises: 58439dd3088b Revises: 58439dd3088b
Create Date: 2025-06-06 12:36:11.737507 Create Date: 2025-06-06 16:53:33.044558
""" """
@ -13,7 +13,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "66771d293b27" revision: str = "1d6bc8015ea9"
down_revision: Union[str, None] = "58439dd3088b" down_revision: Union[str, None] = "58439dd3088b"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -43,18 +43,19 @@ def upgrade() -> None:
op.create_table( op.create_table(
"oauth_states", "oauth_states",
sa.Column("state", sa.String(), nullable=False), sa.Column("state", sa.String(), nullable=False),
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("code", sa.String(), nullable=True), sa.Column("code", sa.String(), nullable=True),
sa.Column("redirect_uri", sa.String(), nullable=False), sa.Column("redirect_uri", sa.String(), nullable=False),
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False), sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
sa.Column("code_challenge", sa.String(), nullable=True), sa.Column("code_challenge", sa.String(), nullable=True),
sa.Column("stale", sa.Boolean(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False), sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
sa.Column( sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
), ),
sa.Column("expires_at", sa.DateTime(), nullable=False), sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.Column("stale", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["client_id"], ["client_id"],
["oauth_client.client_id"], ["oauth_client.client_id"],
@ -63,24 +64,52 @@ def upgrade() -> None:
["user_id"], ["user_id"],
["users.id"], ["users.id"],
), ),
sa.PrimaryKeyConstraint("state"), sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"oauth_refresh_tokens",
sa.Column("token", sa.String(), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("access_token_session_id", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["access_token_session_id"],
["user_sessions.id"],
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_client.client_id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
) )
op.add_column( op.add_column(
"user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True) "user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True)
) )
op.create_foreign_key( op.create_foreign_key(
"fk_user_sessions_oauth_state_id", "user_sessions_oauth_state_id_fkey",
"user_sessions", "user_sessions",
"oauth_states", "oauth_states",
["oauth_state_id"], ["oauth_state_id"],
["state"], ["id"],
) )
def downgrade() -> None: def downgrade() -> None:
op.drop_constraint( op.drop_constraint(
"fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey" "user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey"
) )
op.drop_column("user_sessions", "oauth_state_id") op.drop_column("user_sessions", "oauth_state_id")
op.drop_table("oauth_refresh_tokens")
op.drop_table("oauth_states") op.drop_table("oauth_states")
op.drop_table("oauth_client") op.drop_table("oauth_client")

View File

@ -147,6 +147,7 @@ services:
<<: *env <<: *env
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
QDRANT_URL: http://qdrant:6333 QDRANT_URL: http://qdrant:6333
SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
secrets: [postgres_password] secrets: [postgres_password]
volumes: volumes:
- ./memory_files:/app/memory_files:rw - ./memory_files:/app/memory_files:rw

View File

@ -97,7 +97,9 @@ async def login_page(request: Request):
state = form_data.get("state") state = form_data.get("state")
with make_session() as session: with make_session() as session:
oauth_state = session.query(OAuthState).get(state) oauth_state = (
session.query(OAuthState).filter(OAuthState.state == state).first()
)
if not oauth_state: if not oauth_state:
logger.error(f"State {state} not found in database") logger.error(f"State {state} not found in database")
raise ValueError("Invalid state parameter") raise ValueError("Invalid state parameter")

View File

@ -10,8 +10,10 @@ from memory.common.db.models.users import (
UserSession, UserSession,
OAuthClientInformation, OAuthClientInformation,
OAuthState, OAuthState,
OAuthRefreshToken,
OAuthToken as TokenBase,
) )
from memory.common.db.connection import make_session from memory.common.db.connection import make_session, scoped_session
from memory.common import settings from memory.common import settings
from mcp.server.auth.provider import ( from mcp.server.auth.provider import (
AccessToken, AccessToken,
@ -26,6 +28,101 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Token configuration constants
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days
def create_expiration(lifetime_seconds: int) -> datetime:
"""Create expiration datetime from lifetime in seconds."""
return datetime.fromtimestamp(time.time() + lifetime_seconds)
def generate_refresh_token() -> str:
"""Generate a new refresh token."""
return f"rt_{secrets.token_hex(32)}"
def create_access_token_session(
user_id: int, oauth_state_id: str | None = None
) -> UserSession:
"""Create a new access token session."""
return UserSession(
user_id=user_id,
oauth_state_id=oauth_state_id,
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
)
def create_refresh_token_record(
client_id: str,
user_id: int,
scopes: list[str],
access_token_session_id: Optional[str] = None,
) -> OAuthRefreshToken:
"""Create a new refresh token record."""
return OAuthRefreshToken(
token=generate_refresh_token(),
client_id=client_id,
user_id=user_id,
scopes=scopes,
expires_at=create_expiration(REFRESH_TOKEN_LIFETIME),
access_token_session_id=access_token_session_id,
)
def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None:
"""Validate a refresh token, raising ValueError if invalid."""
now = datetime.now()
if db_refresh_token.expires_at < now: # type: ignore
logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...")
db_refresh_token.revoked = True # type: ignore
raise ValueError("Refresh token expired")
def create_oauth_token(
access_token: str, scopes: list[str], refresh_token: Optional[str] = None
) -> OAuthToken:
"""Create an OAuth token response."""
return OAuthToken(
access_token=access_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_LIFETIME,
refresh_token=refresh_token,
scope=" ".join(scopes),
)
def make_token(
db: scoped_session,
auth_state: TokenBase,
scopes: list[str],
) -> OAuthToken:
new_session = UserSession(
user_id=auth_state.user_id,
oauth_state_id=auth_state.id,
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
)
# Create refresh token
refresh_token = create_refresh_token_record(
cast(str, auth_state.client_id),
cast(int, auth_state.user_id),
scopes,
cast(str, new_session.id),
)
db.add(new_session)
db.add(refresh_token)
db.commit()
return create_oauth_token(
str(new_session.id),
scopes,
cast(str, refresh_token.token),
)
class SimpleOAuthProvider(OAuthAuthorizationServerProvider): class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""Get OAuth client information.""" """Get OAuth client information."""
@ -102,7 +199,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
with make_session() as session: with make_session() as session:
# Load OAuth state from database # Load OAuth state from database
oauth_state = session.query(OAuthState).get(state) oauth_state = (
session.query(OAuthState).filter(OAuthState.state == state).first()
)
if not oauth_state: if not oauth_state:
logger.error(f"State {state} not found in database") logger.error(f"State {state} not found in database")
raise ValueError("Invalid state parameter") raise ValueError("Invalid state parameter")
@ -111,19 +210,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
now = datetime.fromtimestamp(time.time()) now = datetime.fromtimestamp(time.time())
if oauth_state.expires_at < now: if oauth_state.expires_at < now:
logger.error(f"State {state} has expired") logger.error(f"State {state} has expired")
oauth_state.stale = True oauth_state.stale = True # type: ignore
session.commit() session.commit()
raise ValueError("State has expired") raise ValueError("State has expired")
oauth_state.code = f"code_{secrets.token_hex(16)}" oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore
oauth_state.stale = False oauth_state.stale = False # type: ignore
oauth_state.user_id = user.id oauth_state.user_id = user.id
session.add(oauth_state) session.add(oauth_state)
session.commit() session.commit()
return construct_redirect_uri( return construct_redirect_uri(
oauth_state.redirect_uri, code=oauth_state.code, state=state cast(str, oauth_state.redirect_uri),
code=cast(str, oauth_state.code),
state=state,
) )
async def load_authorization_code( async def load_authorization_code(
@ -156,29 +257,11 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
logger.error(f"Invalid authorization code: {authorization_code.code}") logger.error(f"Invalid authorization code: {authorization_code.code}")
raise ValueError("Invalid authorization code") raise ValueError("Invalid authorization code")
# Get the user associated with this auth code
if not auth_code.user: if not auth_code.user:
logger.error(f"No user found for auth code: {authorization_code.code}") logger.error(f"No user found for auth code: {authorization_code.code}")
raise ValueError("Invalid authorization code") raise ValueError("Invalid authorization code")
# Create a UserSession to serve as access token return make_token(session, auth_code, authorization_code.scopes)
expires_at = datetime.fromtimestamp(time.time() + 3600)
auth_code.session = UserSession(
user_id=auth_code.user_id,
oauth_state_id=auth_code.state,
expires_at=expires_at,
)
auth_code.stale = True # type: ignore
session.commit()
access_token = str(auth_code.session.id)
return OAuthToken(
access_token=access_token,
token_type="bearer",
expires_in=3600,
scope=" ".join(authorization_code.scopes),
)
async def load_access_token(self, token: str) -> Optional[AccessToken]: async def load_access_token(self, token: str) -> Optional[AccessToken]:
"""Load and validate an access token.""" """Load and validate an access token."""
@ -192,6 +275,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
if not user_session or user_session.expires_at < now: if not user_session or user_session.expires_at < now:
return None return None
logger.info(
f"Loading access token: {token}, state: {user_session.oauth_state}"
)
return AccessToken( return AccessToken(
token=token, token=token,
client_id=user_session.oauth_state.client_id, client_id=user_session.oauth_state.client_id,
@ -202,27 +288,106 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def load_refresh_token( async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str self, client: OAuthClientInformationFull, refresh_token: str
) -> Optional[RefreshToken]: ) -> Optional[RefreshToken]:
"""Load a refresh token - not supported in this simple implementation.""" """Load and validate a refresh token."""
with make_session() as session:
now = datetime.now()
# Query for the refresh token
db_refresh_token = (
session.query(OAuthRefreshToken)
.filter(
OAuthRefreshToken.token == refresh_token,
OAuthRefreshToken.client_id == client.client_id,
OAuthRefreshToken.revoked == False, # noqa: E712
OAuthRefreshToken.expires_at > now,
)
.first()
)
if not db_refresh_token:
logger.error(
f"Invalid or expired refresh token: {refresh_token[:20]}..."
)
return None return None
return RefreshToken(
token=refresh_token,
client_id=client.client_id,
scopes=cast(list[str], db_refresh_token.scopes),
expires_at=int(db_refresh_token.expires_at.timestamp()),
)
async def exchange_refresh_token( async def exchange_refresh_token(
self, self,
client: OAuthClientInformationFull, client: OAuthClientInformationFull,
refresh_token: RefreshToken, refresh_token: RefreshToken,
scopes: list[str], scopes: list[str],
) -> OAuthToken: ) -> OAuthToken:
"""Exchange refresh token - not supported in this simple implementation.""" """Exchange refresh token for new access token."""
raise NotImplementedError("Refresh tokens not supported") with make_session() as session:
# Load the refresh token from database
db_refresh_token = (
session.query(OAuthRefreshToken)
.filter(
OAuthRefreshToken.token == refresh_token.token,
OAuthRefreshToken.client_id == client.client_id,
OAuthRefreshToken.revoked == False, # noqa: E712
)
.first()
)
if not db_refresh_token:
logger.error(f"Refresh token not found: {refresh_token.token[:20]}...")
raise ValueError("Invalid refresh token")
# Validate refresh token
validate_refresh_token(db_refresh_token)
# Validate requested scopes are subset of original scopes
original_scopes = set(cast(list[str], db_refresh_token.scopes))
requested_scopes = set(scopes) if scopes else original_scopes
if not requested_scopes.issubset(original_scopes):
logger.error(
f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}"
)
raise ValueError("Requested scopes exceed original authorization")
return make_token(session, db_refresh_token, scopes)
async def revoke_token( async def revoke_token(
self, token: str, token_type_hint: Optional[str] = None self, token: str, token_type_hint: Optional[str] = None
) -> None: ) -> None:
"""Revoke a token.""" """Revoke a token (access token or refresh token)."""
with make_session() as session: with make_session() as session:
revoked = False
# Try to revoke as access token (UserSession)
if not token_type_hint or token_type_hint == "access_token":
user_session = session.query(UserSession).get(token) user_session = session.query(UserSession).get(token)
if user_session: if user_session:
session.delete(user_session) session.delete(user_session)
revoked = True
logger.info(f"Revoked access token: {token[:20]}...")
# Try to revoke as refresh token
if not revoked and (
not token_type_hint or token_type_hint == "refresh_token"
):
refresh_token = (
session.query(OAuthRefreshToken)
.filter(OAuthRefreshToken.token == token)
.first()
)
if refresh_token:
refresh_token.revoked = True # type: ignore
revoked = True
logger.info(f"Revoked refresh token: {token[:20]}...")
if revoked:
session.commit() session.commit()
else:
logger.warning(f"Token not found for revocation: {token[:20]}...")
def get_protected_resource_metadata(self) -> dict[str, Any]: def get_protected_resource_metadata(self) -> dict[str, Any]:
"""Return metadata about the protected resource.""" """Return metadata about the protected resource."""
@ -231,6 +396,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
"scopes_supported": ["read", "write"], "scopes_supported": ["read", "write"],
"bearer_methods_supported": ["header"], "bearer_methods_supported": ["header"],
"resource_documentation": f"{settings.SERVER_URL}/docs", "resource_documentation": f"{settings.SERVER_URL}/docs",
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["none", "client_secret_basic"],
"refresh_token_rotation_enabled": True,
"protected_resources": [ "protected_resources": [
{ {
"resource_uri": f"{settings.SERVER_URL}/mcp", "resource_uri": f"{settings.SERVER_URL}/mcp",

View File

@ -37,6 +37,7 @@ from memory.common.db.models.users import (
UserSession, UserSession,
OAuthClientInformation, OAuthClientInformation,
OAuthState, OAuthState,
OAuthRefreshToken,
) )
__all__ = [ __all__ = [
@ -73,4 +74,5 @@ __all__ = [
"UserSession", "UserSession",
"OAuthClientInformation", "OAuthClientInformation",
"OAuthState", "OAuthState",
"OAuthRefreshToken",
] ]

View File

@ -15,6 +15,7 @@ from sqlalchemy import (
) )
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from datetime import datetime
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
@ -69,7 +70,7 @@ class UserSession(Base):
__tablename__ = "user_sessions" __tablename__ = "user_sessions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True) oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime, server_default=func.now()) created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False) expires_at = Column(DateTime, nullable=False)
@ -125,29 +126,38 @@ class OAuthClientInformation(Base):
} }
class OAuthState(Base): class OAuthToken:
__tablename__ = "oauth_states" id = Column(Integer, primary_key=True)
state = Column(String, primary_key=True)
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False) client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True) user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
scopes = Column(ARRAY(String), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
def serialize(self) -> dict:
return {
"client_id": self.client_id,
"scopes": self.scopes,
"expires_at": self.expires_at.timestamp(),
}
class OAuthState(Base, OAuthToken):
__tablename__ = "oauth_states"
state = Column(String, nullable=False)
code = Column(String, nullable=True) code = Column(String, nullable=True)
redirect_uri = Column(String, nullable=False) redirect_uri = Column(String, nullable=False)
redirect_uri_provided_explicitly = Column(Boolean, nullable=False) redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
code_challenge = Column(String, nullable=True) code_challenge = Column(String, nullable=True)
scopes = Column(ARRAY(String), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
stale = Column(Boolean, nullable=False, default=False) stale = Column(Boolean, nullable=False, default=False)
def serialize(self, code: bool = False) -> dict: def serialize(self, code: bool = False) -> dict:
data = { data = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly, "redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
"code_challenge": self.code_challenge, "code_challenge": self.code_challenge,
"scopes": self.scopes, } | super().serialize()
}
if code: if code:
data |= { data |= {
"code": self.code, "code": self.code,
@ -158,3 +168,29 @@ class OAuthState(Base):
client = relationship("OAuthClientInformation", back_populates="sessions") client = relationship("OAuthClientInformation", back_populates="sessions")
session = relationship("UserSession", back_populates="oauth_state", uselist=False) session = relationship("UserSession", back_populates="oauth_state", uselist=False)
user = relationship("User", back_populates="oauth_states") user = relationship("User", back_populates="oauth_states")
class OAuthRefreshToken(Base, OAuthToken):
__tablename__ = "oauth_refresh_tokens"
token = Column(
String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}"
)
revoked = Column(Boolean, nullable=False, default=False)
# Optional: link to the access token session that was created with this refresh token
access_token_session_id = Column(
String, ForeignKey("user_sessions.id"), nullable=True
)
# Relationships
client = relationship("OAuthClientInformation")
user = relationship("User")
access_token_session = relationship("UserSession")
def serialize(self) -> dict:
return {
"token": self.token,
"expires_at": self.expires_at.timestamp(),
"revoked": self.revoked,
} | super().serialize()