mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-29 14:16:09 +02:00
fix oauth
This commit is contained in:
parent
beb94375da
commit
862251fedb
@ -1,30 +1,32 @@
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, Any, cast
|
||||
from urllib.parse import urlencode
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from memory.common.db.models.users import (
|
||||
User,
|
||||
UserSession,
|
||||
OAuthClientInformation,
|
||||
OAuthState,
|
||||
OAuthRefreshToken,
|
||||
OAuthToken as TokenBase,
|
||||
)
|
||||
from memory.common.db.connection import make_session, scoped_session
|
||||
from memory.common import settings
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
AuthorizationParams,
|
||||
AuthorizationCode,
|
||||
AuthorizationParams,
|
||||
OAuthAuthorizationServerProvider,
|
||||
RefreshToken,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session, scoped_session
|
||||
from memory.common.db.models.users import (
|
||||
OAuthClientInformation,
|
||||
OAuthRefreshToken,
|
||||
OAuthState,
|
||||
User,
|
||||
UserSession,
|
||||
)
|
||||
from memory.common.db.models.users import (
|
||||
OAuthToken as TokenBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_SCOPES = ["read", "write", "claudeai"]
|
||||
@ -227,6 +229,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
|
||||
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
||||
"""Complete authorization after successful login."""
|
||||
logger.info(f"Completing authorization with params: {oauth_params}")
|
||||
if not (state := oauth_params.get("state")):
|
||||
logger.error("No state parameter provided")
|
||||
raise ValueError("Missing state parameter")
|
||||
@ -255,16 +258,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
session.add(oauth_state)
|
||||
session.commit()
|
||||
|
||||
return construct_redirect_uri(
|
||||
cast(str, oauth_state.redirect_uri),
|
||||
code=cast(str, oauth_state.code),
|
||||
state=state,
|
||||
)
|
||||
parsed_uri = urlparse(str(oauth_state.redirect_uri))
|
||||
query_params = {
|
||||
k: ",".join(v) for k, v in parse_qs(parsed_uri.query).items()
|
||||
}
|
||||
query_params |= {
|
||||
"code": oauth_state.code,
|
||||
"state": state,
|
||||
}
|
||||
return urlunparse(parsed_uri._replace(query=urlencode(query_params)))
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
) -> Optional[AuthorizationCode]:
|
||||
"""Load an authorization code."""
|
||||
logger.info(f"Loading authorization code: {authorization_code}")
|
||||
with make_session() as session:
|
||||
auth_code = (
|
||||
session.query(OAuthState)
|
||||
@ -281,6 +289,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
||||
) -> OAuthToken:
|
||||
"""Exchange authorization code for tokens."""
|
||||
logger.info(f"Exchanging authorization code: {authorization_code}")
|
||||
with make_session() as session:
|
||||
auth_code = (
|
||||
session.query(OAuthState)
|
||||
@ -296,7 +305,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
return make_token(session, auth_code, authorization_code.scopes)
|
||||
token = make_token(session, auth_code, authorization_code.scopes)
|
||||
logger.info(f"Exchanged authorization code: {token}")
|
||||
return token
|
||||
|
||||
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""Load and validate an access token."""
|
||||
|
@ -53,6 +53,7 @@ def section_processor(
|
||||
if len(content) >= MIN_SECTION_LENGTH:
|
||||
book_section = BookSection(
|
||||
book_id=book.id,
|
||||
book=book,
|
||||
section_title=section.title,
|
||||
section_number=section.number,
|
||||
section_level=level,
|
||||
@ -187,6 +188,8 @@ def sync_book(
|
||||
logger.info("Creating book and sections with relationships")
|
||||
# Create book and sections with relationships
|
||||
book, all_sections = create_book_and_sections(ebook, session, tags)
|
||||
for section in all_sections:
|
||||
print(section.section_title, section.book)
|
||||
|
||||
if title:
|
||||
book.title = title # type: ignore
|
||||
@ -196,10 +199,16 @@ def sync_book(
|
||||
book.publisher = publisher # type: ignore
|
||||
if published:
|
||||
book.published = datetime.fromisoformat(published) # type: ignore
|
||||
if isinstance(book.published, str):
|
||||
book.published = datetime.fromisoformat(book.published) # type: ignore
|
||||
|
||||
if language:
|
||||
book.language = language # type: ignore
|
||||
if edition:
|
||||
book.edition = edition # type: ignore
|
||||
|
||||
if isinstance(series, dict):
|
||||
series = series.get("name")
|
||||
if series:
|
||||
book.series = series # type: ignore
|
||||
if series_number:
|
||||
|
Loading…
x
Reference in New Issue
Block a user