mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
protect admin
This commit is contained in:
parent
986d5b9957
commit
d17d724631
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||||
@ -53,7 +54,7 @@ validate_metadata(TokenRequest)
|
|||||||
|
|
||||||
|
|
||||||
# Setup templates
|
# Setup templates
|
||||||
template_dir = os.path.join(os.path.dirname(__file__), "templates")
|
template_dir = pathlib.Path(__file__).parent.parent / "templates"
|
||||||
templates = Jinja2Templates(directory=template_dir)
|
templates = Jinja2Templates(directory=template_dir)
|
||||||
|
|
||||||
|
|
||||||
@ -86,7 +87,12 @@ async def oauth_protected_resource(request: Request):
|
|||||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"login.html",
|
"login.html",
|
||||||
{"request": request, "form_data": form_data, "error": error},
|
{
|
||||||
|
"request": request,
|
||||||
|
"form_data": form_data,
|
||||||
|
"error": error,
|
||||||
|
"action": "/oauth/login",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -275,9 +275,6 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
if not user_session or user_session.expires_at < now:
|
if not user_session or user_session.expires_at < now:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Loading access token: {token}, state: {user_session.oauth_state}"
|
|
||||||
)
|
|
||||||
return AccessToken(
|
return AccessToken(
|
||||||
token=token,
|
token=token,
|
||||||
client_id=user_session.oauth_state.client_id,
|
client_id=user_session.oauth_state.client_id,
|
||||||
|
@ -2,8 +2,16 @@
|
|||||||
SQLAdmin views for the knowledge base database models.
|
SQLAdmin views for the knowledge base database models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
import logging
|
||||||
|
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||||
|
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
Chunk,
|
Chunk,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
@ -21,8 +29,11 @@ from memory.common.db.models import (
|
|||||||
AgentObservation,
|
AgentObservation,
|
||||||
Note,
|
Note,
|
||||||
User,
|
User,
|
||||||
|
UserSession,
|
||||||
|
OAuthState,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_COLUMNS = (
|
DEFAULT_COLUMNS = (
|
||||||
"modality",
|
"modality",
|
||||||
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
|
|||||||
|
|
||||||
|
|
||||||
def setup_admin(admin: Admin):
|
def setup_admin(admin: Admin):
|
||||||
"""Add all admin views to the admin instance."""
|
"""Add all admin views to the admin instance with OAuth protection."""
|
||||||
admin.add_view(SourceItemAdmin)
|
admin.add_view(SourceItemAdmin)
|
||||||
admin.add_view(AgentObservationAdmin)
|
admin.add_view(AgentObservationAdmin)
|
||||||
admin.add_view(NoteAdmin)
|
admin.add_view(NoteAdmin)
|
||||||
|
@ -27,8 +27,12 @@ from memory.common.db.connection import get_engine
|
|||||||
from memory.common.db.models import User
|
from memory.common.db.models import User
|
||||||
from memory.api.admin import setup_admin
|
from memory.api.admin import setup_admin
|
||||||
from memory.api.search import search, SearchResult
|
from memory.api.search import search, SearchResult
|
||||||
|
from memory.api.auth import (
|
||||||
|
get_current_user,
|
||||||
|
AuthenticationMiddleware,
|
||||||
|
router as auth_router,
|
||||||
|
)
|
||||||
from memory.api.MCP.base import mcp
|
from memory.api.MCP.base import mcp
|
||||||
from memory.api.auth import get_current_user
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -41,19 +45,6 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
# Add request logging middleware
|
|
||||||
# @app.middleware("http")
|
|
||||||
# async def log_requests(request, call_next):
|
|
||||||
# logger.info(f"Main app: {request.method} {request.url.path}")
|
|
||||||
# if request.url.path.startswith("/mcp"):
|
|
||||||
# logger.info(f"Request headers: {dict(request.headers)}")
|
|
||||||
# response = await call_next(request)
|
|
||||||
# return response
|
|
||||||
|
|
||||||
|
|
||||||
# Add CORS middleware
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # In production, specify actual origins
|
allow_origins=["*"], # In production, specify actual origins
|
||||||
@ -62,10 +53,14 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# SQLAdmin setup
|
# SQLAdmin setup with OAuth protection
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
admin = Admin(app, engine)
|
admin = Admin(app, engine)
|
||||||
|
|
||||||
|
# Setup admin with OAuth protection using existing OAuth provider
|
||||||
setup_admin(admin)
|
setup_admin(admin)
|
||||||
|
app.include_router(auth_router)
|
||||||
|
app.add_middleware(AuthenticationMiddleware)
|
||||||
|
|
||||||
|
|
||||||
# Add health check to MCP server instead of main app
|
# Add health check to MCP server instead of main app
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
import textwrap
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import logging
|
import logging
|
||||||
|
import pathlib
|
||||||
|
|
||||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.templating import Jinja2Templates
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
from sqlalchemy.orm import Session as DBSession, scoped_session
|
||||||
@ -52,28 +52,35 @@ def create_user_session(
|
|||||||
return str(session.id)
|
return str(session.id)
|
||||||
|
|
||||||
|
|
||||||
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
|
def get_user_session(
|
||||||
"""Get user from session ID if session is valid"""
|
request: Request, db: DBSession | scoped_session
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Get session ID from request"""
|
||||||
|
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
return None
|
||||||
|
|
||||||
session = db.query(UserSession).get(session_id)
|
session = db.query(UserSession).get(session_id)
|
||||||
if not session:
|
if not session:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
if session.expires_at.replace(tzinfo=timezone.utc) > now:
|
if session.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
|
||||||
|
"""Get user from session ID if session is valid"""
|
||||||
|
if session := get_user_session(request, db):
|
||||||
return session.user
|
return session.user
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
||||||
"""FastAPI dependency to get current authenticated user"""
|
"""FastAPI dependency to get current authenticated user"""
|
||||||
# Check for session ID in header or cookie
|
user = get_session_user(request, db)
|
||||||
session_id = request.headers.get(
|
|
||||||
settings.SESSION_HEADER_NAME
|
|
||||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
raise HTTPException(status_code=401, detail="No session provided")
|
|
||||||
|
|
||||||
user = get_session_user(session_id, db)
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
||||||
|
|
||||||
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/login", response_model=LoginResponse)
|
@router.get("/login", response_model=LoginResponse)
|
||||||
def login_page():
|
def login_page(request: Request):
|
||||||
"""Login page"""
|
"""Login page"""
|
||||||
return HTMLResponse(
|
template_dir = pathlib.Path(__file__).parent / "templates"
|
||||||
content=textwrap.dedent("""
|
templates = Jinja2Templates(directory=template_dir)
|
||||||
<html>
|
return templates.TemplateResponse(
|
||||||
<body>
|
"login.html",
|
||||||
<h1>Login</h1>
|
{
|
||||||
<form method="post" action="/auth/login-form">
|
"request": request,
|
||||||
<input type="email" name="email" placeholder="Email" />
|
"action": router.url_path_for("login_form"),
|
||||||
<input type="password" name="password" placeholder="Password" />
|
"error": None,
|
||||||
<button type="submit">Login</button>
|
"form_data": {},
|
||||||
</form>
|
},
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
"""),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -170,9 +174,17 @@ def login_form(
|
|||||||
return LoginResponse(session_id=session_id, **user.serialize())
|
return LoginResponse(session_id=session_id, **user.serialize())
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout")
|
@router.api_route("/logout", methods=["GET", "POST"])
|
||||||
def logout(response: Response, user: User = Depends(get_current_user)):
|
def logout(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
db: DBSession = Depends(get_session),
|
||||||
|
):
|
||||||
"""Logout and clear session"""
|
"""Logout and clear session"""
|
||||||
|
session = get_user_session(request, db)
|
||||||
|
if session:
|
||||||
|
db.delete(session)
|
||||||
|
db.commit()
|
||||||
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
||||||
return {"message": "Logged out successfully"}
|
return {"message": "Logged out successfully"}
|
||||||
|
|
||||||
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
"/auth/login",
|
"/auth/login",
|
||||||
"/auth/login-form",
|
"/auth/login-form",
|
||||||
"/auth/register",
|
"/auth/register",
|
||||||
|
"/register",
|
||||||
|
"/token",
|
||||||
|
"/mcp",
|
||||||
|
"/oauth/",
|
||||||
|
"/.well-known/",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Check for session ID in header or cookie
|
# Check for session ID in header or cookie
|
||||||
session_id = request.headers.get(
|
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||||
settings.SESSION_HEADER_NAME
|
|
||||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
return Response(
|
return Response(
|
||||||
content="Authentication required",
|
content="Authentication required",
|
||||||
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Validate session and get user
|
# Validate session and get user
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
user = get_session_user(session_id, session)
|
user = get_session_user(request, session)
|
||||||
if not user:
|
if not user:
|
||||||
return Response(
|
return Response(
|
||||||
content="Invalid or expired session",
|
content="Invalid or expired session",
|
||||||
|
@ -136,7 +136,7 @@
|
|||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
<form method="post" action="/oauth/login">
|
<form method="post" action="{{ action }}">
|
||||||
{% for key, value in form_data.items() %}
|
{% for key, value in form_data.items() %}
|
||||||
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||||
{% endfor %}
|
{% endfor %}
|
@ -133,7 +133,6 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
|
|||||||
# API settings
|
# API settings
|
||||||
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||||
HTTPS = boolean_env("HTTPS", False)
|
HTTPS = boolean_env("HTTPS", False)
|
||||||
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
|
||||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||||
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
||||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user