Compare commits

...

4 Commits

Author SHA1 Message Date
Daniel O'Connell
6ee46d6215 remove proxy 2025-06-06 18:32:37 +02:00
Daniel O'Connell
d17d724631 protect admin 2025-06-06 18:24:09 +02:00
Daniel O'Connell
986d5b9957 oauth refresh + revoke 2025-06-06 17:07:25 +02:00
Daniel O'Connell
4556ef2c48 proper oauth flow 2025-06-06 12:55:48 +02:00
14 changed files with 1093 additions and 211 deletions

View File

@ -80,18 +80,6 @@ python tools/run_celery_task.py notes setup-git-notes --origin ssh://git@github.
For this to work you need to make sure you have set up the ssh keys in `secrets` (see the README.md
in that folder), and you will need to add the public key that is generated there to your git server.
### Authentication
The API uses session-based authentication. Login via:
```bash
curl -X POST http://localhost:8000/auth/login \
-H "Content-Type: application/json" \
-d '{"email": "user@example.com", "password": "yourpassword"}'
```
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
## Discord integration
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).

View File

@ -0,0 +1,115 @@
"""oauth codes
Revision ID: 1d6bc8015ea9
Revises: 58439dd3088b
Create Date: 2025-06-06 16:53:33.044558
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "1d6bc8015ea9"
down_revision: Union[str, None] = "58439dd3088b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"oauth_client",
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("client_secret", sa.String(), nullable=True),
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
sa.Column("scope", sa.String(), nullable=False),
sa.Column("client_name", sa.String(), nullable=False),
sa.Column("client_uri", sa.String(), nullable=True),
sa.Column("logo_uri", sa.String(), nullable=True),
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
sa.Column("tos_uri", sa.String(), nullable=True),
sa.Column("policy_uri", sa.String(), nullable=True),
sa.Column("jwks_uri", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("client_id"),
)
op.create_table(
"oauth_states",
sa.Column("state", sa.String(), nullable=False),
sa.Column("code", sa.String(), nullable=True),
sa.Column("redirect_uri", sa.String(), nullable=False),
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
sa.Column("code_challenge", sa.String(), nullable=True),
sa.Column("stale", sa.Boolean(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_client.client_id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"oauth_refresh_tokens",
sa.Column("token", sa.String(), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("access_token_session_id", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("client_id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["access_token_session_id"],
["user_sessions.id"],
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_client.client_id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column(
"user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"user_sessions_oauth_state_id_fkey",
"user_sessions",
"oauth_states",
["oauth_state_id"],
["id"],
)
def downgrade() -> None:
op.drop_constraint(
"user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey"
)
op.drop_column("user_sessions", "oauth_state_id")
op.drop_table("oauth_refresh_tokens")
op.drop_table("oauth_states")
op.drop_table("oauth_client")

View File

@ -147,6 +147,7 @@ services:
<<: *env
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
QDRANT_URL: http://qdrant:6333
SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
secrets: [postgres_password]
volumes:
- ./memory_files:/app/memory_files:rw
@ -158,20 +159,6 @@ services:
ports:
- "8000:8000"
proxy:
build:
context: .
dockerfile: docker/api/Dockerfile
restart: unless-stopped
networks: [kbnet]
environment:
<<: *env
command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "${PROXY_REMOTE_SERVER:-http://api:8000}", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
volumes:
- ./tools:/app/tools:ro
ports:
- "8001:8001"
# ------------------------------------------------------------ Celery workers
worker:
<<: *worker-base
@ -201,4 +188,4 @@ services:
# restart: unless-stopped
# command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
# volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
# networks: [ kbnet ]
# networks: [ kbnet ]

132
src/memory/api/MCP/base.py Normal file
View File

@ -0,0 +1,132 @@
import logging
import os
import pathlib
from typing import cast
from mcp.server.auth.handlers.authorize import AuthorizationRequest
from mcp.server.auth.handlers.token import (
AuthorizationCodeRequest,
RefreshTokenRequest,
TokenRequest,
)
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.shared.auth import OAuthClientMetadata
from memory.common.db.models.users import User
from pydantic import AnyHttpUrl
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import OAuthState
logger = logging.getLogger(__name__)
def validate_metadata(klass: type):
orig_validate = klass.model_validate
def validate(data: dict):
data = dict(data)
if "redirect_uris" in data:
data["redirect_uris"] = [
str(uri).replace("cursor://", "http://")
for uri in data["redirect_uris"]
]
if "redirect_uri" in data:
data["redirect_uri"] = str(data["redirect_uri"]).replace(
"cursor://", "http://"
)
return orig_validate(data)
klass.model_validate = validate
validate_metadata(OAuthClientMetadata)
validate_metadata(AuthorizationRequest)
validate_metadata(AuthorizationCodeRequest)
validate_metadata(RefreshTokenRequest)
validate_metadata(TokenRequest)
# Setup templates
template_dir = pathlib.Path(__file__).parent.parent / "templates"
templates = Jinja2Templates(directory=template_dir)
oauth_provider = SimpleOAuthProvider()
auth_settings = AuthSettings(
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=["read", "write"],
default_scopes=["read"],
),
required_scopes=["read", "write"],
)
mcp = FastMCP(
"memory",
stateless_http=True,
auth_server_provider=oauth_provider,
auth=auth_settings,
)
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
async def oauth_protected_resource(request: Request):
"""OAuth 2.0 Protected Resource Metadata."""
logger.info("Protected resource metadata requested")
return JSONResponse(oauth_provider.get_protected_resource_metadata())
def login_form(request: Request, form_data: dict, error: str | None = None):
return templates.TemplateResponse(
"login.html",
{
"request": request,
"form_data": form_data,
"error": error,
"action": "/oauth/login",
},
)
@mcp.custom_route("/oauth/login", methods=["GET"])
async def login_page(request: Request):
"""Display the login page."""
form_data = dict(request.query_params)
state = form_data.get("state")
with make_session() as session:
oauth_state = (
session.query(OAuthState).filter(OAuthState.state == state).first()
)
if not oauth_state:
logger.error(f"State {state} not found in database")
raise ValueError("Invalid state parameter")
return login_form(request, form_data, None)
@mcp.custom_route("/oauth/login", methods=["POST"])
async def handle_login(request: Request):
"""Handle login form submission."""
form = await request.form()
oauth_params = {
key: value for key, value in form.items() if key not in ["email", "password"]
}
with make_session() as session:
user = session.query(User).filter(User.email == form.get("email")).first()
if not user or not user.is_valid_password(str(form.get("password", ""))):
logger.warning("Login failed - invalid credentials")
return login_form(request, oauth_params, "Invalid email or password")
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
redirect_url = redirect_url.replace("http://", "cursor://")
return RedirectResponse(url=redirect_url, status_code=302)

View File

@ -0,0 +1,411 @@
import secrets
import time
from typing import Optional, Any, cast
from urllib.parse import urlencode
import logging
from datetime import datetime, timezone
from memory.common.db.models.users import (
User,
UserSession,
OAuthClientInformation,
OAuthState,
OAuthRefreshToken,
OAuthToken as TokenBase,
)
from memory.common.db.connection import make_session, scoped_session
from memory.common import settings
from mcp.server.auth.provider import (
AccessToken,
AuthorizationParams,
AuthorizationCode,
OAuthAuthorizationServerProvider,
RefreshToken,
construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
logger = logging.getLogger(__name__)
# Token configuration constants
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days
def create_expiration(lifetime_seconds: int) -> datetime:
"""Create expiration datetime from lifetime in seconds."""
return datetime.fromtimestamp(time.time() + lifetime_seconds)
def generate_refresh_token() -> str:
"""Generate a new refresh token."""
return f"rt_{secrets.token_hex(32)}"
def create_access_token_session(
user_id: int, oauth_state_id: str | None = None
) -> UserSession:
"""Create a new access token session."""
return UserSession(
user_id=user_id,
oauth_state_id=oauth_state_id,
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
)
def create_refresh_token_record(
client_id: str,
user_id: int,
scopes: list[str],
access_token_session_id: Optional[str] = None,
) -> OAuthRefreshToken:
"""Create a new refresh token record."""
return OAuthRefreshToken(
token=generate_refresh_token(),
client_id=client_id,
user_id=user_id,
scopes=scopes,
expires_at=create_expiration(REFRESH_TOKEN_LIFETIME),
access_token_session_id=access_token_session_id,
)
def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None:
"""Validate a refresh token, raising ValueError if invalid."""
now = datetime.now()
if db_refresh_token.expires_at < now: # type: ignore
logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...")
db_refresh_token.revoked = True # type: ignore
raise ValueError("Refresh token expired")
def create_oauth_token(
access_token: str, scopes: list[str], refresh_token: Optional[str] = None
) -> OAuthToken:
"""Create an OAuth token response."""
return OAuthToken(
access_token=access_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_LIFETIME,
refresh_token=refresh_token,
scope=" ".join(scopes),
)
def make_token(
db: scoped_session,
auth_state: TokenBase,
scopes: list[str],
) -> OAuthToken:
new_session = UserSession(
user_id=auth_state.user_id,
oauth_state_id=auth_state.id,
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
)
# Create refresh token
refresh_token = create_refresh_token_record(
cast(str, auth_state.client_id),
cast(int, auth_state.user_id),
scopes,
cast(str, new_session.id),
)
db.add(new_session)
db.add(refresh_token)
db.commit()
return create_oauth_token(
str(new_session.id),
scopes,
cast(str, refresh_token.token),
)
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""Get OAuth client information."""
with make_session() as session:
client = session.query(OAuthClientInformation).get(client_id)
return client and OAuthClientInformationFull(**client.serialize())
async def register_client(self, client_info: OAuthClientInformationFull):
"""Register a new OAuth client."""
with make_session() as session:
client = session.query(OAuthClientInformation).get(client_info.client_id)
if not client:
client = OAuthClientInformation(client_id=client_info.client_id)
for key, value in client_info.model_dump().items():
if key == "redirect_uris":
value = [str(uri) for uri in value]
elif value and key in [
"client_uri",
"logo_uri",
"tos_uri",
"policy_uri",
"jwks_uri",
]:
value = str(value)
setattr(client, key, value)
session.add(client)
session.commit()
async def authorize(
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
"""Redirect to login page for user authentication."""
redirect_uri_str = str(params.redirect_uri)
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
if redirect_uri_str not in registered_uris:
logger.error(
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
)
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
# Store the authorization parameters in database
with make_session() as session:
oauth_state = OAuthState(
state=params.state or secrets.token_hex(16),
client_id=client.client_id,
redirect_uri=str(params.redirect_uri),
redirect_uri_provided_explicitly=str(
params.redirect_uri_provided_explicitly
).lower()
== "true",
code_challenge=params.code_challenge or "",
scopes=["read", "write"], # Default scopes
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
)
session.add(oauth_state)
session.commit()
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
{
"state": oauth_state.state,
"client_id": client.client_id,
"redirect_uri": oauth_state.redirect_uri,
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
"code_challenge": cast(str, oauth_state.code_challenge),
}
)
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
"""Complete authorization after successful login."""
if not (state := oauth_params.get("state")):
logger.error("No state parameter provided")
raise ValueError("Missing state parameter")
with make_session() as session:
# Load OAuth state from database
oauth_state = (
session.query(OAuthState).filter(OAuthState.state == state).first()
)
if not oauth_state:
logger.error(f"State {state} not found in database")
raise ValueError("Invalid state parameter")
# Check if state has expired
now = datetime.fromtimestamp(time.time())
if oauth_state.expires_at < now:
logger.error(f"State {state} has expired")
oauth_state.stale = True # type: ignore
session.commit()
raise ValueError("State has expired")
oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore
oauth_state.stale = False # type: ignore
oauth_state.user_id = user.id
session.add(oauth_state)
session.commit()
return construct_redirect_uri(
cast(str, oauth_state.redirect_uri),
code=cast(str, oauth_state.code),
state=state,
)
async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str
) -> Optional[AuthorizationCode]:
"""Load an authorization code."""
with make_session() as session:
auth_code = (
session.query(OAuthState)
.filter(OAuthState.code == authorization_code)
.first()
)
if not auth_code:
logger.error(f"Invalid authorization code: {authorization_code}")
raise ValueError("Invalid authorization code")
return AuthorizationCode(**auth_code.serialize(code=True))
async def exchange_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
) -> OAuthToken:
"""Exchange authorization code for tokens."""
with make_session() as session:
auth_code = (
session.query(OAuthState)
.filter(OAuthState.code == authorization_code.code)
.first()
)
if not auth_code:
logger.error(f"Invalid authorization code: {authorization_code.code}")
raise ValueError("Invalid authorization code")
if not auth_code.user:
logger.error(f"No user found for auth code: {authorization_code.code}")
raise ValueError("Invalid authorization code")
return make_token(session, auth_code, authorization_code.scopes)
async def load_access_token(self, token: str) -> Optional[AccessToken]:
"""Load and validate an access token."""
with make_session() as session:
now = datetime.now(timezone.utc).replace(
tzinfo=None
) # Make naive for DB comparison
# Query for active (non-expired) session
user_session = session.query(UserSession).get(token)
if not user_session or user_session.expires_at < now:
return None
return AccessToken(
token=token,
client_id=user_session.oauth_state.client_id,
scopes=user_session.oauth_state.scopes,
expires_at=int(user_session.expires_at.timestamp()),
)
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str
) -> Optional[RefreshToken]:
"""Load and validate a refresh token."""
with make_session() as session:
now = datetime.now()
# Query for the refresh token
db_refresh_token = (
session.query(OAuthRefreshToken)
.filter(
OAuthRefreshToken.token == refresh_token,
OAuthRefreshToken.client_id == client.client_id,
OAuthRefreshToken.revoked == False, # noqa: E712
OAuthRefreshToken.expires_at > now,
)
.first()
)
if not db_refresh_token:
logger.error(
f"Invalid or expired refresh token: {refresh_token[:20]}..."
)
return None
return RefreshToken(
token=refresh_token,
client_id=client.client_id,
scopes=cast(list[str], db_refresh_token.scopes),
expires_at=int(db_refresh_token.expires_at.timestamp()),
)
async def exchange_refresh_token(
self,
client: OAuthClientInformationFull,
refresh_token: RefreshToken,
scopes: list[str],
) -> OAuthToken:
"""Exchange refresh token for new access token."""
with make_session() as session:
# Load the refresh token from database
db_refresh_token = (
session.query(OAuthRefreshToken)
.filter(
OAuthRefreshToken.token == refresh_token.token,
OAuthRefreshToken.client_id == client.client_id,
OAuthRefreshToken.revoked == False, # noqa: E712
)
.first()
)
if not db_refresh_token:
logger.error(f"Refresh token not found: {refresh_token.token[:20]}...")
raise ValueError("Invalid refresh token")
# Validate refresh token
validate_refresh_token(db_refresh_token)
# Validate requested scopes are subset of original scopes
original_scopes = set(cast(list[str], db_refresh_token.scopes))
requested_scopes = set(scopes) if scopes else original_scopes
if not requested_scopes.issubset(original_scopes):
logger.error(
f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}"
)
raise ValueError("Requested scopes exceed original authorization")
return make_token(session, db_refresh_token, scopes)
async def revoke_token(
self, token: str, token_type_hint: Optional[str] = None
) -> None:
"""Revoke a token (access token or refresh token)."""
with make_session() as session:
revoked = False
# Try to revoke as access token (UserSession)
if not token_type_hint or token_type_hint == "access_token":
user_session = session.query(UserSession).get(token)
if user_session:
session.delete(user_session)
revoked = True
logger.info(f"Revoked access token: {token[:20]}...")
# Try to revoke as refresh token
if not revoked and (
not token_type_hint or token_type_hint == "refresh_token"
):
refresh_token = (
session.query(OAuthRefreshToken)
.filter(OAuthRefreshToken.token == token)
.first()
)
if refresh_token:
refresh_token.revoked = True # type: ignore
revoked = True
logger.info(f"Revoked refresh token: {token[:20]}...")
if revoked:
session.commit()
else:
logger.warning(f"Token not found for revocation: {token[:20]}...")
def get_protected_resource_metadata(self) -> dict[str, Any]:
"""Return metadata about the protected resource."""
return {
"resource_server": settings.SERVER_URL,
"scopes_supported": ["read", "write"],
"bearer_methods_supported": ["header"],
"resource_documentation": f"{settings.SERVER_URL}/docs",
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["none", "client_secret_basic"],
"refresh_token_rotation_enabled": True,
"protected_resources": [
{
"resource_uri": f"{settings.SERVER_URL}/mcp",
"scopes": ["read", "write"],
"http_methods": ["POST", "GET"],
},
{
"resource_uri": f"{settings.SERVER_URL}/mcp/",
"scopes": ["read", "write"],
"http_methods": ["POST", "GET"],
},
],
}

View File

@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system.
import logging
from datetime import datetime, timezone
from mcp.server.fastmcp import FastMCP
from mcp.server.auth.middleware.auth_context import get_access_token
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem
from memory.common.db.models import (
AgentObservation,
SourceItem,
UserSession,
)
from memory.api.MCP.base import mcp
logger = logging.getLogger(__name__)
# Create MCP server instance
mcp = FastMCP("memory", stateless_http=True)
def filter_observation_source_ids(
tags: list[str] | None = None, observation_types: list[str] | None = None
@ -67,4 +69,42 @@ def filter_source_ids(
@mcp.tool()
async def get_current_time() -> dict:
"""Get the current time in UTC."""
logger.info("get_current_time tool called")
return {"current_time": datetime.now(timezone.utc).isoformat()}
@mcp.tool()
async def get_authenticated_user() -> dict:
"""Get information about the authenticated user."""
logger.info("🔧 get_authenticated_user tool called")
access_token = get_access_token()
logger.info(f"🔧 Access token from MCP context: {access_token}")
if not access_token:
logger.warning("❌ No access token found in MCP context!")
return {"error": "Not authenticated"}
logger.info(
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
)
# Look up the actual user from the session token
with make_session() as session:
user_session = (
session.query(UserSession)
.filter(UserSession.id == access_token.token)
.first()
)
if user_session and user_session.user:
user_info = user_session.user.serialize()
else:
user_info = {"error": "User not found"}
return {
"authenticated": True,
"token_type": "Bearer",
"scopes": access_token.scopes,
"client_id": access_token.client_id,
"user": user_info,
}

View File

@ -2,8 +2,16 @@
SQLAdmin views for the knowledge base database models.
"""
import uuid
from sqladmin import Admin, ModelView
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import RedirectResponse
import logging
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import (
Chunk,
SourceItem,
@ -21,8 +29,11 @@ from memory.common.db.models import (
AgentObservation,
Note,
User,
UserSession,
OAuthState,
)
logger = logging.getLogger(__name__)
DEFAULT_COLUMNS = (
"modality",
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
def setup_admin(admin: Admin):
"""Add all admin views to the admin instance."""
"""Add all admin views to the admin instance with OAuth protection."""
admin.add_view(SourceItemAdmin)
admin.add_view(AgentObservationAdmin)
admin.add_view(NoteAdmin)

View File

@ -16,22 +16,23 @@ from fastapi import (
Query,
Form,
Depends,
Request,
)
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from sqladmin import Admin
from memory.common import settings
from memory.common import extract
from memory.common import extract, settings
from memory.common.db.connection import get_engine
from memory.common.db.models import User
from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult
from memory.api.MCP.tools import mcp
from memory.api.auth import (
router as auth_router,
get_current_user,
AuthenticationMiddleware,
router as auth_router,
)
from memory.api.MCP.base import mcp
logger = logging.getLogger(__name__)
@ -44,24 +45,37 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify actual origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# SQLAdmin setup
# SQLAdmin setup with OAuth protection
engine = get_engine()
admin = Admin(app, engine)
# Setup admin with OAuth protection using existing OAuth provider
setup_admin(admin)
# Include auth router
app.add_middleware(AuthenticationMiddleware)
app.include_router(auth_router)
app.add_middleware(AuthenticationMiddleware)
# Add health check to MCP server instead of main app
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Simple health check endpoint on MCP server"""
from fastapi.responses import JSONResponse
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
# Mount MCP server at root - OAuth endpoints need to be at root level
app.mount("/", mcp.streamable_http_app())
@app.get("/health")
def health_check():
"""Simple health check endpoint"""
return {"status": "healthy"}
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
if not item:
return []

View File

@ -1,10 +1,10 @@
from datetime import datetime, timedelta, timezone
import textwrap
from typing import cast
import logging
import pathlib
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
@ -52,28 +52,35 @@ def create_user_session(
return str(session.id)
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
"""Get user from session ID if session is valid"""
def get_user_session(
request: Request, db: DBSession | scoped_session
) -> UserSession | None:
"""Get session ID from request"""
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
return None
session = db.query(UserSession).get(session_id)
if not session:
return None
now = datetime.now(timezone.utc)
if session.expires_at.replace(tzinfo=timezone.utc) > now:
if session.expires_at.replace(tzinfo=timezone.utc) < now:
return None
return session
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
"""Get user from session ID if session is valid"""
if session := get_user_session(request, db):
return session.user
return None
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
"""FastAPI dependency to get current authenticated user"""
# Check for session ID in header or cookie
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
raise HTTPException(status_code=401, detail="No session provided")
user = get_session_user(session_id, db)
user = get_session_user(request, db)
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired session")
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
@router.get("/login", response_model=LoginResponse)
def login_page():
def login_page(request: Request):
"""Login page"""
return HTMLResponse(
content=textwrap.dedent("""
<html>
<body>
<h1>Login</h1>
<form method="post" action="/auth/login-form">
<input type="email" name="email" placeholder="Email" />
<input type="password" name="password" placeholder="Password" />
<button type="submit">Login</button>
</form>
</body>
</html>
"""),
template_dir = pathlib.Path(__file__).parent / "templates"
templates = Jinja2Templates(directory=template_dir)
return templates.TemplateResponse(
"login.html",
{
"request": request,
"action": router.url_path_for("login_form"),
"error": None,
"form_data": {},
},
)
@ -170,9 +174,17 @@ def login_form(
return LoginResponse(session_id=session_id, **user.serialize())
@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
@router.api_route("/logout", methods=["GET", "POST"])
def logout(
request: Request,
response: Response,
db: DBSession = Depends(get_session),
):
"""Logout and clear session"""
session = get_user_session(request, db)
if session:
db.delete(session)
db.commit()
response.delete_cookie(settings.SESSION_COOKIE_NAME)
return {"message": "Logged out successfully"}
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
"/auth/login",
"/auth/login-form",
"/auth/register",
"/register",
"/token",
"/mcp",
"/oauth/",
"/.well-known/",
}
async def dispatch(self, request: Request, call_next):
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# Check for session ID in header or cookie
session_id = request.headers.get(
settings.SESSION_HEADER_NAME
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
if not session_id:
return Response(
content="Authentication required",
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
# Validate session and get user
with make_session() as session:
user = get_session_user(session_id, session)
user = get_session_user(request, session)
if not user:
return Response(
content="Invalid or expired session",

View File

@ -0,0 +1,159 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Login - Memory OAuth</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
margin: 0;
padding: 0;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}
.login-container {
background: white;
padding: 2rem;
border-radius: 12px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 400px;
}
.login-header {
text-align: center;
margin-bottom: 2rem;
}
.login-header h1 {
color: #333;
margin: 0 0 0.5rem 0;
font-size: 1.8rem;
font-weight: 600;
}
.login-header p {
color: #666;
margin: 0;
font-size: 0.9rem;
}
.form-group {
margin-bottom: 1.5rem;
}
.form-group label {
display: block;
margin-bottom: 0.5rem;
color: #333;
font-weight: 500;
font-size: 0.9rem;
}
.form-group input {
width: 100%;
padding: 0.75rem;
border: 2px solid #e1e5e9;
border-radius: 8px;
font-size: 1rem;
transition: border-color 0.2s;
box-sizing: border-box;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.login-button {
width: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 0.875rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s;
}
.login-button:hover {
transform: translateY(-1px);
}
.login-button:active {
transform: translateY(0);
}
.error-message {
background: #fee;
color: #c33;
padding: 0.75rem;
border-radius: 6px;
margin-bottom: 1rem;
font-size: 0.9rem;
border-left: 4px solid #c33;
}
.demo-credentials {
background: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin-top: 1.5rem;
font-size: 0.85rem;
}
.demo-credentials strong {
color: #333;
}
.demo-credentials code {
background: #e9ecef;
padding: 0.2rem 0.4rem;
border-radius: 4px;
font-family: monospace;
}
</style>
</head>
<body>
<div class="login-container">
<div class="login-header">
<h1>Sign In</h1>
<p>Access your Memory knowledge base</p>
</div>
{% if error %}
<div class="error-message">
{{ error }}
</div>
{% endif %}
<form method="post" action="{{ action }}">
{% for key, value in form_data.items() %}
<input type="hidden" name="{{ key }}" value="{{ value }}">
{% endfor %}
<div class="form-group">
<label for="email">Email</label>
<input type="email" id="email" name="email" required>
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit" class="login-button">Sign In</button>
</form>
</div>
</body>
</html>

View File

@ -35,6 +35,9 @@ from memory.common.db.models.sources import (
from memory.common.db.models.users import (
User,
UserSession,
OAuthClientInformation,
OAuthState,
OAuthRefreshToken,
)
__all__ = [
@ -69,4 +72,7 @@ __all__ = [
# Users
"User",
"UserSession",
"OAuthClientInformation",
"OAuthState",
"OAuthRefreshToken",
]

View File

@ -3,9 +3,19 @@ import secrets
from typing import cast
import uuid
from memory.common.db.models.base import Base
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
Boolean,
ARRAY,
Numeric,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from datetime import datetime
def hash_password(password: str) -> str:
@ -35,6 +45,9 @@ class User(Base):
sessions = relationship(
"UserSession", back_populates="user", cascade="all, delete-orphan"
)
oauth_states = relationship(
"OAuthState", back_populates="user", cascade="all, delete-orphan"
)
def serialize(self) -> dict:
return {
@ -57,9 +70,127 @@ class UserSession(Base):
__tablename__ = "user_sessions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
# Relationship to user
user = relationship("User", back_populates="sessions")
oauth_state = relationship("OAuthState", back_populates="session")
class OAuthClientInformation(Base):
__tablename__ = "oauth_client"
client_id = Column(String, primary_key=True)
client_secret = Column(String, nullable=True)
client_id_issued_at = Column(Numeric, nullable=False)
client_secret_expires_at = Column(Numeric, nullable=True)
redirect_uris = Column(ARRAY(String), nullable=False)
token_endpoint_auth_method = Column(String, nullable=False)
grant_types = Column(ARRAY(String), nullable=False)
response_types = Column(ARRAY(String), nullable=False)
scope = Column(String, nullable=False)
client_name = Column(String, nullable=False)
client_uri = Column(String, nullable=True)
logo_uri = Column(String, nullable=True)
contacts = Column(ARRAY(String), nullable=True)
tos_uri = Column(String, nullable=True)
policy_uri = Column(String, nullable=True)
jwks_uri = Column(String, nullable=True)
sessions = relationship(
"OAuthState", back_populates="client", cascade="all, delete-orphan"
)
def serialize(self) -> dict:
return {
"client_id": self.client_id,
"client_secret": self.client_secret,
"client_id_issued_at": self.client_id_issued_at,
"client_secret_expires_at": self.client_secret_expires_at,
"redirect_uris": self.redirect_uris,
"token_endpoint_auth_method": self.token_endpoint_auth_method,
"grant_types": self.grant_types,
"response_types": self.response_types,
"scope": self.scope,
"client_name": self.client_name,
"client_uri": self.client_uri,
"logo_uri": self.logo_uri,
"contacts": self.contacts,
"tos_uri": self.tos_uri,
"policy_uri": self.policy_uri,
"jwks_uri": self.jwks_uri,
}
class OAuthToken:
id = Column(Integer, primary_key=True)
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
scopes = Column(ARRAY(String), nullable=False)
created_at = Column(DateTime, server_default=func.now())
expires_at = Column(DateTime, nullable=False)
def serialize(self) -> dict:
return {
"client_id": self.client_id,
"scopes": self.scopes,
"expires_at": self.expires_at.timestamp(),
}
class OAuthState(Base, OAuthToken):
__tablename__ = "oauth_states"
state = Column(String, nullable=False)
code = Column(String, nullable=True)
redirect_uri = Column(String, nullable=False)
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
code_challenge = Column(String, nullable=True)
stale = Column(Boolean, nullable=False, default=False)
def serialize(self, code: bool = False) -> dict:
data = {
"redirect_uri": self.redirect_uri,
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
"code_challenge": self.code_challenge,
} | super().serialize()
if code:
data |= {
"code": self.code,
"expires_at": self.expires_at.timestamp(),
}
return data
client = relationship("OAuthClientInformation", back_populates="sessions")
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
user = relationship("User", back_populates="oauth_states")
class OAuthRefreshToken(Base, OAuthToken):
__tablename__ = "oauth_refresh_tokens"
token = Column(
String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}"
)
revoked = Column(Boolean, nullable=False, default=False)
# Optional: link to the access token session that was created with this refresh token
access_token_session_id = Column(
String, ForeignKey("user_sessions.id"), nullable=True
)
# Relationships
client = relationship("OAuthClientInformation")
user = relationship("User")
access_token_session = relationship("UserSession")
def serialize(self) -> dict:
return {
"token": self.token,
"expires_at": self.expires_at.timestamp(),
"revoked": self.revoked,
} | super().serialize()

View File

@ -131,8 +131,8 @@ else:
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
# API settings
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
HTTPS = boolean_env("HTTPS", False)
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))

View File

@ -1,126 +0,0 @@
#!/usr/bin/env python3
import argparse
import httpx
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import Response
class State(BaseModel):
email: str
password: str
remote_server: str
session_header: str = "X-Session-ID"
session_id: str | None = None
port: int = 8080
def parse_args() -> State:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Simple HTTP proxy with authentication"
)
parser.add_argument("--remote-server", required=True, help="Remote server URL")
parser.add_argument("--email", required=True, help="Email for authentication")
parser.add_argument("--password", required=True, help="Password for authentication")
parser.add_argument(
"--session-header", default="X-Session-ID", help="Session header name"
)
parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
return State(**vars(parser.parse_args()))
state = parse_args()
async def login() -> None:
"""Login to remote server and store session ID"""
login_url = f"{state.remote_server}/auth/login"
login_data = {"email": state.email, "password": state.password}
async with httpx.AsyncClient() as client:
try:
response = await client.post(login_url, json=login_data)
response.raise_for_status()
login_response = response.json()
state.session_id = login_response["session_id"]
print(f"Successfully logged in, session ID: {state.session_id}")
except httpx.HTTPStatusError as e:
print(
f"Login failed with status {e.response.status_code}: {e.response.text}"
)
raise
except Exception as e:
print(f"Login failed: {e}")
raise
async def proxy_request(request: Request) -> Response:
"""Proxy request to remote server with session header"""
if not state.session_id:
try:
await login()
except Exception as e:
print(f"Login failed: {e}")
raise HTTPException(status_code=401, detail="Unauthorized")
# Build the target URL
target_url = f"{state.remote_server}{request.url.path}"
if request.url.query:
target_url += f"?{request.url.query}"
# Get request body
body = await request.body()
headers = dict(request.headers)
headers.pop("host", None)
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method=request.method,
url=target_url,
headers=headers | {state.session_header: state.session_id}, # type: ignore
content=body,
timeout=30.0,
)
# Forward response
resp = Response(
content=response.content,
status_code=response.status_code,
headers={
k: v.replace(state.remote_server, f"http://localhost:{state.port}")
for k, v in response.headers.items()
},
media_type=response.headers.get("content-type"),
)
print(resp.headers)
return resp
except httpx.RequestError as e:
print(f"Request failed: {e}")
raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")
# Create FastAPI app
app = FastAPI(title="Simple Proxy")
@app.api_route(
"/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
)
async def proxy_all(request: Request):
"""Proxy all requests to remote server"""
return await proxy_request(request)
if __name__ == "__main__":
print(f"Starting proxy server on port {state.port}")
print(f"Proxying to: {state.remote_server}")
uvicorn.run(app, host="0.0.0.0", port=state.port)