diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index c3b7369..440d65f 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -1,5 +1,6 @@ import logging import os +import pathlib from typing import cast from mcp.server.auth.handlers.authorize import AuthorizationRequest @@ -53,7 +54,7 @@ validate_metadata(TokenRequest) # Setup templates -template_dir = os.path.join(os.path.dirname(__file__), "templates") +template_dir = pathlib.Path(__file__).parent.parent / "templates" templates = Jinja2Templates(directory=template_dir) @@ -86,7 +87,12 @@ async def oauth_protected_resource(request: Request): def login_form(request: Request, form_data: dict, error: str | None = None): return templates.TemplateResponse( "login.html", - {"request": request, "form_data": form_data, "error": error}, + { + "request": request, + "form_data": form_data, + "error": error, + "action": "/oauth/login", + }, ) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index 2405890..ad96000 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -275,9 +275,6 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): if not user_session or user_session.expires_at < now: return None - logger.info( - f"Loading access token: {token}, state: {user_session.oauth_state}" - ) return AccessToken( token=token, client_id=user_session.oauth_state.client_id, diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index 540a599..9d9bc08 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -2,8 +2,16 @@ SQLAdmin views for the knowledge base database models. """ +import uuid from sqladmin import Admin, ModelView - +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import RedirectResponse +import logging +from mcp.server.auth.provider import OAuthAuthorizationServerProvider +from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME +from memory.common import settings +from memory.common.db.connection import make_session from memory.common.db.models import ( Chunk, SourceItem, @@ -21,8 +29,11 @@ from memory.common.db.models import ( AgentObservation, Note, User, + UserSession, + OAuthState, ) +logger = logging.getLogger(__name__) DEFAULT_COLUMNS = ( "modality", @@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User): def setup_admin(admin: Admin): - """Add all admin views to the admin instance.""" + """Add all admin views to the admin instance with OAuth protection.""" admin.add_view(SourceItemAdmin) admin.add_view(AgentObservationAdmin) admin.add_view(NoteAdmin) diff --git a/src/memory/api/app.py b/src/memory/api/app.py index aaf2b17..2c4556a 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -27,8 +27,12 @@ from memory.common.db.connection import get_engine from memory.common.db.models import User from memory.api.admin import setup_admin from memory.api.search import search, SearchResult +from memory.api.auth import ( + get_current_user, + AuthenticationMiddleware, + router as auth_router, +) from memory.api.MCP.base import mcp -from memory.api.auth import get_current_user logger = logging.getLogger(__name__) @@ -41,19 +45,6 @@ async def lifespan(app: FastAPI): app = FastAPI(title="Knowledge Base API", lifespan=lifespan) - - -# Add request logging middleware -# @app.middleware("http") -# async def log_requests(request, call_next): -# logger.info(f"Main app: {request.method} {request.url.path}") -# if request.url.path.startswith("/mcp"): -# logger.info(f"Request headers: {dict(request.headers)}") -# response = await call_next(request) -# return response - - -# Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production, specify actual origins @@ -62,10 +53,14 @@ app.add_middleware( allow_headers=["*"], ) -# SQLAdmin setup +# SQLAdmin setup with OAuth protection engine = get_engine() admin = Admin(app, engine) + +# Setup admin with OAuth protection using existing OAuth provider setup_admin(admin) +app.include_router(auth_router) +app.add_middleware(AuthenticationMiddleware) # Add health check to MCP server instead of main app diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index 2b92e17..215d782 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta, timezone -import textwrap from typing import cast import logging +import pathlib from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form -from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates from starlette.middleware.base import BaseHTTPMiddleware from memory.common import settings from sqlalchemy.orm import Session as DBSession, scoped_session @@ -52,28 +52,35 @@ def create_user_session( return str(session.id) -def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None: - """Get user from session ID if session is valid""" +def get_user_session( + request: Request, db: DBSession | scoped_session +) -> UserSession | None: + """Get session ID from request""" + session_id = request.cookies.get(settings.SESSION_COOKIE_NAME) + + if not session_id: + return None + session = db.query(UserSession).get(session_id) if not session: return None + now = datetime.now(timezone.utc) - if session.expires_at.replace(tzinfo=timezone.utc) > now: + if session.expires_at.replace(tzinfo=timezone.utc) < now: + return None + return session + + +def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None: + """Get user from session ID if session is valid""" + if session := get_user_session(request, db): return session.user return None def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User: """FastAPI dependency to get current authenticated user""" - # Check for session ID in header or cookie - session_id = request.headers.get( - settings.SESSION_HEADER_NAME - ) or request.cookies.get(settings.SESSION_COOKIE_NAME) - - if not session_id: - raise HTTPException(status_code=401, detail="No session provided") - - user = get_session_user(session_id, db) + user = get_session_user(request, db) if not user: raise HTTPException(status_code=401, detail="Invalid or expired session") @@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)): @router.get("/login", response_model=LoginResponse) -def login_page(): +def login_page(request: Request): """Login page""" - return HTMLResponse( - content=textwrap.dedent(""" - -
-