fix oauth

This commit is contained in:
EC2 Default User 2025-07-26 14:57:41 +00:00
parent beb94375da
commit 862251fedb
2 changed files with 41 additions and 21 deletions

View File

@ -1,30 +1,32 @@
import logging
import secrets
import time
from typing import Optional, Any, cast
from urllib.parse import urlencode
import logging
from datetime import datetime, timezone
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from memory.common.db.models.users import (
User,
UserSession,
OAuthClientInformation,
OAuthState,
OAuthRefreshToken,
OAuthToken as TokenBase,
)
from memory.common.db.connection import make_session, scoped_session
from memory.common import settings
from mcp.server.auth.provider import (
AccessToken,
AuthorizationParams,
AuthorizationCode,
AuthorizationParams,
OAuthAuthorizationServerProvider,
RefreshToken,
construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from memory.common import settings
from memory.common.db.connection import make_session, scoped_session
from memory.common.db.models.users import (
OAuthClientInformation,
OAuthRefreshToken,
OAuthState,
User,
UserSession,
)
from memory.common.db.models.users import (
OAuthToken as TokenBase,
)
logger = logging.getLogger(__name__)
ALLOWED_SCOPES = ["read", "write", "claudeai"]
@ -227,6 +229,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
"""Complete authorization after successful login."""
logger.info(f"Completing authorization with params: {oauth_params}")
if not (state := oauth_params.get("state")):
logger.error("No state parameter provided")
raise ValueError("Missing state parameter")
@ -255,16 +258,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
session.add(oauth_state)
session.commit()
return construct_redirect_uri(
cast(str, oauth_state.redirect_uri),
code=cast(str, oauth_state.code),
state=state,
)
parsed_uri = urlparse(str(oauth_state.redirect_uri))
query_params = {
k: ",".join(v) for k, v in parse_qs(parsed_uri.query).items()
}
query_params |= {
"code": oauth_state.code,
"state": state,
}
return urlunparse(parsed_uri._replace(query=urlencode(query_params)))
async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str
) -> Optional[AuthorizationCode]:
"""Load an authorization code."""
logger.info(f"Loading authorization code: {authorization_code}")
with make_session() as session:
auth_code = (
session.query(OAuthState)
@ -281,6 +289,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
) -> OAuthToken:
"""Exchange authorization code for tokens."""
logger.info(f"Exchanging authorization code: {authorization_code}")
with make_session() as session:
auth_code = (
session.query(OAuthState)
@ -296,7 +305,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
logger.error(f"No user found for auth code: {authorization_code.code}")
raise ValueError("Invalid authorization code")
return make_token(session, auth_code, authorization_code.scopes)
token = make_token(session, auth_code, authorization_code.scopes)
logger.info(f"Exchanged authorization code: {token}")
return token
async def load_access_token(self, token: str) -> Optional[AccessToken]:
"""Load and validate an access token."""

View File

@ -53,6 +53,7 @@ def section_processor(
if len(content) >= MIN_SECTION_LENGTH:
book_section = BookSection(
book_id=book.id,
book=book,
section_title=section.title,
section_number=section.number,
section_level=level,
@ -187,6 +188,8 @@ def sync_book(
logger.info("Creating book and sections with relationships")
# Create book and sections with relationships
book, all_sections = create_book_and_sections(ebook, session, tags)
for section in all_sections:
print(section.section_title, section.book)
if title:
book.title = title # type: ignore
@ -196,10 +199,16 @@ def sync_book(
book.publisher = publisher # type: ignore
if published:
book.published = datetime.fromisoformat(published) # type: ignore
if isinstance(book.published, str):
book.published = datetime.fromisoformat(book.published) # type: ignore
if language:
book.language = language # type: ignore
if edition:
book.edition = edition # type: ignore
if isinstance(series, dict):
series = series.get("name")
if series:
book.series = series # type: ignore
if series_number: