From 4556ef2c4888616026e72ea56c26c509381661d8 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Fri, 6 Jun 2025 12:55:48 +0200 Subject: [PATCH] proper oauth flow --- .../versions/20250606_123611_oauth_codes.py | 86 ++++++ src/memory/api/MCP/base.py | 124 +++++++++ src/memory/api/MCP/oauth_provider.py | 246 ++++++++++++++++++ src/memory/api/MCP/templates/login.html | 159 +++++++++++ src/memory/api/MCP/tools.py | 50 +++- src/memory/api/app.py | 53 ++-- src/memory/common/db/models/__init__.py | 4 + src/memory/common/db/models/users.py | 97 ++++++- src/memory/common/settings.py | 1 + 9 files changed, 797 insertions(+), 23 deletions(-) create mode 100644 db/migrations/versions/20250606_123611_oauth_codes.py create mode 100644 src/memory/api/MCP/base.py create mode 100644 src/memory/api/MCP/oauth_provider.py create mode 100644 src/memory/api/MCP/templates/login.html diff --git a/db/migrations/versions/20250606_123611_oauth_codes.py b/db/migrations/versions/20250606_123611_oauth_codes.py new file mode 100644 index 0000000..b2a91df --- /dev/null +++ b/db/migrations/versions/20250606_123611_oauth_codes.py @@ -0,0 +1,86 @@ +"""oauth codes + +Revision ID: 66771d293b27 +Revises: 58439dd3088b +Create Date: 2025-06-06 12:36:11.737507 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "66771d293b27" +down_revision: Union[str, None] = "58439dd3088b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "oauth_client", + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("client_secret", sa.String(), nullable=True), + sa.Column("client_id_issued_at", sa.Numeric(), nullable=False), + sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True), + sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False), + sa.Column("token_endpoint_auth_method", sa.String(), nullable=False), + sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False), + sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False), + sa.Column("scope", sa.String(), nullable=False), + sa.Column("client_name", sa.String(), nullable=False), + sa.Column("client_uri", sa.String(), nullable=True), + sa.Column("logo_uri", sa.String(), nullable=True), + sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True), + sa.Column("tos_uri", sa.String(), nullable=True), + sa.Column("policy_uri", sa.String(), nullable=True), + sa.Column("jwks_uri", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("client_id"), + ) + op.create_table( + "oauth_states", + sa.Column("state", sa.String(), nullable=False), + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("code", sa.String(), nullable=True), + sa.Column("redirect_uri", sa.String(), nullable=False), + sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False), + sa.Column("code_challenge", sa.String(), nullable=True), + sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True + ), + sa.Column("expires_at", sa.DateTime(), nullable=False), + sa.Column("stale", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["client_id"], + ["oauth_client.client_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("state"), + ) + op.add_column( + "user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True) + ) + op.create_foreign_key( + "fk_user_sessions_oauth_state_id", + "user_sessions", + "oauth_states", + ["oauth_state_id"], + ["state"], + ) + + +def downgrade() -> None: + op.drop_constraint( + "fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey" + ) + op.drop_column("user_sessions", "oauth_state_id") + op.drop_table("oauth_states") + op.drop_table("oauth_client") diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py new file mode 100644 index 0000000..755c0c2 --- /dev/null +++ b/src/memory/api/MCP/base.py @@ -0,0 +1,124 @@ +import logging +import os +from typing import cast + +from mcp.server.auth.handlers.authorize import AuthorizationRequest +from mcp.server.auth.handlers.token import ( + AuthorizationCodeRequest, + RefreshTokenRequest, + TokenRequest, +) +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp import FastMCP +from mcp.shared.auth import OAuthClientMetadata +from memory.common.db.models.users import User +from pydantic import AnyHttpUrl +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse +from starlette.templating import Jinja2Templates + +from memory.api.MCP.oauth_provider import SimpleOAuthProvider +from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models import OAuthState + +logger = logging.getLogger(__name__) + + +def validate_metadata(klass: type): + orig_validate = klass.model_validate + + def validate(data: dict): + data = dict(data) + if "redirect_uris" in data: + data["redirect_uris"] = [ + str(uri).replace("cursor://", "http://") + for uri in data["redirect_uris"] + ] + if "redirect_uri" in data: + data["redirect_uri"] = str(data["redirect_uri"]).replace( + "cursor://", "http://" + ) + + return orig_validate(data) + + klass.model_validate = validate + + +validate_metadata(OAuthClientMetadata) +validate_metadata(AuthorizationRequest) +validate_metadata(AuthorizationCodeRequest) +validate_metadata(RefreshTokenRequest) +validate_metadata(TokenRequest) + + +# Setup templates +template_dir = os.path.join(os.path.dirname(__file__), "templates") +templates = Jinja2Templates(directory=template_dir) + + +oauth_provider = SimpleOAuthProvider() +auth_settings = AuthSettings( + issuer_url=cast(AnyHttpUrl, settings.SERVER_URL), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["read", "write"], + default_scopes=["read"], + ), + required_scopes=["read", "write"], +) + +mcp = FastMCP( + "memory", + stateless_http=True, + auth_server_provider=oauth_provider, + auth=auth_settings, +) + + +@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"]) +async def oauth_protected_resource(request: Request): + """OAuth 2.0 Protected Resource Metadata.""" + logger.info("Protected resource metadata requested") + return JSONResponse(oauth_provider.get_protected_resource_metadata()) + + +def login_form(request: Request, form_data: dict, error: str | None = None): + return templates.TemplateResponse( + "login.html", + {"request": request, "form_data": form_data, "error": error}, + ) + + +@mcp.custom_route("/oauth/login", methods=["GET"]) +async def login_page(request: Request): + """Display the login page.""" + form_data = dict(request.query_params) + + state = form_data.get("state") + with make_session() as session: + oauth_state = session.query(OAuthState).get(state) + if not oauth_state: + logger.error(f"State {state} not found in database") + raise ValueError("Invalid state parameter") + + return login_form(request, form_data, None) + + +@mcp.custom_route("/oauth/login", methods=["POST"]) +async def handle_login(request: Request): + """Handle login form submission.""" + form = await request.form() + oauth_params = { + key: value for key, value in form.items() if key not in ["email", "password"] + } + with make_session() as session: + user = session.query(User).filter(User.email == form.get("email")).first() + if not user or not user.is_valid_password(str(form.get("password", ""))): + logger.warning("Login failed - invalid credentials") + return login_form(request, oauth_params, "Invalid email or password") + + redirect_url = await oauth_provider.complete_authorization(oauth_params, user) + if redirect_url.startswith("http://anysphere.cursor-retrieval"): + redirect_url = redirect_url.replace("http://", "cursor://") + return RedirectResponse(url=redirect_url, status_code=302) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py new file mode 100644 index 0000000..e5b52d1 --- /dev/null +++ b/src/memory/api/MCP/oauth_provider.py @@ -0,0 +1,246 @@ +import secrets +import time +from typing import Optional, Any, cast +from urllib.parse import urlencode +import logging +from datetime import datetime, timezone + +from memory.common.db.models.users import ( + User, + UserSession, + OAuthClientInformation, + OAuthState, +) +from memory.common.db.connection import make_session +from memory.common import settings +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationParams, + AuthorizationCode, + OAuthAuthorizationServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + + +class SimpleOAuthProvider(OAuthAuthorizationServerProvider): + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get OAuth client information.""" + with make_session() as session: + client = session.query(OAuthClientInformation).get(client_id) + return client and OAuthClientInformationFull(**client.serialize()) + + async def register_client(self, client_info: OAuthClientInformationFull): + """Register a new OAuth client.""" + with make_session() as session: + client = session.query(OAuthClientInformation).get(client_info.client_id) + if not client: + client = OAuthClientInformation(client_id=client_info.client_id) + + for key, value in client_info.model_dump().items(): + if key == "redirect_uris": + value = [str(uri) for uri in value] + elif value and key in [ + "client_uri", + "logo_uri", + "tos_uri", + "policy_uri", + "jwks_uri", + ]: + value = str(value) + setattr(client, key, value) + session.add(client) + session.commit() + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Redirect to login page for user authentication.""" + redirect_uri_str = str(params.redirect_uri) + registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])] + if redirect_uri_str not in registered_uris: + logger.error( + f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}" + ) + raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}") + + # Store the authorization parameters in database + with make_session() as session: + oauth_state = OAuthState( + state=params.state or secrets.token_hex(16), + client_id=client.client_id, + redirect_uri=str(params.redirect_uri), + redirect_uri_provided_explicitly=str( + params.redirect_uri_provided_explicitly + ).lower() + == "true", + code_challenge=params.code_challenge or "", + scopes=["read", "write"], # Default scopes + expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry + ) + session.add(oauth_state) + session.commit() + + return f"{settings.SERVER_URL}/oauth/login?" + urlencode( + { + "state": oauth_state.state, + "client_id": client.client_id, + "redirect_uri": oauth_state.redirect_uri, + "redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly, + "code_challenge": cast(str, oauth_state.code_challenge), + } + ) + + async def complete_authorization(self, oauth_params: dict, user: User) -> str: + """Complete authorization after successful login.""" + if not (state := oauth_params.get("state")): + logger.error("No state parameter provided") + raise ValueError("Missing state parameter") + + with make_session() as session: + # Load OAuth state from database + oauth_state = session.query(OAuthState).get(state) + if not oauth_state: + logger.error(f"State {state} not found in database") + raise ValueError("Invalid state parameter") + + # Check if state has expired + now = datetime.fromtimestamp(time.time()) + if oauth_state.expires_at < now: + logger.error(f"State {state} has expired") + oauth_state.stale = True + session.commit() + raise ValueError("State has expired") + + oauth_state.code = f"code_{secrets.token_hex(16)}" + oauth_state.stale = False + oauth_state.user_id = user.id + + session.add(oauth_state) + session.commit() + + return construct_redirect_uri( + oauth_state.redirect_uri, code=oauth_state.code, state=state + ) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> Optional[AuthorizationCode]: + """Load an authorization code.""" + with make_session() as session: + auth_code = ( + session.query(OAuthState) + .filter(OAuthState.code == authorization_code) + .first() + ) + if not auth_code: + logger.error(f"Invalid authorization code: {authorization_code}") + raise ValueError("Invalid authorization code") + return AuthorizationCode(**auth_code.serialize(code=True)) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + with make_session() as session: + auth_code = ( + session.query(OAuthState) + .filter(OAuthState.code == authorization_code.code) + .first() + ) + + if not auth_code: + logger.error(f"Invalid authorization code: {authorization_code.code}") + raise ValueError("Invalid authorization code") + + # Get the user associated with this auth code + if not auth_code.user: + logger.error(f"No user found for auth code: {authorization_code.code}") + raise ValueError("Invalid authorization code") + + # Create a UserSession to serve as access token + expires_at = datetime.fromtimestamp(time.time() + 3600) + + auth_code.session = UserSession( + user_id=auth_code.user_id, + oauth_state_id=auth_code.state, + expires_at=expires_at, + ) + auth_code.stale = True # type: ignore + session.commit() + access_token = str(auth_code.session.id) + + return OAuthToken( + access_token=access_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(authorization_code.scopes), + ) + + async def load_access_token(self, token: str) -> Optional[AccessToken]: + """Load and validate an access token.""" + with make_session() as session: + now = datetime.now(timezone.utc).replace( + tzinfo=None + ) # Make naive for DB comparison + + # Query for active (non-expired) session + user_session = session.query(UserSession).get(token) + if not user_session or user_session.expires_at < now: + return None + + return AccessToken( + token=token, + client_id=user_session.oauth_state.client_id, + scopes=user_session.oauth_state.scopes, + expires_at=int(user_session.expires_at.timestamp()), + ) + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> Optional[RefreshToken]: + """Load a refresh token - not supported in this simple implementation.""" + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token - not supported in this simple implementation.""" + raise NotImplementedError("Refresh tokens not supported") + + async def revoke_token( + self, token: str, token_type_hint: Optional[str] = None + ) -> None: + """Revoke a token.""" + with make_session() as session: + user_session = session.query(UserSession).get(token) + if user_session: + session.delete(user_session) + session.commit() + + def get_protected_resource_metadata(self) -> dict[str, Any]: + """Return metadata about the protected resource.""" + return { + "resource_server": settings.SERVER_URL, + "scopes_supported": ["read", "write"], + "bearer_methods_supported": ["header"], + "resource_documentation": f"{settings.SERVER_URL}/docs", + "protected_resources": [ + { + "resource_uri": f"{settings.SERVER_URL}/mcp", + "scopes": ["read", "write"], + "http_methods": ["POST", "GET"], + }, + { + "resource_uri": f"{settings.SERVER_URL}/mcp/", + "scopes": ["read", "write"], + "http_methods": ["POST", "GET"], + }, + ], + } diff --git a/src/memory/api/MCP/templates/login.html b/src/memory/api/MCP/templates/login.html new file mode 100644 index 0000000..5ec2011 --- /dev/null +++ b/src/memory/api/MCP/templates/login.html @@ -0,0 +1,159 @@ + + + + + + + Login - Memory OAuth + + + + +
+ + + {% if error %} +
+ {{ error }} +
+ {% endif %} + +
+ {% for key, value in form_data.items() %} + + {% endfor %} + +
+ + +
+ +
+ + +
+ + +
+
+ + + \ No newline at end of file diff --git a/src/memory/api/MCP/tools.py b/src/memory/api/MCP/tools.py index a0409b8..d0d8b82 100644 --- a/src/memory/api/MCP/tools.py +++ b/src/memory/api/MCP/tools.py @@ -5,19 +5,21 @@ MCP tools for the epistemic sparring partner system. import logging from datetime import datetime, timezone -from mcp.server.fastmcp import FastMCP +from mcp.server.auth.middleware.auth_context import get_access_token from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY from memory.common.db.connection import make_session -from memory.common.db.models import AgentObservation, SourceItem +from memory.common.db.models import ( + AgentObservation, + SourceItem, + UserSession, +) +from memory.api.MCP.base import mcp logger = logging.getLogger(__name__) -# Create MCP server instance -mcp = FastMCP("memory", stateless_http=True) - def filter_observation_source_ids( tags: list[str] | None = None, observation_types: list[str] | None = None @@ -67,4 +69,42 @@ def filter_source_ids( @mcp.tool() async def get_current_time() -> dict: """Get the current time in UTC.""" + logger.info("get_current_time tool called") return {"current_time": datetime.now(timezone.utc).isoformat()} + + +@mcp.tool() +async def get_authenticated_user() -> dict: + """Get information about the authenticated user.""" + logger.info("🔧 get_authenticated_user tool called") + access_token = get_access_token() + logger.info(f"🔧 Access token from MCP context: {access_token}") + + if not access_token: + logger.warning("❌ No access token found in MCP context!") + return {"error": "Not authenticated"} + + logger.info( + f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..." + ) + + # Look up the actual user from the session token + with make_session() as session: + user_session = ( + session.query(UserSession) + .filter(UserSession.id == access_token.token) + .first() + ) + + if user_session and user_session.user: + user_info = user_session.user.serialize() + else: + user_info = {"error": "User not found"} + + return { + "authenticated": True, + "token_type": "Bearer", + "scopes": access_token.scopes, + "client_id": access_token.client_id, + "user": user_info, + } diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 700427e..aaf2b17 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -16,22 +16,19 @@ from fastapi import ( Query, Form, Depends, + Request, ) from fastapi.responses import FileResponse +from fastapi.middleware.cors import CORSMiddleware from sqladmin import Admin -from memory.common import settings -from memory.common import extract +from memory.common import extract, settings from memory.common.db.connection import get_engine from memory.common.db.models import User from memory.api.admin import setup_admin from memory.api.search import search, SearchResult -from memory.api.MCP.tools import mcp -from memory.api.auth import ( - router as auth_router, - get_current_user, - AuthenticationMiddleware, -) +from memory.api.MCP.base import mcp +from memory.api.auth import get_current_user logger = logging.getLogger(__name__) @@ -45,23 +42,45 @@ async def lifespan(app: FastAPI): app = FastAPI(title="Knowledge Base API", lifespan=lifespan) + +# Add request logging middleware +# @app.middleware("http") +# async def log_requests(request, call_next): +# logger.info(f"Main app: {request.method} {request.url.path}") +# if request.url.path.startswith("/mcp"): +# logger.info(f"Request headers: {dict(request.headers)}") +# response = await call_next(request) +# return response + + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # In production, specify actual origins + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + # SQLAdmin setup engine = get_engine() admin = Admin(app, engine) setup_admin(admin) -# Include auth router -app.add_middleware(AuthenticationMiddleware) -app.include_router(auth_router) + +# Add health check to MCP server instead of main app +@mcp.custom_route("/health", methods=["GET"]) +async def health_check(request: Request): + """Simple health check endpoint on MCP server""" + from fastapi.responses import JSONResponse + + return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"}) + + +# Mount MCP server at root - OAuth endpoints need to be at root level app.mount("/", mcp.streamable_http_app()) -@app.get("/health") -def health_check(): - """Simple health check endpoint""" - return {"status": "healthy"} - - async def input_type(item: str | UploadFile) -> list[extract.DataChunk]: if not item: return [] diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index fedb34b..352454d 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -35,6 +35,8 @@ from memory.common.db.models.sources import ( from memory.common.db.models.users import ( User, UserSession, + OAuthClientInformation, + OAuthState, ) __all__ = [ @@ -69,4 +71,6 @@ __all__ = [ # Users "User", "UserSession", + "OAuthClientInformation", + "OAuthState", ] diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index c42a940..ef48be5 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -3,7 +3,16 @@ import secrets from typing import cast import uuid from memory.common.db.models.base import Base -from sqlalchemy import Column, Integer, String, DateTime, ForeignKey +from sqlalchemy import ( + Column, + Integer, + String, + DateTime, + ForeignKey, + Boolean, + ARRAY, + Numeric, +) from sqlalchemy.sql import func from sqlalchemy.orm import relationship @@ -35,6 +44,9 @@ class User(Base): sessions = relationship( "UserSession", back_populates="user", cascade="all, delete-orphan" ) + oauth_states = relationship( + "OAuthState", back_populates="user", cascade="all, delete-orphan" + ) def serialize(self) -> dict: return { @@ -57,9 +69,92 @@ class UserSession(Base): __tablename__ = "user_sessions" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime, server_default=func.now()) expires_at = Column(DateTime, nullable=False) # Relationship to user user = relationship("User", back_populates="sessions") + oauth_state = relationship("OAuthState", back_populates="session") + + +class OAuthClientInformation(Base): + __tablename__ = "oauth_client" + + client_id = Column(String, primary_key=True) + client_secret = Column(String, nullable=True) + client_id_issued_at = Column(Numeric, nullable=False) + client_secret_expires_at = Column(Numeric, nullable=True) + + redirect_uris = Column(ARRAY(String), nullable=False) + token_endpoint_auth_method = Column(String, nullable=False) + grant_types = Column(ARRAY(String), nullable=False) + response_types = Column(ARRAY(String), nullable=False) + scope = Column(String, nullable=False) + client_name = Column(String, nullable=False) + client_uri = Column(String, nullable=True) + logo_uri = Column(String, nullable=True) + contacts = Column(ARRAY(String), nullable=True) + tos_uri = Column(String, nullable=True) + policy_uri = Column(String, nullable=True) + jwks_uri = Column(String, nullable=True) + + sessions = relationship( + "OAuthState", back_populates="client", cascade="all, delete-orphan" + ) + + def serialize(self) -> dict: + return { + "client_id": self.client_id, + "client_secret": self.client_secret, + "client_id_issued_at": self.client_id_issued_at, + "client_secret_expires_at": self.client_secret_expires_at, + "redirect_uris": self.redirect_uris, + "token_endpoint_auth_method": self.token_endpoint_auth_method, + "grant_types": self.grant_types, + "response_types": self.response_types, + "scope": self.scope, + "client_name": self.client_name, + "client_uri": self.client_uri, + "logo_uri": self.logo_uri, + "contacts": self.contacts, + "tos_uri": self.tos_uri, + "policy_uri": self.policy_uri, + "jwks_uri": self.jwks_uri, + } + + +class OAuthState(Base): + __tablename__ = "oauth_states" + + state = Column(String, primary_key=True) + client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + code = Column(String, nullable=True) + redirect_uri = Column(String, nullable=False) + redirect_uri_provided_explicitly = Column(Boolean, nullable=False) + code_challenge = Column(String, nullable=True) + scopes = Column(ARRAY(String), nullable=False) + created_at = Column(DateTime, server_default=func.now()) + expires_at = Column(DateTime, nullable=False) + stale = Column(Boolean, nullable=False, default=False) + + def serialize(self, code: bool = False) -> dict: + data = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly, + "code_challenge": self.code_challenge, + "scopes": self.scopes, + } + if code: + data |= { + "code": self.code, + "expires_at": self.expires_at.timestamp(), + } + return data + + client = relationship("OAuthClientInformation", back_populates="sessions") + session = relationship("UserSession", back_populates="oauth_state", uselist=False) + user = relationship("User", back_populates="oauth_states") diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 633aa88..cd0ed13 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -131,6 +131,7 @@ else: SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307") # API settings +SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000") HTTPS = boolean_env("HTTPS", False) SESSION_HEADER_NAME = os.getenv("SESSION_HEADER_NAME", "X-Session-ID") SESSION_COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "session_id")