mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-29 14:16:09 +02:00
Compare commits
4 Commits
4d057d1ec6
...
6ee46d6215
Author | SHA1 | Date | |
---|---|---|---|
![]() |
6ee46d6215 | ||
![]() |
d17d724631 | ||
![]() |
986d5b9957 | ||
![]() |
4556ef2c48 |
12
README.md
12
README.md
@ -80,18 +80,6 @@ python tools/run_celery_task.py notes setup-git-notes --origin ssh://git@github.
|
||||
For this to work you need to make sure you have set up the ssh keys in `secrets` (see the README.md
|
||||
in that folder), and you will need to add the public key that is generated there to your git server.
|
||||
|
||||
### Authentication
|
||||
|
||||
The API uses session-based authentication. Login via:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/auth/login \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"email": "user@example.com", "password": "yourpassword"}'
|
||||
```
|
||||
|
||||
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
|
||||
|
||||
## Discord integration
|
||||
|
||||
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
||||
|
115
db/migrations/versions/20250606_165333_oauth_codes.py
Normal file
115
db/migrations/versions/20250606_165333_oauth_codes.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""oauth codes
|
||||
|
||||
Revision ID: 1d6bc8015ea9
|
||||
Revises: 58439dd3088b
|
||||
Create Date: 2025-06-06 16:53:33.044558
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "1d6bc8015ea9"
|
||||
down_revision: Union[str, None] = "58439dd3088b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"oauth_client",
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("client_secret", sa.String(), nullable=True),
|
||||
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
|
||||
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
|
||||
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
|
||||
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("scope", sa.String(), nullable=False),
|
||||
sa.Column("client_name", sa.String(), nullable=False),
|
||||
sa.Column("client_uri", sa.String(), nullable=True),
|
||||
sa.Column("logo_uri", sa.String(), nullable=True),
|
||||
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column("tos_uri", sa.String(), nullable=True),
|
||||
sa.Column("policy_uri", sa.String(), nullable=True),
|
||||
sa.Column("jwks_uri", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("client_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(), nullable=False),
|
||||
sa.Column("code", sa.String(), nullable=True),
|
||||
sa.Column("redirect_uri", sa.String(), nullable=False),
|
||||
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(), nullable=True),
|
||||
sa.Column("stale", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||
),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_client.client_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_refresh_tokens",
|
||||
sa.Column("token", sa.String(), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("access_token_session_id", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("client_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||
),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["access_token_session_id"],
|
||||
["user_sessions.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_client.client_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"user_sessions_oauth_state_id_fkey",
|
||||
"user_sessions",
|
||||
"oauth_states",
|
||||
["oauth_state_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("user_sessions", "oauth_state_id")
|
||||
op.drop_table("oauth_refresh_tokens")
|
||||
op.drop_table("oauth_states")
|
||||
op.drop_table("oauth_client")
|
@ -147,6 +147,7 @@ services:
|
||||
<<: *env
|
||||
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
||||
QDRANT_URL: http://qdrant:6333
|
||||
SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
|
||||
secrets: [postgres_password]
|
||||
volumes:
|
||||
- ./memory_files:/app/memory_files:rw
|
||||
@ -158,20 +159,6 @@ services:
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
proxy:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/api/Dockerfile
|
||||
restart: unless-stopped
|
||||
networks: [kbnet]
|
||||
environment:
|
||||
<<: *env
|
||||
command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "${PROXY_REMOTE_SERVER:-http://api:8000}", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
|
||||
volumes:
|
||||
- ./tools:/app/tools:ro
|
||||
ports:
|
||||
- "8001:8001"
|
||||
|
||||
# ------------------------------------------------------------ Celery workers
|
||||
worker:
|
||||
<<: *worker-base
|
||||
@ -201,4 +188,4 @@ services:
|
||||
# restart: unless-stopped
|
||||
# command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
|
||||
# volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
|
||||
# networks: [ kbnet ]
|
||||
# networks: [ kbnet ]
|
||||
|
132
src/memory/api/MCP/base.py
Normal file
132
src/memory/api/MCP/base.py
Normal file
@ -0,0 +1,132 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import cast
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||
from mcp.server.auth.handlers.token import (
|
||||
AuthorizationCodeRequest,
|
||||
RefreshTokenRequest,
|
||||
TokenRequest,
|
||||
)
|
||||
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
from memory.common.db.models.users import User
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
||||
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import OAuthState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_metadata(klass: type):
|
||||
orig_validate = klass.model_validate
|
||||
|
||||
def validate(data: dict):
|
||||
data = dict(data)
|
||||
if "redirect_uris" in data:
|
||||
data["redirect_uris"] = [
|
||||
str(uri).replace("cursor://", "http://")
|
||||
for uri in data["redirect_uris"]
|
||||
]
|
||||
if "redirect_uri" in data:
|
||||
data["redirect_uri"] = str(data["redirect_uri"]).replace(
|
||||
"cursor://", "http://"
|
||||
)
|
||||
|
||||
return orig_validate(data)
|
||||
|
||||
klass.model_validate = validate
|
||||
|
||||
|
||||
validate_metadata(OAuthClientMetadata)
|
||||
validate_metadata(AuthorizationRequest)
|
||||
validate_metadata(AuthorizationCodeRequest)
|
||||
validate_metadata(RefreshTokenRequest)
|
||||
validate_metadata(TokenRequest)
|
||||
|
||||
|
||||
# Setup templates
|
||||
template_dir = pathlib.Path(__file__).parent.parent / "templates"
|
||||
templates = Jinja2Templates(directory=template_dir)
|
||||
|
||||
|
||||
oauth_provider = SimpleOAuthProvider()
|
||||
auth_settings = AuthSettings(
|
||||
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||
client_registration_options=ClientRegistrationOptions(
|
||||
enabled=True,
|
||||
valid_scopes=["read", "write"],
|
||||
default_scopes=["read"],
|
||||
),
|
||||
required_scopes=["read", "write"],
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
"memory",
|
||||
stateless_http=True,
|
||||
auth_server_provider=oauth_provider,
|
||||
auth=auth_settings,
|
||||
)
|
||||
|
||||
|
||||
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
||||
async def oauth_protected_resource(request: Request):
|
||||
"""OAuth 2.0 Protected Resource Metadata."""
|
||||
logger.info("Protected resource metadata requested")
|
||||
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
||||
|
||||
|
||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{
|
||||
"request": request,
|
||||
"form_data": form_data,
|
||||
"error": error,
|
||||
"action": "/oauth/login",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@mcp.custom_route("/oauth/login", methods=["GET"])
|
||||
async def login_page(request: Request):
|
||||
"""Display the login page."""
|
||||
form_data = dict(request.query_params)
|
||||
|
||||
state = form_data.get("state")
|
||||
with make_session() as session:
|
||||
oauth_state = (
|
||||
session.query(OAuthState).filter(OAuthState.state == state).first()
|
||||
)
|
||||
if not oauth_state:
|
||||
logger.error(f"State {state} not found in database")
|
||||
raise ValueError("Invalid state parameter")
|
||||
|
||||
return login_form(request, form_data, None)
|
||||
|
||||
|
||||
@mcp.custom_route("/oauth/login", methods=["POST"])
|
||||
async def handle_login(request: Request):
|
||||
"""Handle login form submission."""
|
||||
form = await request.form()
|
||||
oauth_params = {
|
||||
key: value for key, value in form.items() if key not in ["email", "password"]
|
||||
}
|
||||
with make_session() as session:
|
||||
user = session.query(User).filter(User.email == form.get("email")).first()
|
||||
if not user or not user.is_valid_password(str(form.get("password", ""))):
|
||||
logger.warning("Login failed - invalid credentials")
|
||||
return login_form(request, oauth_params, "Invalid email or password")
|
||||
|
||||
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
||||
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
||||
redirect_url = redirect_url.replace("http://", "cursor://")
|
||||
return RedirectResponse(url=redirect_url, status_code=302)
|
411
src/memory/api/MCP/oauth_provider.py
Normal file
411
src/memory/api/MCP/oauth_provider.py
Normal file
@ -0,0 +1,411 @@
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, Any, cast
|
||||
from urllib.parse import urlencode
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from memory.common.db.models.users import (
|
||||
User,
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
OAuthRefreshToken,
|
||||
OAuthToken as TokenBase,
|
||||
)
|
||||
from memory.common.db.connection import make_session, scoped_session
|
||||
from memory.common import settings
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
AuthorizationParams,
|
||||
AuthorizationCode,
|
||||
OAuthAuthorizationServerProvider,
|
||||
RefreshToken,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Token configuration constants
|
||||
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
||||
REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days
|
||||
|
||||
|
||||
def create_expiration(lifetime_seconds: int) -> datetime:
|
||||
"""Create expiration datetime from lifetime in seconds."""
|
||||
return datetime.fromtimestamp(time.time() + lifetime_seconds)
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""Generate a new refresh token."""
|
||||
return f"rt_{secrets.token_hex(32)}"
|
||||
|
||||
|
||||
def create_access_token_session(
|
||||
user_id: int, oauth_state_id: str | None = None
|
||||
) -> UserSession:
|
||||
"""Create a new access token session."""
|
||||
return UserSession(
|
||||
user_id=user_id,
|
||||
oauth_state_id=oauth_state_id,
|
||||
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token_record(
|
||||
client_id: str,
|
||||
user_id: int,
|
||||
scopes: list[str],
|
||||
access_token_session_id: Optional[str] = None,
|
||||
) -> OAuthRefreshToken:
|
||||
"""Create a new refresh token record."""
|
||||
return OAuthRefreshToken(
|
||||
token=generate_refresh_token(),
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
expires_at=create_expiration(REFRESH_TOKEN_LIFETIME),
|
||||
access_token_session_id=access_token_session_id,
|
||||
)
|
||||
|
||||
|
||||
def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None:
|
||||
"""Validate a refresh token, raising ValueError if invalid."""
|
||||
now = datetime.now()
|
||||
if db_refresh_token.expires_at < now: # type: ignore
|
||||
logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...")
|
||||
db_refresh_token.revoked = True # type: ignore
|
||||
raise ValueError("Refresh token expired")
|
||||
|
||||
|
||||
def create_oauth_token(
|
||||
access_token: str, scopes: list[str], refresh_token: Optional[str] = None
|
||||
) -> OAuthToken:
|
||||
"""Create an OAuth token response."""
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_LIFETIME,
|
||||
refresh_token=refresh_token,
|
||||
scope=" ".join(scopes),
|
||||
)
|
||||
|
||||
|
||||
def make_token(
|
||||
db: scoped_session,
|
||||
auth_state: TokenBase,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
new_session = UserSession(
|
||||
user_id=auth_state.user_id,
|
||||
oauth_state_id=auth_state.id,
|
||||
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
||||
)
|
||||
|
||||
# Create refresh token
|
||||
refresh_token = create_refresh_token_record(
|
||||
cast(str, auth_state.client_id),
|
||||
cast(int, auth_state.user_id),
|
||||
scopes,
|
||||
cast(str, new_session.id),
|
||||
)
|
||||
|
||||
db.add(new_session)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
|
||||
return create_oauth_token(
|
||||
str(new_session.id),
|
||||
scopes,
|
||||
cast(str, refresh_token.token),
|
||||
)
|
||||
|
||||
|
||||
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""Get OAuth client information."""
|
||||
with make_session() as session:
|
||||
client = session.query(OAuthClientInformation).get(client_id)
|
||||
return client and OAuthClientInformationFull(**client.serialize())
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||
"""Register a new OAuth client."""
|
||||
with make_session() as session:
|
||||
client = session.query(OAuthClientInformation).get(client_info.client_id)
|
||||
if not client:
|
||||
client = OAuthClientInformation(client_id=client_info.client_id)
|
||||
|
||||
for key, value in client_info.model_dump().items():
|
||||
if key == "redirect_uris":
|
||||
value = [str(uri) for uri in value]
|
||||
elif value and key in [
|
||||
"client_uri",
|
||||
"logo_uri",
|
||||
"tos_uri",
|
||||
"policy_uri",
|
||||
"jwks_uri",
|
||||
]:
|
||||
value = str(value)
|
||||
setattr(client, key, value)
|
||||
session.add(client)
|
||||
session.commit()
|
||||
|
||||
async def authorize(
|
||||
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||
) -> str:
|
||||
"""Redirect to login page for user authentication."""
|
||||
redirect_uri_str = str(params.redirect_uri)
|
||||
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
|
||||
if redirect_uri_str not in registered_uris:
|
||||
logger.error(
|
||||
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
|
||||
)
|
||||
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
||||
|
||||
# Store the authorization parameters in database
|
||||
with make_session() as session:
|
||||
oauth_state = OAuthState(
|
||||
state=params.state or secrets.token_hex(16),
|
||||
client_id=client.client_id,
|
||||
redirect_uri=str(params.redirect_uri),
|
||||
redirect_uri_provided_explicitly=str(
|
||||
params.redirect_uri_provided_explicitly
|
||||
).lower()
|
||||
== "true",
|
||||
code_challenge=params.code_challenge or "",
|
||||
scopes=["read", "write"], # Default scopes
|
||||
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
||||
)
|
||||
session.add(oauth_state)
|
||||
session.commit()
|
||||
|
||||
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
|
||||
{
|
||||
"state": oauth_state.state,
|
||||
"client_id": client.client_id,
|
||||
"redirect_uri": oauth_state.redirect_uri,
|
||||
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
|
||||
"code_challenge": cast(str, oauth_state.code_challenge),
|
||||
}
|
||||
)
|
||||
|
||||
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
||||
"""Complete authorization after successful login."""
|
||||
if not (state := oauth_params.get("state")):
|
||||
logger.error("No state parameter provided")
|
||||
raise ValueError("Missing state parameter")
|
||||
|
||||
with make_session() as session:
|
||||
# Load OAuth state from database
|
||||
oauth_state = (
|
||||
session.query(OAuthState).filter(OAuthState.state == state).first()
|
||||
)
|
||||
if not oauth_state:
|
||||
logger.error(f"State {state} not found in database")
|
||||
raise ValueError("Invalid state parameter")
|
||||
|
||||
# Check if state has expired
|
||||
now = datetime.fromtimestamp(time.time())
|
||||
if oauth_state.expires_at < now:
|
||||
logger.error(f"State {state} has expired")
|
||||
oauth_state.stale = True # type: ignore
|
||||
session.commit()
|
||||
raise ValueError("State has expired")
|
||||
|
||||
oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore
|
||||
oauth_state.stale = False # type: ignore
|
||||
oauth_state.user_id = user.id
|
||||
|
||||
session.add(oauth_state)
|
||||
session.commit()
|
||||
|
||||
return construct_redirect_uri(
|
||||
cast(str, oauth_state.redirect_uri),
|
||||
code=cast(str, oauth_state.code),
|
||||
state=state,
|
||||
)
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
) -> Optional[AuthorizationCode]:
|
||||
"""Load an authorization code."""
|
||||
with make_session() as session:
|
||||
auth_code = (
|
||||
session.query(OAuthState)
|
||||
.filter(OAuthState.code == authorization_code)
|
||||
.first()
|
||||
)
|
||||
if not auth_code:
|
||||
logger.error(f"Invalid authorization code: {authorization_code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
return AuthorizationCode(**auth_code.serialize(code=True))
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
||||
) -> OAuthToken:
|
||||
"""Exchange authorization code for tokens."""
|
||||
with make_session() as session:
|
||||
auth_code = (
|
||||
session.query(OAuthState)
|
||||
.filter(OAuthState.code == authorization_code.code)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not auth_code:
|
||||
logger.error(f"Invalid authorization code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
if not auth_code.user:
|
||||
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
return make_token(session, auth_code, authorization_code.scopes)
|
||||
|
||||
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""Load and validate an access token."""
|
||||
with make_session() as session:
|
||||
now = datetime.now(timezone.utc).replace(
|
||||
tzinfo=None
|
||||
) # Make naive for DB comparison
|
||||
|
||||
# Query for active (non-expired) session
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if not user_session or user_session.expires_at < now:
|
||||
return None
|
||||
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id=user_session.oauth_state.client_id,
|
||||
scopes=user_session.oauth_state.scopes,
|
||||
expires_at=int(user_session.expires_at.timestamp()),
|
||||
)
|
||||
|
||||
async def load_refresh_token(
|
||||
self, client: OAuthClientInformationFull, refresh_token: str
|
||||
) -> Optional[RefreshToken]:
|
||||
"""Load and validate a refresh token."""
|
||||
with make_session() as session:
|
||||
now = datetime.now()
|
||||
|
||||
# Query for the refresh token
|
||||
db_refresh_token = (
|
||||
session.query(OAuthRefreshToken)
|
||||
.filter(
|
||||
OAuthRefreshToken.token == refresh_token,
|
||||
OAuthRefreshToken.client_id == client.client_id,
|
||||
OAuthRefreshToken.revoked == False, # noqa: E712
|
||||
OAuthRefreshToken.expires_at > now,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_refresh_token:
|
||||
logger.error(
|
||||
f"Invalid or expired refresh token: {refresh_token[:20]}..."
|
||||
)
|
||||
return None
|
||||
|
||||
return RefreshToken(
|
||||
token=refresh_token,
|
||||
client_id=client.client_id,
|
||||
scopes=cast(list[str], db_refresh_token.scopes),
|
||||
expires_at=int(db_refresh_token.expires_at.timestamp()),
|
||||
)
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self,
|
||||
client: OAuthClientInformationFull,
|
||||
refresh_token: RefreshToken,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
"""Exchange refresh token for new access token."""
|
||||
with make_session() as session:
|
||||
# Load the refresh token from database
|
||||
db_refresh_token = (
|
||||
session.query(OAuthRefreshToken)
|
||||
.filter(
|
||||
OAuthRefreshToken.token == refresh_token.token,
|
||||
OAuthRefreshToken.client_id == client.client_id,
|
||||
OAuthRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_refresh_token:
|
||||
logger.error(f"Refresh token not found: {refresh_token.token[:20]}...")
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
# Validate refresh token
|
||||
validate_refresh_token(db_refresh_token)
|
||||
|
||||
# Validate requested scopes are subset of original scopes
|
||||
original_scopes = set(cast(list[str], db_refresh_token.scopes))
|
||||
requested_scopes = set(scopes) if scopes else original_scopes
|
||||
|
||||
if not requested_scopes.issubset(original_scopes):
|
||||
logger.error(
|
||||
f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}"
|
||||
)
|
||||
raise ValueError("Requested scopes exceed original authorization")
|
||||
|
||||
return make_token(session, db_refresh_token, scopes)
|
||||
|
||||
async def revoke_token(
|
||||
self, token: str, token_type_hint: Optional[str] = None
|
||||
) -> None:
|
||||
"""Revoke a token (access token or refresh token)."""
|
||||
with make_session() as session:
|
||||
revoked = False
|
||||
|
||||
# Try to revoke as access token (UserSession)
|
||||
if not token_type_hint or token_type_hint == "access_token":
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if user_session:
|
||||
session.delete(user_session)
|
||||
revoked = True
|
||||
logger.info(f"Revoked access token: {token[:20]}...")
|
||||
|
||||
# Try to revoke as refresh token
|
||||
if not revoked and (
|
||||
not token_type_hint or token_type_hint == "refresh_token"
|
||||
):
|
||||
refresh_token = (
|
||||
session.query(OAuthRefreshToken)
|
||||
.filter(OAuthRefreshToken.token == token)
|
||||
.first()
|
||||
)
|
||||
if refresh_token:
|
||||
refresh_token.revoked = True # type: ignore
|
||||
revoked = True
|
||||
logger.info(f"Revoked refresh token: {token[:20]}...")
|
||||
|
||||
if revoked:
|
||||
session.commit()
|
||||
else:
|
||||
logger.warning(f"Token not found for revocation: {token[:20]}...")
|
||||
|
||||
def get_protected_resource_metadata(self) -> dict[str, Any]:
|
||||
"""Return metadata about the protected resource."""
|
||||
return {
|
||||
"resource_server": settings.SERVER_URL,
|
||||
"scopes_supported": ["read", "write"],
|
||||
"bearer_methods_supported": ["header"],
|
||||
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": ["none", "client_secret_basic"],
|
||||
"refresh_token_rotation_enabled": True,
|
||||
"protected_resources": [
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||
"scopes": ["read", "write"],
|
||||
"http_methods": ["POST", "GET"],
|
||||
},
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
||||
"scopes": ["read", "write"],
|
||||
"http_methods": ["POST", "GET"],
|
||||
},
|
||||
],
|
||||
}
|
@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system.
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.auth.middleware.auth_context import get_access_token
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import AgentObservation, SourceItem
|
||||
from memory.common.db.models import (
|
||||
AgentObservation,
|
||||
SourceItem,
|
||||
UserSession,
|
||||
)
|
||||
from memory.api.MCP.base import mcp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create MCP server instance
|
||||
mcp = FastMCP("memory", stateless_http=True)
|
||||
|
||||
|
||||
def filter_observation_source_ids(
|
||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||
@ -67,4 +69,42 @@ def filter_source_ids(
|
||||
@mcp.tool()
|
||||
async def get_current_time() -> dict:
|
||||
"""Get the current time in UTC."""
|
||||
logger.info("get_current_time tool called")
|
||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_authenticated_user() -> dict:
|
||||
"""Get information about the authenticated user."""
|
||||
logger.info("🔧 get_authenticated_user tool called")
|
||||
access_token = get_access_token()
|
||||
logger.info(f"🔧 Access token from MCP context: {access_token}")
|
||||
|
||||
if not access_token:
|
||||
logger.warning("❌ No access token found in MCP context!")
|
||||
return {"error": "Not authenticated"}
|
||||
|
||||
logger.info(
|
||||
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
|
||||
)
|
||||
|
||||
# Look up the actual user from the session token
|
||||
with make_session() as session:
|
||||
user_session = (
|
||||
session.query(UserSession)
|
||||
.filter(UserSession.id == access_token.token)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_session and user_session.user:
|
||||
user_info = user_session.user.serialize()
|
||||
else:
|
||||
user_info = {"error": "User not found"}
|
||||
|
||||
return {
|
||||
"authenticated": True,
|
||||
"token_type": "Bearer",
|
||||
"scopes": access_token.scopes,
|
||||
"client_id": access_token.client_id,
|
||||
"user": user_info,
|
||||
}
|
||||
|
@ -2,8 +2,16 @@
|
||||
SQLAdmin views for the knowledge base database models.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from sqladmin import Admin, ModelView
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
import logging
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import (
|
||||
Chunk,
|
||||
SourceItem,
|
||||
@ -21,8 +29,11 @@ from memory.common.db.models import (
|
||||
AgentObservation,
|
||||
Note,
|
||||
User,
|
||||
UserSession,
|
||||
OAuthState,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_COLUMNS = (
|
||||
"modality",
|
||||
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
|
||||
|
||||
|
||||
def setup_admin(admin: Admin):
|
||||
"""Add all admin views to the admin instance."""
|
||||
"""Add all admin views to the admin instance with OAuth protection."""
|
||||
admin.add_view(SourceItemAdmin)
|
||||
admin.add_view(AgentObservationAdmin)
|
||||
admin.add_view(NoteAdmin)
|
||||
|
@ -16,22 +16,23 @@ from fastapi import (
|
||||
Query,
|
||||
Form,
|
||||
Depends,
|
||||
Request,
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqladmin import Admin
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common import extract
|
||||
from memory.common import extract, settings
|
||||
from memory.common.db.connection import get_engine
|
||||
from memory.common.db.models import User
|
||||
from memory.api.admin import setup_admin
|
||||
from memory.api.search import search, SearchResult
|
||||
from memory.api.MCP.tools import mcp
|
||||
from memory.api.auth import (
|
||||
router as auth_router,
|
||||
get_current_user,
|
||||
AuthenticationMiddleware,
|
||||
router as auth_router,
|
||||
)
|
||||
from memory.api.MCP.base import mcp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -44,24 +45,37 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify actual origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# SQLAdmin setup
|
||||
# SQLAdmin setup with OAuth protection
|
||||
engine = get_engine()
|
||||
admin = Admin(app, engine)
|
||||
|
||||
# Setup admin with OAuth protection using existing OAuth provider
|
||||
setup_admin(admin)
|
||||
|
||||
# Include auth router
|
||||
app.add_middleware(AuthenticationMiddleware)
|
||||
app.include_router(auth_router)
|
||||
app.add_middleware(AuthenticationMiddleware)
|
||||
|
||||
|
||||
# Add health check to MCP server instead of main app
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request):
|
||||
"""Simple health check endpoint on MCP server"""
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
|
||||
|
||||
|
||||
# Mount MCP server at root - OAuth endpoints need to be at root level
|
||||
app.mount("/", mcp.streamable_http_app())
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
"""Simple health check endpoint"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
||||
if not item:
|
||||
return []
|
||||
|
@ -1,10 +1,10 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import textwrap
|
||||
from typing import cast
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from memory.common import settings
|
||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
||||
@ -52,28 +52,35 @@ def create_user_session(
|
||||
return str(session.id)
|
||||
|
||||
|
||||
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
|
||||
"""Get user from session ID if session is valid"""
|
||||
def get_user_session(
|
||||
request: Request, db: DBSession | scoped_session
|
||||
) -> UserSession | None:
|
||||
"""Get session ID from request"""
|
||||
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
session = db.query(UserSession).get(session_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if session.expires_at.replace(tzinfo=timezone.utc) > now:
|
||||
if session.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
return None
|
||||
return session
|
||||
|
||||
|
||||
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
|
||||
"""Get user from session ID if session is valid"""
|
||||
if session := get_user_session(request, db):
|
||||
return session.user
|
||||
return None
|
||||
|
||||
|
||||
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
||||
"""FastAPI dependency to get current authenticated user"""
|
||||
# Check for session ID in header or cookie
|
||||
session_id = request.headers.get(
|
||||
settings.SESSION_HEADER_NAME
|
||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=401, detail="No session provided")
|
||||
|
||||
user = get_session_user(session_id, db)
|
||||
user = get_session_user(request, db)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
||||
|
||||
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
|
||||
|
||||
|
||||
@router.get("/login", response_model=LoginResponse)
|
||||
def login_page():
|
||||
def login_page(request: Request):
|
||||
"""Login page"""
|
||||
return HTMLResponse(
|
||||
content=textwrap.dedent("""
|
||||
<html>
|
||||
<body>
|
||||
<h1>Login</h1>
|
||||
<form method="post" action="/auth/login-form">
|
||||
<input type="email" name="email" placeholder="Email" />
|
||||
<input type="password" name="password" placeholder="Password" />
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
"""),
|
||||
template_dir = pathlib.Path(__file__).parent / "templates"
|
||||
templates = Jinja2Templates(directory=template_dir)
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{
|
||||
"request": request,
|
||||
"action": router.url_path_for("login_form"),
|
||||
"error": None,
|
||||
"form_data": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -170,9 +174,17 @@ def login_form(
|
||||
return LoginResponse(session_id=session_id, **user.serialize())
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(response: Response, user: User = Depends(get_current_user)):
|
||||
@router.api_route("/logout", methods=["GET", "POST"])
|
||||
def logout(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: DBSession = Depends(get_session),
|
||||
):
|
||||
"""Logout and clear session"""
|
||||
session = get_user_session(request, db)
|
||||
if session:
|
||||
db.delete(session)
|
||||
db.commit()
|
||||
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"/auth/login",
|
||||
"/auth/login-form",
|
||||
"/auth/register",
|
||||
"/register",
|
||||
"/token",
|
||||
"/mcp",
|
||||
"/oauth/",
|
||||
"/.well-known/",
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
# Check for session ID in header or cookie
|
||||
session_id = request.headers.get(
|
||||
settings.SESSION_HEADER_NAME
|
||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
|
||||
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
if not session_id:
|
||||
return Response(
|
||||
content="Authentication required",
|
||||
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Validate session and get user
|
||||
with make_session() as session:
|
||||
user = get_session_user(session_id, session)
|
||||
user = get_session_user(request, session)
|
||||
if not user:
|
||||
return Response(
|
||||
content="Invalid or expired session",
|
||||
|
159
src/memory/api/templates/login.html
Normal file
159
src/memory/api/templates/login.html
Normal file
@ -0,0 +1,159 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Login - Memory OAuth</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.login-container {
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
.login-header {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.login-header h1 {
|
||||
color: #333;
|
||||
margin: 0 0 0.5rem 0;
|
||||
font-size: 1.8rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.login-header p {
|
||||
color: #666;
|
||||
margin: 0;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.form-group {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.form-group label {
|
||||
display: block;
|
||||
margin-bottom: 0.5rem;
|
||||
color: #333;
|
||||
font-weight: 500;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 0.75rem;
|
||||
border: 2px solid #e1e5e9;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
transition: border-color 0.2s;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
.login-button {
|
||||
width: 100%;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
|
||||
.login-button:hover {
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.login-button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background: #fee;
|
||||
color: #c33;
|
||||
padding: 0.75rem;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 1rem;
|
||||
font-size: 0.9rem;
|
||||
border-left: 4px solid #c33;
|
||||
}
|
||||
|
||||
.demo-credentials {
|
||||
background: #f8f9fa;
|
||||
padding: 1rem;
|
||||
border-radius: 8px;
|
||||
margin-top: 1.5rem;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.demo-credentials strong {
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.demo-credentials code {
|
||||
background: #e9ecef;
|
||||
padding: 0.2rem 0.4rem;
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="login-container">
|
||||
<div class="login-header">
|
||||
<h1>Sign In</h1>
|
||||
<p>Access your Memory knowledge base</p>
|
||||
</div>
|
||||
|
||||
{% if error %}
|
||||
<div class="error-message">
|
||||
{{ error }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<form method="post" action="{{ action }}">
|
||||
{% for key, value in form_data.items() %}
|
||||
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||
{% endfor %}
|
||||
|
||||
<div class="form-group">
|
||||
<label for="email">Email</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="login-button">Sign In</button>
|
||||
</form>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
@ -35,6 +35,9 @@ from memory.common.db.models.sources import (
|
||||
from memory.common.db.models.users import (
|
||||
User,
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
OAuthRefreshToken,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -69,4 +72,7 @@ __all__ = [
|
||||
# Users
|
||||
"User",
|
||||
"UserSession",
|
||||
"OAuthClientInformation",
|
||||
"OAuthState",
|
||||
"OAuthRefreshToken",
|
||||
]
|
||||
|
@ -3,9 +3,19 @@ import secrets
|
||||
from typing import cast
|
||||
import uuid
|
||||
from memory.common.db.models.base import Base
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Boolean,
|
||||
ARRAY,
|
||||
Numeric,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
@ -35,6 +45,9 @@ class User(Base):
|
||||
sessions = relationship(
|
||||
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_states = relationship(
|
||||
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
@ -57,9 +70,127 @@ class UserSession(Base):
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
# Relationship to user
|
||||
user = relationship("User", back_populates="sessions")
|
||||
oauth_state = relationship("OAuthState", back_populates="session")
|
||||
|
||||
|
||||
class OAuthClientInformation(Base):
|
||||
__tablename__ = "oauth_client"
|
||||
|
||||
client_id = Column(String, primary_key=True)
|
||||
client_secret = Column(String, nullable=True)
|
||||
client_id_issued_at = Column(Numeric, nullable=False)
|
||||
client_secret_expires_at = Column(Numeric, nullable=True)
|
||||
|
||||
redirect_uris = Column(ARRAY(String), nullable=False)
|
||||
token_endpoint_auth_method = Column(String, nullable=False)
|
||||
grant_types = Column(ARRAY(String), nullable=False)
|
||||
response_types = Column(ARRAY(String), nullable=False)
|
||||
scope = Column(String, nullable=False)
|
||||
client_name = Column(String, nullable=False)
|
||||
client_uri = Column(String, nullable=True)
|
||||
logo_uri = Column(String, nullable=True)
|
||||
contacts = Column(ARRAY(String), nullable=True)
|
||||
tos_uri = Column(String, nullable=True)
|
||||
policy_uri = Column(String, nullable=True)
|
||||
jwks_uri = Column(String, nullable=True)
|
||||
|
||||
sessions = relationship(
|
||||
"OAuthState", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"client_id_issued_at": self.client_id_issued_at,
|
||||
"client_secret_expires_at": self.client_secret_expires_at,
|
||||
"redirect_uris": self.redirect_uris,
|
||||
"token_endpoint_auth_method": self.token_endpoint_auth_method,
|
||||
"grant_types": self.grant_types,
|
||||
"response_types": self.response_types,
|
||||
"scope": self.scope,
|
||||
"client_name": self.client_name,
|
||||
"client_uri": self.client_uri,
|
||||
"logo_uri": self.logo_uri,
|
||||
"contacts": self.contacts,
|
||||
"tos_uri": self.tos_uri,
|
||||
"policy_uri": self.policy_uri,
|
||||
"jwks_uri": self.jwks_uri,
|
||||
}
|
||||
|
||||
|
||||
class OAuthToken:
|
||||
id = Column(Integer, primary_key=True)
|
||||
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
scopes = Column(ARRAY(String), nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
"client_id": self.client_id,
|
||||
"scopes": self.scopes,
|
||||
"expires_at": self.expires_at.timestamp(),
|
||||
}
|
||||
|
||||
|
||||
class OAuthState(Base, OAuthToken):
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
state = Column(String, nullable=False)
|
||||
code = Column(String, nullable=True)
|
||||
redirect_uri = Column(String, nullable=False)
|
||||
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
|
||||
code_challenge = Column(String, nullable=True)
|
||||
stale = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
def serialize(self, code: bool = False) -> dict:
|
||||
data = {
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
|
||||
"code_challenge": self.code_challenge,
|
||||
} | super().serialize()
|
||||
if code:
|
||||
data |= {
|
||||
"code": self.code,
|
||||
"expires_at": self.expires_at.timestamp(),
|
||||
}
|
||||
return data
|
||||
|
||||
client = relationship("OAuthClientInformation", back_populates="sessions")
|
||||
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
|
||||
user = relationship("User", back_populates="oauth_states")
|
||||
|
||||
|
||||
class OAuthRefreshToken(Base, OAuthToken):
|
||||
__tablename__ = "oauth_refresh_tokens"
|
||||
|
||||
token = Column(
|
||||
String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}"
|
||||
)
|
||||
revoked = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Optional: link to the access token session that was created with this refresh token
|
||||
access_token_session_id = Column(
|
||||
String, ForeignKey("user_sessions.id"), nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClientInformation")
|
||||
user = relationship("User")
|
||||
access_token_session = relationship("UserSession")
|
||||
|
||||
def serialize(self) -> dict:
|
||||
return {
|
||||
"token": self.token,
|
||||
"expires_at": self.expires_at.timestamp(),
|
||||
"revoked": self.revoked,
|
||||
} | super().serialize()
|
||||
|
@ -131,8 +131,8 @@ else:
|
||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||
|
||||
# API settings
|
||||
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||
HTTPS = boolean_env("HTTPS", False)
|
||||
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||
|
@ -1,126 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
remote_server: str
|
||||
session_header: str = "X-Session-ID"
|
||||
session_id: str | None = None
|
||||
port: int = 8080
|
||||
|
||||
|
||||
def parse_args() -> State:
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Simple HTTP proxy with authentication"
|
||||
)
|
||||
parser.add_argument("--remote-server", required=True, help="Remote server URL")
|
||||
parser.add_argument("--email", required=True, help="Email for authentication")
|
||||
parser.add_argument("--password", required=True, help="Password for authentication")
|
||||
parser.add_argument(
|
||||
"--session-header", default="X-Session-ID", help="Session header name"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
|
||||
return State(**vars(parser.parse_args()))
|
||||
|
||||
|
||||
state = parse_args()
|
||||
|
||||
|
||||
async def login() -> None:
|
||||
"""Login to remote server and store session ID"""
|
||||
login_url = f"{state.remote_server}/auth/login"
|
||||
login_data = {"email": state.email, "password": state.password}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(login_url, json=login_data)
|
||||
response.raise_for_status()
|
||||
|
||||
login_response = response.json()
|
||||
state.session_id = login_response["session_id"]
|
||||
print(f"Successfully logged in, session ID: {state.session_id}")
|
||||
except httpx.HTTPStatusError as e:
|
||||
print(
|
||||
f"Login failed with status {e.response.status_code}: {e.response.text}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Login failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def proxy_request(request: Request) -> Response:
|
||||
"""Proxy request to remote server with session header"""
|
||||
if not state.session_id:
|
||||
try:
|
||||
await login()
|
||||
except Exception as e:
|
||||
print(f"Login failed: {e}")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
# Build the target URL
|
||||
target_url = f"{state.remote_server}{request.url.path}"
|
||||
if request.url.query:
|
||||
target_url += f"?{request.url.query}"
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
headers=headers | {state.session_header: state.session_id}, # type: ignore
|
||||
content=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
# Forward response
|
||||
resp = Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers={
|
||||
k: v.replace(state.remote_server, f"http://localhost:{state.port}")
|
||||
for k, v in response.headers.items()
|
||||
},
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
print(resp.headers)
|
||||
return resp
|
||||
|
||||
except httpx.RequestError as e:
|
||||
print(f"Request failed: {e}")
|
||||
raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(title="Simple Proxy")
|
||||
|
||||
|
||||
@app.api_route(
|
||||
"/{path:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
|
||||
)
|
||||
async def proxy_all(request: Request):
|
||||
"""Proxy all requests to remote server"""
|
||||
return await proxy_request(request)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Starting proxy server on port {state.port}")
|
||||
print(f"Proxying to: {state.remote_server}")
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=state.port)
|
Loading…
x
Reference in New Issue
Block a user