fix oauth

This commit is contained in:
EC2 Default User 2025-07-26 14:57:41 +00:00
parent beb94375da
commit 862251fedb
2 changed files with 41 additions and 21 deletions

View File

@ -1,30 +1,32 @@
import logging
import secrets import secrets
import time import time
from typing import Optional, Any, cast
from urllib.parse import urlencode
import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from memory.common.db.models.users import (
User,
UserSession,
OAuthClientInformation,
OAuthState,
OAuthRefreshToken,
OAuthToken as TokenBase,
)
from memory.common.db.connection import make_session, scoped_session
from memory.common import settings
from mcp.server.auth.provider import ( from mcp.server.auth.provider import (
AccessToken, AccessToken,
AuthorizationParams,
AuthorizationCode, AuthorizationCode,
AuthorizationParams,
OAuthAuthorizationServerProvider, OAuthAuthorizationServerProvider,
RefreshToken, RefreshToken,
construct_redirect_uri,
) )
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from memory.common import settings
from memory.common.db.connection import make_session, scoped_session
from memory.common.db.models.users import (
OAuthClientInformation,
OAuthRefreshToken,
OAuthState,
User,
UserSession,
)
from memory.common.db.models.users import (
OAuthToken as TokenBase,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALLOWED_SCOPES = ["read", "write", "claudeai"] ALLOWED_SCOPES = ["read", "write", "claudeai"]
@ -227,6 +229,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def complete_authorization(self, oauth_params: dict, user: User) -> str: async def complete_authorization(self, oauth_params: dict, user: User) -> str:
"""Complete authorization after successful login.""" """Complete authorization after successful login."""
logger.info(f"Completing authorization with params: {oauth_params}")
if not (state := oauth_params.get("state")): if not (state := oauth_params.get("state")):
logger.error("No state parameter provided") logger.error("No state parameter provided")
raise ValueError("Missing state parameter") raise ValueError("Missing state parameter")
@ -255,16 +258,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
session.add(oauth_state) session.add(oauth_state)
session.commit() session.commit()
return construct_redirect_uri( parsed_uri = urlparse(str(oauth_state.redirect_uri))
cast(str, oauth_state.redirect_uri), query_params = {
code=cast(str, oauth_state.code), k: ",".join(v) for k, v in parse_qs(parsed_uri.query).items()
state=state, }
) query_params |= {
"code": oauth_state.code,
"state": state,
}
return urlunparse(parsed_uri._replace(query=urlencode(query_params)))
async def load_authorization_code( async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str self, client: OAuthClientInformationFull, authorization_code: str
) -> Optional[AuthorizationCode]: ) -> Optional[AuthorizationCode]:
"""Load an authorization code.""" """Load an authorization code."""
logger.info(f"Loading authorization code: {authorization_code}")
with make_session() as session: with make_session() as session:
auth_code = ( auth_code = (
session.query(OAuthState) session.query(OAuthState)
@ -281,6 +289,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
) -> OAuthToken: ) -> OAuthToken:
"""Exchange authorization code for tokens.""" """Exchange authorization code for tokens."""
logger.info(f"Exchanging authorization code: {authorization_code}")
with make_session() as session: with make_session() as session:
auth_code = ( auth_code = (
session.query(OAuthState) session.query(OAuthState)
@ -296,7 +305,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
logger.error(f"No user found for auth code: {authorization_code.code}") logger.error(f"No user found for auth code: {authorization_code.code}")
raise ValueError("Invalid authorization code") raise ValueError("Invalid authorization code")
return make_token(session, auth_code, authorization_code.scopes) token = make_token(session, auth_code, authorization_code.scopes)
logger.info(f"Exchanged authorization code: {token}")
return token
async def load_access_token(self, token: str) -> Optional[AccessToken]: async def load_access_token(self, token: str) -> Optional[AccessToken]:
"""Load and validate an access token.""" """Load and validate an access token."""

View File

@ -53,6 +53,7 @@ def section_processor(
if len(content) >= MIN_SECTION_LENGTH: if len(content) >= MIN_SECTION_LENGTH:
book_section = BookSection( book_section = BookSection(
book_id=book.id, book_id=book.id,
book=book,
section_title=section.title, section_title=section.title,
section_number=section.number, section_number=section.number,
section_level=level, section_level=level,
@ -187,6 +188,8 @@ def sync_book(
logger.info("Creating book and sections with relationships") logger.info("Creating book and sections with relationships")
# Create book and sections with relationships # Create book and sections with relationships
book, all_sections = create_book_and_sections(ebook, session, tags) book, all_sections = create_book_and_sections(ebook, session, tags)
for section in all_sections:
print(section.section_title, section.book)
if title: if title:
book.title = title # type: ignore book.title = title # type: ignore
@ -196,10 +199,16 @@ def sync_book(
book.publisher = publisher # type: ignore book.publisher = publisher # type: ignore
if published: if published:
book.published = datetime.fromisoformat(published) # type: ignore book.published = datetime.fromisoformat(published) # type: ignore
if isinstance(book.published, str):
book.published = datetime.fromisoformat(book.published) # type: ignore
if language: if language:
book.language = language # type: ignore book.language = language # type: ignore
if edition: if edition:
book.edition = edition # type: ignore book.edition = edition # type: ignore
if isinstance(series, dict):
series = series.get("name")
if series: if series:
book.series = series # type: ignore book.series = series # type: ignore
if series_number: if series_number: