mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 05:14:43 +02:00
protect admin
This commit is contained in:
parent
986d5b9957
commit
d17d724631
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import cast
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||
@ -53,7 +54,7 @@ validate_metadata(TokenRequest)
|
||||
|
||||
|
||||
# Setup templates
|
||||
template_dir = os.path.join(os.path.dirname(__file__), "templates")
|
||||
template_dir = pathlib.Path(__file__).parent.parent / "templates"
|
||||
templates = Jinja2Templates(directory=template_dir)
|
||||
|
||||
|
||||
@ -86,7 +87,12 @@ async def oauth_protected_resource(request: Request):
|
||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "form_data": form_data, "error": error},
|
||||
{
|
||||
"request": request,
|
||||
"form_data": form_data,
|
||||
"error": error,
|
||||
"action": "/oauth/login",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
@ -275,9 +275,6 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
if not user_session or user_session.expires_at < now:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Loading access token: {token}, state: {user_session.oauth_state}"
|
||||
)
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id=user_session.oauth_state.client_id,
|
||||
|
@ -2,8 +2,16 @@
|
||||
SQLAdmin views for the knowledge base database models.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from sqladmin import Admin, ModelView
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
import logging
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import (
|
||||
Chunk,
|
||||
SourceItem,
|
||||
@ -21,8 +29,11 @@ from memory.common.db.models import (
|
||||
AgentObservation,
|
||||
Note,
|
||||
User,
|
||||
UserSession,
|
||||
OAuthState,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_COLUMNS = (
|
||||
"modality",
|
||||
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
|
||||
|
||||
|
||||
def setup_admin(admin: Admin):
|
||||
"""Add all admin views to the admin instance."""
|
||||
"""Add all admin views to the admin instance with OAuth protection."""
|
||||
admin.add_view(SourceItemAdmin)
|
||||
admin.add_view(AgentObservationAdmin)
|
||||
admin.add_view(NoteAdmin)
|
||||
|
@ -27,8 +27,12 @@ from memory.common.db.connection import get_engine
|
||||
from memory.common.db.models import User
|
||||
from memory.api.admin import setup_admin
|
||||
from memory.api.search import search, SearchResult
|
||||
from memory.api.auth import (
|
||||
get_current_user,
|
||||
AuthenticationMiddleware,
|
||||
router as auth_router,
|
||||
)
|
||||
from memory.api.MCP.base import mcp
|
||||
from memory.api.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -41,19 +45,6 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||
|
||||
|
||||
# Add request logging middleware
|
||||
# @app.middleware("http")
|
||||
# async def log_requests(request, call_next):
|
||||
# logger.info(f"Main app: {request.method} {request.url.path}")
|
||||
# if request.url.path.startswith("/mcp"):
|
||||
# logger.info(f"Request headers: {dict(request.headers)}")
|
||||
# response = await call_next(request)
|
||||
# return response
|
||||
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify actual origins
|
||||
@ -62,10 +53,14 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# SQLAdmin setup
|
||||
# SQLAdmin setup with OAuth protection
|
||||
engine = get_engine()
|
||||
admin = Admin(app, engine)
|
||||
|
||||
# Setup admin with OAuth protection using existing OAuth provider
|
||||
setup_admin(admin)
|
||||
app.include_router(auth_router)
|
||||
app.add_middleware(AuthenticationMiddleware)
|
||||
|
||||
|
||||
# Add health check to MCP server instead of main app
|
||||
|
@ -1,10 +1,10 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import textwrap
|
||||
from typing import cast
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from memory.common import settings
|
||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
||||
@ -52,28 +52,35 @@ def create_user_session(
|
||||
return str(session.id)
|
||||
|
||||
|
||||
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
|
||||
"""Get user from session ID if session is valid"""
|
||||
def get_user_session(
|
||||
request: Request, db: DBSession | scoped_session
|
||||
) -> UserSession | None:
|
||||
"""Get session ID from request"""
|
||||
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
session = db.query(UserSession).get(session_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if session.expires_at.replace(tzinfo=timezone.utc) > now:
|
||||
if session.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
return None
|
||||
return session
|
||||
|
||||
|
||||
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
|
||||
"""Get user from session ID if session is valid"""
|
||||
if session := get_user_session(request, db):
|
||||
return session.user
|
||||
return None
|
||||
|
||||
|
||||
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
||||
"""FastAPI dependency to get current authenticated user"""
|
||||
# Check for session ID in header or cookie
|
||||
session_id = request.headers.get(
|
||||
settings.SESSION_HEADER_NAME
|
||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=401, detail="No session provided")
|
||||
|
||||
user = get_session_user(session_id, db)
|
||||
user = get_session_user(request, db)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
||||
|
||||
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
|
||||
|
||||
|
||||
@router.get("/login", response_model=LoginResponse)
|
||||
def login_page():
|
||||
def login_page(request: Request):
|
||||
"""Login page"""
|
||||
return HTMLResponse(
|
||||
content=textwrap.dedent("""
|
||||
<html>
|
||||
<body>
|
||||
<h1>Login</h1>
|
||||
<form method="post" action="/auth/login-form">
|
||||
<input type="email" name="email" placeholder="Email" />
|
||||
<input type="password" name="password" placeholder="Password" />
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
"""),
|
||||
template_dir = pathlib.Path(__file__).parent / "templates"
|
||||
templates = Jinja2Templates(directory=template_dir)
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{
|
||||
"request": request,
|
||||
"action": router.url_path_for("login_form"),
|
||||
"error": None,
|
||||
"form_data": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -170,9 +174,17 @@ def login_form(
|
||||
return LoginResponse(session_id=session_id, **user.serialize())
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(response: Response, user: User = Depends(get_current_user)):
|
||||
@router.api_route("/logout", methods=["GET", "POST"])
|
||||
def logout(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: DBSession = Depends(get_session),
|
||||
):
|
||||
"""Logout and clear session"""
|
||||
session = get_user_session(request, db)
|
||||
if session:
|
||||
db.delete(session)
|
||||
db.commit()
|
||||
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"/auth/login",
|
||||
"/auth/login-form",
|
||||
"/auth/register",
|
||||
"/register",
|
||||
"/token",
|
||||
"/mcp",
|
||||
"/oauth/",
|
||||
"/.well-known/",
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
# Check for session ID in header or cookie
|
||||
session_id = request.headers.get(
|
||||
settings.SESSION_HEADER_NAME
|
||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
|
||||
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
if not session_id:
|
||||
return Response(
|
||||
content="Authentication required",
|
||||
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Validate session and get user
|
||||
with make_session() as session:
|
||||
user = get_session_user(session_id, session)
|
||||
user = get_session_user(request, session)
|
||||
if not user:
|
||||
return Response(
|
||||
content="Invalid or expired session",
|
||||
|
@ -136,7 +136,7 @@
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<form method="post" action="/oauth/login">
|
||||
<form method="post" action="{{ action }}">
|
||||
{% for key, value in form_data.items() %}
|
||||
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||
{% endfor %}
|
@ -133,7 +133,6 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
|
||||
# API settings
|
||||
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||
HTTPS = boolean_env("HTTPS", False)
|
||||
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||
|
Loading…
x
Reference in New Issue
Block a user