protect admin

This commit is contained in:
Daniel O'Connell 2025-06-06 18:24:09 +02:00
parent 986d5b9957
commit d17d724631
7 changed files with 81 additions and 59 deletions

View File

@ -1,5 +1,6 @@
import logging
import os
import pathlib
from typing import cast
from mcp.server.auth.handlers.authorize import AuthorizationRequest
@ -53,7 +54,7 @@ validate_metadata(TokenRequest)
# Setup templates
template_dir = os.path.join(os.path.dirname(__file__), "templates")
template_dir = pathlib.Path(__file__).parent.parent / "templates"
templates = Jinja2Templates(directory=template_dir)
@ -86,7 +87,12 @@ async def oauth_protected_resource(request: Request):
def login_form(request: Request, form_data: dict, error: str | None = None):
return templates.TemplateResponse(
"login.html",
{"request": request, "form_data": form_data, "error": error},
{
"request": request,
"form_data": form_data,
"error": error,
"action": "/oauth/login",
},
)

View File

@ -275,9 +275,6 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
if not user_session or user_session.expires_at < now:
return None
logger.info(
f"Loading access token: {token}, state: {user_session.oauth_state}"
)
return AccessToken(
token=token,
client_id=user_session.oauth_state.client_id,

View File

@ -2,8 +2,16 @@
SQLAdmin views for the knowledge base database models.
"""
import uuid
from sqladmin import Admin, ModelView
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import RedirectResponse
import logging
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import (
Chunk,
SourceItem,
@ -21,8 +29,11 @@ from memory.common.db.models import (
AgentObservation,
Note,
User,
UserSession,
OAuthState,
)
logger = logging.getLogger(__name__)
DEFAULT_COLUMNS = (
"modality",
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
def setup_admin(admin: Admin):
"""Add all admin views to the admin instance."""
"""Add all admin views to the admin instance with OAuth protection."""
admin.add_view(SourceItemAdmin)
admin.add_view(AgentObservationAdmin)
admin.add_view(NoteAdmin)

View File

@ -27,8 +27,12 @@ from memory.common.db.connection import get_engine
from memory.common.db.models import User
from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult
from memory.api.auth import (
get_current_user,
AuthenticationMiddleware,
router as auth_router,
)
from memory.api.MCP.base import mcp
from memory.api.auth import get_current_user
logger = logging.getLogger(__name__)
@ -41,19 +45,6 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
# Add request logging middleware
# @app.middleware("http")
# async def log_requests(request, call_next):
# logger.info(f"Main app: {request.method} {request.url.path}")
# if request.url.path.startswith("/mcp"):
# logger.info(f"Request headers: {dict(request.headers)}")
# response = await call_next(request)
# return response
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify actual origins
@ -62,10 +53,14 @@ app.add_middleware(
allow_headers=["*"],
)
# SQLAdmin setup
# SQLAdmin setup with OAuth protection
engine = get_engine()
admin = Admin(app, engine)
# Setup admin with OAuth protection using existing OAuth provider
setup_admin(admin)
app.include_router(auth_router)
app.add_middleware(AuthenticationMiddleware)
# Add health check to MCP server instead of main app

View File

@ -1,10 +1,10 @@
from datetime import datetime, timedelta, timezone
import textwrap
from typing import cast
import logging
import pathlib
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
@ -52,28 +52,35 @@ def create_user_session(
return str(session.id)
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
"""Get user from session ID if session is valid"""
def get_user_session(
request: Request, db: DBSession | scoped_session
) -> UserSession | None:
"""Get session ID from request"""
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
return None
session = db.query(UserSession).get(session_id)
if not session:
return None
now = datetime.now(timezone.utc)
if session.expires_at.replace(tzinfo=timezone.utc) > now:
if session.expires_at.replace(tzinfo=timezone.utc) < now:
return None
return session
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
"""Get user from session ID if session is valid"""
if session := get_user_session(request, db):
return session.user
return None
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
"""FastAPI dependency to get current authenticated user"""
# Check for session ID in header or cookie
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
raise HTTPException(status_code=401, detail="No session provided")
user = get_session_user(session_id, db)
user = get_session_user(request, db)
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired session")
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
@router.get("/login", response_model=LoginResponse)
def login_page():
def login_page(request: Request):
"""Login page"""
return HTMLResponse(
content=textwrap.dedent("""
<html>
<body>
<h1>Login</h1>
<form method="post" action="/auth/login-form">
<input type="email" name="email" placeholder="Email" />
<input type="password" name="password" placeholder="Password" />
<button type="submit">Login</button>
</form>
</body>
</html>
"""),
template_dir = pathlib.Path(__file__).parent / "templates"
templates = Jinja2Templates(directory=template_dir)
return templates.TemplateResponse(
"login.html",
{
"request": request,
"action": router.url_path_for("login_form"),
"error": None,
"form_data": {},
},
)
@ -170,9 +174,17 @@ def login_form(
return LoginResponse(session_id=session_id, **user.serialize())
@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
@router.api_route("/logout", methods=["GET", "POST"])
def logout(
request: Request,
response: Response,
db: DBSession = Depends(get_session),
):
"""Logout and clear session"""
session = get_user_session(request, db)
if session:
db.delete(session)
db.commit()
response.delete_cookie(settings.SESSION_COOKIE_NAME)
return {"message": "Logged out successfully"}
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
"/auth/login",
"/auth/login-form",
"/auth/register",
"/register",
"/token",
"/mcp",
"/oauth/",
"/.well-known/",
}
async def dispatch(self, request: Request, call_next):
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# Check for session ID in header or cookie
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
return Response(
content="Authentication required",
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
# Validate session and get user
with make_session() as session:
user = get_session_user(session_id, session)
user = get_session_user(request, session)
if not user:
return Response(
content="Invalid or expired session",

View File

@ -136,7 +136,7 @@
</div>
{% endif %}
<form method="post" action="/oauth/login">
<form method="post" action="{{ action }}">
{% for key, value in form_data.items() %}
<input type="hidden" name="{{ key }}" value="{{ value }}">
{% endfor %}

View File

@ -133,7 +133,6 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
# API settings
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
HTTPS = boolean_env("HTTPS", False)
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))