mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00
oauth refresh + revoke
This commit is contained in:
parent
4556ef2c48
commit
986d5b9957
@ -1,8 +1,8 @@
|
||||
"""oauth codes
|
||||
|
||||
Revision ID: 66771d293b27
|
||||
Revision ID: 1d6bc8015ea9
|
||||
Revises: 58439dd3088b
|
||||
Create Date: 2025-06-06 12:36:11.737507
|
||||
Create Date: 2025-06-06 16:53:33.044558
|
||||
|
||||
"""
|
||||
|
||||
@ -13,7 +13,7 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "66771d293b27"
|
||||
revision: str = "1d6bc8015ea9"
|
||||
down_revision: Union[str, None] = "58439dd3088b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@ -43,18 +43,19 @@ def upgrade() -> None:
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(), nullable=False),
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("code", sa.String(), nullable=True),
|
||||
sa.Column("redirect_uri", sa.String(), nullable=False),
|
||||
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(), nullable=True),
|
||||
sa.Column("stale", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||
),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("stale", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_client.client_id"],
|
||||
@ -63,24 +64,52 @@ def upgrade() -> None:
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("state"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_refresh_tokens",
|
||||
sa.Column("token", sa.String(), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("access_token_session_id", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||
),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["access_token_session_id"],
|
||||
["user_sessions.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_client.client_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True)
|
||||
"user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_user_sessions_oauth_state_id",
|
||||
"user_sessions_oauth_state_id_fkey",
|
||||
"user_sessions",
|
||||
"oauth_states",
|
||||
["oauth_state_id"],
|
||||
["state"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey"
|
||||
"user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("user_sessions", "oauth_state_id")
|
||||
op.drop_table("oauth_refresh_tokens")
|
||||
op.drop_table("oauth_states")
|
||||
op.drop_table("oauth_client")
|
@ -147,6 +147,7 @@ services:
|
||||
<<: *env
|
||||
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
||||
QDRANT_URL: http://qdrant:6333
|
||||
SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
|
||||
secrets: [postgres_password]
|
||||
volumes:
|
||||
- ./memory_files:/app/memory_files:rw
|
||||
|
@ -97,7 +97,9 @@ async def login_page(request: Request):
|
||||
|
||||
state = form_data.get("state")
|
||||
with make_session() as session:
|
||||
oauth_state = session.query(OAuthState).get(state)
|
||||
oauth_state = (
|
||||
session.query(OAuthState).filter(OAuthState.state == state).first()
|
||||
)
|
||||
if not oauth_state:
|
||||
logger.error(f"State {state} not found in database")
|
||||
raise ValueError("Invalid state parameter")
|
||||
|
@ -10,8 +10,10 @@ from memory.common.db.models.users import (
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
OAuthRefreshToken,
|
||||
OAuthToken as TokenBase,
|
||||
)
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.connection import make_session, scoped_session
|
||||
from memory.common import settings
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
@ -26,6 +28,101 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Token configuration constants
|
||||
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
||||
REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days
|
||||
|
||||
|
||||
def create_expiration(lifetime_seconds: int) -> datetime:
|
||||
"""Create expiration datetime from lifetime in seconds."""
|
||||
return datetime.fromtimestamp(time.time() + lifetime_seconds)
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""Generate a new refresh token."""
|
||||
return f"rt_{secrets.token_hex(32)}"
|
||||
|
||||
|
||||
def create_access_token_session(
|
||||
user_id: int, oauth_state_id: str | None = None
|
||||
) -> UserSession:
|
||||
"""Create a new access token session."""
|
||||
return UserSession(
|
||||
user_id=user_id,
|
||||
oauth_state_id=oauth_state_id,
|
||||
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token_record(
|
||||
client_id: str,
|
||||
user_id: int,
|
||||
scopes: list[str],
|
||||
access_token_session_id: Optional[str] = None,
|
||||
) -> OAuthRefreshToken:
|
||||
"""Create a new refresh token record."""
|
||||
return OAuthRefreshToken(
|
||||
token=generate_refresh_token(),
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
expires_at=create_expiration(REFRESH_TOKEN_LIFETIME),
|
||||
access_token_session_id=access_token_session_id,
|
||||
)
|
||||
|
||||
|
||||
def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None:
|
||||
"""Validate a refresh token, raising ValueError if invalid."""
|
||||
now = datetime.now()
|
||||
if db_refresh_token.expires_at < now: # type: ignore
|
||||
logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...")
|
||||
db_refresh_token.revoked = True # type: ignore
|
||||
raise ValueError("Refresh token expired")
|
||||
|
||||
|
||||
def create_oauth_token(
|
||||
access_token: str, scopes: list[str], refresh_token: Optional[str] = None
|
||||
) -> OAuthToken:
|
||||
"""Create an OAuth token response."""
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_LIFETIME,
|
||||
refresh_token=refresh_token,
|
||||
scope=" ".join(scopes),
|
||||
)
|
||||
|
||||
|
||||
def make_token(
|
||||
db: scoped_session,
|
||||
auth_state: TokenBase,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
new_session = UserSession(
|
||||
user_id=auth_state.user_id,
|
||||
oauth_state_id=auth_state.id,
|
||||
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
||||
)
|
||||
|
||||
# Create refresh token
|
||||
refresh_token = create_refresh_token_record(
|
||||
cast(str, auth_state.client_id),
|
||||
cast(int, auth_state.user_id),
|
||||
scopes,
|
||||
cast(str, new_session.id),
|
||||
)
|
||||
|
||||
db.add(new_session)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
|
||||
return create_oauth_token(
|
||||
str(new_session.id),
|
||||
scopes,
|
||||
cast(str, refresh_token.token),
|
||||
)
|
||||
|
||||
|
||||
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""Get OAuth client information."""
|
||||
@ -102,7 +199,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
|
||||
with make_session() as session:
|
||||
# Load OAuth state from database
|
||||
oauth_state = session.query(OAuthState).get(state)
|
||||
oauth_state = (
|
||||
session.query(OAuthState).filter(OAuthState.state == state).first()
|
||||
)
|
||||
if not oauth_state:
|
||||
logger.error(f"State {state} not found in database")
|
||||
raise ValueError("Invalid state parameter")
|
||||
@ -111,19 +210,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
now = datetime.fromtimestamp(time.time())
|
||||
if oauth_state.expires_at < now:
|
||||
logger.error(f"State {state} has expired")
|
||||
oauth_state.stale = True
|
||||
oauth_state.stale = True # type: ignore
|
||||
session.commit()
|
||||
raise ValueError("State has expired")
|
||||
|
||||
oauth_state.code = f"code_{secrets.token_hex(16)}"
|
||||
oauth_state.stale = False
|
||||
oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore
|
||||
oauth_state.stale = False # type: ignore
|
||||
oauth_state.user_id = user.id
|
||||
|
||||
session.add(oauth_state)
|
||||
session.commit()
|
||||
|
||||
return construct_redirect_uri(
|
||||
oauth_state.redirect_uri, code=oauth_state.code, state=state
|
||||
cast(str, oauth_state.redirect_uri),
|
||||
code=cast(str, oauth_state.code),
|
||||
state=state,
|
||||
)
|
||||
|
||||
async def load_authorization_code(
|
||||
@ -156,29 +257,11 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
logger.error(f"Invalid authorization code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
# Get the user associated with this auth code
|
||||
if not auth_code.user:
|
||||
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
# Create a UserSession to serve as access token
|
||||
expires_at = datetime.fromtimestamp(time.time() + 3600)
|
||||
|
||||
auth_code.session = UserSession(
|
||||
user_id=auth_code.user_id,
|
||||
oauth_state_id=auth_code.state,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
auth_code.stale = True # type: ignore
|
||||
session.commit()
|
||||
access_token = str(auth_code.session.id)
|
||||
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=3600,
|
||||
scope=" ".join(authorization_code.scopes),
|
||||
)
|
||||
return make_token(session, auth_code, authorization_code.scopes)
|
||||
|
||||
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""Load and validate an access token."""
|
||||
@ -192,6 +275,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
if not user_session or user_session.expires_at < now:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Loading access token: {token}, state: {user_session.oauth_state}"
|
||||
)
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id=user_session.oauth_state.client_id,
|
||||
@ -202,8 +288,34 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def load_refresh_token(
|
||||
self, client: OAuthClientInformationFull, refresh_token: str
|
||||
) -> Optional[RefreshToken]:
|
||||
"""Load a refresh token - not supported in this simple implementation."""
|
||||
return None
|
||||
"""Load and validate a refresh token."""
|
||||
with make_session() as session:
|
||||
now = datetime.now()
|
||||
|
||||
# Query for the refresh token
|
||||
db_refresh_token = (
|
||||
session.query(OAuthRefreshToken)
|
||||
.filter(
|
||||
OAuthRefreshToken.token == refresh_token,
|
||||
OAuthRefreshToken.client_id == client.client_id,
|
||||
OAuthRefreshToken.revoked == False, # noqa: E712
|
||||
OAuthRefreshToken.expires_at > now,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_refresh_token:
|
||||
logger.error(
|
||||
f"Invalid or expired refresh token: {refresh_token[:20]}..."
|
||||
)
|
||||
return None
|
||||
|
||||
return RefreshToken(
|
||||
token=refresh_token,
|
||||
client_id=client.client_id,
|
||||
scopes=cast(list[str], db_refresh_token.scopes),
|
||||
expires_at=int(db_refresh_token.expires_at.timestamp()),
|
||||
)
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self,
|
||||
@ -211,18 +323,71 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
refresh_token: RefreshToken,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
"""Exchange refresh token - not supported in this simple implementation."""
|
||||
raise NotImplementedError("Refresh tokens not supported")
|
||||
"""Exchange refresh token for new access token."""
|
||||
with make_session() as session:
|
||||
# Load the refresh token from database
|
||||
db_refresh_token = (
|
||||
session.query(OAuthRefreshToken)
|
||||
.filter(
|
||||
OAuthRefreshToken.token == refresh_token.token,
|
||||
OAuthRefreshToken.client_id == client.client_id,
|
||||
OAuthRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_refresh_token:
|
||||
logger.error(f"Refresh token not found: {refresh_token.token[:20]}...")
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
# Validate refresh token
|
||||
validate_refresh_token(db_refresh_token)
|
||||
|
||||
# Validate requested scopes are subset of original scopes
|
||||
original_scopes = set(cast(list[str], db_refresh_token.scopes))
|
||||
requested_scopes = set(scopes) if scopes else original_scopes
|
||||
|
||||
if not requested_scopes.issubset(original_scopes):
|
||||
logger.error(
|
||||
f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}"
|
||||
)
|
||||
raise ValueError("Requested scopes exceed original authorization")
|
||||
|
||||
return make_token(session, db_refresh_token, scopes)
|
||||
|
||||
async def revoke_token(
|
||||
self, token: str, token_type_hint: Optional[str] = None
|
||||
) -> None:
|
||||
"""Revoke a token."""
|
||||
"""Revoke a token (access token or refresh token)."""
|
||||
with make_session() as session:
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if user_session:
|
||||
session.delete(user_session)
|
||||
revoked = False
|
||||
|
||||
# Try to revoke as access token (UserSession)
|
||||
if not token_type_hint or token_type_hint == "access_token":
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if user_session:
|
||||
session.delete(user_session)
|
||||
revoked = True
|
||||
logger.info(f"Revoked access token: {token[:20]}...")
|
||||
|
||||
# Try to revoke as refresh token
|
||||
if not revoked and (
|
||||
not token_type_hint or token_type_hint == "refresh_token"
|
||||
):
|
||||
refresh_token = (
|
||||
session.query(OAuthRefreshToken)
|
||||
.filter(OAuthRefreshToken.token == token)
|
||||
.first()
|
||||
)
|
||||
if refresh_token:
|
||||
refresh_token.revoked = True # type: ignore
|
||||
revoked = True
|
||||
logger.info(f"Revoked refresh token: {token[:20]}...")
|
||||
|
||||
if revoked:
|
||||
session.commit()
|
||||
else:
|
||||
logger.warning(f"Token not found for revocation: {token[:20]}...")
|
||||
|
||||
def get_protected_resource_metadata(self) -> dict[str, Any]:
|
||||
"""Return metadata about the protected resource."""
|
||||
@ -231,6 +396,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
"scopes_supported": ["read", "write"],
|
||||
"bearer_methods_supported": ["header"],
|
||||
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": ["none", "client_secret_basic"],
|
||||
"refresh_token_rotation_enabled": True,
|
||||
"protected_resources": [
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||
|
@ -37,6 +37,7 @@ from memory.common.db.models.users import (
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
OAuthRefreshToken,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -73,4 +74,5 @@ __all__ = [
|
||||
"UserSession",
|
||||
"OAuthClientInformation",
|
||||
"OAuthState",
|
||||
"OAuthRefreshToken",
|
||||
]
|
||||
|
@ -15,6 +15,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
@ -69,7 +70,7 @@ class UserSession(Base):
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True)
|
||||
oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
@ -125,29 +126,38 @@ class OAuthClientInformation(Base):
|
||||
}
|
||||
|
||||
|
||||
class OAuthState(Base):
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
state = Column(String, primary_key=True)
|
||||
class OAuthToken:
|
||||
id = Column(Integer, primary_key=True)
|
||||
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
scopes = Column(ARRAY(String), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
"client_id": self.client_id,
|
||||
"scopes": self.scopes,
|
||||
"expires_at": self.expires_at.timestamp(),
|
||||
}
|
||||
|
||||
|
||||
class OAuthState(Base, OAuthToken):
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
state = Column(String, nullable=False)
|
||||
code = Column(String, nullable=True)
|
||||
redirect_uri = Column(String, nullable=False)
|
||||
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
|
||||
code_challenge = Column(String, nullable=True)
|
||||
scopes = Column(ARRAY(String), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
stale = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
def serialize(self, code: bool = False) -> dict:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
|
||||
"code_challenge": self.code_challenge,
|
||||
"scopes": self.scopes,
|
||||
}
|
||||
} | super().serialize()
|
||||
if code:
|
||||
data |= {
|
||||
"code": self.code,
|
||||
@ -158,3 +168,29 @@ class OAuthState(Base):
|
||||
client = relationship("OAuthClientInformation", back_populates="sessions")
|
||||
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
|
||||
user = relationship("User", back_populates="oauth_states")
|
||||
|
||||
|
||||
class OAuthRefreshToken(Base, OAuthToken):
|
||||
__tablename__ = "oauth_refresh_tokens"
|
||||
|
||||
token = Column(
|
||||
String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}"
|
||||
)
|
||||
revoked = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Optional: link to the access token session that was created with this refresh token
|
||||
access_token_session_id = Column(
|
||||
String, ForeignKey("user_sessions.id"), nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClientInformation")
|
||||
user = relationship("User")
|
||||
access_token_session = relationship("UserSession")
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
"token": self.token,
|
||||
"expires_at": self.expires_at.timestamp(),
|
||||
"revoked": self.revoked,
|
||||
} | super().serialize()
|
||||
|
Loading…
x
Reference in New Issue
Block a user