protect admin

This commit is contained in:
Daniel O'Connell 2025-06-06 18:24:09 +02:00
parent 986d5b9957
commit d17d724631
7 changed files with 81 additions and 59 deletions

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
import pathlib
from typing import cast from typing import cast
from mcp.server.auth.handlers.authorize import AuthorizationRequest from mcp.server.auth.handlers.authorize import AuthorizationRequest
@ -53,7 +54,7 @@ validate_metadata(TokenRequest)
# Setup templates # Setup templates
template_dir = os.path.join(os.path.dirname(__file__), "templates") template_dir = pathlib.Path(__file__).parent.parent / "templates"
templates = Jinja2Templates(directory=template_dir) templates = Jinja2Templates(directory=template_dir)
@ -86,7 +87,12 @@ async def oauth_protected_resource(request: Request):
def login_form(request: Request, form_data: dict, error: str | None = None): def login_form(request: Request, form_data: dict, error: str | None = None):
return templates.TemplateResponse( return templates.TemplateResponse(
"login.html", "login.html",
{"request": request, "form_data": form_data, "error": error}, {
"request": request,
"form_data": form_data,
"error": error,
"action": "/oauth/login",
},
) )

View File

@ -275,9 +275,6 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
if not user_session or user_session.expires_at < now: if not user_session or user_session.expires_at < now:
return None return None
logger.info(
f"Loading access token: {token}, state: {user_session.oauth_state}"
)
return AccessToken( return AccessToken(
token=token, token=token,
client_id=user_session.oauth_state.client_id, client_id=user_session.oauth_state.client_id,

View File

@ -2,8 +2,16 @@
SQLAdmin views for the knowledge base database models. SQLAdmin views for the knowledge base database models.
""" """
import uuid
from sqladmin import Admin, ModelView from sqladmin import Admin, ModelView
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import RedirectResponse
import logging
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import ( from memory.common.db.models import (
Chunk, Chunk,
SourceItem, SourceItem,
@ -21,8 +29,11 @@ from memory.common.db.models import (
AgentObservation, AgentObservation,
Note, Note,
User, User,
UserSession,
OAuthState,
) )
logger = logging.getLogger(__name__)
DEFAULT_COLUMNS = ( DEFAULT_COLUMNS = (
"modality", "modality",
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
def setup_admin(admin: Admin): def setup_admin(admin: Admin):
"""Add all admin views to the admin instance.""" """Add all admin views to the admin instance with OAuth protection."""
admin.add_view(SourceItemAdmin) admin.add_view(SourceItemAdmin)
admin.add_view(AgentObservationAdmin) admin.add_view(AgentObservationAdmin)
admin.add_view(NoteAdmin) admin.add_view(NoteAdmin)

View File

@ -27,8 +27,12 @@ from memory.common.db.connection import get_engine
from memory.common.db.models import User from memory.common.db.models import User
from memory.api.admin import setup_admin from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult from memory.api.search import search, SearchResult
from memory.api.auth import (
get_current_user,
AuthenticationMiddleware,
router as auth_router,
)
from memory.api.MCP.base import mcp from memory.api.MCP.base import mcp
from memory.api.auth import get_current_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,19 +45,6 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Knowledge Base API", lifespan=lifespan) app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
# Add request logging middleware
# @app.middleware("http")
# async def log_requests(request, call_next):
# logger.info(f"Main app: {request.method} {request.url.path}")
# if request.url.path.startswith("/mcp"):
# logger.info(f"Request headers: {dict(request.headers)}")
# response = await call_next(request)
# return response
# Add CORS middleware
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], # In production, specify actual origins allow_origins=["*"], # In production, specify actual origins
@ -62,10 +53,14 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# SQLAdmin setup # SQLAdmin setup with OAuth protection
engine = get_engine() engine = get_engine()
admin = Admin(app, engine) admin = Admin(app, engine)
# Setup admin with OAuth protection using existing OAuth provider
setup_admin(admin) setup_admin(admin)
app.include_router(auth_router)
app.add_middleware(AuthenticationMiddleware)
# Add health check to MCP server instead of main app # Add health check to MCP server instead of main app

View File

@ -1,10 +1,10 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import textwrap
from typing import cast from typing import cast
import logging import logging
import pathlib
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session from sqlalchemy.orm import Session as DBSession, scoped_session
@ -52,28 +52,35 @@ def create_user_session(
return str(session.id) return str(session.id)
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None: def get_user_session(
"""Get user from session ID if session is valid""" request: Request, db: DBSession | scoped_session
) -> UserSession | None:
"""Get session ID from request"""
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
return None
session = db.query(UserSession).get(session_id) session = db.query(UserSession).get(session_id)
if not session: if not session:
return None return None
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
if session.expires_at.replace(tzinfo=timezone.utc) > now: if session.expires_at.replace(tzinfo=timezone.utc) < now:
return None
return session
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
"""Get user from session ID if session is valid"""
if session := get_user_session(request, db):
return session.user return session.user
return None return None
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User: def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
"""FastAPI dependency to get current authenticated user""" """FastAPI dependency to get current authenticated user"""
# Check for session ID in header or cookie user = get_session_user(request, db)
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
raise HTTPException(status_code=401, detail="No session provided")
user = get_session_user(session_id, db)
if not user: if not user:
raise HTTPException(status_code=401, detail="Invalid or expired session") raise HTTPException(status_code=401, detail="Invalid or expired session")
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
@router.get("/login", response_model=LoginResponse) @router.get("/login", response_model=LoginResponse)
def login_page(): def login_page(request: Request):
"""Login page""" """Login page"""
return HTMLResponse( template_dir = pathlib.Path(__file__).parent / "templates"
content=textwrap.dedent(""" templates = Jinja2Templates(directory=template_dir)
<html> return templates.TemplateResponse(
<body> "login.html",
<h1>Login</h1> {
<form method="post" action="/auth/login-form"> "request": request,
<input type="email" name="email" placeholder="Email" /> "action": router.url_path_for("login_form"),
<input type="password" name="password" placeholder="Password" /> "error": None,
<button type="submit">Login</button> "form_data": {},
</form> },
</body>
</html>
"""),
) )
@ -170,9 +174,17 @@ def login_form(
return LoginResponse(session_id=session_id, **user.serialize()) return LoginResponse(session_id=session_id, **user.serialize())
@router.post("/logout") @router.api_route("/logout", methods=["GET", "POST"])
def logout(response: Response, user: User = Depends(get_current_user)): def logout(
request: Request,
response: Response,
db: DBSession = Depends(get_session),
):
"""Logout and clear session""" """Logout and clear session"""
session = get_user_session(request, db)
if session:
db.delete(session)
db.commit()
response.delete_cookie(settings.SESSION_COOKIE_NAME) response.delete_cookie(settings.SESSION_COOKIE_NAME)
return {"message": "Logged out successfully"} return {"message": "Logged out successfully"}
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
"/auth/login", "/auth/login",
"/auth/login-form", "/auth/login-form",
"/auth/register", "/auth/register",
"/register",
"/token",
"/mcp",
"/oauth/",
"/.well-known/",
} }
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
return await call_next(request) return await call_next(request)
# Check for session ID in header or cookie # Check for session ID in header or cookie
session_id = request.headers.get( session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id: if not session_id:
return Response( return Response(
content="Authentication required", content="Authentication required",
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
# Validate session and get user # Validate session and get user
with make_session() as session: with make_session() as session:
user = get_session_user(session_id, session) user = get_session_user(request, session)
if not user: if not user:
return Response( return Response(
content="Invalid or expired session", content="Invalid or expired session",

View File

@ -136,7 +136,7 @@
</div> </div>
{% endif %} {% endif %}
<form method="post" action="/oauth/login"> <form method="post" action="{{ action }}">
{% for key, value in form_data.items() %} {% for key, value in form_data.items() %}
<input type="hidden" name="{{ key }}" value="{{ value }}"> <input type="hidden" name="{{ key }}" value="{{ value }}">
{% endfor %} {% endfor %}

View File

@ -133,7 +133,6 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
# API settings # API settings
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000") SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
HTTPS = boolean_env("HTTPS", False) HTTPS = boolean_env("HTTPS", False)
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id") SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60)) SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30)) SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))