mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-30 06:36:07 +02:00
Compare commits
4 Commits
4d057d1ec6
...
6ee46d6215
Author | SHA1 | Date | |
---|---|---|---|
![]() |
6ee46d6215 | ||
![]() |
d17d724631 | ||
![]() |
986d5b9957 | ||
![]() |
4556ef2c48 |
12
README.md
12
README.md
@ -80,18 +80,6 @@ python tools/run_celery_task.py notes setup-git-notes --origin ssh://git@github.
|
|||||||
For this to work you need to make sure you have set up the ssh keys in `secrets` (see the README.md
|
For this to work you need to make sure you have set up the ssh keys in `secrets` (see the README.md
|
||||||
in that folder), and you will need to add the public key that is generated there to your git server.
|
in that folder), and you will need to add the public key that is generated there to your git server.
|
||||||
|
|
||||||
### Authentication
|
|
||||||
|
|
||||||
The API uses session-based authentication. Login via:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST http://localhost:8000/auth/login \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{"email": "user@example.com", "password": "yourpassword"}'
|
|
||||||
```
|
|
||||||
|
|
||||||
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
|
|
||||||
|
|
||||||
## Discord integration
|
## Discord integration
|
||||||
|
|
||||||
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
||||||
|
115
db/migrations/versions/20250606_165333_oauth_codes.py
Normal file
115
db/migrations/versions/20250606_165333_oauth_codes.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
"""oauth codes
|
||||||
|
|
||||||
|
Revision ID: 1d6bc8015ea9
|
||||||
|
Revises: 58439dd3088b
|
||||||
|
Create Date: 2025-06-06 16:53:33.044558
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "1d6bc8015ea9"
|
||||||
|
down_revision: Union[str, None] = "58439dd3088b"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"oauth_client",
|
||||||
|
sa.Column("client_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_secret", sa.String(), nullable=True),
|
||||||
|
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
|
||||||
|
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
|
||||||
|
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
|
||||||
|
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_name", sa.String(), nullable=False),
|
||||||
|
sa.Column("client_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("logo_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
|
||||||
|
sa.Column("tos_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("policy_uri", sa.String(), nullable=True),
|
||||||
|
sa.Column("jwks_uri", sa.String(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("client_id"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_states",
|
||||||
|
sa.Column("state", sa.String(), nullable=False),
|
||||||
|
sa.Column("code", sa.String(), nullable=True),
|
||||||
|
sa.Column("redirect_uri", sa.String(), nullable=False),
|
||||||
|
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("code_challenge", sa.String(), nullable=True),
|
||||||
|
sa.Column("stale", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"],
|
||||||
|
["oauth_client.client_id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_refresh_tokens",
|
||||||
|
sa.Column("token", sa.String(), nullable=False),
|
||||||
|
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("access_token_session_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["access_token_session_id"],
|
||||||
|
["user_sessions.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"],
|
||||||
|
["oauth_client.client_id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True)
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"user_sessions_oauth_state_id_fkey",
|
||||||
|
"user_sessions",
|
||||||
|
"oauth_states",
|
||||||
|
["oauth_state_id"],
|
||||||
|
["id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_constraint(
|
||||||
|
"user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey"
|
||||||
|
)
|
||||||
|
op.drop_column("user_sessions", "oauth_state_id")
|
||||||
|
op.drop_table("oauth_refresh_tokens")
|
||||||
|
op.drop_table("oauth_states")
|
||||||
|
op.drop_table("oauth_client")
|
@ -147,6 +147,7 @@ services:
|
|||||||
<<: *env
|
<<: *env
|
||||||
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
||||||
QDRANT_URL: http://qdrant:6333
|
QDRANT_URL: http://qdrant:6333
|
||||||
|
SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
|
||||||
secrets: [postgres_password]
|
secrets: [postgres_password]
|
||||||
volumes:
|
volumes:
|
||||||
- ./memory_files:/app/memory_files:rw
|
- ./memory_files:/app/memory_files:rw
|
||||||
@ -158,20 +159,6 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|
||||||
proxy:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: docker/api/Dockerfile
|
|
||||||
restart: unless-stopped
|
|
||||||
networks: [kbnet]
|
|
||||||
environment:
|
|
||||||
<<: *env
|
|
||||||
command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "${PROXY_REMOTE_SERVER:-http://api:8000}", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
|
|
||||||
volumes:
|
|
||||||
- ./tools:/app/tools:ro
|
|
||||||
ports:
|
|
||||||
- "8001:8001"
|
|
||||||
|
|
||||||
# ------------------------------------------------------------ Celery workers
|
# ------------------------------------------------------------ Celery workers
|
||||||
worker:
|
worker:
|
||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
@ -201,4 +188,4 @@ services:
|
|||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
# command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
|
# command: [ "--schedule", "0 0 4 * * *", "--cleanup" ]
|
||||||
# volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
|
# volumes: [ "/var/run/docker.sock:/var/run/docker.sock:ro" ]
|
||||||
# networks: [ kbnet ]
|
# networks: [ kbnet ]
|
||||||
|
132
src/memory/api/MCP/base.py
Normal file
132
src/memory/api/MCP/base.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||||
|
from mcp.server.auth.handlers.token import (
|
||||||
|
AuthorizationCodeRequest,
|
||||||
|
RefreshTokenRequest,
|
||||||
|
TokenRequest,
|
||||||
|
)
|
||||||
|
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from mcp.shared.auth import OAuthClientMetadata
|
||||||
|
from memory.common.db.models.users import User
|
||||||
|
from pydantic import AnyHttpUrl
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, RedirectResponse
|
||||||
|
from starlette.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import OAuthState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_metadata(klass: type):
|
||||||
|
orig_validate = klass.model_validate
|
||||||
|
|
||||||
|
def validate(data: dict):
|
||||||
|
data = dict(data)
|
||||||
|
if "redirect_uris" in data:
|
||||||
|
data["redirect_uris"] = [
|
||||||
|
str(uri).replace("cursor://", "http://")
|
||||||
|
for uri in data["redirect_uris"]
|
||||||
|
]
|
||||||
|
if "redirect_uri" in data:
|
||||||
|
data["redirect_uri"] = str(data["redirect_uri"]).replace(
|
||||||
|
"cursor://", "http://"
|
||||||
|
)
|
||||||
|
|
||||||
|
return orig_validate(data)
|
||||||
|
|
||||||
|
klass.model_validate = validate
|
||||||
|
|
||||||
|
|
||||||
|
validate_metadata(OAuthClientMetadata)
|
||||||
|
validate_metadata(AuthorizationRequest)
|
||||||
|
validate_metadata(AuthorizationCodeRequest)
|
||||||
|
validate_metadata(RefreshTokenRequest)
|
||||||
|
validate_metadata(TokenRequest)
|
||||||
|
|
||||||
|
|
||||||
|
# Setup templates
|
||||||
|
template_dir = pathlib.Path(__file__).parent.parent / "templates"
|
||||||
|
templates = Jinja2Templates(directory=template_dir)
|
||||||
|
|
||||||
|
|
||||||
|
oauth_provider = SimpleOAuthProvider()
|
||||||
|
auth_settings = AuthSettings(
|
||||||
|
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||||
|
client_registration_options=ClientRegistrationOptions(
|
||||||
|
enabled=True,
|
||||||
|
valid_scopes=["read", "write"],
|
||||||
|
default_scopes=["read"],
|
||||||
|
),
|
||||||
|
required_scopes=["read", "write"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp = FastMCP(
|
||||||
|
"memory",
|
||||||
|
stateless_http=True,
|
||||||
|
auth_server_provider=oauth_provider,
|
||||||
|
auth=auth_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
||||||
|
async def oauth_protected_resource(request: Request):
|
||||||
|
"""OAuth 2.0 Protected Resource Metadata."""
|
||||||
|
logger.info("Protected resource metadata requested")
|
||||||
|
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
||||||
|
|
||||||
|
|
||||||
|
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"login.html",
|
||||||
|
{
|
||||||
|
"request": request,
|
||||||
|
"form_data": form_data,
|
||||||
|
"error": error,
|
||||||
|
"action": "/oauth/login",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.custom_route("/oauth/login", methods=["GET"])
|
||||||
|
async def login_page(request: Request):
|
||||||
|
"""Display the login page."""
|
||||||
|
form_data = dict(request.query_params)
|
||||||
|
|
||||||
|
state = form_data.get("state")
|
||||||
|
with make_session() as session:
|
||||||
|
oauth_state = (
|
||||||
|
session.query(OAuthState).filter(OAuthState.state == state).first()
|
||||||
|
)
|
||||||
|
if not oauth_state:
|
||||||
|
logger.error(f"State {state} not found in database")
|
||||||
|
raise ValueError("Invalid state parameter")
|
||||||
|
|
||||||
|
return login_form(request, form_data, None)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.custom_route("/oauth/login", methods=["POST"])
|
||||||
|
async def handle_login(request: Request):
|
||||||
|
"""Handle login form submission."""
|
||||||
|
form = await request.form()
|
||||||
|
oauth_params = {
|
||||||
|
key: value for key, value in form.items() if key not in ["email", "password"]
|
||||||
|
}
|
||||||
|
with make_session() as session:
|
||||||
|
user = session.query(User).filter(User.email == form.get("email")).first()
|
||||||
|
if not user or not user.is_valid_password(str(form.get("password", ""))):
|
||||||
|
logger.warning("Login failed - invalid credentials")
|
||||||
|
return login_form(request, oauth_params, "Invalid email or password")
|
||||||
|
|
||||||
|
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
||||||
|
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
||||||
|
redirect_url = redirect_url.replace("http://", "cursor://")
|
||||||
|
return RedirectResponse(url=redirect_url, status_code=302)
|
411
src/memory/api/MCP/oauth_provider.py
Normal file
411
src/memory/api/MCP/oauth_provider.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Optional, Any, cast
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from memory.common.db.models.users import (
|
||||||
|
User,
|
||||||
|
UserSession,
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthState,
|
||||||
|
OAuthRefreshToken,
|
||||||
|
OAuthToken as TokenBase,
|
||||||
|
)
|
||||||
|
from memory.common.db.connection import make_session, scoped_session
|
||||||
|
from memory.common import settings
|
||||||
|
from mcp.server.auth.provider import (
|
||||||
|
AccessToken,
|
||||||
|
AuthorizationParams,
|
||||||
|
AuthorizationCode,
|
||||||
|
OAuthAuthorizationServerProvider,
|
||||||
|
RefreshToken,
|
||||||
|
construct_redirect_uri,
|
||||||
|
)
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Token configuration constants
|
||||||
|
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
||||||
|
REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days
|
||||||
|
|
||||||
|
|
||||||
|
def create_expiration(lifetime_seconds: int) -> datetime:
|
||||||
|
"""Create expiration datetime from lifetime in seconds."""
|
||||||
|
return datetime.fromtimestamp(time.time() + lifetime_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_refresh_token() -> str:
|
||||||
|
"""Generate a new refresh token."""
|
||||||
|
return f"rt_{secrets.token_hex(32)}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token_session(
|
||||||
|
user_id: int, oauth_state_id: str | None = None
|
||||||
|
) -> UserSession:
|
||||||
|
"""Create a new access token session."""
|
||||||
|
return UserSession(
|
||||||
|
user_id=user_id,
|
||||||
|
oauth_state_id=oauth_state_id,
|
||||||
|
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token_record(
|
||||||
|
client_id: str,
|
||||||
|
user_id: int,
|
||||||
|
scopes: list[str],
|
||||||
|
access_token_session_id: Optional[str] = None,
|
||||||
|
) -> OAuthRefreshToken:
|
||||||
|
"""Create a new refresh token record."""
|
||||||
|
return OAuthRefreshToken(
|
||||||
|
token=generate_refresh_token(),
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user_id,
|
||||||
|
scopes=scopes,
|
||||||
|
expires_at=create_expiration(REFRESH_TOKEN_LIFETIME),
|
||||||
|
access_token_session_id=access_token_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None:
|
||||||
|
"""Validate a refresh token, raising ValueError if invalid."""
|
||||||
|
now = datetime.now()
|
||||||
|
if db_refresh_token.expires_at < now: # type: ignore
|
||||||
|
logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...")
|
||||||
|
db_refresh_token.revoked = True # type: ignore
|
||||||
|
raise ValueError("Refresh token expired")
|
||||||
|
|
||||||
|
|
||||||
|
def create_oauth_token(
|
||||||
|
access_token: str, scopes: list[str], refresh_token: Optional[str] = None
|
||||||
|
) -> OAuthToken:
|
||||||
|
"""Create an OAuth token response."""
|
||||||
|
return OAuthToken(
|
||||||
|
access_token=access_token,
|
||||||
|
token_type="bearer",
|
||||||
|
expires_in=ACCESS_TOKEN_LIFETIME,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
scope=" ".join(scopes),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_token(
|
||||||
|
db: scoped_session,
|
||||||
|
auth_state: TokenBase,
|
||||||
|
scopes: list[str],
|
||||||
|
) -> OAuthToken:
|
||||||
|
new_session = UserSession(
|
||||||
|
user_id=auth_state.user_id,
|
||||||
|
oauth_state_id=auth_state.id,
|
||||||
|
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create refresh token
|
||||||
|
refresh_token = create_refresh_token_record(
|
||||||
|
cast(str, auth_state.client_id),
|
||||||
|
cast(int, auth_state.user_id),
|
||||||
|
scopes,
|
||||||
|
cast(str, new_session.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(new_session)
|
||||||
|
db.add(refresh_token)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return create_oauth_token(
|
||||||
|
str(new_session.id),
|
||||||
|
scopes,
|
||||||
|
cast(str, refresh_token.token),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||||
|
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||||
|
"""Get OAuth client information."""
|
||||||
|
with make_session() as session:
|
||||||
|
client = session.query(OAuthClientInformation).get(client_id)
|
||||||
|
return client and OAuthClientInformationFull(**client.serialize())
|
||||||
|
|
||||||
|
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||||
|
"""Register a new OAuth client."""
|
||||||
|
with make_session() as session:
|
||||||
|
client = session.query(OAuthClientInformation).get(client_info.client_id)
|
||||||
|
if not client:
|
||||||
|
client = OAuthClientInformation(client_id=client_info.client_id)
|
||||||
|
|
||||||
|
for key, value in client_info.model_dump().items():
|
||||||
|
if key == "redirect_uris":
|
||||||
|
value = [str(uri) for uri in value]
|
||||||
|
elif value and key in [
|
||||||
|
"client_uri",
|
||||||
|
"logo_uri",
|
||||||
|
"tos_uri",
|
||||||
|
"policy_uri",
|
||||||
|
"jwks_uri",
|
||||||
|
]:
|
||||||
|
value = str(value)
|
||||||
|
setattr(client, key, value)
|
||||||
|
session.add(client)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
async def authorize(
|
||||||
|
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||||
|
) -> str:
|
||||||
|
"""Redirect to login page for user authentication."""
|
||||||
|
redirect_uri_str = str(params.redirect_uri)
|
||||||
|
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
|
||||||
|
if redirect_uri_str not in registered_uris:
|
||||||
|
logger.error(
|
||||||
|
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
|
||||||
|
)
|
||||||
|
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
||||||
|
|
||||||
|
# Store the authorization parameters in database
|
||||||
|
with make_session() as session:
|
||||||
|
oauth_state = OAuthState(
|
||||||
|
state=params.state or secrets.token_hex(16),
|
||||||
|
client_id=client.client_id,
|
||||||
|
redirect_uri=str(params.redirect_uri),
|
||||||
|
redirect_uri_provided_explicitly=str(
|
||||||
|
params.redirect_uri_provided_explicitly
|
||||||
|
).lower()
|
||||||
|
== "true",
|
||||||
|
code_challenge=params.code_challenge or "",
|
||||||
|
scopes=["read", "write"], # Default scopes
|
||||||
|
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
||||||
|
)
|
||||||
|
session.add(oauth_state)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
|
||||||
|
{
|
||||||
|
"state": oauth_state.state,
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"redirect_uri": oauth_state.redirect_uri,
|
||||||
|
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
|
||||||
|
"code_challenge": cast(str, oauth_state.code_challenge),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
||||||
|
"""Complete authorization after successful login."""
|
||||||
|
if not (state := oauth_params.get("state")):
|
||||||
|
logger.error("No state parameter provided")
|
||||||
|
raise ValueError("Missing state parameter")
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Load OAuth state from database
|
||||||
|
oauth_state = (
|
||||||
|
session.query(OAuthState).filter(OAuthState.state == state).first()
|
||||||
|
)
|
||||||
|
if not oauth_state:
|
||||||
|
logger.error(f"State {state} not found in database")
|
||||||
|
raise ValueError("Invalid state parameter")
|
||||||
|
|
||||||
|
# Check if state has expired
|
||||||
|
now = datetime.fromtimestamp(time.time())
|
||||||
|
if oauth_state.expires_at < now:
|
||||||
|
logger.error(f"State {state} has expired")
|
||||||
|
oauth_state.stale = True # type: ignore
|
||||||
|
session.commit()
|
||||||
|
raise ValueError("State has expired")
|
||||||
|
|
||||||
|
oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore
|
||||||
|
oauth_state.stale = False # type: ignore
|
||||||
|
oauth_state.user_id = user.id
|
||||||
|
|
||||||
|
session.add(oauth_state)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return construct_redirect_uri(
|
||||||
|
cast(str, oauth_state.redirect_uri),
|
||||||
|
code=cast(str, oauth_state.code),
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_authorization_code(
|
||||||
|
self, client: OAuthClientInformationFull, authorization_code: str
|
||||||
|
) -> Optional[AuthorizationCode]:
|
||||||
|
"""Load an authorization code."""
|
||||||
|
with make_session() as session:
|
||||||
|
auth_code = (
|
||||||
|
session.query(OAuthState)
|
||||||
|
.filter(OAuthState.code == authorization_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not auth_code:
|
||||||
|
logger.error(f"Invalid authorization code: {authorization_code}")
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
return AuthorizationCode(**auth_code.serialize(code=True))
|
||||||
|
|
||||||
|
async def exchange_authorization_code(
|
||||||
|
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
||||||
|
) -> OAuthToken:
|
||||||
|
"""Exchange authorization code for tokens."""
|
||||||
|
with make_session() as session:
|
||||||
|
auth_code = (
|
||||||
|
session.query(OAuthState)
|
||||||
|
.filter(OAuthState.code == authorization_code.code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not auth_code:
|
||||||
|
logger.error(f"Invalid authorization code: {authorization_code.code}")
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
|
if not auth_code.user:
|
||||||
|
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
|
return make_token(session, auth_code, authorization_code.scopes)
|
||||||
|
|
||||||
|
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||||
|
"""Load and validate an access token."""
|
||||||
|
with make_session() as session:
|
||||||
|
now = datetime.now(timezone.utc).replace(
|
||||||
|
tzinfo=None
|
||||||
|
) # Make naive for DB comparison
|
||||||
|
|
||||||
|
# Query for active (non-expired) session
|
||||||
|
user_session = session.query(UserSession).get(token)
|
||||||
|
if not user_session or user_session.expires_at < now:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return AccessToken(
|
||||||
|
token=token,
|
||||||
|
client_id=user_session.oauth_state.client_id,
|
||||||
|
scopes=user_session.oauth_state.scopes,
|
||||||
|
expires_at=int(user_session.expires_at.timestamp()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_refresh_token(
|
||||||
|
self, client: OAuthClientInformationFull, refresh_token: str
|
||||||
|
) -> Optional[RefreshToken]:
|
||||||
|
"""Load and validate a refresh token."""
|
||||||
|
with make_session() as session:
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Query for the refresh token
|
||||||
|
db_refresh_token = (
|
||||||
|
session.query(OAuthRefreshToken)
|
||||||
|
.filter(
|
||||||
|
OAuthRefreshToken.token == refresh_token,
|
||||||
|
OAuthRefreshToken.client_id == client.client_id,
|
||||||
|
OAuthRefreshToken.revoked == False, # noqa: E712
|
||||||
|
OAuthRefreshToken.expires_at > now,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not db_refresh_token:
|
||||||
|
logger.error(
|
||||||
|
f"Invalid or expired refresh token: {refresh_token[:20]}..."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return RefreshToken(
|
||||||
|
token=refresh_token,
|
||||||
|
client_id=client.client_id,
|
||||||
|
scopes=cast(list[str], db_refresh_token.scopes),
|
||||||
|
expires_at=int(db_refresh_token.expires_at.timestamp()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def exchange_refresh_token(
|
||||||
|
self,
|
||||||
|
client: OAuthClientInformationFull,
|
||||||
|
refresh_token: RefreshToken,
|
||||||
|
scopes: list[str],
|
||||||
|
) -> OAuthToken:
|
||||||
|
"""Exchange refresh token for new access token."""
|
||||||
|
with make_session() as session:
|
||||||
|
# Load the refresh token from database
|
||||||
|
db_refresh_token = (
|
||||||
|
session.query(OAuthRefreshToken)
|
||||||
|
.filter(
|
||||||
|
OAuthRefreshToken.token == refresh_token.token,
|
||||||
|
OAuthRefreshToken.client_id == client.client_id,
|
||||||
|
OAuthRefreshToken.revoked == False, # noqa: E712
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not db_refresh_token:
|
||||||
|
logger.error(f"Refresh token not found: {refresh_token.token[:20]}...")
|
||||||
|
raise ValueError("Invalid refresh token")
|
||||||
|
|
||||||
|
# Validate refresh token
|
||||||
|
validate_refresh_token(db_refresh_token)
|
||||||
|
|
||||||
|
# Validate requested scopes are subset of original scopes
|
||||||
|
original_scopes = set(cast(list[str], db_refresh_token.scopes))
|
||||||
|
requested_scopes = set(scopes) if scopes else original_scopes
|
||||||
|
|
||||||
|
if not requested_scopes.issubset(original_scopes):
|
||||||
|
logger.error(
|
||||||
|
f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}"
|
||||||
|
)
|
||||||
|
raise ValueError("Requested scopes exceed original authorization")
|
||||||
|
|
||||||
|
return make_token(session, db_refresh_token, scopes)
|
||||||
|
|
||||||
|
async def revoke_token(
|
||||||
|
self, token: str, token_type_hint: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Revoke a token (access token or refresh token)."""
|
||||||
|
with make_session() as session:
|
||||||
|
revoked = False
|
||||||
|
|
||||||
|
# Try to revoke as access token (UserSession)
|
||||||
|
if not token_type_hint or token_type_hint == "access_token":
|
||||||
|
user_session = session.query(UserSession).get(token)
|
||||||
|
if user_session:
|
||||||
|
session.delete(user_session)
|
||||||
|
revoked = True
|
||||||
|
logger.info(f"Revoked access token: {token[:20]}...")
|
||||||
|
|
||||||
|
# Try to revoke as refresh token
|
||||||
|
if not revoked and (
|
||||||
|
not token_type_hint or token_type_hint == "refresh_token"
|
||||||
|
):
|
||||||
|
refresh_token = (
|
||||||
|
session.query(OAuthRefreshToken)
|
||||||
|
.filter(OAuthRefreshToken.token == token)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if refresh_token:
|
||||||
|
refresh_token.revoked = True # type: ignore
|
||||||
|
revoked = True
|
||||||
|
logger.info(f"Revoked refresh token: {token[:20]}...")
|
||||||
|
|
||||||
|
if revoked:
|
||||||
|
session.commit()
|
||||||
|
else:
|
||||||
|
logger.warning(f"Token not found for revocation: {token[:20]}...")
|
||||||
|
|
||||||
|
def get_protected_resource_metadata(self) -> dict[str, Any]:
|
||||||
|
"""Return metadata about the protected resource."""
|
||||||
|
return {
|
||||||
|
"resource_server": settings.SERVER_URL,
|
||||||
|
"scopes_supported": ["read", "write"],
|
||||||
|
"bearer_methods_supported": ["header"],
|
||||||
|
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||||
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
|
"token_endpoint_auth_methods_supported": ["none", "client_secret_basic"],
|
||||||
|
"refresh_token_rotation_enabled": True,
|
||||||
|
"protected_resources": [
|
||||||
|
{
|
||||||
|
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||||
|
"scopes": ["read", "write"],
|
||||||
|
"http_methods": ["POST", "GET"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
||||||
|
"scopes": ["read", "write"],
|
||||||
|
"http_methods": ["POST", "GET"],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.auth.middleware.auth_context import get_access_token
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import AgentObservation, SourceItem
|
from memory.common.db.models import (
|
||||||
|
AgentObservation,
|
||||||
|
SourceItem,
|
||||||
|
UserSession,
|
||||||
|
)
|
||||||
|
from memory.api.MCP.base import mcp
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Create MCP server instance
|
|
||||||
mcp = FastMCP("memory", stateless_http=True)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_observation_source_ids(
|
def filter_observation_source_ids(
|
||||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||||
@ -67,4 +69,42 @@ def filter_source_ids(
|
|||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def get_current_time() -> dict:
|
async def get_current_time() -> dict:
|
||||||
"""Get the current time in UTC."""
|
"""Get the current time in UTC."""
|
||||||
|
logger.info("get_current_time tool called")
|
||||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_authenticated_user() -> dict:
|
||||||
|
"""Get information about the authenticated user."""
|
||||||
|
logger.info("🔧 get_authenticated_user tool called")
|
||||||
|
access_token = get_access_token()
|
||||||
|
logger.info(f"🔧 Access token from MCP context: {access_token}")
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
logger.warning("❌ No access token found in MCP context!")
|
||||||
|
return {"error": "Not authenticated"}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look up the actual user from the session token
|
||||||
|
with make_session() as session:
|
||||||
|
user_session = (
|
||||||
|
session.query(UserSession)
|
||||||
|
.filter(UserSession.id == access_token.token)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_session and user_session.user:
|
||||||
|
user_info = user_session.user.serialize()
|
||||||
|
else:
|
||||||
|
user_info = {"error": "User not found"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"authenticated": True,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scopes": access_token.scopes,
|
||||||
|
"client_id": access_token.client_id,
|
||||||
|
"user": user_info,
|
||||||
|
}
|
||||||
|
@ -2,8 +2,16 @@
|
|||||||
SQLAdmin views for the knowledge base database models.
|
SQLAdmin views for the knowledge base database models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
import logging
|
||||||
|
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||||
|
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
Chunk,
|
Chunk,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
@ -21,8 +29,11 @@ from memory.common.db.models import (
|
|||||||
AgentObservation,
|
AgentObservation,
|
||||||
Note,
|
Note,
|
||||||
User,
|
User,
|
||||||
|
UserSession,
|
||||||
|
OAuthState,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_COLUMNS = (
|
DEFAULT_COLUMNS = (
|
||||||
"modality",
|
"modality",
|
||||||
@ -218,7 +229,7 @@ class UserAdmin(ModelView, model=User):
|
|||||||
|
|
||||||
|
|
||||||
def setup_admin(admin: Admin):
|
def setup_admin(admin: Admin):
|
||||||
"""Add all admin views to the admin instance."""
|
"""Add all admin views to the admin instance with OAuth protection."""
|
||||||
admin.add_view(SourceItemAdmin)
|
admin.add_view(SourceItemAdmin)
|
||||||
admin.add_view(AgentObservationAdmin)
|
admin.add_view(AgentObservationAdmin)
|
||||||
admin.add_view(NoteAdmin)
|
admin.add_view(NoteAdmin)
|
||||||
|
@ -16,22 +16,23 @@ from fastapi import (
|
|||||||
Query,
|
Query,
|
||||||
Form,
|
Form,
|
||||||
Depends,
|
Depends,
|
||||||
|
Request,
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from sqladmin import Admin
|
from sqladmin import Admin
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import extract, settings
|
||||||
from memory.common import extract
|
|
||||||
from memory.common.db.connection import get_engine
|
from memory.common.db.connection import get_engine
|
||||||
from memory.common.db.models import User
|
from memory.common.db.models import User
|
||||||
from memory.api.admin import setup_admin
|
from memory.api.admin import setup_admin
|
||||||
from memory.api.search import search, SearchResult
|
from memory.api.search import search, SearchResult
|
||||||
from memory.api.MCP.tools import mcp
|
|
||||||
from memory.api.auth import (
|
from memory.api.auth import (
|
||||||
router as auth_router,
|
|
||||||
get_current_user,
|
get_current_user,
|
||||||
AuthenticationMiddleware,
|
AuthenticationMiddleware,
|
||||||
|
router as auth_router,
|
||||||
)
|
)
|
||||||
|
from memory.api.MCP.base import mcp
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -44,24 +45,37 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # In production, specify actual origins
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
# SQLAdmin setup
|
# SQLAdmin setup with OAuth protection
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
admin = Admin(app, engine)
|
admin = Admin(app, engine)
|
||||||
|
|
||||||
|
# Setup admin with OAuth protection using existing OAuth provider
|
||||||
setup_admin(admin)
|
setup_admin(admin)
|
||||||
|
|
||||||
# Include auth router
|
|
||||||
app.add_middleware(AuthenticationMiddleware)
|
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
|
app.add_middleware(AuthenticationMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
# Add health check to MCP server instead of main app
|
||||||
|
@mcp.custom_route("/health", methods=["GET"])
|
||||||
|
async def health_check(request: Request):
|
||||||
|
"""Simple health check endpoint on MCP server"""
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
|
||||||
|
|
||||||
|
|
||||||
|
# Mount MCP server at root - OAuth endpoints need to be at root level
|
||||||
app.mount("/", mcp.streamable_http_app())
|
app.mount("/", mcp.streamable_http_app())
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
def health_check():
|
|
||||||
"""Simple health check endpoint"""
|
|
||||||
return {"status": "healthy"}
|
|
||||||
|
|
||||||
|
|
||||||
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
||||||
if not item:
|
if not item:
|
||||||
return []
|
return []
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
import textwrap
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import logging
|
import logging
|
||||||
|
import pathlib
|
||||||
|
|
||||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.templating import Jinja2Templates
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
from sqlalchemy.orm import Session as DBSession, scoped_session
|
||||||
@ -52,28 +52,35 @@ def create_user_session(
|
|||||||
return str(session.id)
|
return str(session.id)
|
||||||
|
|
||||||
|
|
||||||
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
|
def get_user_session(
|
||||||
"""Get user from session ID if session is valid"""
|
request: Request, db: DBSession | scoped_session
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Get session ID from request"""
|
||||||
|
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
return None
|
||||||
|
|
||||||
session = db.query(UserSession).get(session_id)
|
session = db.query(UserSession).get(session_id)
|
||||||
if not session:
|
if not session:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
if session.expires_at.replace(tzinfo=timezone.utc) > now:
|
if session.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
|
||||||
|
"""Get user from session ID if session is valid"""
|
||||||
|
if session := get_user_session(request, db):
|
||||||
return session.user
|
return session.user
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
||||||
"""FastAPI dependency to get current authenticated user"""
|
"""FastAPI dependency to get current authenticated user"""
|
||||||
# Check for session ID in header or cookie
|
user = get_session_user(request, db)
|
||||||
session_id = request.headers.get(
|
|
||||||
settings.SESSION_HEADER_NAME
|
|
||||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
raise HTTPException(status_code=401, detail="No session provided")
|
|
||||||
|
|
||||||
user = get_session_user(session_id, db)
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
||||||
|
|
||||||
@ -117,21 +124,18 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/login", response_model=LoginResponse)
|
@router.get("/login", response_model=LoginResponse)
|
||||||
def login_page():
|
def login_page(request: Request):
|
||||||
"""Login page"""
|
"""Login page"""
|
||||||
return HTMLResponse(
|
template_dir = pathlib.Path(__file__).parent / "templates"
|
||||||
content=textwrap.dedent("""
|
templates = Jinja2Templates(directory=template_dir)
|
||||||
<html>
|
return templates.TemplateResponse(
|
||||||
<body>
|
"login.html",
|
||||||
<h1>Login</h1>
|
{
|
||||||
<form method="post" action="/auth/login-form">
|
"request": request,
|
||||||
<input type="email" name="email" placeholder="Email" />
|
"action": router.url_path_for("login_form"),
|
||||||
<input type="password" name="password" placeholder="Password" />
|
"error": None,
|
||||||
<button type="submit">Login</button>
|
"form_data": {},
|
||||||
</form>
|
},
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
"""),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -170,9 +174,17 @@ def login_form(
|
|||||||
return LoginResponse(session_id=session_id, **user.serialize())
|
return LoginResponse(session_id=session_id, **user.serialize())
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout")
|
@router.api_route("/logout", methods=["GET", "POST"])
|
||||||
def logout(response: Response, user: User = Depends(get_current_user)):
|
def logout(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
db: DBSession = Depends(get_session),
|
||||||
|
):
|
||||||
"""Logout and clear session"""
|
"""Logout and clear session"""
|
||||||
|
session = get_user_session(request, db)
|
||||||
|
if session:
|
||||||
|
db.delete(session)
|
||||||
|
db.commit()
|
||||||
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
||||||
return {"message": "Logged out successfully"}
|
return {"message": "Logged out successfully"}
|
||||||
|
|
||||||
@ -192,6 +204,11 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
"/auth/login",
|
"/auth/login",
|
||||||
"/auth/login-form",
|
"/auth/login-form",
|
||||||
"/auth/register",
|
"/auth/register",
|
||||||
|
"/register",
|
||||||
|
"/token",
|
||||||
|
"/mcp",
|
||||||
|
"/oauth/",
|
||||||
|
"/.well-known/",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
@ -205,10 +222,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Check for session ID in header or cookie
|
# Check for session ID in header or cookie
|
||||||
session_id = request.headers.get(
|
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||||
settings.SESSION_HEADER_NAME
|
|
||||||
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
return Response(
|
return Response(
|
||||||
content="Authentication required",
|
content="Authentication required",
|
||||||
@ -218,7 +232,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Validate session and get user
|
# Validate session and get user
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
user = get_session_user(session_id, session)
|
user = get_session_user(request, session)
|
||||||
if not user:
|
if not user:
|
||||||
return Response(
|
return Response(
|
||||||
content="Invalid or expired session",
|
content="Invalid or expired session",
|
||||||
|
159
src/memory/api/templates/login.html
Normal file
159
src/memory/api/templates/login.html
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Login - Memory OAuth</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-container {
|
||||||
|
background: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||||
|
width: 100%;
|
||||||
|
max-width: 400px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-header {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-header h1 {
|
||||||
|
color: #333;
|
||||||
|
margin: 0 0 0.5rem 0;
|
||||||
|
font-size: 1.8rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-header p {
|
||||||
|
color: #666;
|
||||||
|
margin: 0;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
color: #333;
|
||||||
|
font-weight: 500;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border: 2px solid #e1e5e9;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 1rem;
|
||||||
|
transition: border-color 0.2s;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-button {
|
||||||
|
width: 100%;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.875rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-button:hover {
|
||||||
|
transform: translateY(-1px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-button:active {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-message {
|
||||||
|
background: #fee;
|
||||||
|
color: #c33;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
border-left: 4px solid #c33;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-credentials {
|
||||||
|
background: #f8f9fa;
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-top: 1.5rem;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-credentials strong {
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-credentials code {
|
||||||
|
background: #e9ecef;
|
||||||
|
padding: 0.2rem 0.4rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-family: monospace;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div class="login-container">
|
||||||
|
<div class="login-header">
|
||||||
|
<h1>Sign In</h1>
|
||||||
|
<p>Access your Memory knowledge base</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<div class="error-message">
|
||||||
|
{{ error }}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<form method="post" action="{{ action }}">
|
||||||
|
{% for key, value in form_data.items() %}
|
||||||
|
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="email">Email</label>
|
||||||
|
<input type="email" id="email" name="email" required>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input type="password" id="password" name="password" required>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit" class="login-button">Sign In</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
@ -35,6 +35,9 @@ from memory.common.db.models.sources import (
|
|||||||
from memory.common.db.models.users import (
|
from memory.common.db.models.users import (
|
||||||
User,
|
User,
|
||||||
UserSession,
|
UserSession,
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthState,
|
||||||
|
OAuthRefreshToken,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -69,4 +72,7 @@ __all__ = [
|
|||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"UserSession",
|
"UserSession",
|
||||||
|
"OAuthClientInformation",
|
||||||
|
"OAuthState",
|
||||||
|
"OAuthRefreshToken",
|
||||||
]
|
]
|
||||||
|
@ -3,9 +3,19 @@ import secrets
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
import uuid
|
import uuid
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Boolean,
|
||||||
|
ARRAY,
|
||||||
|
Numeric,
|
||||||
|
)
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
@ -35,6 +45,9 @@ class User(Base):
|
|||||||
sessions = relationship(
|
sessions = relationship(
|
||||||
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
oauth_states = relationship(
|
||||||
|
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
def serialize(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@ -57,9 +70,127 @@ class UserSession(Base):
|
|||||||
__tablename__ = "user_sessions"
|
__tablename__ = "user_sessions"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True)
|
||||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
created_at = Column(DateTime, server_default=func.now())
|
created_at = Column(DateTime, server_default=func.now())
|
||||||
expires_at = Column(DateTime, nullable=False)
|
expires_at = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
# Relationship to user
|
# Relationship to user
|
||||||
user = relationship("User", back_populates="sessions")
|
user = relationship("User", back_populates="sessions")
|
||||||
|
oauth_state = relationship("OAuthState", back_populates="session")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientInformation(Base):
|
||||||
|
__tablename__ = "oauth_client"
|
||||||
|
|
||||||
|
client_id = Column(String, primary_key=True)
|
||||||
|
client_secret = Column(String, nullable=True)
|
||||||
|
client_id_issued_at = Column(Numeric, nullable=False)
|
||||||
|
client_secret_expires_at = Column(Numeric, nullable=True)
|
||||||
|
|
||||||
|
redirect_uris = Column(ARRAY(String), nullable=False)
|
||||||
|
token_endpoint_auth_method = Column(String, nullable=False)
|
||||||
|
grant_types = Column(ARRAY(String), nullable=False)
|
||||||
|
response_types = Column(ARRAY(String), nullable=False)
|
||||||
|
scope = Column(String, nullable=False)
|
||||||
|
client_name = Column(String, nullable=False)
|
||||||
|
client_uri = Column(String, nullable=True)
|
||||||
|
logo_uri = Column(String, nullable=True)
|
||||||
|
contacts = Column(ARRAY(String), nullable=True)
|
||||||
|
tos_uri = Column(String, nullable=True)
|
||||||
|
policy_uri = Column(String, nullable=True)
|
||||||
|
jwks_uri = Column(String, nullable=True)
|
||||||
|
|
||||||
|
sessions = relationship(
|
||||||
|
"OAuthState", back_populates="client", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize(self) -> dict:
|
||||||
|
return {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"client_id_issued_at": self.client_id_issued_at,
|
||||||
|
"client_secret_expires_at": self.client_secret_expires_at,
|
||||||
|
"redirect_uris": self.redirect_uris,
|
||||||
|
"token_endpoint_auth_method": self.token_endpoint_auth_method,
|
||||||
|
"grant_types": self.grant_types,
|
||||||
|
"response_types": self.response_types,
|
||||||
|
"scope": self.scope,
|
||||||
|
"client_name": self.client_name,
|
||||||
|
"client_uri": self.client_uri,
|
||||||
|
"logo_uri": self.logo_uri,
|
||||||
|
"contacts": self.contacts,
|
||||||
|
"tos_uri": self.tos_uri,
|
||||||
|
"policy_uri": self.policy_uri,
|
||||||
|
"jwks_uri": self.jwks_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthToken:
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
|
||||||
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||||
|
scopes = Column(ARRAY(String), nullable=False)
|
||||||
|
created_at = Column(DateTime, server_default=func.now())
|
||||||
|
expires_at = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
|
def serialize(self) -> dict:
|
||||||
|
return {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"scopes": self.scopes,
|
||||||
|
"expires_at": self.expires_at.timestamp(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthState(Base, OAuthToken):
|
||||||
|
__tablename__ = "oauth_states"
|
||||||
|
|
||||||
|
state = Column(String, nullable=False)
|
||||||
|
code = Column(String, nullable=True)
|
||||||
|
redirect_uri = Column(String, nullable=False)
|
||||||
|
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
|
||||||
|
code_challenge = Column(String, nullable=True)
|
||||||
|
stale = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
def serialize(self, code: bool = False) -> dict:
|
||||||
|
data = {
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
|
||||||
|
"code_challenge": self.code_challenge,
|
||||||
|
} | super().serialize()
|
||||||
|
if code:
|
||||||
|
data |= {
|
||||||
|
"code": self.code,
|
||||||
|
"expires_at": self.expires_at.timestamp(),
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
client = relationship("OAuthClientInformation", back_populates="sessions")
|
||||||
|
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
|
||||||
|
user = relationship("User", back_populates="oauth_states")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthRefreshToken(Base, OAuthToken):
|
||||||
|
__tablename__ = "oauth_refresh_tokens"
|
||||||
|
|
||||||
|
token = Column(
|
||||||
|
String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}"
|
||||||
|
)
|
||||||
|
revoked = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
# Optional: link to the access token session that was created with this refresh token
|
||||||
|
access_token_session_id = Column(
|
||||||
|
String, ForeignKey("user_sessions.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
client = relationship("OAuthClientInformation")
|
||||||
|
user = relationship("User")
|
||||||
|
access_token_session = relationship("UserSession")
|
||||||
|
|
||||||
|
def serialize(self) -> dict:
|
||||||
|
return {
|
||||||
|
"token": self.token,
|
||||||
|
"expires_at": self.expires_at.timestamp(),
|
||||||
|
"revoked": self.revoked,
|
||||||
|
} | super().serialize()
|
||||||
|
@ -131,8 +131,8 @@ else:
|
|||||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
|
|
||||||
# API settings
|
# API settings
|
||||||
|
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||||
HTTPS = boolean_env("HTTPS", False)
|
HTTPS = boolean_env("HTTPS", False)
|
||||||
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
|
||||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||||
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
||||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||||
|
@ -1,126 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import uvicorn
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
|
|
||||||
class State(BaseModel):
|
|
||||||
email: str
|
|
||||||
password: str
|
|
||||||
remote_server: str
|
|
||||||
session_header: str = "X-Session-ID"
|
|
||||||
session_id: str | None = None
|
|
||||||
port: int = 8080
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> State:
|
|
||||||
"""Parse command line arguments"""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Simple HTTP proxy with authentication"
|
|
||||||
)
|
|
||||||
parser.add_argument("--remote-server", required=True, help="Remote server URL")
|
|
||||||
parser.add_argument("--email", required=True, help="Email for authentication")
|
|
||||||
parser.add_argument("--password", required=True, help="Password for authentication")
|
|
||||||
parser.add_argument(
|
|
||||||
"--session-header", default="X-Session-ID", help="Session header name"
|
|
||||||
)
|
|
||||||
parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
|
|
||||||
return State(**vars(parser.parse_args()))
|
|
||||||
|
|
||||||
|
|
||||||
state = parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
async def login() -> None:
|
|
||||||
"""Login to remote server and store session ID"""
|
|
||||||
login_url = f"{state.remote_server}/auth/login"
|
|
||||||
login_data = {"email": state.email, "password": state.password}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
try:
|
|
||||||
response = await client.post(login_url, json=login_data)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
login_response = response.json()
|
|
||||||
state.session_id = login_response["session_id"]
|
|
||||||
print(f"Successfully logged in, session ID: {state.session_id}")
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
print(
|
|
||||||
f"Login failed with status {e.response.status_code}: {e.response.text}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Login failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
async def proxy_request(request: Request) -> Response:
|
|
||||||
"""Proxy request to remote server with session header"""
|
|
||||||
if not state.session_id:
|
|
||||||
try:
|
|
||||||
await login()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Login failed: {e}")
|
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
||||||
|
|
||||||
# Build the target URL
|
|
||||||
target_url = f"{state.remote_server}{request.url.path}"
|
|
||||||
if request.url.query:
|
|
||||||
target_url += f"?{request.url.query}"
|
|
||||||
|
|
||||||
# Get request body
|
|
||||||
body = await request.body()
|
|
||||||
headers = dict(request.headers)
|
|
||||||
headers.pop("host", None)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
try:
|
|
||||||
response = await client.request(
|
|
||||||
method=request.method,
|
|
||||||
url=target_url,
|
|
||||||
headers=headers | {state.session_header: state.session_id}, # type: ignore
|
|
||||||
content=body,
|
|
||||||
timeout=30.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Forward response
|
|
||||||
resp = Response(
|
|
||||||
content=response.content,
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers={
|
|
||||||
k: v.replace(state.remote_server, f"http://localhost:{state.port}")
|
|
||||||
for k, v in response.headers.items()
|
|
||||||
},
|
|
||||||
media_type=response.headers.get("content-type"),
|
|
||||||
)
|
|
||||||
print(resp.headers)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
print(f"Request failed: {e}")
|
|
||||||
raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# Create FastAPI app
|
|
||||||
app = FastAPI(title="Simple Proxy")
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route(
|
|
||||||
"/{path:path}",
|
|
||||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
|
|
||||||
)
|
|
||||||
async def proxy_all(request: Request):
|
|
||||||
"""Proxy all requests to remote server"""
|
|
||||||
return await proxy_request(request)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(f"Starting proxy server on port {state.port}")
|
|
||||||
print(f"Proxying to: {state.remote_server}")
|
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=state.port)
|
|
Loading…
x
Reference in New Issue
Block a user