From 862251fedbb3bd3f45389ac42906abd6b2c168f0 Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Sat, 26 Jul 2025 14:57:41 +0000 Subject: [PATCH] fix oauth --- src/memory/api/MCP/oauth_provider.py | 53 +++++++++++++++++----------- src/memory/workers/tasks/ebook.py | 9 +++++ 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index f7b8dfd..ca0fd39 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -1,30 +1,32 @@ +import logging import secrets import time -from typing import Optional, Any, cast -from urllib.parse import urlencode -import logging from datetime import datetime, timezone +from typing import Any, Optional, cast +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from memory.common.db.models.users import ( - User, - UserSession, - OAuthClientInformation, - OAuthState, - OAuthRefreshToken, - OAuthToken as TokenBase, -) -from memory.common.db.connection import make_session, scoped_session -from memory.common import settings from mcp.server.auth.provider import ( AccessToken, - AuthorizationParams, AuthorizationCode, + AuthorizationParams, OAuthAuthorizationServerProvider, RefreshToken, - construct_redirect_uri, ) from mcp.shared.auth import OAuthClientInformationFull, OAuthToken +from memory.common import settings +from memory.common.db.connection import make_session, scoped_session +from memory.common.db.models.users import ( + OAuthClientInformation, + OAuthRefreshToken, + OAuthState, + User, + UserSession, +) +from memory.common.db.models.users import ( + OAuthToken as TokenBase, +) + logger = logging.getLogger(__name__) ALLOWED_SCOPES = ["read", "write", "claudeai"] @@ -227,6 +229,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): async def complete_authorization(self, oauth_params: dict, user: User) -> str: """Complete authorization after successful login.""" + logger.info(f"Completing authorization with params: {oauth_params}") if not (state := oauth_params.get("state")): logger.error("No state parameter provided") raise ValueError("Missing state parameter") @@ -255,16 +258,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): session.add(oauth_state) session.commit() - return construct_redirect_uri( - cast(str, oauth_state.redirect_uri), - code=cast(str, oauth_state.code), - state=state, - ) + parsed_uri = urlparse(str(oauth_state.redirect_uri)) + query_params = { + k: ",".join(v) for k, v in parse_qs(parsed_uri.query).items() + } + query_params |= { + "code": oauth_state.code, + "state": state, + } + return urlunparse(parsed_uri._replace(query=urlencode(query_params))) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str ) -> Optional[AuthorizationCode]: """Load an authorization code.""" + logger.info(f"Loading authorization code: {authorization_code}") with make_session() as session: auth_code = ( session.query(OAuthState) @@ -281,6 +289,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode ) -> OAuthToken: """Exchange authorization code for tokens.""" + logger.info(f"Exchanging authorization code: {authorization_code}") with make_session() as session: auth_code = ( session.query(OAuthState) @@ -296,7 +305,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): logger.error(f"No user found for auth code: {authorization_code.code}") raise ValueError("Invalid authorization code") - return make_token(session, auth_code, authorization_code.scopes) + token = make_token(session, auth_code, authorization_code.scopes) + logger.info(f"Exchanged authorization code: {token}") + return token async def load_access_token(self, token: str) -> Optional[AccessToken]: """Load and validate an access token.""" diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py index 20a62ef..aedaeb2 100644 --- a/src/memory/workers/tasks/ebook.py +++ b/src/memory/workers/tasks/ebook.py @@ -53,6 +53,7 @@ def section_processor( if len(content) >= MIN_SECTION_LENGTH: book_section = BookSection( book_id=book.id, + book=book, section_title=section.title, section_number=section.number, section_level=level, @@ -187,6 +188,8 @@ def sync_book( logger.info("Creating book and sections with relationships") # Create book and sections with relationships book, all_sections = create_book_and_sections(ebook, session, tags) + for section in all_sections: + print(section.section_title, section.book) if title: book.title = title # type: ignore @@ -196,10 +199,16 @@ def sync_book( book.publisher = publisher # type: ignore if published: book.published = datetime.fromisoformat(published) # type: ignore + if isinstance(book.published, str): + book.published = datetime.fromisoformat(book.published) # type: ignore + if language: book.language = language # type: ignore if edition: book.edition = edition # type: ignore + + if isinstance(series, dict): + series = series.get("name") if series: book.series = series # type: ignore if series_number: