mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
proper oauth flow
This commit is contained in:
parent
4d057d1ec6
commit
4556ef2c48
86
db/migrations/versions/20250606_123611_oauth_codes.py
Normal file
86
db/migrations/versions/20250606_123611_oauth_codes.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""oauth codes
|
||||
|
||||
Revision ID: 66771d293b27
|
||||
Revises: 58439dd3088b
|
||||
Create Date: 2025-06-06 12:36:11.737507
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "66771d293b27"
|
||||
down_revision: Union[str, None] = "58439dd3088b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"oauth_client",
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("client_secret", sa.String(), nullable=True),
|
||||
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
|
||||
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
|
||||
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
|
||||
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("scope", sa.String(), nullable=False),
|
||||
sa.Column("client_name", sa.String(), nullable=False),
|
||||
sa.Column("client_uri", sa.String(), nullable=True),
|
||||
sa.Column("logo_uri", sa.String(), nullable=True),
|
||||
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column("tos_uri", sa.String(), nullable=True),
|
||||
sa.Column("policy_uri", sa.String(), nullable=True),
|
||||
sa.Column("jwks_uri", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("client_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(), nullable=False),
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("code", sa.String(), nullable=True),
|
||||
sa.Column("redirect_uri", sa.String(), nullable=False),
|
||||
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(), nullable=True),
|
||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||
),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("stale", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_client.client_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("state"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_user_sessions_oauth_state_id",
|
||||
"user_sessions",
|
||||
"oauth_states",
|
||||
["oauth_state_id"],
|
||||
["state"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("user_sessions", "oauth_state_id")
|
||||
op.drop_table("oauth_states")
|
||||
op.drop_table("oauth_client")
|
124
src/memory/api/MCP/base.py
Normal file
124
src/memory/api/MCP/base.py
Normal file
@ -0,0 +1,124 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||
from mcp.server.auth.handlers.token import (
|
||||
AuthorizationCodeRequest,
|
||||
RefreshTokenRequest,
|
||||
TokenRequest,
|
||||
)
|
||||
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
from memory.common.db.models.users import User
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
||||
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import OAuthState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_metadata(klass: type):
|
||||
orig_validate = klass.model_validate
|
||||
|
||||
def validate(data: dict):
|
||||
data = dict(data)
|
||||
if "redirect_uris" in data:
|
||||
data["redirect_uris"] = [
|
||||
str(uri).replace("cursor://", "http://")
|
||||
for uri in data["redirect_uris"]
|
||||
]
|
||||
if "redirect_uri" in data:
|
||||
data["redirect_uri"] = str(data["redirect_uri"]).replace(
|
||||
"cursor://", "http://"
|
||||
)
|
||||
|
||||
return orig_validate(data)
|
||||
|
||||
klass.model_validate = validate
|
||||
|
||||
|
||||
validate_metadata(OAuthClientMetadata)
|
||||
validate_metadata(AuthorizationRequest)
|
||||
validate_metadata(AuthorizationCodeRequest)
|
||||
validate_metadata(RefreshTokenRequest)
|
||||
validate_metadata(TokenRequest)
|
||||
|
||||
|
||||
# Setup templates
|
||||
template_dir = os.path.join(os.path.dirname(__file__), "templates")
|
||||
templates = Jinja2Templates(directory=template_dir)
|
||||
|
||||
|
||||
oauth_provider = SimpleOAuthProvider()
|
||||
auth_settings = AuthSettings(
|
||||
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||
client_registration_options=ClientRegistrationOptions(
|
||||
enabled=True,
|
||||
valid_scopes=["read", "write"],
|
||||
default_scopes=["read"],
|
||||
),
|
||||
required_scopes=["read", "write"],
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
"memory",
|
||||
stateless_http=True,
|
||||
auth_server_provider=oauth_provider,
|
||||
auth=auth_settings,
|
||||
)
|
||||
|
||||
|
||||
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
||||
async def oauth_protected_resource(request: Request):
|
||||
"""OAuth 2.0 Protected Resource Metadata."""
|
||||
logger.info("Protected resource metadata requested")
|
||||
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
||||
|
||||
|
||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "form_data": form_data, "error": error},
|
||||
)
|
||||
|
||||
|
||||
@mcp.custom_route("/oauth/login", methods=["GET"])
|
||||
async def login_page(request: Request):
|
||||
"""Display the login page."""
|
||||
form_data = dict(request.query_params)
|
||||
|
||||
state = form_data.get("state")
|
||||
with make_session() as session:
|
||||
oauth_state = session.query(OAuthState).get(state)
|
||||
if not oauth_state:
|
||||
logger.error(f"State {state} not found in database")
|
||||
raise ValueError("Invalid state parameter")
|
||||
|
||||
return login_form(request, form_data, None)
|
||||
|
||||
|
||||
@mcp.custom_route("/oauth/login", methods=["POST"])
|
||||
async def handle_login(request: Request):
|
||||
"""Handle login form submission."""
|
||||
form = await request.form()
|
||||
oauth_params = {
|
||||
key: value for key, value in form.items() if key not in ["email", "password"]
|
||||
}
|
||||
with make_session() as session:
|
||||
user = session.query(User).filter(User.email == form.get("email")).first()
|
||||
if not user or not user.is_valid_password(str(form.get("password", ""))):
|
||||
logger.warning("Login failed - invalid credentials")
|
||||
return login_form(request, oauth_params, "Invalid email or password")
|
||||
|
||||
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
||||
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
||||
redirect_url = redirect_url.replace("http://", "cursor://")
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
246
src/memory/api/MCP/oauth_provider.py
Normal file
246
src/memory/api/MCP/oauth_provider.py
Normal file
@ -0,0 +1,246 @@
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, Any, cast
|
||||
from urllib.parse import urlencode
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from memory.common.db.models.users import (
|
||||
User,
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
)
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common import settings
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
AuthorizationParams,
|
||||
AuthorizationCode,
|
||||
OAuthAuthorizationServerProvider,
|
||||
RefreshToken,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""Get OAuth client information."""
|
||||
with make_session() as session:
|
||||
client = session.query(OAuthClientInformation).get(client_id)
|
||||
return client and OAuthClientInformationFull(**client.serialize())
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||
"""Register a new OAuth client."""
|
||||
with make_session() as session:
|
||||
client = session.query(OAuthClientInformation).get(client_info.client_id)
|
||||
if not client:
|
||||
client = OAuthClientInformation(client_id=client_info.client_id)
|
||||
|
||||
for key, value in client_info.model_dump().items():
|
||||
if key == "redirect_uris":
|
||||
value = [str(uri) for uri in value]
|
||||
elif value and key in [
|
||||
"client_uri",
|
||||
"logo_uri",
|
||||
"tos_uri",
|
||||
"policy_uri",
|
||||
"jwks_uri",
|
||||
]:
|
||||
value = str(value)
|
||||
setattr(client, key, value)
|
||||
session.add(client)
|
||||
session.commit()
|
||||
|
||||
async def authorize(
|
||||
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||
) -> str:
|
||||
"""Redirect to login page for user authentication."""
|
||||
redirect_uri_str = str(params.redirect_uri)
|
||||
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
|
||||
if redirect_uri_str not in registered_uris:
|
||||
logger.error(
|
||||
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
|
||||
)
|
||||
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
||||
|
||||
# Store the authorization parameters in database
|
||||
with make_session() as session:
|
||||
oauth_state = OAuthState(
|
||||
state=params.state or secrets.token_hex(16),
|
||||
client_id=client.client_id,
|
||||
redirect_uri=str(params.redirect_uri),
|
||||
redirect_uri_provided_explicitly=str(
|
||||
params.redirect_uri_provided_explicitly
|
||||
).lower()
|
||||
== "true",
|
||||
code_challenge=params.code_challenge or "",
|
||||
scopes=["read", "write"], # Default scopes
|
||||
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
||||
)
|
||||
session.add(oauth_state)
|
||||
session.commit()
|
||||
|
||||
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
|
||||
{
|
||||
"state": oauth_state.state,
|
||||
"client_id": client.client_id,
|
||||
"redirect_uri": oauth_state.redirect_uri,
|
||||
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
|
||||
"code_challenge": cast(str, oauth_state.code_challenge),
|
||||
}
|
||||
)
|
||||
|
||||
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
||||
"""Complete authorization after successful login."""
|
||||
if not (state := oauth_params.get("state")):
|
||||
logger.error("No state parameter provided")
|
||||
raise ValueError("Missing state parameter")
|
||||
|
||||
with make_session() as session:
|
||||
# Load OAuth state from database
|
||||
oauth_state = session.query(OAuthState).get(state)
|
||||
if not oauth_state:
|
||||
logger.error(f"State {state} not found in database")
|
||||
raise ValueError("Invalid state parameter")
|
||||
|
||||
# Check if state has expired
|
||||
now = datetime.fromtimestamp(time.time())
|
||||
if oauth_state.expires_at < now:
|
||||
logger.error(f"State {state} has expired")
|
||||
oauth_state.stale = True
|
||||
session.commit()
|
||||
raise ValueError("State has expired")
|
||||
|
||||
oauth_state.code = f"code_{secrets.token_hex(16)}"
|
||||
oauth_state.stale = False
|
||||
oauth_state.user_id = user.id
|
||||
|
||||
session.add(oauth_state)
|
||||
session.commit()
|
||||
|
||||
return construct_redirect_uri(
|
||||
oauth_state.redirect_uri, code=oauth_state.code, state=state
|
||||
)
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
) -> Optional[AuthorizationCode]:
|
||||
"""Load an authorization code."""
|
||||
with make_session() as session:
|
||||
auth_code = (
|
||||
session.query(OAuthState)
|
||||
.filter(OAuthState.code == authorization_code)
|
||||
.first()
|
||||
)
|
||||
if not auth_code:
|
||||
logger.error(f"Invalid authorization code: {authorization_code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
return AuthorizationCode(**auth_code.serialize(code=True))
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
||||
) -> OAuthToken:
|
||||
"""Exchange authorization code for tokens."""
|
||||
with make_session() as session:
|
||||
auth_code = (
|
||||
session.query(OAuthState)
|
||||
.filter(OAuthState.code == authorization_code.code)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not auth_code:
|
||||
logger.error(f"Invalid authorization code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
# Get the user associated with this auth code
|
||||
if not auth_code.user:
|
||||
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
# Create a UserSession to serve as access token
|
||||
expires_at = datetime.fromtimestamp(time.time() + 3600)
|
||||
|
||||
auth_code.session = UserSession(
|
||||
user_id=auth_code.user_id,
|
||||
oauth_state_id=auth_code.state,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
auth_code.stale = True # type: ignore
|
||||
session.commit()
|
||||
access_token = str(auth_code.session.id)
|
||||
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=3600,
|
||||
scope=" ".join(authorization_code.scopes),
|
||||
)
|
||||
|
||||
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""Load and validate an access token."""
|
||||
with make_session() as session:
|
||||
now = datetime.now(timezone.utc).replace(
|
||||
tzinfo=None
|
||||
) # Make naive for DB comparison
|
||||
|
||||
# Query for active (non-expired) session
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if not user_session or user_session.expires_at < now:
|
||||
return None
|
||||
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id=user_session.oauth_state.client_id,
|
||||
scopes=user_session.oauth_state.scopes,
|
||||
expires_at=int(user_session.expires_at.timestamp()),
|
||||
)
|
||||
|
||||
async def load_refresh_token(
|
||||
self, client: OAuthClientInformationFull, refresh_token: str
|
||||
) -> Optional[RefreshToken]:
|
||||
"""Load a refresh token - not supported in this simple implementation."""
|
||||
return None
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self,
|
||||
client: OAuthClientInformationFull,
|
||||
refresh_token: RefreshToken,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
"""Exchange refresh token - not supported in this simple implementation."""
|
||||
raise NotImplementedError("Refresh tokens not supported")
|
||||
|
||||
async def revoke_token(
|
||||
self, token: str, token_type_hint: Optional[str] = None
|
||||
) -> None:
|
||||
"""Revoke a token."""
|
||||
with make_session() as session:
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if user_session:
|
||||
session.delete(user_session)
|
||||
session.commit()
|
||||
|
||||
def get_protected_resource_metadata(self) -> dict[str, Any]:
|
||||
"""Return metadata about the protected resource."""
|
||||
return {
|
||||
"resource_server": settings.SERVER_URL,
|
||||
"scopes_supported": ["read", "write"],
|
||||
"bearer_methods_supported": ["header"],
|
||||
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||
"protected_resources": [
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||
"scopes": ["read", "write"],
|
||||
"http_methods": ["POST", "GET"],
|
||||
},
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
||||
"scopes": ["read", "write"],
|
||||
"http_methods": ["POST", "GET"],
|
||||
},
|
||||
],
|
||||
}
|
159
src/memory/api/MCP/templates/login.html
Normal file
159
src/memory/api/MCP/templates/login.html
Normal file
@ -0,0 +1,159 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Login - Memory OAuth</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.login-container {
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
.login-header {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.login-header h1 {
|
||||
color: #333;
|
||||
margin: 0 0 0.5rem 0;
|
||||
font-size: 1.8rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.login-header p {
|
||||
color: #666;
|
||||
margin: 0;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.form-group {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.form-group label {
|
||||
display: block;
|
||||
margin-bottom: 0.5rem;
|
||||
color: #333;
|
||||
font-weight: 500;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 0.75rem;
|
||||
border: 2px solid #e1e5e9;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
transition: border-color 0.2s;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
.login-button {
|
||||
width: 100%;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
|
||||
.login-button:hover {
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.login-button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background: #fee;
|
||||
color: #c33;
|
||||
padding: 0.75rem;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 1rem;
|
||||
font-size: 0.9rem;
|
||||
border-left: 4px solid #c33;
|
||||
}
|
||||
|
||||
.demo-credentials {
|
||||
background: #f8f9fa;
|
||||
padding: 1rem;
|
||||
border-radius: 8px;
|
||||
margin-top: 1.5rem;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.demo-credentials strong {
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.demo-credentials code {
|
||||
background: #e9ecef;
|
||||
padding: 0.2rem 0.4rem;
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="login-container">
|
||||
<div class="login-header">
|
||||
<h1>Sign In</h1>
|
||||
<p>Access your Memory knowledge base</p>
|
||||
</div>
|
||||
|
||||
{% if error %}
|
||||
<div class="error-message">
|
||||
{{ error }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<form method="post" action="/oauth/login">
|
||||
{% for key, value in form_data.items() %}
|
||||
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||
{% endfor %}
|
||||
|
||||
<div class="form-group">
|
||||
<label for="email">Email</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="login-button">Sign In</button>
|
||||
</form>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system.
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.auth.middleware.auth_context import get_access_token
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import AgentObservation, SourceItem
|
||||
from memory.common.db.models import (
|
||||
AgentObservation,
|
||||
SourceItem,
|
||||
UserSession,
|
||||
)
|
||||
from memory.api.MCP.base import mcp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create MCP server instance
|
||||
mcp = FastMCP("memory", stateless_http=True)
|
||||
|
||||
|
||||
def filter_observation_source_ids(
|
||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||
@ -67,4 +69,42 @@ def filter_source_ids(
|
||||
@mcp.tool()
|
||||
async def get_current_time() -> dict:
|
||||
"""Get the current time in UTC."""
|
||||
logger.info("get_current_time tool called")
|
||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_authenticated_user() -> dict:
|
||||
"""Get information about the authenticated user."""
|
||||
logger.info("🔧 get_authenticated_user tool called")
|
||||
access_token = get_access_token()
|
||||
logger.info(f"🔧 Access token from MCP context: {access_token}")
|
||||
|
||||
if not access_token:
|
||||
logger.warning("❌ No access token found in MCP context!")
|
||||
return {"error": "Not authenticated"}
|
||||
|
||||
logger.info(
|
||||
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
|
||||
)
|
||||
|
||||
# Look up the actual user from the session token
|
||||
with make_session() as session:
|
||||
user_session = (
|
||||
session.query(UserSession)
|
||||
.filter(UserSession.id == access_token.token)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_session and user_session.user:
|
||||
user_info = user_session.user.serialize()
|
||||
else:
|
||||
user_info = {"error": "User not found"}
|
||||
|
||||
return {
|
||||
"authenticated": True,
|
||||
"token_type": "Bearer",
|
||||
"scopes": access_token.scopes,
|
||||
"client_id": access_token.client_id,
|
||||
"user": user_info,
|
||||
}
|
||||
|
@ -16,22 +16,19 @@ from fastapi import (
|
||||
Query,
|
||||
Form,
|
||||
Depends,
|
||||
Request,
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqladmin import Admin
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common import extract
|
||||
from memory.common import extract, settings
|
||||
from memory.common.db.connection import get_engine
|
||||
from memory.common.db.models import User
|
||||
from memory.api.admin import setup_admin
|
||||
from memory.api.search import search, SearchResult
|
||||
from memory.api.MCP.tools import mcp
|
||||
from memory.api.auth import (
|
||||
router as auth_router,
|
||||
get_current_user,
|
||||
AuthenticationMiddleware,
|
||||
)
|
||||
from memory.api.MCP.base import mcp
|
||||
from memory.api.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -45,23 +42,45 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||
|
||||
|
||||
# Add request logging middleware
|
||||
# @app.middleware("http")
|
||||
# async def log_requests(request, call_next):
|
||||
# logger.info(f"Main app: {request.method} {request.url.path}")
|
||||
# if request.url.path.startswith("/mcp"):
|
||||
# logger.info(f"Request headers: {dict(request.headers)}")
|
||||
# response = await call_next(request)
|
||||
# return response
|
||||
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify actual origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# SQLAdmin setup
|
||||
engine = get_engine()
|
||||
admin = Admin(app, engine)
|
||||
setup_admin(admin)
|
||||
|
||||
# Include auth router
|
||||
app.add_middleware(AuthenticationMiddleware)
|
||||
app.include_router(auth_router)
|
||||
|
||||
# Add health check to MCP server instead of main app
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request):
|
||||
"""Simple health check endpoint on MCP server"""
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
|
||||
|
||||
|
||||
# Mount MCP server at root - OAuth endpoints need to be at root level
|
||||
app.mount("/", mcp.streamable_http_app())
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
"""Simple health check endpoint"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
||||
if not item:
|
||||
return []
|
||||
|
@ -35,6 +35,8 @@ from memory.common.db.models.sources import (
|
||||
from memory.common.db.models.users import (
|
||||
User,
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -69,4 +71,6 @@ __all__ = [
|
||||
# Users
|
||||
"User",
|
||||
"UserSession",
|
||||
"OAuthClientInformation",
|
||||
"OAuthState",
|
||||
]
|
||||
|
@ -3,7 +3,16 @@ import secrets
|
||||
from typing import cast
|
||||
import uuid
|
||||
from memory.common.db.models.base import Base
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Boolean,
|
||||
ARRAY,
|
||||
Numeric,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@ -35,6 +44,9 @@ class User(Base):
|
||||
sessions = relationship(
|
||||
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_states = relationship(
|
||||
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
@ -57,9 +69,92 @@ class UserSession(Base):
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
# Relationship to user
|
||||
user = relationship("User", back_populates="sessions")
|
||||
oauth_state = relationship("OAuthState", back_populates="session")
|
||||
|
||||
|
||||
class OAuthClientInformation(Base):
|
||||
__tablename__ = "oauth_client"
|
||||
|
||||
client_id = Column(String, primary_key=True)
|
||||
client_secret = Column(String, nullable=True)
|
||||
client_id_issued_at = Column(Numeric, nullable=False)
|
||||
client_secret_expires_at = Column(Numeric, nullable=True)
|
||||
|
||||
redirect_uris = Column(ARRAY(String), nullable=False)
|
||||
token_endpoint_auth_method = Column(String, nullable=False)
|
||||
grant_types = Column(ARRAY(String), nullable=False)
|
||||
response_types = Column(ARRAY(String), nullable=False)
|
||||
scope = Column(String, nullable=False)
|
||||
client_name = Column(String, nullable=False)
|
||||
client_uri = Column(String, nullable=True)
|
||||
logo_uri = Column(String, nullable=True)
|
||||
contacts = Column(ARRAY(String), nullable=True)
|
||||
tos_uri = Column(String, nullable=True)
|
||||
policy_uri = Column(String, nullable=True)
|
||||
jwks_uri = Column(String, nullable=True)
|
||||
|
||||
sessions = relationship(
|
||||
"OAuthState", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"client_id_issued_at": self.client_id_issued_at,
|
||||
"client_secret_expires_at": self.client_secret_expires_at,
|
||||
"redirect_uris": self.redirect_uris,
|
||||
"token_endpoint_auth_method": self.token_endpoint_auth_method,
|
||||
"grant_types": self.grant_types,
|
||||
"response_types": self.response_types,
|
||||
"scope": self.scope,
|
||||
"client_name": self.client_name,
|
||||
"client_uri": self.client_uri,
|
||||
"logo_uri": self.logo_uri,
|
||||
"contacts": self.contacts,
|
||||
"tos_uri": self.tos_uri,
|
||||
"policy_uri": self.policy_uri,
|
||||
"jwks_uri": self.jwks_uri,
|
||||
}
|
||||
|
||||
|
||||
class OAuthState(Base):
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
state = Column(String, primary_key=True)
|
||||
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
code = Column(String, nullable=True)
|
||||
redirect_uri = Column(String, nullable=False)
|
||||
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
|
||||
code_challenge = Column(String, nullable=True)
|
||||
scopes = Column(ARRAY(String), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
stale = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
def serialize(self, code: bool = False) -> dict:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
|
||||
"code_challenge": self.code_challenge,
|
||||
"scopes": self.scopes,
|
||||
}
|
||||
if code:
|
||||
data |= {
|
||||
"code": self.code,
|
||||
"expires_at": self.expires_at.timestamp(),
|
||||
}
|
||||
return data
|
||||
|
||||
client = relationship("OAuthClientInformation", back_populates="sessions")
|
||||
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
|
||||
user = relationship("User", back_populates="oauth_states")
|
||||
|
@ -131,6 +131,7 @@ else:
|
||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||
|
||||
# API settings
|
||||
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||
HTTPS = boolean_env("HTTPS", False)
|
||||
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||
|
Loading…
x
Reference in New Issue
Block a user