diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index c3b7369..440d65f 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -1,5 +1,6 @@ import logging import os +import pathlib from typing import cast from mcp.server.auth.handlers.authorize import AuthorizationRequest @@ -53,7 +54,7 @@ validate_metadata(TokenRequest) # Setup templates -template_dir = os.path.join(os.path.dirname(__file__), "templates") +template_dir = pathlib.Path(__file__).parent.parent / "templates" templates = Jinja2Templates(directory=template_dir) @@ -86,7 +87,12 @@ async def oauth_protected_resource(request: Request): def login_form(request: Request, form_data: dict, error: str | None = None): return templates.TemplateResponse( "login.html", - {"request": request, "form_data": form_data, "error": error}, + { + "request": request, + "form_data": form_data, + "error": error, + "action": "/oauth/login", + }, ) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index 2405890..ad96000 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -275,9 +275,6 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): if not user_session or user_session.expires_at < now: return None - logger.info( - f"Loading access token: {token}, state: {user_session.oauth_state}" - ) return AccessToken( token=token, client_id=user_session.oauth_state.client_id, diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index 540a599..9d9bc08 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -2,8 +2,16 @@ SQLAdmin views for the knowledge base database models. """ +import uuid from sqladmin import Admin, ModelView - +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import RedirectResponse +import logging +from mcp.server.auth.provider import OAuthAuthorizationServerProvider +from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME +from memory.common import settings +from memory.common.db.connection import make_session from memory.common.db.models import ( Chunk, SourceItem, @@ -21,8 +29,11 @@ from memory.common.db.models import ( AgentObservation, Note, User, + UserSession, + OAuthState, ) +logger = logging.getLogger(__name__) DEFAULT_COLUMNS = ( "modality", @@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User): def setup_admin(admin: Admin): - """Add all admin views to the admin instance.""" + """Add all admin views to the admin instance with OAuth protection.""" admin.add_view(SourceItemAdmin) admin.add_view(AgentObservationAdmin) admin.add_view(NoteAdmin) diff --git a/src/memory/api/app.py b/src/memory/api/app.py index aaf2b17..2c4556a 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -27,8 +27,12 @@ from memory.common.db.connection import get_engine from memory.common.db.models import User from memory.api.admin import setup_admin from memory.api.search import search, SearchResult +from memory.api.auth import ( + get_current_user, + AuthenticationMiddleware, + router as auth_router, +) from memory.api.MCP.base import mcp -from memory.api.auth import get_current_user logger = logging.getLogger(__name__) @@ -41,19 +45,6 @@ async def lifespan(app: FastAPI): app = FastAPI(title="Knowledge Base API", lifespan=lifespan) - - -# Add request logging middleware -# @app.middleware("http") -# async def log_requests(request, call_next): -# logger.info(f"Main app: {request.method} {request.url.path}") -# if request.url.path.startswith("/mcp"): -# logger.info(f"Request headers: {dict(request.headers)}") -# response = await call_next(request) -# return response - - -# Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production, specify actual origins @@ -62,10 +53,14 @@ app.add_middleware( allow_headers=["*"], ) -# SQLAdmin setup +# SQLAdmin setup with OAuth protection engine = get_engine() admin = Admin(app, engine) + +# Setup admin with OAuth protection using existing OAuth provider setup_admin(admin) +app.include_router(auth_router) +app.add_middleware(AuthenticationMiddleware) # Add health check to MCP server instead of main app diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index 2b92e17..215d782 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta, timezone -import textwrap from typing import cast import logging +import pathlib from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form -from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates from starlette.middleware.base import BaseHTTPMiddleware from memory.common import settings from sqlalchemy.orm import Session as DBSession, scoped_session @@ -52,28 +52,35 @@ def create_user_session( return str(session.id) -def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None: - """Get user from session ID if session is valid""" +def get_user_session( + request: Request, db: DBSession | scoped_session +) -> UserSession | None: + """Get session ID from request""" + session_id = request.cookies.get(settings.SESSION_COOKIE_NAME) + + if not session_id: + return None + session = db.query(UserSession).get(session_id) if not session: return None + now = datetime.now(timezone.utc) - if session.expires_at.replace(tzinfo=timezone.utc) > now: + if session.expires_at.replace(tzinfo=timezone.utc) < now: + return None + return session + + +def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None: + """Get user from session ID if session is valid""" + if session := get_user_session(request, db): return session.user return None def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User: """FastAPI dependency to get current authenticated user""" - # Check for session ID in header or cookie - session_id = request.headers.get( - settings.SESSION_HEADER_NAME - ) or request.cookies.get(settings.SESSION_COOKIE_NAME) - - if not session_id: - raise HTTPException(status_code=401, detail="No session provided") - - user = get_session_user(session_id, db) + user = get_session_user(request, db) if not user: raise HTTPException(status_code=401, detail="Invalid or expired session") @@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)): @router.get("/login", response_model=LoginResponse) -def login_page(): +def login_page(request: Request): """Login page""" - return HTMLResponse( - content=textwrap.dedent(""" - - -

Login

-
- - - -
- - - """), + template_dir = pathlib.Path(__file__).parent / "templates" + templates = Jinja2Templates(directory=template_dir) + return templates.TemplateResponse( + "login.html", + { + "request": request, + "action": router.url_path_for("login_form"), + "error": None, + "form_data": {}, + }, ) @@ -170,9 +174,17 @@ def login_form( return LoginResponse(session_id=session_id, **user.serialize()) -@router.post("/logout") -def logout(response: Response, user: User = Depends(get_current_user)): +@router.api_route("/logout", methods=["GET", "POST"]) +def logout( + request: Request, + response: Response, + db: DBSession = Depends(get_session), +): """Logout and clear session""" + session = get_user_session(request, db) + if session: + db.delete(session) + db.commit() response.delete_cookie(settings.SESSION_COOKIE_NAME) return {"message": "Logged out successfully"} @@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware): "/auth/login", "/auth/login-form", "/auth/register", + "/register", + "/token", + "/mcp", + "/oauth/", + "/.well-known/", } async def dispatch(self, request: Request, call_next): @@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware): return await call_next(request) # Check for session ID in header or cookie - session_id = request.headers.get( - settings.SESSION_HEADER_NAME - ) or request.cookies.get(settings.SESSION_COOKIE_NAME) - + session_id = request.cookies.get(settings.SESSION_COOKIE_NAME) if not session_id: return Response( content="Authentication required", @@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware): # Validate session and get user with make_session() as session: - user = get_session_user(session_id, session) + user = get_session_user(request, session) if not user: return Response( content="Invalid or expired session", diff --git a/src/memory/api/MCP/templates/login.html b/src/memory/api/templates/login.html similarity index 98% rename from src/memory/api/MCP/templates/login.html rename to src/memory/api/templates/login.html index 5ec2011..100036a 100644 --- a/src/memory/api/MCP/templates/login.html +++ b/src/memory/api/templates/login.html @@ -136,7 +136,7 @@ {% endif %} -
+ {% for key, value in form_data.items() %} {% endfor %} diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index cd0ed13..ddbfa8e 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -133,7 +133,6 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240 # API settings SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000") HTTPS = boolean_env("HTTPS", False) -SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID") SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id") SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60)) SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))