mirror of
https://github.com/mruwnik/memory.git
synced 2025-08-01 15:36:55 +02:00
Compare commits
No commits in common. "6ee46d62159db12d602170f9ea44156db174bf30" and "4d057d1ec61b49bf156c156a90073e7b01d60356" have entirely different histories.
6ee46d6215
...
4d057d1ec6
12
README.md
12
README.md
@ -80,6 +80,18 @@ python tools/run_celery_task.py notes setup-git-notes --origin ssh://git@github.
|
|||||||
For this to work you need to make sure you have set up the ssh keys in `secrets` (see the README.md
|
For this to work you need to make sure you have set up the ssh keys in `secrets` (see the README.md
|
||||||
in that folder), and you will need to add the public key that is generated there to your git server.
|
in that folder), and you will need to add the public key that is generated there to your git server.
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
|
||||||
|
The API uses session-based authentication. Login via:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/auth/login \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"email": "user@example.com", "password": "yourpassword"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
|
||||||
|
|
||||||
## Discord integration
|
## Discord integration
|
||||||
|
|
||||||
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
||||||
|
@ -1,115 +0,0 @@
|
|||||||
"""oauth codes
|
|
||||||
|
|
||||||
Revision ID: 1d6bc8015ea9
|
|
||||||
Revises: 58439dd3088b
|
|
||||||
Create Date: 2025-06-06 16:53:33.044558
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "1d6bc8015ea9"
|
|
||||||
down_revision: Union[str, None] = "58439dd3088b"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"oauth_client",
|
|
||||||
sa.Column("client_id", sa.String(), nullable=False),
|
|
||||||
sa.Column("client_secret", sa.String(), nullable=True),
|
|
||||||
sa.Column("client_id_issued_at", sa.Numeric(), nullable=False),
|
|
||||||
sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True),
|
|
||||||
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False),
|
|
||||||
sa.Column("token_endpoint_auth_method", sa.String(), nullable=False),
|
|
||||||
sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False),
|
|
||||||
sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False),
|
|
||||||
sa.Column("scope", sa.String(), nullable=False),
|
|
||||||
sa.Column("client_name", sa.String(), nullable=False),
|
|
||||||
sa.Column("client_uri", sa.String(), nullable=True),
|
|
||||||
sa.Column("logo_uri", sa.String(), nullable=True),
|
|
||||||
sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True),
|
|
||||||
sa.Column("tos_uri", sa.String(), nullable=True),
|
|
||||||
sa.Column("policy_uri", sa.String(), nullable=True),
|
|
||||||
sa.Column("jwks_uri", sa.String(), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint("client_id"),
|
|
||||||
)
|
|
||||||
op.create_table(
|
|
||||||
"oauth_states",
|
|
||||||
sa.Column("state", sa.String(), nullable=False),
|
|
||||||
sa.Column("code", sa.String(), nullable=True),
|
|
||||||
sa.Column("redirect_uri", sa.String(), nullable=False),
|
|
||||||
sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False),
|
|
||||||
sa.Column("code_challenge", sa.String(), nullable=True),
|
|
||||||
sa.Column("stale", sa.Boolean(), nullable=False),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("client_id", sa.String(), nullable=False),
|
|
||||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
|
||||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
|
||||||
),
|
|
||||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["client_id"],
|
|
||||||
["oauth_client.client_id"],
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["user_id"],
|
|
||||||
["users.id"],
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
)
|
|
||||||
op.create_table(
|
|
||||||
"oauth_refresh_tokens",
|
|
||||||
sa.Column("token", sa.String(), nullable=False),
|
|
||||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
|
||||||
sa.Column("access_token_session_id", sa.String(), nullable=True),
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("client_id", sa.String(), nullable=False),
|
|
||||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
|
||||||
sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
|
||||||
),
|
|
||||||
sa.Column("expires_at", sa.DateTime(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["access_token_session_id"],
|
|
||||||
["user_sessions.id"],
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["client_id"],
|
|
||||||
["oauth_client.client_id"],
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["user_id"],
|
|
||||||
["users.id"],
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
)
|
|
||||||
op.add_column(
|
|
||||||
"user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True)
|
|
||||||
)
|
|
||||||
op.create_foreign_key(
|
|
||||||
"user_sessions_oauth_state_id_fkey",
|
|
||||||
"user_sessions",
|
|
||||||
"oauth_states",
|
|
||||||
["oauth_state_id"],
|
|
||||||
["id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_constraint(
|
|
||||||
"user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey"
|
|
||||||
)
|
|
||||||
op.drop_column("user_sessions", "oauth_state_id")
|
|
||||||
op.drop_table("oauth_refresh_tokens")
|
|
||||||
op.drop_table("oauth_states")
|
|
||||||
op.drop_table("oauth_client")
|
|
@ -147,7 +147,6 @@ services:
|
|||||||
<<: *env
|
<<: *env
|
||||||
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password
|
||||||
QDRANT_URL: http://qdrant:6333
|
QDRANT_URL: http://qdrant:6333
|
||||||
SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
|
|
||||||
secrets: [postgres_password]
|
secrets: [postgres_password]
|
||||||
volumes:
|
volumes:
|
||||||
- ./memory_files:/app/memory_files:rw
|
- ./memory_files:/app/memory_files:rw
|
||||||
@ -159,6 +158,20 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|
||||||
|
proxy:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: docker/api/Dockerfile
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [kbnet]
|
||||||
|
environment:
|
||||||
|
<<: *env
|
||||||
|
command: ["python", "/app/tools/simple_proxy.py", "--remote-server", "${PROXY_REMOTE_SERVER:-http://api:8000}", "--email", "${PROXY_EMAIL}", "--password", "${PROXY_PASSWORD}", "--port", "8001"]
|
||||||
|
volumes:
|
||||||
|
- ./tools:/app/tools:ro
|
||||||
|
ports:
|
||||||
|
- "8001:8001"
|
||||||
|
|
||||||
# ------------------------------------------------------------ Celery workers
|
# ------------------------------------------------------------ Celery workers
|
||||||
worker:
|
worker:
|
||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
|
@ -1,132 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
|
||||||
from mcp.server.auth.handlers.token import (
|
|
||||||
AuthorizationCodeRequest,
|
|
||||||
RefreshTokenRequest,
|
|
||||||
TokenRequest,
|
|
||||||
)
|
|
||||||
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
|
||||||
from mcp.server.fastmcp import FastMCP
|
|
||||||
from mcp.shared.auth import OAuthClientMetadata
|
|
||||||
from memory.common.db.models.users import User
|
|
||||||
from pydantic import AnyHttpUrl
|
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.responses import JSONResponse, RedirectResponse
|
|
||||||
from starlette.templating import Jinja2Templates
|
|
||||||
|
|
||||||
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
|
||||||
from memory.common import settings
|
|
||||||
from memory.common.db.connection import make_session
|
|
||||||
from memory.common.db.models import OAuthState
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_metadata(klass: type):
|
|
||||||
orig_validate = klass.model_validate
|
|
||||||
|
|
||||||
def validate(data: dict):
|
|
||||||
data = dict(data)
|
|
||||||
if "redirect_uris" in data:
|
|
||||||
data["redirect_uris"] = [
|
|
||||||
str(uri).replace("cursor://", "http://")
|
|
||||||
for uri in data["redirect_uris"]
|
|
||||||
]
|
|
||||||
if "redirect_uri" in data:
|
|
||||||
data["redirect_uri"] = str(data["redirect_uri"]).replace(
|
|
||||||
"cursor://", "http://"
|
|
||||||
)
|
|
||||||
|
|
||||||
return orig_validate(data)
|
|
||||||
|
|
||||||
klass.model_validate = validate
|
|
||||||
|
|
||||||
|
|
||||||
validate_metadata(OAuthClientMetadata)
|
|
||||||
validate_metadata(AuthorizationRequest)
|
|
||||||
validate_metadata(AuthorizationCodeRequest)
|
|
||||||
validate_metadata(RefreshTokenRequest)
|
|
||||||
validate_metadata(TokenRequest)
|
|
||||||
|
|
||||||
|
|
||||||
# Setup templates
|
|
||||||
template_dir = pathlib.Path(__file__).parent.parent / "templates"
|
|
||||||
templates = Jinja2Templates(directory=template_dir)
|
|
||||||
|
|
||||||
|
|
||||||
oauth_provider = SimpleOAuthProvider()
|
|
||||||
auth_settings = AuthSettings(
|
|
||||||
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
|
||||||
client_registration_options=ClientRegistrationOptions(
|
|
||||||
enabled=True,
|
|
||||||
valid_scopes=["read", "write"],
|
|
||||||
default_scopes=["read"],
|
|
||||||
),
|
|
||||||
required_scopes=["read", "write"],
|
|
||||||
)
|
|
||||||
|
|
||||||
mcp = FastMCP(
|
|
||||||
"memory",
|
|
||||||
stateless_http=True,
|
|
||||||
auth_server_provider=oauth_provider,
|
|
||||||
auth=auth_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
|
||||||
async def oauth_protected_resource(request: Request):
|
|
||||||
"""OAuth 2.0 Protected Resource Metadata."""
|
|
||||||
logger.info("Protected resource metadata requested")
|
|
||||||
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
|
||||||
|
|
||||||
|
|
||||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
"login.html",
|
|
||||||
{
|
|
||||||
"request": request,
|
|
||||||
"form_data": form_data,
|
|
||||||
"error": error,
|
|
||||||
"action": "/oauth/login",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.custom_route("/oauth/login", methods=["GET"])
|
|
||||||
async def login_page(request: Request):
|
|
||||||
"""Display the login page."""
|
|
||||||
form_data = dict(request.query_params)
|
|
||||||
|
|
||||||
state = form_data.get("state")
|
|
||||||
with make_session() as session:
|
|
||||||
oauth_state = (
|
|
||||||
session.query(OAuthState).filter(OAuthState.state == state).first()
|
|
||||||
)
|
|
||||||
if not oauth_state:
|
|
||||||
logger.error(f"State {state} not found in database")
|
|
||||||
raise ValueError("Invalid state parameter")
|
|
||||||
|
|
||||||
return login_form(request, form_data, None)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.custom_route("/oauth/login", methods=["POST"])
|
|
||||||
async def handle_login(request: Request):
|
|
||||||
"""Handle login form submission."""
|
|
||||||
form = await request.form()
|
|
||||||
oauth_params = {
|
|
||||||
key: value for key, value in form.items() if key not in ["email", "password"]
|
|
||||||
}
|
|
||||||
with make_session() as session:
|
|
||||||
user = session.query(User).filter(User.email == form.get("email")).first()
|
|
||||||
if not user or not user.is_valid_password(str(form.get("password", ""))):
|
|
||||||
logger.warning("Login failed - invalid credentials")
|
|
||||||
return login_form(request, oauth_params, "Invalid email or password")
|
|
||||||
|
|
||||||
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
|
||||||
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
|
||||||
redirect_url = redirect_url.replace("http://", "cursor://")
|
|
||||||
return RedirectResponse(url=redirect_url, status_code=302)
|
|
@ -1,411 +0,0 @@
|
|||||||
import secrets
|
|
||||||
import time
|
|
||||||
from typing import Optional, Any, cast
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from memory.common.db.models.users import (
|
|
||||||
User,
|
|
||||||
UserSession,
|
|
||||||
OAuthClientInformation,
|
|
||||||
OAuthState,
|
|
||||||
OAuthRefreshToken,
|
|
||||||
OAuthToken as TokenBase,
|
|
||||||
)
|
|
||||||
from memory.common.db.connection import make_session, scoped_session
|
|
||||||
from memory.common import settings
|
|
||||||
from mcp.server.auth.provider import (
|
|
||||||
AccessToken,
|
|
||||||
AuthorizationParams,
|
|
||||||
AuthorizationCode,
|
|
||||||
OAuthAuthorizationServerProvider,
|
|
||||||
RefreshToken,
|
|
||||||
construct_redirect_uri,
|
|
||||||
)
|
|
||||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Token configuration constants
|
|
||||||
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
|
||||||
REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days
|
|
||||||
|
|
||||||
|
|
||||||
def create_expiration(lifetime_seconds: int) -> datetime:
|
|
||||||
"""Create expiration datetime from lifetime in seconds."""
|
|
||||||
return datetime.fromtimestamp(time.time() + lifetime_seconds)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_refresh_token() -> str:
|
|
||||||
"""Generate a new refresh token."""
|
|
||||||
return f"rt_{secrets.token_hex(32)}"
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token_session(
|
|
||||||
user_id: int, oauth_state_id: str | None = None
|
|
||||||
) -> UserSession:
|
|
||||||
"""Create a new access token session."""
|
|
||||||
return UserSession(
|
|
||||||
user_id=user_id,
|
|
||||||
oauth_state_id=oauth_state_id,
|
|
||||||
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_token_record(
|
|
||||||
client_id: str,
|
|
||||||
user_id: int,
|
|
||||||
scopes: list[str],
|
|
||||||
access_token_session_id: Optional[str] = None,
|
|
||||||
) -> OAuthRefreshToken:
|
|
||||||
"""Create a new refresh token record."""
|
|
||||||
return OAuthRefreshToken(
|
|
||||||
token=generate_refresh_token(),
|
|
||||||
client_id=client_id,
|
|
||||||
user_id=user_id,
|
|
||||||
scopes=scopes,
|
|
||||||
expires_at=create_expiration(REFRESH_TOKEN_LIFETIME),
|
|
||||||
access_token_session_id=access_token_session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None:
|
|
||||||
"""Validate a refresh token, raising ValueError if invalid."""
|
|
||||||
now = datetime.now()
|
|
||||||
if db_refresh_token.expires_at < now: # type: ignore
|
|
||||||
logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...")
|
|
||||||
db_refresh_token.revoked = True # type: ignore
|
|
||||||
raise ValueError("Refresh token expired")
|
|
||||||
|
|
||||||
|
|
||||||
def create_oauth_token(
|
|
||||||
access_token: str, scopes: list[str], refresh_token: Optional[str] = None
|
|
||||||
) -> OAuthToken:
|
|
||||||
"""Create an OAuth token response."""
|
|
||||||
return OAuthToken(
|
|
||||||
access_token=access_token,
|
|
||||||
token_type="bearer",
|
|
||||||
expires_in=ACCESS_TOKEN_LIFETIME,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
scope=" ".join(scopes),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_token(
|
|
||||||
db: scoped_session,
|
|
||||||
auth_state: TokenBase,
|
|
||||||
scopes: list[str],
|
|
||||||
) -> OAuthToken:
|
|
||||||
new_session = UserSession(
|
|
||||||
user_id=auth_state.user_id,
|
|
||||||
oauth_state_id=auth_state.id,
|
|
||||||
expires_at=create_expiration(ACCESS_TOKEN_LIFETIME),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create refresh token
|
|
||||||
refresh_token = create_refresh_token_record(
|
|
||||||
cast(str, auth_state.client_id),
|
|
||||||
cast(int, auth_state.user_id),
|
|
||||||
scopes,
|
|
||||||
cast(str, new_session.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(new_session)
|
|
||||||
db.add(refresh_token)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return create_oauth_token(
|
|
||||||
str(new_session.id),
|
|
||||||
scopes,
|
|
||||||
cast(str, refresh_token.token),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|
||||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
|
||||||
"""Get OAuth client information."""
|
|
||||||
with make_session() as session:
|
|
||||||
client = session.query(OAuthClientInformation).get(client_id)
|
|
||||||
return client and OAuthClientInformationFull(**client.serialize())
|
|
||||||
|
|
||||||
async def register_client(self, client_info: OAuthClientInformationFull):
|
|
||||||
"""Register a new OAuth client."""
|
|
||||||
with make_session() as session:
|
|
||||||
client = session.query(OAuthClientInformation).get(client_info.client_id)
|
|
||||||
if not client:
|
|
||||||
client = OAuthClientInformation(client_id=client_info.client_id)
|
|
||||||
|
|
||||||
for key, value in client_info.model_dump().items():
|
|
||||||
if key == "redirect_uris":
|
|
||||||
value = [str(uri) for uri in value]
|
|
||||||
elif value and key in [
|
|
||||||
"client_uri",
|
|
||||||
"logo_uri",
|
|
||||||
"tos_uri",
|
|
||||||
"policy_uri",
|
|
||||||
"jwks_uri",
|
|
||||||
]:
|
|
||||||
value = str(value)
|
|
||||||
setattr(client, key, value)
|
|
||||||
session.add(client)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
async def authorize(
|
|
||||||
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
|
||||||
) -> str:
|
|
||||||
"""Redirect to login page for user authentication."""
|
|
||||||
redirect_uri_str = str(params.redirect_uri)
|
|
||||||
registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])]
|
|
||||||
if redirect_uri_str not in registered_uris:
|
|
||||||
logger.error(
|
|
||||||
f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}"
|
|
||||||
)
|
|
||||||
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
|
||||||
|
|
||||||
# Store the authorization parameters in database
|
|
||||||
with make_session() as session:
|
|
||||||
oauth_state = OAuthState(
|
|
||||||
state=params.state or secrets.token_hex(16),
|
|
||||||
client_id=client.client_id,
|
|
||||||
redirect_uri=str(params.redirect_uri),
|
|
||||||
redirect_uri_provided_explicitly=str(
|
|
||||||
params.redirect_uri_provided_explicitly
|
|
||||||
).lower()
|
|
||||||
== "true",
|
|
||||||
code_challenge=params.code_challenge or "",
|
|
||||||
scopes=["read", "write"], # Default scopes
|
|
||||||
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
|
||||||
)
|
|
||||||
session.add(oauth_state)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
return f"{settings.SERVER_URL}/oauth/login?" + urlencode(
|
|
||||||
{
|
|
||||||
"state": oauth_state.state,
|
|
||||||
"client_id": client.client_id,
|
|
||||||
"redirect_uri": oauth_state.redirect_uri,
|
|
||||||
"redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly,
|
|
||||||
"code_challenge": cast(str, oauth_state.code_challenge),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
|
||||||
"""Complete authorization after successful login."""
|
|
||||||
if not (state := oauth_params.get("state")):
|
|
||||||
logger.error("No state parameter provided")
|
|
||||||
raise ValueError("Missing state parameter")
|
|
||||||
|
|
||||||
with make_session() as session:
|
|
||||||
# Load OAuth state from database
|
|
||||||
oauth_state = (
|
|
||||||
session.query(OAuthState).filter(OAuthState.state == state).first()
|
|
||||||
)
|
|
||||||
if not oauth_state:
|
|
||||||
logger.error(f"State {state} not found in database")
|
|
||||||
raise ValueError("Invalid state parameter")
|
|
||||||
|
|
||||||
# Check if state has expired
|
|
||||||
now = datetime.fromtimestamp(time.time())
|
|
||||||
if oauth_state.expires_at < now:
|
|
||||||
logger.error(f"State {state} has expired")
|
|
||||||
oauth_state.stale = True # type: ignore
|
|
||||||
session.commit()
|
|
||||||
raise ValueError("State has expired")
|
|
||||||
|
|
||||||
oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore
|
|
||||||
oauth_state.stale = False # type: ignore
|
|
||||||
oauth_state.user_id = user.id
|
|
||||||
|
|
||||||
session.add(oauth_state)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
return construct_redirect_uri(
|
|
||||||
cast(str, oauth_state.redirect_uri),
|
|
||||||
code=cast(str, oauth_state.code),
|
|
||||||
state=state,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_authorization_code(
|
|
||||||
self, client: OAuthClientInformationFull, authorization_code: str
|
|
||||||
) -> Optional[AuthorizationCode]:
|
|
||||||
"""Load an authorization code."""
|
|
||||||
with make_session() as session:
|
|
||||||
auth_code = (
|
|
||||||
session.query(OAuthState)
|
|
||||||
.filter(OAuthState.code == authorization_code)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not auth_code:
|
|
||||||
logger.error(f"Invalid authorization code: {authorization_code}")
|
|
||||||
raise ValueError("Invalid authorization code")
|
|
||||||
return AuthorizationCode(**auth_code.serialize(code=True))
|
|
||||||
|
|
||||||
async def exchange_authorization_code(
|
|
||||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
|
||||||
) -> OAuthToken:
|
|
||||||
"""Exchange authorization code for tokens."""
|
|
||||||
with make_session() as session:
|
|
||||||
auth_code = (
|
|
||||||
session.query(OAuthState)
|
|
||||||
.filter(OAuthState.code == authorization_code.code)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not auth_code:
|
|
||||||
logger.error(f"Invalid authorization code: {authorization_code.code}")
|
|
||||||
raise ValueError("Invalid authorization code")
|
|
||||||
|
|
||||||
if not auth_code.user:
|
|
||||||
logger.error(f"No user found for auth code: {authorization_code.code}")
|
|
||||||
raise ValueError("Invalid authorization code")
|
|
||||||
|
|
||||||
return make_token(session, auth_code, authorization_code.scopes)
|
|
||||||
|
|
||||||
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
|
||||||
"""Load and validate an access token."""
|
|
||||||
with make_session() as session:
|
|
||||||
now = datetime.now(timezone.utc).replace(
|
|
||||||
tzinfo=None
|
|
||||||
) # Make naive for DB comparison
|
|
||||||
|
|
||||||
# Query for active (non-expired) session
|
|
||||||
user_session = session.query(UserSession).get(token)
|
|
||||||
if not user_session or user_session.expires_at < now:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return AccessToken(
|
|
||||||
token=token,
|
|
||||||
client_id=user_session.oauth_state.client_id,
|
|
||||||
scopes=user_session.oauth_state.scopes,
|
|
||||||
expires_at=int(user_session.expires_at.timestamp()),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_refresh_token(
|
|
||||||
self, client: OAuthClientInformationFull, refresh_token: str
|
|
||||||
) -> Optional[RefreshToken]:
|
|
||||||
"""Load and validate a refresh token."""
|
|
||||||
with make_session() as session:
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
# Query for the refresh token
|
|
||||||
db_refresh_token = (
|
|
||||||
session.query(OAuthRefreshToken)
|
|
||||||
.filter(
|
|
||||||
OAuthRefreshToken.token == refresh_token,
|
|
||||||
OAuthRefreshToken.client_id == client.client_id,
|
|
||||||
OAuthRefreshToken.revoked == False, # noqa: E712
|
|
||||||
OAuthRefreshToken.expires_at > now,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not db_refresh_token:
|
|
||||||
logger.error(
|
|
||||||
f"Invalid or expired refresh token: {refresh_token[:20]}..."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return RefreshToken(
|
|
||||||
token=refresh_token,
|
|
||||||
client_id=client.client_id,
|
|
||||||
scopes=cast(list[str], db_refresh_token.scopes),
|
|
||||||
expires_at=int(db_refresh_token.expires_at.timestamp()),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def exchange_refresh_token(
|
|
||||||
self,
|
|
||||||
client: OAuthClientInformationFull,
|
|
||||||
refresh_token: RefreshToken,
|
|
||||||
scopes: list[str],
|
|
||||||
) -> OAuthToken:
|
|
||||||
"""Exchange refresh token for new access token."""
|
|
||||||
with make_session() as session:
|
|
||||||
# Load the refresh token from database
|
|
||||||
db_refresh_token = (
|
|
||||||
session.query(OAuthRefreshToken)
|
|
||||||
.filter(
|
|
||||||
OAuthRefreshToken.token == refresh_token.token,
|
|
||||||
OAuthRefreshToken.client_id == client.client_id,
|
|
||||||
OAuthRefreshToken.revoked == False, # noqa: E712
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not db_refresh_token:
|
|
||||||
logger.error(f"Refresh token not found: {refresh_token.token[:20]}...")
|
|
||||||
raise ValueError("Invalid refresh token")
|
|
||||||
|
|
||||||
# Validate refresh token
|
|
||||||
validate_refresh_token(db_refresh_token)
|
|
||||||
|
|
||||||
# Validate requested scopes are subset of original scopes
|
|
||||||
original_scopes = set(cast(list[str], db_refresh_token.scopes))
|
|
||||||
requested_scopes = set(scopes) if scopes else original_scopes
|
|
||||||
|
|
||||||
if not requested_scopes.issubset(original_scopes):
|
|
||||||
logger.error(
|
|
||||||
f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}"
|
|
||||||
)
|
|
||||||
raise ValueError("Requested scopes exceed original authorization")
|
|
||||||
|
|
||||||
return make_token(session, db_refresh_token, scopes)
|
|
||||||
|
|
||||||
async def revoke_token(
|
|
||||||
self, token: str, token_type_hint: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""Revoke a token (access token or refresh token)."""
|
|
||||||
with make_session() as session:
|
|
||||||
revoked = False
|
|
||||||
|
|
||||||
# Try to revoke as access token (UserSession)
|
|
||||||
if not token_type_hint or token_type_hint == "access_token":
|
|
||||||
user_session = session.query(UserSession).get(token)
|
|
||||||
if user_session:
|
|
||||||
session.delete(user_session)
|
|
||||||
revoked = True
|
|
||||||
logger.info(f"Revoked access token: {token[:20]}...")
|
|
||||||
|
|
||||||
# Try to revoke as refresh token
|
|
||||||
if not revoked and (
|
|
||||||
not token_type_hint or token_type_hint == "refresh_token"
|
|
||||||
):
|
|
||||||
refresh_token = (
|
|
||||||
session.query(OAuthRefreshToken)
|
|
||||||
.filter(OAuthRefreshToken.token == token)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if refresh_token:
|
|
||||||
refresh_token.revoked = True # type: ignore
|
|
||||||
revoked = True
|
|
||||||
logger.info(f"Revoked refresh token: {token[:20]}...")
|
|
||||||
|
|
||||||
if revoked:
|
|
||||||
session.commit()
|
|
||||||
else:
|
|
||||||
logger.warning(f"Token not found for revocation: {token[:20]}...")
|
|
||||||
|
|
||||||
def get_protected_resource_metadata(self) -> dict[str, Any]:
|
|
||||||
"""Return metadata about the protected resource."""
|
|
||||||
return {
|
|
||||||
"resource_server": settings.SERVER_URL,
|
|
||||||
"scopes_supported": ["read", "write"],
|
|
||||||
"bearer_methods_supported": ["header"],
|
|
||||||
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
|
||||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
|
||||||
"token_endpoint_auth_methods_supported": ["none", "client_secret_basic"],
|
|
||||||
"refresh_token_rotation_enabled": True,
|
|
||||||
"protected_resources": [
|
|
||||||
{
|
|
||||||
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
|
||||||
"scopes": ["read", "write"],
|
|
||||||
"http_methods": ["POST", "GET"],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
|
||||||
"scopes": ["read", "write"],
|
|
||||||
"http_methods": ["POST", "GET"],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
@ -5,21 +5,19 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from mcp.server.auth.middleware.auth_context import get_access_token
|
from mcp.server.fastmcp import FastMCP
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import AgentObservation, SourceItem
|
||||||
AgentObservation,
|
|
||||||
SourceItem,
|
|
||||||
UserSession,
|
|
||||||
)
|
|
||||||
from memory.api.MCP.base import mcp
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create MCP server instance
|
||||||
|
mcp = FastMCP("memory", stateless_http=True)
|
||||||
|
|
||||||
|
|
||||||
def filter_observation_source_ids(
|
def filter_observation_source_ids(
|
||||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||||
@ -69,42 +67,4 @@ def filter_source_ids(
|
|||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def get_current_time() -> dict:
|
async def get_current_time() -> dict:
|
||||||
"""Get the current time in UTC."""
|
"""Get the current time in UTC."""
|
||||||
logger.info("get_current_time tool called")
|
|
||||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_authenticated_user() -> dict:
|
|
||||||
"""Get information about the authenticated user."""
|
|
||||||
logger.info("🔧 get_authenticated_user tool called")
|
|
||||||
access_token = get_access_token()
|
|
||||||
logger.info(f"🔧 Access token from MCP context: {access_token}")
|
|
||||||
|
|
||||||
if not access_token:
|
|
||||||
logger.warning("❌ No access token found in MCP context!")
|
|
||||||
return {"error": "Not authenticated"}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Look up the actual user from the session token
|
|
||||||
with make_session() as session:
|
|
||||||
user_session = (
|
|
||||||
session.query(UserSession)
|
|
||||||
.filter(UserSession.id == access_token.token)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_session and user_session.user:
|
|
||||||
user_info = user_session.user.serialize()
|
|
||||||
else:
|
|
||||||
user_info = {"error": "User not found"}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"authenticated": True,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
"scopes": access_token.scopes,
|
|
||||||
"client_id": access_token.client_id,
|
|
||||||
"user": user_info,
|
|
||||||
}
|
|
||||||
|
@ -2,16 +2,8 @@
|
|||||||
SQLAdmin views for the knowledge base database models.
|
SQLAdmin views for the knowledge base database models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
from fastapi import Request
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import RedirectResponse
|
|
||||||
import logging
|
|
||||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
|
||||||
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
|
||||||
from memory.common import settings
|
|
||||||
from memory.common.db.connection import make_session
|
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
Chunk,
|
Chunk,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
@ -29,11 +21,8 @@ from memory.common.db.models import (
|
|||||||
AgentObservation,
|
AgentObservation,
|
||||||
Note,
|
Note,
|
||||||
User,
|
User,
|
||||||
UserSession,
|
|
||||||
OAuthState,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DEFAULT_COLUMNS = (
|
DEFAULT_COLUMNS = (
|
||||||
"modality",
|
"modality",
|
||||||
@ -229,7 +218,7 @@ class UserAdmin(ModelView, model=User):
|
|||||||
|
|
||||||
|
|
||||||
def setup_admin(admin: Admin):
|
def setup_admin(admin: Admin):
|
||||||
"""Add all admin views to the admin instance with OAuth protection."""
|
"""Add all admin views to the admin instance."""
|
||||||
admin.add_view(SourceItemAdmin)
|
admin.add_view(SourceItemAdmin)
|
||||||
admin.add_view(AgentObservationAdmin)
|
admin.add_view(AgentObservationAdmin)
|
||||||
admin.add_view(NoteAdmin)
|
admin.add_view(NoteAdmin)
|
||||||
|
@ -16,23 +16,22 @@ from fastapi import (
|
|||||||
Query,
|
Query,
|
||||||
Form,
|
Form,
|
||||||
Depends,
|
Depends,
|
||||||
Request,
|
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from sqladmin import Admin
|
from sqladmin import Admin
|
||||||
|
|
||||||
from memory.common import extract, settings
|
from memory.common import settings
|
||||||
|
from memory.common import extract
|
||||||
from memory.common.db.connection import get_engine
|
from memory.common.db.connection import get_engine
|
||||||
from memory.common.db.models import User
|
from memory.common.db.models import User
|
||||||
from memory.api.admin import setup_admin
|
from memory.api.admin import setup_admin
|
||||||
from memory.api.search import search, SearchResult
|
from memory.api.search import search, SearchResult
|
||||||
|
from memory.api.MCP.tools import mcp
|
||||||
from memory.api.auth import (
|
from memory.api.auth import (
|
||||||
|
router as auth_router,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
AuthenticationMiddleware,
|
AuthenticationMiddleware,
|
||||||
router as auth_router,
|
|
||||||
)
|
)
|
||||||
from memory.api.MCP.base import mcp
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -45,37 +44,24 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # In production, specify actual origins
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# SQLAdmin setup with OAuth protection
|
# SQLAdmin setup
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
admin = Admin(app, engine)
|
admin = Admin(app, engine)
|
||||||
|
|
||||||
# Setup admin with OAuth protection using existing OAuth provider
|
|
||||||
setup_admin(admin)
|
setup_admin(admin)
|
||||||
app.include_router(auth_router)
|
|
||||||
|
# Include auth router
|
||||||
app.add_middleware(AuthenticationMiddleware)
|
app.add_middleware(AuthenticationMiddleware)
|
||||||
|
app.include_router(auth_router)
|
||||||
|
|
||||||
# Add health check to MCP server instead of main app
|
|
||||||
@mcp.custom_route("/health", methods=["GET"])
|
|
||||||
async def health_check(request: Request):
|
|
||||||
"""Simple health check endpoint on MCP server"""
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
|
|
||||||
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
|
|
||||||
|
|
||||||
|
|
||||||
# Mount MCP server at root - OAuth endpoints need to be at root level
|
|
||||||
app.mount("/", mcp.streamable_http_app())
|
app.mount("/", mcp.streamable_http_app())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health_check():
|
||||||
|
"""Simple health check endpoint"""
|
||||||
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
||||||
if not item:
|
if not item:
|
||||||
return []
|
return []
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import textwrap
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
from fastapi import HTTPException, Depends, Request, Response, APIRouter, Form
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.responses import HTMLResponse
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
from sqlalchemy.orm import Session as DBSession, scoped_session
|
||||||
@ -52,35 +52,28 @@ def create_user_session(
|
|||||||
return str(session.id)
|
return str(session.id)
|
||||||
|
|
||||||
|
|
||||||
def get_user_session(
|
def get_session_user(session_id: str, db: DBSession | scoped_session) -> User | None:
|
||||||
request: Request, db: DBSession | scoped_session
|
"""Get user from session ID if session is valid"""
|
||||||
) -> UserSession | None:
|
|
||||||
"""Get session ID from request"""
|
|
||||||
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
return None
|
|
||||||
|
|
||||||
session = db.query(UserSession).get(session_id)
|
session = db.query(UserSession).get(session_id)
|
||||||
if not session:
|
if not session:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
if session.expires_at.replace(tzinfo=timezone.utc) < now:
|
if session.expires_at.replace(tzinfo=timezone.utc) > now:
|
||||||
return None
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def get_session_user(request: Request, db: DBSession | scoped_session) -> User | None:
|
|
||||||
"""Get user from session ID if session is valid"""
|
|
||||||
if session := get_user_session(request, db):
|
|
||||||
return session.user
|
return session.user
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
def get_current_user(request: Request, db: DBSession = Depends(get_session)) -> User:
|
||||||
"""FastAPI dependency to get current authenticated user"""
|
"""FastAPI dependency to get current authenticated user"""
|
||||||
user = get_session_user(request, db)
|
# Check for session ID in header or cookie
|
||||||
|
session_id = request.headers.get(
|
||||||
|
settings.SESSION_HEADER_NAME
|
||||||
|
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
raise HTTPException(status_code=401, detail="No session provided")
|
||||||
|
|
||||||
|
user = get_session_user(session_id, db)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
||||||
|
|
||||||
@ -124,18 +117,21 @@ def register(request: RegisterRequest, db: DBSession = Depends(get_session)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/login", response_model=LoginResponse)
|
@router.get("/login", response_model=LoginResponse)
|
||||||
def login_page(request: Request):
|
def login_page():
|
||||||
"""Login page"""
|
"""Login page"""
|
||||||
template_dir = pathlib.Path(__file__).parent / "templates"
|
return HTMLResponse(
|
||||||
templates = Jinja2Templates(directory=template_dir)
|
content=textwrap.dedent("""
|
||||||
return templates.TemplateResponse(
|
<html>
|
||||||
"login.html",
|
<body>
|
||||||
{
|
<h1>Login</h1>
|
||||||
"request": request,
|
<form method="post" action="/auth/login-form">
|
||||||
"action": router.url_path_for("login_form"),
|
<input type="email" name="email" placeholder="Email" />
|
||||||
"error": None,
|
<input type="password" name="password" placeholder="Password" />
|
||||||
"form_data": {},
|
<button type="submit">Login</button>
|
||||||
},
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -174,17 +170,9 @@ def login_form(
|
|||||||
return LoginResponse(session_id=session_id, **user.serialize())
|
return LoginResponse(session_id=session_id, **user.serialize())
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/logout", methods=["GET", "POST"])
|
@router.post("/logout")
|
||||||
def logout(
|
def logout(response: Response, user: User = Depends(get_current_user)):
|
||||||
request: Request,
|
|
||||||
response: Response,
|
|
||||||
db: DBSession = Depends(get_session),
|
|
||||||
):
|
|
||||||
"""Logout and clear session"""
|
"""Logout and clear session"""
|
||||||
session = get_user_session(request, db)
|
|
||||||
if session:
|
|
||||||
db.delete(session)
|
|
||||||
db.commit()
|
|
||||||
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
response.delete_cookie(settings.SESSION_COOKIE_NAME)
|
||||||
return {"message": "Logged out successfully"}
|
return {"message": "Logged out successfully"}
|
||||||
|
|
||||||
@ -204,11 +192,6 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
"/auth/login",
|
"/auth/login",
|
||||||
"/auth/login-form",
|
"/auth/login-form",
|
||||||
"/auth/register",
|
"/auth/register",
|
||||||
"/register",
|
|
||||||
"/token",
|
|
||||||
"/mcp",
|
|
||||||
"/oauth/",
|
|
||||||
"/.well-known/",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
@ -222,7 +205,10 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Check for session ID in header or cookie
|
# Check for session ID in header or cookie
|
||||||
session_id = request.cookies.get(settings.SESSION_COOKIE_NAME)
|
session_id = request.headers.get(
|
||||||
|
settings.SESSION_HEADER_NAME
|
||||||
|
) or request.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
return Response(
|
return Response(
|
||||||
content="Authentication required",
|
content="Authentication required",
|
||||||
@ -232,7 +218,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Validate session and get user
|
# Validate session and get user
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
user = get_session_user(request, session)
|
user = get_session_user(session_id, session)
|
||||||
if not user:
|
if not user:
|
||||||
return Response(
|
return Response(
|
||||||
content="Invalid or expired session",
|
content="Invalid or expired session",
|
||||||
|
@ -1,159 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>Login - Memory OAuth</title>
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
|
||||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
|
||||||
margin: 0;
|
|
||||||
padding: 0;
|
|
||||||
min-height: 100vh;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-container {
|
|
||||||
background: white;
|
|
||||||
padding: 2rem;
|
|
||||||
border-radius: 12px;
|
|
||||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
|
||||||
width: 100%;
|
|
||||||
max-width: 400px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-header {
|
|
||||||
text-align: center;
|
|
||||||
margin-bottom: 2rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-header h1 {
|
|
||||||
color: #333;
|
|
||||||
margin: 0 0 0.5rem 0;
|
|
||||||
font-size: 1.8rem;
|
|
||||||
font-weight: 600;
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-header p {
|
|
||||||
color: #666;
|
|
||||||
margin: 0;
|
|
||||||
font-size: 0.9rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.form-group {
|
|
||||||
margin-bottom: 1.5rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.form-group label {
|
|
||||||
display: block;
|
|
||||||
margin-bottom: 0.5rem;
|
|
||||||
color: #333;
|
|
||||||
font-weight: 500;
|
|
||||||
font-size: 0.9rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.form-group input {
|
|
||||||
width: 100%;
|
|
||||||
padding: 0.75rem;
|
|
||||||
border: 2px solid #e1e5e9;
|
|
||||||
border-radius: 8px;
|
|
||||||
font-size: 1rem;
|
|
||||||
transition: border-color 0.2s;
|
|
||||||
box-sizing: border-box;
|
|
||||||
}
|
|
||||||
|
|
||||||
.form-group input:focus {
|
|
||||||
outline: none;
|
|
||||||
border-color: #667eea;
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-button {
|
|
||||||
width: 100%;
|
|
||||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
|
||||||
color: white;
|
|
||||||
border: none;
|
|
||||||
padding: 0.875rem;
|
|
||||||
border-radius: 8px;
|
|
||||||
font-size: 1rem;
|
|
||||||
font-weight: 600;
|
|
||||||
cursor: pointer;
|
|
||||||
transition: transform 0.2s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-button:hover {
|
|
||||||
transform: translateY(-1px);
|
|
||||||
}
|
|
||||||
|
|
||||||
.login-button:active {
|
|
||||||
transform: translateY(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
.error-message {
|
|
||||||
background: #fee;
|
|
||||||
color: #c33;
|
|
||||||
padding: 0.75rem;
|
|
||||||
border-radius: 6px;
|
|
||||||
margin-bottom: 1rem;
|
|
||||||
font-size: 0.9rem;
|
|
||||||
border-left: 4px solid #c33;
|
|
||||||
}
|
|
||||||
|
|
||||||
.demo-credentials {
|
|
||||||
background: #f8f9fa;
|
|
||||||
padding: 1rem;
|
|
||||||
border-radius: 8px;
|
|
||||||
margin-top: 1.5rem;
|
|
||||||
font-size: 0.85rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.demo-credentials strong {
|
|
||||||
color: #333;
|
|
||||||
}
|
|
||||||
|
|
||||||
.demo-credentials code {
|
|
||||||
background: #e9ecef;
|
|
||||||
padding: 0.2rem 0.4rem;
|
|
||||||
border-radius: 4px;
|
|
||||||
font-family: monospace;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
|
|
||||||
<body>
|
|
||||||
<div class="login-container">
|
|
||||||
<div class="login-header">
|
|
||||||
<h1>Sign In</h1>
|
|
||||||
<p>Access your Memory knowledge base</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{% if error %}
|
|
||||||
<div class="error-message">
|
|
||||||
{{ error }}
|
|
||||||
</div>
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
<form method="post" action="{{ action }}">
|
|
||||||
{% for key, value in form_data.items() %}
|
|
||||||
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
<div class="form-group">
|
|
||||||
<label for="email">Email</label>
|
|
||||||
<input type="email" id="email" name="email" required>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="form-group">
|
|
||||||
<label for="password">Password</label>
|
|
||||||
<input type="password" id="password" name="password" required>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<button type="submit" class="login-button">Sign In</button>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
|
|
||||||
</html>
|
|
@ -35,9 +35,6 @@ from memory.common.db.models.sources import (
|
|||||||
from memory.common.db.models.users import (
|
from memory.common.db.models.users import (
|
||||||
User,
|
User,
|
||||||
UserSession,
|
UserSession,
|
||||||
OAuthClientInformation,
|
|
||||||
OAuthState,
|
|
||||||
OAuthRefreshToken,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -72,7 +69,4 @@ __all__ = [
|
|||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"UserSession",
|
"UserSession",
|
||||||
"OAuthClientInformation",
|
|
||||||
"OAuthState",
|
|
||||||
"OAuthRefreshToken",
|
|
||||||
]
|
]
|
||||||
|
@ -3,19 +3,9 @@ import secrets
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
import uuid
|
import uuid
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
from sqlalchemy import (
|
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||||
Column,
|
|
||||||
Integer,
|
|
||||||
String,
|
|
||||||
DateTime,
|
|
||||||
ForeignKey,
|
|
||||||
Boolean,
|
|
||||||
ARRAY,
|
|
||||||
Numeric,
|
|
||||||
)
|
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
@ -45,9 +35,6 @@ class User(Base):
|
|||||||
sessions = relationship(
|
sessions = relationship(
|
||||||
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
"UserSession", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
oauth_states = relationship(
|
|
||||||
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
def serialize(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@ -70,127 +57,9 @@ class UserSession(Base):
|
|||||||
__tablename__ = "user_sessions"
|
__tablename__ = "user_sessions"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True)
|
|
||||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
created_at = Column(DateTime, server_default=func.now())
|
created_at = Column(DateTime, server_default=func.now())
|
||||||
expires_at = Column(DateTime, nullable=False)
|
expires_at = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
# Relationship to user
|
# Relationship to user
|
||||||
user = relationship("User", back_populates="sessions")
|
user = relationship("User", back_populates="sessions")
|
||||||
oauth_state = relationship("OAuthState", back_populates="session")
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthClientInformation(Base):
|
|
||||||
__tablename__ = "oauth_client"
|
|
||||||
|
|
||||||
client_id = Column(String, primary_key=True)
|
|
||||||
client_secret = Column(String, nullable=True)
|
|
||||||
client_id_issued_at = Column(Numeric, nullable=False)
|
|
||||||
client_secret_expires_at = Column(Numeric, nullable=True)
|
|
||||||
|
|
||||||
redirect_uris = Column(ARRAY(String), nullable=False)
|
|
||||||
token_endpoint_auth_method = Column(String, nullable=False)
|
|
||||||
grant_types = Column(ARRAY(String), nullable=False)
|
|
||||||
response_types = Column(ARRAY(String), nullable=False)
|
|
||||||
scope = Column(String, nullable=False)
|
|
||||||
client_name = Column(String, nullable=False)
|
|
||||||
client_uri = Column(String, nullable=True)
|
|
||||||
logo_uri = Column(String, nullable=True)
|
|
||||||
contacts = Column(ARRAY(String), nullable=True)
|
|
||||||
tos_uri = Column(String, nullable=True)
|
|
||||||
policy_uri = Column(String, nullable=True)
|
|
||||||
jwks_uri = Column(String, nullable=True)
|
|
||||||
|
|
||||||
sessions = relationship(
|
|
||||||
"OAuthState", back_populates="client", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
|
||||||
return {
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"client_secret": self.client_secret,
|
|
||||||
"client_id_issued_at": self.client_id_issued_at,
|
|
||||||
"client_secret_expires_at": self.client_secret_expires_at,
|
|
||||||
"redirect_uris": self.redirect_uris,
|
|
||||||
"token_endpoint_auth_method": self.token_endpoint_auth_method,
|
|
||||||
"grant_types": self.grant_types,
|
|
||||||
"response_types": self.response_types,
|
|
||||||
"scope": self.scope,
|
|
||||||
"client_name": self.client_name,
|
|
||||||
"client_uri": self.client_uri,
|
|
||||||
"logo_uri": self.logo_uri,
|
|
||||||
"contacts": self.contacts,
|
|
||||||
"tos_uri": self.tos_uri,
|
|
||||||
"policy_uri": self.policy_uri,
|
|
||||||
"jwks_uri": self.jwks_uri,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthToken:
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False)
|
|
||||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
|
||||||
scopes = Column(ARRAY(String), nullable=False)
|
|
||||||
created_at = Column(DateTime, server_default=func.now())
|
|
||||||
expires_at = Column(DateTime, nullable=False)
|
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
|
||||||
return {
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"scopes": self.scopes,
|
|
||||||
"expires_at": self.expires_at.timestamp(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthState(Base, OAuthToken):
|
|
||||||
__tablename__ = "oauth_states"
|
|
||||||
|
|
||||||
state = Column(String, nullable=False)
|
|
||||||
code = Column(String, nullable=True)
|
|
||||||
redirect_uri = Column(String, nullable=False)
|
|
||||||
redirect_uri_provided_explicitly = Column(Boolean, nullable=False)
|
|
||||||
code_challenge = Column(String, nullable=True)
|
|
||||||
stale = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
def serialize(self, code: bool = False) -> dict:
|
|
||||||
data = {
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly,
|
|
||||||
"code_challenge": self.code_challenge,
|
|
||||||
} | super().serialize()
|
|
||||||
if code:
|
|
||||||
data |= {
|
|
||||||
"code": self.code,
|
|
||||||
"expires_at": self.expires_at.timestamp(),
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
|
|
||||||
client = relationship("OAuthClientInformation", back_populates="sessions")
|
|
||||||
session = relationship("UserSession", back_populates="oauth_state", uselist=False)
|
|
||||||
user = relationship("User", back_populates="oauth_states")
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthRefreshToken(Base, OAuthToken):
|
|
||||||
__tablename__ = "oauth_refresh_tokens"
|
|
||||||
|
|
||||||
token = Column(
|
|
||||||
String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}"
|
|
||||||
)
|
|
||||||
revoked = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
# Optional: link to the access token session that was created with this refresh token
|
|
||||||
access_token_session_id = Column(
|
|
||||||
String, ForeignKey("user_sessions.id"), nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationships
|
|
||||||
client = relationship("OAuthClientInformation")
|
|
||||||
user = relationship("User")
|
|
||||||
access_token_session = relationship("UserSession")
|
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
|
||||||
return {
|
|
||||||
"token": self.token,
|
|
||||||
"expires_at": self.expires_at.timestamp(),
|
|
||||||
"revoked": self.revoked,
|
|
||||||
} | super().serialize()
|
|
||||||
|
@ -131,8 +131,8 @@ else:
|
|||||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
|
|
||||||
# API settings
|
# API settings
|
||||||
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
|
||||||
HTTPS = boolean_env("HTTPS", False)
|
HTTPS = boolean_env("HTTPS", False)
|
||||||
|
SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID")
|
||||||
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")
|
||||||
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * 60))
|
||||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||||
|
126
tools/simple_proxy.py
Normal file
126
tools/simple_proxy.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import uvicorn
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
|
||||||
|
class State(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
remote_server: str
|
||||||
|
session_header: str = "X-Session-ID"
|
||||||
|
session_id: str | None = None
|
||||||
|
port: int = 8080
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> State:
|
||||||
|
"""Parse command line arguments"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Simple HTTP proxy with authentication"
|
||||||
|
)
|
||||||
|
parser.add_argument("--remote-server", required=True, help="Remote server URL")
|
||||||
|
parser.add_argument("--email", required=True, help="Email for authentication")
|
||||||
|
parser.add_argument("--password", required=True, help="Password for authentication")
|
||||||
|
parser.add_argument(
|
||||||
|
"--session-header", default="X-Session-ID", help="Session header name"
|
||||||
|
)
|
||||||
|
parser.add_argument("--port", type=int, default=8080, help="Port to run proxy on")
|
||||||
|
return State(**vars(parser.parse_args()))
|
||||||
|
|
||||||
|
|
||||||
|
state = parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
async def login() -> None:
|
||||||
|
"""Login to remote server and store session ID"""
|
||||||
|
login_url = f"{state.remote_server}/auth/login"
|
||||||
|
login_data = {"email": state.email, "password": state.password}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(login_url, json=login_data)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
login_response = response.json()
|
||||||
|
state.session_id = login_response["session_id"]
|
||||||
|
print(f"Successfully logged in, session ID: {state.session_id}")
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
print(
|
||||||
|
f"Login failed with status {e.response.status_code}: {e.response.text}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Login failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_request(request: Request) -> Response:
|
||||||
|
"""Proxy request to remote server with session header"""
|
||||||
|
if not state.session_id:
|
||||||
|
try:
|
||||||
|
await login()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Login failed: {e}")
|
||||||
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
# Build the target URL
|
||||||
|
target_url = f"{state.remote_server}{request.url.path}"
|
||||||
|
if request.url.query:
|
||||||
|
target_url += f"?{request.url.query}"
|
||||||
|
|
||||||
|
# Get request body
|
||||||
|
body = await request.body()
|
||||||
|
headers = dict(request.headers)
|
||||||
|
headers.pop("host", None)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.request(
|
||||||
|
method=request.method,
|
||||||
|
url=target_url,
|
||||||
|
headers=headers | {state.session_header: state.session_id}, # type: ignore
|
||||||
|
content=body,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forward response
|
||||||
|
resp = Response(
|
||||||
|
content=response.content,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers={
|
||||||
|
k: v.replace(state.remote_server, f"http://localhost:{state.port}")
|
||||||
|
for k, v in response.headers.items()
|
||||||
|
},
|
||||||
|
media_type=response.headers.get("content-type"),
|
||||||
|
)
|
||||||
|
print(resp.headers)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print(f"Request failed: {e}")
|
||||||
|
raise HTTPException(status_code=502, detail=f"Proxy request failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# Create FastAPI app
|
||||||
|
app = FastAPI(title="Simple Proxy")
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route(
|
||||||
|
"/{path:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
|
||||||
|
)
|
||||||
|
async def proxy_all(request: Request):
|
||||||
|
"""Proxy all requests to remote server"""
|
||||||
|
return await proxy_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(f"Starting proxy server on port {state.port}")
|
||||||
|
print(f"Proxying to: {state.remote_server}")
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=state.port)
|
Loading…
x
Reference in New Issue
Block a user