proper oauth flow

This commit is contained in:
Daniel O'Connell 2025-06-06 12:55:48 +02:00
parent 4d057d1ec6
commit 4556ef2c48
9 changed files with 797 additions and 23 deletions

View File

@ -0,0 +1,86 @@
"""oauth codes
Revision ID: 66771d293b27
Revises: 58439dd3088b
Create Date: 2025-06-06 12:36:11.737507
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "66771d293b27"
down_revision: Union[str, None] = "58439dd3088b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"oauth_client",
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("client_secret", sa.String(), nullable=True),
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
sa.Column("scope", sa.String(), nullable=False),
sa.Column("client_name", sa.String(), nullable=False),
sa.Column("client_uri", sa.String(), nullable=True),
sa.Column("logo_uri", sa.String(), nullable=True),
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
sa.Column("tos_uri", sa.String(), nullable=True),
sa.Column("policy_uri", sa.String(), nullable=True),
sa.Column("jwks_uri", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("client_id"),
)
op.create_table(
"oauth_states",
sa.Column("state", sa.String(), nullable=False),
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("code", sa.String(), nullable=True),
sa.Column("redirect_uri", sa.String(), nullable=False),
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
sa.Column("code_challenge", sa.String(), nullable=True),
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.Column("stale", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_client.client_id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("state"),
)
op.add_column(
"user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True)
)
op.create_foreign_key(
"fk_user_sessions_oauth_state_id",
"user_sessions",
"oauth_states",
["oauth_state_id"],
["state"],
)
def downgrade() -> None:
op.drop_constraint(
"fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey"
)
op.drop_column("user_sessions", "oauth_state_id")
op.drop_table("oauth_states")
op.drop_table("oauth_client")

124
src/memory/api/MCP/base.py Normal file
View File

@ -0,0 +1,124 @@
import logging
import os
from typing import cast
from mcp.server.auth.handlers.authorize import AuthorizationRequest
from mcp.server.auth.handlers.token import (
AuthorizationCodeRequest,
RefreshTokenRequest,
TokenRequest,
)
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.shared.auth import OAuthClientMetadata
from memory.common.db.models.users import User
from pydantic import AnyHttpUrl
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import OAuthState
logger = logging.getLogger(__name__)
def validate_metadata(klass: type):
orig_validate = klass.model_validate
def validate(data: dict):
data = dict(data)
if "redirect_uris" in data:
data["redirect_uris"] = [
str(uri).replace("cursor://", "http://")
for uri in data["redirect_uris"]
]
if "redirect_uri" in data:
data["redirect_uri"] = str(data["redirect_uri"]).replace(
"cursor://", "http://"
)
return orig_validate(data)
klass.model_validate = validate
validate_metadata(OAuthClientMetadata)
validate_metadata(AuthorizationRequest)
validate_metadata(AuthorizationCodeRequest)
validate_metadata(RefreshTokenRequest)
validate_metadata(TokenRequest)
# Setup templates
template_dir = os.path.join(os.path.dirname(__file__), "templates")
templates = Jinja2Templates(directory=template_dir)
oauth_provider = SimpleOAuthProvider()
auth_settings = AuthSettings(
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=["read", "write"],
default_scopes=["read"],
),
required_scopes=["read", "write"],
)
mcp = FastMCP(
"memory",
stateless_http=True,
auth_server_provider=oauth_provider,
auth=auth_settings,
)
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
async def oauth_protected_resource(request: Request):
"""OAuth 2.0 Protected Resource Metadata."""
logger.info("Protected resource metadata requested")
return JSONResponse(oauth_provider.get_protected_resource_metadata())
def login_form(request: Request, form_data: dict, error: str | None = None):
return templates.TemplateResponse(
"login.html",
{"request": request, "form_data": form_data, "error": error},
)
@mcp.custom_route("/oauth/login", methods=["GET"])
async def login_page(request: Request):
"""Display the login page."""
form_data = dict(request.query_params)
state = form_data.get("state")
with make_session() as session:
oauth_state = session.query(OAuthState).get(state)
if not oauth_state:
logger.error(f"State {state} not found in database")
raise ValueError("Invalid state parameter")
return login_form(request, form_data, None)
@mcp.custom_route("/oauth/login", methods=["POST"])
async def handle_login(request: Request):
"""Handle login form submission."""
form = await request.form()
oauth_params = {
key: value for key, value in form.items() if key not in ["email", "password"]
}
with make_session() as session:
user = session.query(User).filter(User.email == form.get("email")).first()
if not user or not user.is_valid_password(str(form.get("password", ""))):
logger.warning("Login failed - invalid credentials")
return login_form(request, oauth_params, "Invalid email or password")
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
redirect_url = redirect_url.replace("http://", "cursor://")
return RedirectResponse(url=redirect_url, status_code=302)

View File

@ -0,0 +1,246 @@
import secrets
import time
from typing import Optional, Any, cast
from urllib.parse import urlencode
import logging
from datetime import datetime, timezone
from memory.common.db.models.users import (
User,
UserSession,
OAuthClientInformation,
OAuthState,
)
from memory.common.db.connection import make_session
from memory.common import settings
from mcp.server.auth.provider import (
AccessToken,
AuthorizationParams,
AuthorizationCode,
OAuthAuthorizationServerProvider,
RefreshToken,
construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
logger = logging.getLogger(__name__)
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""Get OAuth client information."""
with make_session() as session:
client = session.query(OAuthClientInformation).get(client_id)
return client and OAuthClientInformationFull(**client.serialize())
async def register_client(self, client_info: OAuthClientInformationFull):
"""Register a new OAuth client."""
with make_session() as session:
client = session.query(OAuthClientInformation).get(client_info.client_id)
if not client:
client = OAuthClientInformation(client_id=client_info.client_id)
for key, value in client_info.model_dump().items():
if key == "redirect_uris":
value = [str(uri) for uri in value]
elif value and key in [
"client_uri",
"logo_uri",
"tos_uri",
"policy_uri",
"jwks_uri",
]:
value = str(value)
setattr(client, key, value)
session.add(client)
session.commit()
async def authorize(
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
"""Redirect to login page for user authentication."""
redirect_uri_str = str(params.redirect_uri)
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
if redirect_uri_str not in registered_uris:
logger.error(
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
)
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
# Store the authorization parameters in database
with make_session() as session:
oauth_state = OAuthState(
state=params.state or secrets.token_hex(16),
client_id=client.client_id,
redirect_uri=str(params.redirect_uri),
redirect_uri_provided_explicitly=str(
params.redirect_uri_provided_explicitly
).lower()
== "true",
code_challenge=params.code_challenge or "",
scopes=["read", "write"], # Default scopes
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
)
session.add(oauth_state)
session.commit()
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
{
"state": oauth_state.state,
"client_id": client.client_id,
"redirect_uri": oauth_state.redirect_uri,
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
"code_challenge": cast(str, oauth_state.code_challenge),
}
)
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
"""Complete authorization after successful login."""
if not (state := oauth_params.get("state")):
logger.error("No state parameter provided")
raise ValueError("Missing state parameter")
with make_session() as session:
# Load OAuth state from database
oauth_state = session.query(OAuthState).get(state)
if not oauth_state:
logger.error(f"State {state} not found in database")
raise ValueError("Invalid state parameter")
# Check if state has expired
now = datetime.fromtimestamp(time.time())
if oauth_state.expires_at < now:
logger.error(f"State {state} has expired")
oauth_state.stale = True
session.commit()
raise ValueError("State has expired")
oauth_state.code = f"code_{secrets.token_hex(16)}"
oauth_state.stale = False
oauth_state.user_id = user.id
session.add(oauth_state)
session.commit()
return construct_redirect_uri(
oauth_state.redirect_uri, code=oauth_state.code, state=state
)
async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str
) -> Optional[AuthorizationCode]:
"""Load an authorization code."""
with make_session() as session:
auth_code = (
session.query(OAuthState)
.filter(OAuthState.code == authorization_code)
.first()
)
if not auth_code:
logger.error(f"Invalid authorization code: {authorization_code}")
raise ValueError("Invalid authorization code")
return AuthorizationCode(**auth_code.serialize(code=True))
async def exchange_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
) -> OAuthToken:
"""Exchange authorization code for tokens."""
with make_session() as session:
auth_code = (
session.query(OAuthState)
.filter(OAuthState.code == authorization_code.code)
.first()
)
if not auth_code:
logger.error(f"Invalid authorization code: {authorization_code.code}")
raise ValueError("Invalid authorization code")
# Get the user associated with this auth code
if not auth_code.user:
logger.error(f"No user found for auth code: {authorization_code.code}")
raise ValueError("Invalid authorization code")
# Create a UserSession to serve as access token
expires_at = datetime.fromtimestamp(time.time() + 3600)
auth_code.session = UserSession(
user_id=auth_code.user_id,
oauth_state_id=auth_code.state,
expires_at=expires_at,
)
auth_code.stale = True # type: ignore
session.commit()
access_token = str(auth_code.session.id)
return OAuthToken(
access_token=access_token,
token_type="bearer",
expires_in=3600,
scope=" ".join(authorization_code.scopes),
)
async def load_access_token(self, token: str) -> Optional[AccessToken]:
"""Load and validate an access token."""
with make_session() as session:
now = datetime.now(timezone.utc).replace(
tzinfo=None
) # Make naive for DB comparison
# Query for active (non-expired) session
user_session = session.query(UserSession).get(token)
if not user_session or user_session.expires_at < now:
return None
return AccessToken(
token=token,
client_id=user_session.oauth_state.client_id,
scopes=user_session.oauth_state.scopes,
expires_at=int(user_session.expires_at.timestamp()),
)
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str
) -> Optional[RefreshToken]:
"""Load a refresh token - not supported in this simple implementation."""
return None
async def exchange_refresh_token(
self,
client: OAuthClientInformationFull,
refresh_token: RefreshToken,
scopes: list[str],
) -> OAuthToken:
"""Exchange refresh token - not supported in this simple implementation."""
raise NotImplementedError("Refresh tokens not supported")
async def revoke_token(
self, token: str, token_type_hint: Optional[str] = None
) -> None:
"""Revoke a token."""
with make_session() as session:
user_session = session.query(UserSession).get(token)
if user_session:
session.delete(user_session)
session.commit()
def get_protected_resource_metadata(self) -> dict[str, Any]:
"""Return metadata about the protected resource."""
return {
"resource_server": settings.SERVER_URL,
"scopes_supported": ["read", "write"],
"bearer_methods_supported": ["header"],
"resource_documentation": f"{settings.SERVER_URL}/docs",
"protected_resources": [
{
"resource_uri": f"{settings.SERVER_URL}/mcp",
"scopes": ["read", "write"],
"http_methods": ["POST", "GET"],
},
{
"resource_uri": f"{settings.SERVER_URL}/mcp/",
"scopes": ["read", "write"],
"http_methods": ["POST", "GET"],
},
],
}

View File

@ -0,0 +1,159 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Login - Memory OAuth</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
margin: 0;
padding: 0;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}
.login-container {
background: white;
padding: 2rem;
border-radius: 12px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 400px;
}
.login-header {
text-align: center;
margin-bottom: 2rem;
}
.login-header h1 {
color: #333;
margin: 0 0 0.5rem 0;
font-size: 1.8rem;
font-weight: 600;
}
.login-header p {
color: #666;
margin: 0;
font-size: 0.9rem;
}
.form-group {
margin-bottom: 1.5rem;
}
.form-group label {
display: block;
margin-bottom: 0.5rem;
color: #333;
font-weight: 500;
font-size: 0.9rem;
}
.form-group input {
width: 100%;
padding: 0.75rem;
border: 2px solid #e1e5e9;
border-radius: 8px;
font-size: 1rem;
transition: border-color 0.2s;
box-sizing: border-box;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.login-button {
width: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 0.875rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s;
}
.login-button:hover {
transform: translateY(-1px);
}
.login-button:active {
transform: translateY(0);
}
.error-message {
background: #fee;
color: #c33;
padding: 0.75rem;
border-radius: 6px;
margin-bottom: 1rem;
font-size: 0.9rem;
border-left: 4px solid #c33;
}
.demo-credentials {
background: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin-top: 1.5rem;
font-size: 0.85rem;
}
.demo-credentials strong {
color: #333;
}
.demo-credentials code {
background: #e9ecef;
padding: 0.2rem 0.4rem;
border-radius: 4px;
font-family: monospace;
}
</style>
</head>
<body>
<div class="login-container">
<div class="login-header">
<h1>Sign In</h1>
<p>Access your Memory knowledge base</p>
</div>
{% if error %}
<div class="error-message">
{{ error }}
</div>
{% endif %}
<form method="post" action="/oauth/login">
{% for key, value in form_data.items() %}
<input type="hidden" name="{{ key }}" value="{{ value }}">
{% endfor %}
<div class="form-group">
<label for="email">Email</label>
<input type="email" id="email" name="email" required>
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit" class="login-button">Sign In</button>
</form>
</div>
</body>
</html>

View File

@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system.
import logging
from datetime import datetime, timezone
from mcp.server.fastmcp import FastMCP
from mcp.server.auth.middleware.auth_context import get_access_token
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem
from memory.common.db.models import (
AgentObservation,
SourceItem,
UserSession,
)
from memory.api.MCP.base import mcp
logger = logging.getLogger(__name__)
# Create MCP server instance
mcp = FastMCP("memory", stateless_http=True)
def filter_observation_source_ids(
tags: list[str] | None = None, observation_types: list[str] | None = None
@ -67,4 +69,42 @@ def filter_source_ids(
@mcp.tool()
async def get_current_time() -> dict:
"""Get the current time in UTC."""
logger.info("get_current_time tool called")
return {"current_time": datetime.now(timezone.utc).isoformat()}
@mcp.tool()
async def get_authenticated_user() -> dict:
"""Get information about the authenticated user."""
logger.info("🔧 get_authenticated_user tool called")
access_token = get_access_token()
logger.info(f"🔧 Access token from MCP context: {access_token}")
if not access_token:
logger.warning("❌ No access token found in MCP context!")
return {"error": "Not authenticated"}
logger.info(
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
)
# Look up the actual user from the session token
with make_session() as session:
user_session = (
session.query(UserSession)
.filter(UserSession.id == access_token.token)
.first()
)
if user_session and user_session.user:
user_info = user_session.user.serialize()
else:
user_info = {"error": "User not found"}
return {
"authenticated": True,
"token_type": "Bearer",
"scopes": access_token.scopes,
"client_id": access_token.client_id,
"user": user_info,
}

View File

@ -16,22 +16,19 @@ from fastapi import (
Query,
Form,
Depends,
Request,
)
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from sqladmin import Admin
from memory.common import settings
from memory.common import extract
from memory.common import extract, settings
from memory.common.db.connection import get_engine
from memory.common.db.models import User
from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult
from memory.api.MCP.tools import mcp
from memory.api.auth import (
router as auth_router,
get_current_user,
AuthenticationMiddleware,
)
from memory.api.MCP.base import mcp
from memory.api.auth import get_current_user
logger = logging.getLogger(__name__)
@ -45,23 +42,45 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
# Add request logging middleware
# @app.middleware("http")
# async def log_requests(request, call_next):
# logger.info(f"Main app: {request.method} {request.url.path}")
# if request.url.path.startswith("/mcp"):
# logger.info(f"Request headers: {dict(request.headers)}")
# response = await call_next(request)
# return response
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify actual origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# SQLAdmin setup
engine = get_engine()
admin = Admin(app, engine)
setup_admin(admin)
# Include auth router
app.add_middleware(AuthenticationMiddleware)
app.include_router(auth_router)
# Add health check to MCP server instead of main app
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Simple health check endpoint on MCP server"""
from fastapi.responses import JSONResponse
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
# Mount MCP server at root - OAuth endpoints need to be at root level
app.mount("/", mcp.streamable_http_app())
@app.get("/health")
def health_check():
"""Simple health check endpoint"""
return {"status": "healthy"}
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
if not item:
return []

View File

@ -35,6 +35,8 @@ from memory.common.db.models.sources import (
from memory.common.db.models.users import (
User,
UserSession,
OAuthClientInformation,
OAuthState,
)
__all__ = [
@ -69,4 +71,6 @@ __all__ = [
# Users
"User",
"UserSession",
"OAuthClientInformation",
"OAuthState",
]

View File

@ -3,7 +3,16 @@ import secrets
from typing import cast
import uuid
from memory.common.db.models.base import Base
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
Boolean,
ARRAY,
Numeric,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
@ -35,6 +44,9 @@ class User(Base):
sessions = relationship(
"UserSession", back_populates="user", cascade="all, delete-orphan"
)
oauth_states = relationship(
"OAuthState", back_populates="user", cascade="all, delete-orphan"
)
def serialize(self) -> dict:
return {
@ -57,9 +69,92 @@ class UserSession(Base):
__tablename__ = "user_sessions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
# Relationship to user
user = relationship("User", back_populates="sessions")
oauth_state = relationship("OAuthState", back_populates="session")
class OAuthClientInformation(Base):
__tablename__ = "oauth_client"
client_id = Column(String, primary_key=True)
client_secret = Column(String, nullable=True)
client_id_issued_at = Column(Numeric, nullable=False)
client_secret_expires_at = Column(Numeric, nullable=True)
redirect_uris = Column(ARRAY(String), nullable=False)
token_endpoint_auth_method = Column(String, nullable=False)
grant_types = Column(ARRAY(String), nullable=False)
response_types = Column(ARRAY(String), nullable=False)
scope = Column(String, nullable=False)
client_name = Column(String, nullable=False)
client_uri = Column(String, nullable=True)
logo_uri = Column(String, nullable=True)
contacts = Column(ARRAY(String), nullable=True)
tos_uri = Column(String, nullable=True)
policy_uri = Column(String, nullable=True)
jwks_uri = Column(String, nullable=True)
sessions = relationship(
"OAuthState", back_populates="client", cascade="all, delete-orphan"
)
def serialize(self) -> dict:
return {
"client_id": self.client_id,
"client_secret": self.client_secret,
"client_id_issued_at": self.client_id_issued_at,
"client_secret_expires_at": self.client_secret_expires_at,
"redirect_uris": self.redirect_uris,
"token_endpoint_auth_method": self.token_endpoint_auth_method,
"grant_types": self.grant_types,
"response_types": self.response_types,
"scope": self.scope,
"client_name": self.client_name,
"client_uri": self.client_uri,
"logo_uri": self.logo_uri,
"contacts": self.contacts,
"tos_uri": self.tos_uri,
"policy_uri": self.policy_uri,
"jwks_uri": self.jwks_uri,
}
class OAuthState(Base):
__tablename__ = "oauth_states"
state = Column(String, primary_key=True)
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
code = Column(String, nullable=True)
redirect_uri = Column(String, nullable=False)
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
code_challenge = Column(String, nullable=True)
scopes = Column(ARRAY(String), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
stale = Column(Boolean, nullable=False, default=False)
def serialize(self, code: bool = False) -> dict:
data = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
"code_challenge": self.code_challenge,
"scopes": self.scopes,
}
if code:
data |= {
"code": self.code,
"expires_at": self.expires_at.timestamp(),
}
return data
client = relationship("OAuthClientInformation", back_populates="sessions")
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
user = relationship("User", back_populates="oauth_states")

View File

@ -131,6 +131,7 @@ else:
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
# API settings
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
HTTPS = boolean_env("HTTPS", False)
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")