mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-30 06:36:07 +02:00
proper oauth flow
This commit is contained in:
parent
4d057d1ec6
commit
4556ef2c48
86
db/migrations/versions/20250606_123611_oauth_codes.py
Normal file
86
db/migrations/versions/20250606_123611_oauth_codes.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""oauth codes
|
||||||
|
|
||||||
|
Revision ID: 66771d293b27
|
||||||
|
Revises: 58439dd3088b
|
||||||
|
Create Date: 2025-06-06 12:36:11.737507
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "66771d293b27"
|
||||||
|
down_revision: Union[str, None] = "58439dd3088b"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"oauth_client",
|
||||||
|
sa.Column("client_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_secret", sa.String(), nullable=True),
|
||||||
|
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
|
||||||
|
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
|
||||||
|
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
|
||||||
|
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("logo_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
|
||||||
|
sa.Column("tos_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("policy_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("jwks_uri", sa.String(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("client_id"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_states",
|
||||||
|
sa.Column("state", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("code", sa.String(), nullable=True),
|
||||||
|
sa.Column("redirect_uri", sa.String(), nullable=False),
|
||||||
|
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("code_challenge", sa.String(), nullable=True),
|
||||||
|
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||||
|
sa.Column("stale", sa.Boolean(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"],
|
||||||
|
["oauth_client.client_id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("state"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True)
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_user_sessions_oauth_state_id",
|
||||||
|
"user_sessions",
|
||||||
|
"oauth_states",
|
||||||
|
["oauth_state_id"],
|
||||||
|
["state"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_constraint(
|
||||||
|
"fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey"
|
||||||
|
)
|
||||||
|
op.drop_column("user_sessions", "oauth_state_id")
|
||||||
|
op.drop_table("oauth_states")
|
||||||
|
op.drop_table("oauth_client")
|
124
src/memory/api/MCP/base.py
Normal file
124
src/memory/api/MCP/base.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||||
|
from mcp.server.auth.handlers.token import (
|
||||||
|
AuthorizationCodeRequest,
|
||||||
|
RefreshTokenRequest,
|
||||||
|
TokenRequest,
|
||||||
|
)
|
||||||
|
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from mcp.shared.auth import OAuthClientMetadata
|
||||||
|
from memory.common.db.models.users import User
|
||||||
|
from pydantic import AnyHttpUrl
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, RedirectResponse
|
||||||
|
from starlette.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import OAuthState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_metadata(klass: type):
|
||||||
|
orig_validate = klass.model_validate
|
||||||
|
|
||||||
|
def validate(data: dict):
|
||||||
|
data = dict(data)
|
||||||
|
if "redirect_uris" in data:
|
||||||
|
data["redirect_uris"] = [
|
||||||
|
str(uri).replace("cursor://", "http://")
|
||||||
|
for uri in data["redirect_uris"]
|
||||||
|
]
|
||||||
|
if "redirect_uri" in data:
|
||||||
|
data["redirect_uri"] = str(data["redirect_uri"]).replace(
|
||||||
|
"cursor://", "http://"
|
||||||
|
)
|
||||||
|
|
||||||
|
return orig_validate(data)
|
||||||
|
|
||||||
|
klass.model_validate = validate
|
||||||
|
|
||||||
|
|
||||||
|
validate_metadata(OAuthClientMetadata)
|
||||||
|
validate_metadata(AuthorizationRequest)
|
||||||
|
validate_metadata(AuthorizationCodeRequest)
|
||||||
|
validate_metadata(RefreshTokenRequest)
|
||||||
|
validate_metadata(TokenRequest)
|
||||||
|
|
||||||
|
|
||||||
|
# Setup templates
|
||||||
|
template_dir = os.path.join(os.path.dirname(__file__), "templates")
|
||||||
|
templates = Jinja2Templates(directory=template_dir)
|
||||||
|
|
||||||
|
|
||||||
|
oauth_provider = SimpleOAuthProvider()
|
||||||
|
auth_settings = AuthSettings(
|
||||||
|
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||||
|
client_registration_options=ClientRegistrationOptions(
|
||||||
|
enabled=True,
|
||||||
|
valid_scopes=["read", "write"],
|
||||||
|
default_scopes=["read"],
|
||||||
|
),
|
||||||
|
required_scopes=["read", "write"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp = FastMCP(
|
||||||
|
"memory",
|
||||||
|
stateless_http=True,
|
||||||
|
auth_server_provider=oauth_provider,
|
||||||
|
auth=auth_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
||||||
|
async def oauth_protected_resource(request: Request):
|
||||||
|
"""OAuth 2.0 Protected Resource Metadata."""
|
||||||
|
logger.info("Protected resource metadata requested")
|
||||||
|
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
||||||
|
|
||||||
|
|
||||||
|
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"login.html",
|
||||||
|
{"request": request, "form_data": form_data, "error": error},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.custom_route("/oauth/login", methods=["GET"])
|
||||||
|
async def login_page(request: Request):
|
||||||
|
"""Display the login page."""
|
||||||
|
form_data = dict(request.query_params)
|
||||||
|
|
||||||
|
state = form_data.get("state")
|
||||||
|
with make_session() as session:
|
||||||
|
oauth_state = session.query(OAuthState).get(state)
|
||||||
|
if not oauth_state:
|
||||||
|
logger.error(f"State {state} not found in database")
|
||||||
|
raise ValueError("Invalid state parameter")
|
||||||
|
|
||||||
|
return login_form(request, form_data, None)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.custom_route("/oauth/login", methods=["POST"])
|
||||||
|
async def handle_login(request: Request):
|
||||||
|
"""Handle login form submission."""
|
||||||
|
form = await request.form()
|
||||||
|
oauth_params = {
|
||||||
|
key: value for key, value in form.items() if key not in ["email", "password"]
|
||||||
|
}
|
||||||
|
with make_session() as session:
|
||||||
|
user = session.query(User).filter(User.email == form.get("email")).first()
|
||||||
|
if not user or not user.is_valid_password(str(form.get("password", ""))):
|
||||||
|
logger.warning("Login failed - invalid credentials")
|
||||||
|
return login_form(request, oauth_params, "Invalid email or password")
|
||||||
|
|
||||||
|
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
||||||
|
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
||||||
|
redirect_url = redirect_url.replace("http://", "cursor://")
|
||||||
|
return RedirectResponse(url=redirect_url, status_code=302)
|
246
src/memory/api/MCP/oauth_provider.py
Normal file
246
src/memory/api/MCP/oauth_provider.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Optional, Any, cast
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from memory.common.db.models.users import (
|
||||||
|
User,
|
||||||
|
UserSession,
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthState,
|
||||||
|
)
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common import settings
|
||||||
|
from mcp.server.auth.provider import (
|
||||||
|
AccessToken,
|
||||||
|
AuthorizationParams,
|
||||||
|
AuthorizationCode,
|
||||||
|
OAuthAuthorizationServerProvider,
|
||||||
|
RefreshToken,
|
||||||
|
construct_redirect_uri,
|
||||||
|
)
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||||
|
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||||
|
"""Get OAuth client information."""
|
||||||
|
with make_session() as session:
|
||||||
|
client = session.query(OAuthClientInformation).get(client_id)
|
||||||
|
return client and OAuthClientInformationFull(**client.serialize())
|
||||||
|
|
||||||
|
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||||
|
"""Register a new OAuth client."""
|
||||||
|
with make_session() as session:
|
||||||
|
client = session.query(OAuthClientInformation).get(client_info.client_id)
|
||||||
|
if not client:
|
||||||
|
client = OAuthClientInformation(client_id=client_info.client_id)
|
||||||
|
|
||||||
|
for key, value in client_info.model_dump().items():
|
||||||
|
if key == "redirect_uris":
|
||||||
|
value = [str(uri) for uri in value]
|
||||||
|
elif value and key in [
|
||||||
|
"client_uri",
|
||||||
|
"logo_uri",
|
||||||
|
"tos_uri",
|
||||||
|
"policy_uri",
|
||||||
|
"jwks_uri",
|
||||||
|
]:
|
||||||
|
value = str(value)
|
||||||
|
setattr(client, key, value)
|
||||||
|
session.add(client)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
async def authorize(
|
||||||
|
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||||
|
) -> str:
|
||||||
|
"""Redirect to login page for user authentication."""
|
||||||
|
redirect_uri_str = str(params.redirect_uri)
|
||||||
|
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
|
||||||
|
if redirect_uri_str not in registered_uris:
|
||||||
|
logger.error(
|
||||||
|
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
|
||||||
|
)
|
||||||
|
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
||||||
|
|
||||||
|
# Store the authorization parameters in database
|
||||||
|
with make_session() as session:
|
||||||
|
oauth_state = OAuthState(
|
||||||
|
state=params.state or secrets.token_hex(16),
|
||||||
|
client_id=client.client_id,
|
||||||
|
redirect_uri=str(params.redirect_uri),
|
||||||
|
redirect_uri_provided_explicitly=str(
|
||||||
|
params.redirect_uri_provided_explicitly
|
||||||
|
).lower()
|
||||||
|
== "true",
|
||||||
|
code_challenge=params.code_challenge or "",
|
||||||
|
scopes=["read", "write"], # Default scopes
|
||||||
|
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
||||||
|
)
|
||||||
|
session.add(oauth_state)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
|
||||||
|
{
|
||||||
|
"state": oauth_state.state,
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"redirect_uri": oauth_state.redirect_uri,
|
||||||
|
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
|
||||||
|
"code_challenge": cast(str, oauth_state.code_challenge),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
||||||
|
"""Complete authorization after successful login."""
|
||||||
|
if not (state := oauth_params.get("state")):
|
||||||
|
logger.error("No state parameter provided")
|
||||||
|
raise ValueError("Missing state parameter")
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Load OAuth state from database
|
||||||
|
oauth_state = session.query(OAuthState).get(state)
|
||||||
|
if not oauth_state:
|
||||||
|
logger.error(f"State {state} not found in database")
|
||||||
|
raise ValueError("Invalid state parameter")
|
||||||
|
|
||||||
|
# Check if state has expired
|
||||||
|
now = datetime.fromtimestamp(time.time())
|
||||||
|
if oauth_state.expires_at < now:
|
||||||
|
logger.error(f"State {state} has expired")
|
||||||
|
oauth_state.stale = True
|
||||||
|
session.commit()
|
||||||
|
raise ValueError("State has expired")
|
||||||
|
|
||||||
|
oauth_state.code = f"code_{secrets.token_hex(16)}"
|
||||||
|
oauth_state.stale = False
|
||||||
|
oauth_state.user_id = user.id
|
||||||
|
|
||||||
|
session.add(oauth_state)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return construct_redirect_uri(
|
||||||
|
oauth_state.redirect_uri, code=oauth_state.code, state=state
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_authorization_code(
|
||||||
|
self, client: OAuthClientInformationFull, authorization_code: str
|
||||||
|
) -> Optional[AuthorizationCode]:
|
||||||
|
"""Load an authorization code."""
|
||||||
|
with make_session() as session:
|
||||||
|
auth_code = (
|
||||||
|
session.query(OAuthState)
|
||||||
|
.filter(OAuthState.code == authorization_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not auth_code:
|
||||||
|
logger.error(f"Invalid authorization code: {authorization_code}")
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
return AuthorizationCode(**auth_code.serialize(code=True))
|
||||||
|
|
||||||
|
async def exchange_authorization_code(
|
||||||
|
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
||||||
|
) -> OAuthToken:
|
||||||
|
"""Exchange authorization code for tokens."""
|
||||||
|
with make_session() as session:
|
||||||
|
auth_code = (
|
||||||
|
session.query(OAuthState)
|
||||||
|
.filter(OAuthState.code == authorization_code.code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not auth_code:
|
||||||
|
logger.error(f"Invalid authorization code: {authorization_code.code}")
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
|
# Get the user associated with this auth code
|
||||||
|
if not auth_code.user:
|
||||||
|
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
|
# Create a UserSession to serve as access token
|
||||||
|
expires_at = datetime.fromtimestamp(time.time() + 3600)
|
||||||
|
|
||||||
|
auth_code.session = UserSession(
|
||||||
|
user_id=auth_code.user_id,
|
||||||
|
oauth_state_id=auth_code.state,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
auth_code.stale = True # type: ignore
|
||||||
|
session.commit()
|
||||||
|
access_token = str(auth_code.session.id)
|
||||||
|
|
||||||
|
return OAuthToken(
|
||||||
|
access_token=access_token,
|
||||||
|
token_type="bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
scope=" ".join(authorization_code.scopes),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||||
|
"""Load and validate an access token."""
|
||||||
|
with make_session() as session:
|
||||||
|
now = datetime.now(timezone.utc).replace(
|
||||||
|
tzinfo=None
|
||||||
|
) # Make naive for DB comparison
|
||||||
|
|
||||||
|
# Query for active (non-expired) session
|
||||||
|
user_session = session.query(UserSession).get(token)
|
||||||
|
if not user_session or user_session.expires_at < now:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return AccessToken(
|
||||||
|
token=token,
|
||||||
|
client_id=user_session.oauth_state.client_id,
|
||||||
|
scopes=user_session.oauth_state.scopes,
|
||||||
|
expires_at=int(user_session.expires_at.timestamp()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_refresh_token(
|
||||||
|
self, client: OAuthClientInformationFull, refresh_token: str
|
||||||
|
) -> Optional[RefreshToken]:
|
||||||
|
"""Load a refresh token - not supported in this simple implementation."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def exchange_refresh_token(
|
||||||
|
self,
|
||||||
|
client: OAuthClientInformationFull,
|
||||||
|
refresh_token: RefreshToken,
|
||||||
|
scopes: list[str],
|
||||||
|
) -> OAuthToken:
|
||||||
|
"""Exchange refresh token - not supported in this simple implementation."""
|
||||||
|
raise NotImplementedError("Refresh tokens not supported")
|
||||||
|
|
||||||
|
async def revoke_token(
|
||||||
|
self, token: str, token_type_hint: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Revoke a token."""
|
||||||
|
with make_session() as session:
|
||||||
|
user_session = session.query(UserSession).get(token)
|
||||||
|
if user_session:
|
||||||
|
session.delete(user_session)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def get_protected_resource_metadata(self) -> dict[str, Any]:
|
||||||
|
"""Return metadata about the protected resource."""
|
||||||
|
return {
|
||||||
|
"resource_server": settings.SERVER_URL,
|
||||||
|
"scopes_supported": ["read", "write"],
|
||||||
|
"bearer_methods_supported": ["header"],
|
||||||
|
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||||
|
"protected_resources": [
|
||||||
|
{
|
||||||
|
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||||
|
"scopes": ["read", "write"],
|
||||||
|
"http_methods": ["POST", "GET"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
||||||
|
"scopes": ["read", "write"],
|
||||||
|
"http_methods": ["POST", "GET"],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
159
src/memory/api/MCP/templates/login.html
Normal file
159
src/memory/api/MCP/templates/login.html
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Login - Memory OAuth</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-container {
|
||||||
|
background: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||||
|
width: 100%;
|
||||||
|
max-width: 400px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-header {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-header h1 {
|
||||||
|
color: #333;
|
||||||
|
margin: 0 0 0.5rem 0;
|
||||||
|
font-size: 1.8rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-header p {
|
||||||
|
color: #666;
|
||||||
|
margin: 0;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
color: #333;
|
||||||
|
font-weight: 500;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border: 2px solid #e1e5e9;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 1rem;
|
||||||
|
transition: border-color 0.2s;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-button {
|
||||||
|
width: 100%;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.875rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-button:hover {
|
||||||
|
transform: translateY(-1px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-button:active {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-message {
|
||||||
|
background: #fee;
|
||||||
|
color: #c33;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
border-left: 4px solid #c33;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-credentials {
|
||||||
|
background: #f8f9fa;
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-top: 1.5rem;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-credentials strong {
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-credentials code {
|
||||||
|
background: #e9ecef;
|
||||||
|
padding: 0.2rem 0.4rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-family: monospace;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div class="login-container">
|
||||||
|
<div class="login-header">
|
||||||
|
<h1>Sign In</h1>
|
||||||
|
<p>Access your Memory knowledge base</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<div class="error-message">
|
||||||
|
{{ error }}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<form method="post" action="/oauth/login">
|
||||||
|
{% for key, value in form_data.items() %}
|
||||||
|
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="email">Email</label>
|
||||||
|
<input type="email" id="email" name="email" required>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input type="password" id="password" name="password" required>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit" class="login-button">Sign In</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.auth.middleware.auth_context import get_access_token
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import AgentObservation, SourceItem
|
from memory.common.db.models import (
|
||||||
|
AgentObservation,
|
||||||
|
SourceItem,
|
||||||
|
UserSession,
|
||||||
|
)
|
||||||
|
from memory.api.MCP.base import mcp
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Create MCP server instance
|
|
||||||
mcp = FastMCP("memory", stateless_http=True)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_observation_source_ids(
|
def filter_observation_source_ids(
|
||||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||||
@ -67,4 +69,42 @@ def filter_source_ids(
|
|||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def get_current_time() -> dict:
|
async def get_current_time() -> dict:
|
||||||
"""Get the current time in UTC."""
|
"""Get the current time in UTC."""
|
||||||
|
logger.info("get_current_time tool called")
|
||||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_authenticated_user() -> dict:
|
||||||
|
"""Get information about the authenticated user."""
|
||||||
|
logger.info("🔧 get_authenticated_user tool called")
|
||||||
|
access_token = get_access_token()
|
||||||
|
logger.info(f"🔧 Access token from MCP context: {access_token}")
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
logger.warning("❌ No access token found in MCP context!")
|
||||||
|
return {"error": "Not authenticated"}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look up the actual user from the session token
|
||||||
|
with make_session() as session:
|
||||||
|
user_session = (
|
||||||
|
session.query(UserSession)
|
||||||
|
.filter(UserSession.id == access_token.token)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_session and user_session.user:
|
||||||
|
user_info = user_session.user.serialize()
|
||||||
|
else:
|
||||||
|
user_info = {"error": "User not found"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"authenticated": True,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scopes": access_token.scopes,
|
||||||
|
"client_id": access_token.client_id,
|
||||||
|
"user": user_info,
|
||||||
|
}
|
||||||
|
@ -16,22 +16,19 @@ from fastapi import (
|
|||||||
Query,
|
Query,
|
||||||
Form,
|
Form,
|
||||||
Depends,
|
Depends,
|
||||||
|
Request,
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from sqladmin import Admin
|
from sqladmin import Admin
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import extract, settings
|
||||||
from memory.common import extract
|
|
||||||
from memory.common.db.connection import get_engine
|
from memory.common.db.connection import get_engine
|
||||||
from memory.common.db.models import User
|
from memory.common.db.models import User
|
||||||
from memory.api.admin import setup_admin
|
from memory.api.admin import setup_admin
|
||||||
from memory.api.search import search, SearchResult
|
from memory.api.search import search, SearchResult
|
||||||
from memory.api.MCP.tools import mcp
|
from memory.api.MCP.base import mcp
|
||||||
from memory.api.auth import (
|
from memory.api.auth import get_current_user
|
||||||
router as auth_router,
|
|
||||||
get_current_user,
|
|
||||||
AuthenticationMiddleware,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -45,23 +42,45 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
# Add request logging middleware
|
||||||
|
# @app.middleware("http")
|
||||||
|
# async def log_requests(request, call_next):
|
||||||
|
# logger.info(f"Main app: {request.method} {request.url.path}")
|
||||||
|
# if request.url.path.startswith("/mcp"):
|
||||||
|
# logger.info(f"Request headers: {dict(request.headers)}")
|
||||||
|
# response = await call_next(request)
|
||||||
|
# return response
|
||||||
|
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # In production, specify actual origins
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
# SQLAdmin setup
|
# SQLAdmin setup
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
admin = Admin(app, engine)
|
admin = Admin(app, engine)
|
||||||
setup_admin(admin)
|
setup_admin(admin)
|
||||||
|
|
||||||
# Include auth router
|
|
||||||
app.add_middleware(AuthenticationMiddleware)
|
# Add health check to MCP server instead of main app
|
||||||
app.include_router(auth_router)
|
@mcp.custom_route("/health", methods=["GET"])
|
||||||
|
async def health_check(request: Request):
|
||||||
|
"""Simple health check endpoint on MCP server"""
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
|
||||||
|
|
||||||
|
|
||||||
|
# Mount MCP server at root - OAuth endpoints need to be at root level
|
||||||
app.mount("/", mcp.streamable_http_app())
|
app.mount("/", mcp.streamable_http_app())
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
def health_check():
|
|
||||||
"""Simple health check endpoint"""
|
|
||||||
return {"status": "healthy"}
|
|
||||||
|
|
||||||
|
|
||||||
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
||||||
if not item:
|
if not item:
|
||||||
return []
|
return []
|
||||||
|
@ -35,6 +35,8 @@ from memory.common.db.models.sources import (
|
|||||||
from memory.common.db.models.users import (
|
from memory.common.db.models.users import (
|
||||||
User,
|
User,
|
||||||
UserSession,
|
UserSession,
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthState,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -69,4 +71,6 @@ __all__ = [
|
|||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"UserSession",
|
"UserSession",
|
||||||
|
"OAuthClientInformation",
|
||||||
|
"OAuthState",
|
||||||
]
|
]
|
||||||
|
@ -3,7 +3,16 @@ import secrets
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
import uuid
|
import uuid
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Boolean,
|
||||||
|
ARRAY,
|
||||||
|
Numeric,
|
||||||
|
)
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@ -35,6 +44,9 @@ class User(Base):
|
|||||||
sessions = relationship(
|
sessions = relationship(
|
||||||
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
oauth_states = relationship(
|
||||||
|
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
def serialize(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@ -57,9 +69,92 @@ class UserSession(Base):
|
|||||||
__tablename__ = "user_sessions"
|
__tablename__ = "user_sessions"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True)
|
||||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
created_at = Column(DateTime, server_default=func.now())
|
created_at = Column(DateTime, server_default=func.now())
|
||||||
expires_at = Column(DateTime, nullable=False)
|
expires_at = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
# Relationship to user
|
# Relationship to user
|
||||||
user = relationship("User", back_populates="sessions")
|
user = relationship("User", back_populates="sessions")
|
||||||
|
oauth_state = relationship("OAuthState", back_populates="session")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientInformation(Base):
|
||||||
|
__tablename__ = "oauth_client"
|
||||||
|
|
||||||
|
client_id = Column(String, primary_key=True)
|
||||||
|
client_secret = Column(String, nullable=True)
|
||||||
|
client_id_issued_at = Column(Numeric, nullable=False)
|
||||||
|
client_secret_expires_at = Column(Numeric, nullable=True)
|
||||||
|
|
||||||
|
redirect_uris = Column(ARRAY(String), nullable=False)
|
||||||
|
token_endpoint_auth_method = Column(String, nullable=False)
|
||||||
|
grant_types = Column(ARRAY(String), nullable=False)
|
||||||
|
response_types = Column(ARRAY(String), nullable=False)
|
||||||
|
scope = Column(String, nullable=False)
|
||||||
|
client_name = Column(String, nullable=False)
|
||||||
|
client_uri = Column(String, nullable=True)
|
||||||
|
logo_uri = Column(String, nullable=True)
|
||||||
|
contacts = Column(ARRAY(String), nullable=True)
|
||||||
|
tos_uri = Column(String, nullable=True)
|
||||||
|
policy_uri = Column(String, nullable=True)
|
||||||
|
jwks_uri = Column(String, nullable=True)
|
||||||
|
|
||||||
|
sessions = relationship(
|
||||||
|
"OAuthState", back_populates="client", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize(self) -> dict:
|
||||||
|
return {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"client_id_issued_at": self.client_id_issued_at,
|
||||||
|
"client_secret_expires_at": self.client_secret_expires_at,
|
||||||
|
"redirect_uris": self.redirect_uris,
|
||||||
|
"token_endpoint_auth_method": self.token_endpoint_auth_method,
|
||||||
|
"grant_types": self.grant_types,
|
||||||
|
"response_types": self.response_types,
|
||||||
|
"scope": self.scope,
|
||||||
|
"client_name": self.client_name,
|
||||||
|
"client_uri": self.client_uri,
|
||||||
|
"logo_uri": self.logo_uri,
|
||||||
|
"contacts": self.contacts,
|
||||||
|
"tos_uri": self.tos_uri,
|
||||||
|
"policy_uri": self.policy_uri,
|
||||||
|
"jwks_uri": self.jwks_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthState(Base):
|
||||||
|
__tablename__ = "oauth_states"
|
||||||
|
|
||||||
|
state = Column(String, primary_key=True)
|
||||||
|
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
|
||||||
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||||
|
code = Column(String, nullable=True)
|
||||||
|
redirect_uri = Column(String, nullable=False)
|
||||||
|
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
|
||||||
|
code_challenge = Column(String, nullable=True)
|
||||||
|
scopes = Column(ARRAY(String), nullable=False)
|
||||||
|
created_at = Column(DateTime, server_default=func.now())
|
||||||
|
expires_at = Column(DateTime, nullable=False)
|
||||||
|
stale = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
def serialize(self, code: bool = False) -> dict:
|
||||||
|
data = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
|
||||||
|
"code_challenge": self.code_challenge,
|
||||||
|
"scopes": self.scopes,
|
||||||
|
}
|
||||||
|
if code:
|
||||||
|
data |= {
|
||||||
|
"code": self.code,
|
||||||
|
"expires_at": self.expires_at.timestamp(),
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
client = relationship("OAuthClientInformation", back_populates="sessions")
|
||||||
|
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
|
||||||
|
user = relationship("User", back_populates="oauth_states")
|
||||||
|
@ -131,6 +131,7 @@ else:
|
|||||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
|
|
||||||
# API settings
|
# API settings
|
||||||
|
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||||
HTTPS = boolean_env("HTTPS", False)
|
HTTPS = boolean_env("HTTPS", False)
|
||||||
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
||||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user