mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-29 14:16:09 +02:00
fix oauth
This commit is contained in:
parent
beb94375da
commit
862251fedb
@ -1,30 +1,32 @@
|
|||||||
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Any, cast
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
from memory.common.db.models.users import (
|
|
||||||
User,
|
|
||||||
UserSession,
|
|
||||||
OAuthClientInformation,
|
|
||||||
OAuthState,
|
|
||||||
OAuthRefreshToken,
|
|
||||||
OAuthToken as TokenBase,
|
|
||||||
)
|
|
||||||
from memory.common.db.connection import make_session, scoped_session
|
|
||||||
from memory.common import settings
|
|
||||||
from mcp.server.auth.provider import (
|
from mcp.server.auth.provider import (
|
||||||
AccessToken,
|
AccessToken,
|
||||||
AuthorizationParams,
|
|
||||||
AuthorizationCode,
|
AuthorizationCode,
|
||||||
|
AuthorizationParams,
|
||||||
OAuthAuthorizationServerProvider,
|
OAuthAuthorizationServerProvider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
construct_redirect_uri,
|
|
||||||
)
|
)
|
||||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session, scoped_session
|
||||||
|
from memory.common.db.models.users import (
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthRefreshToken,
|
||||||
|
OAuthState,
|
||||||
|
User,
|
||||||
|
UserSession,
|
||||||
|
)
|
||||||
|
from memory.common.db.models.users import (
|
||||||
|
OAuthToken as TokenBase,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALLOWED_SCOPES = ["read", "write", "claudeai"]
|
ALLOWED_SCOPES = ["read", "write", "claudeai"]
|
||||||
@ -227,6 +229,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
|
|
||||||
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
async def complete_authorization(self, oauth_params: dict, user: User) -> str:
|
||||||
"""Complete authorization after successful login."""
|
"""Complete authorization after successful login."""
|
||||||
|
logger.info(f"Completing authorization with params: {oauth_params}")
|
||||||
if not (state := oauth_params.get("state")):
|
if not (state := oauth_params.get("state")):
|
||||||
logger.error("No state parameter provided")
|
logger.error("No state parameter provided")
|
||||||
raise ValueError("Missing state parameter")
|
raise ValueError("Missing state parameter")
|
||||||
@ -255,16 +258,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
session.add(oauth_state)
|
session.add(oauth_state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return construct_redirect_uri(
|
parsed_uri = urlparse(str(oauth_state.redirect_uri))
|
||||||
cast(str, oauth_state.redirect_uri),
|
query_params = {
|
||||||
code=cast(str, oauth_state.code),
|
k: ",".join(v) for k, v in parse_qs(parsed_uri.query).items()
|
||||||
state=state,
|
}
|
||||||
)
|
query_params |= {
|
||||||
|
"code": oauth_state.code,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
return urlunparse(parsed_uri._replace(query=urlencode(query_params)))
|
||||||
|
|
||||||
async def load_authorization_code(
|
async def load_authorization_code(
|
||||||
self, client: OAuthClientInformationFull, authorization_code: str
|
self, client: OAuthClientInformationFull, authorization_code: str
|
||||||
) -> Optional[AuthorizationCode]:
|
) -> Optional[AuthorizationCode]:
|
||||||
"""Load an authorization code."""
|
"""Load an authorization code."""
|
||||||
|
logger.info(f"Loading authorization code: {authorization_code}")
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
auth_code = (
|
auth_code = (
|
||||||
session.query(OAuthState)
|
session.query(OAuthState)
|
||||||
@ -281,6 +289,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
|
||||||
) -> OAuthToken:
|
) -> OAuthToken:
|
||||||
"""Exchange authorization code for tokens."""
|
"""Exchange authorization code for tokens."""
|
||||||
|
logger.info(f"Exchanging authorization code: {authorization_code}")
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
auth_code = (
|
auth_code = (
|
||||||
session.query(OAuthState)
|
session.query(OAuthState)
|
||||||
@ -296,7 +305,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
logger.error(f"No user found for auth code: {authorization_code.code}")
|
logger.error(f"No user found for auth code: {authorization_code.code}")
|
||||||
raise ValueError("Invalid authorization code")
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
return make_token(session, auth_code, authorization_code.scopes)
|
token = make_token(session, auth_code, authorization_code.scopes)
|
||||||
|
logger.info(f"Exchanged authorization code: {token}")
|
||||||
|
return token
|
||||||
|
|
||||||
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
async def load_access_token(self, token: str) -> Optional[AccessToken]:
|
||||||
"""Load and validate an access token."""
|
"""Load and validate an access token."""
|
||||||
|
@ -53,6 +53,7 @@ def section_processor(
|
|||||||
if len(content) >= MIN_SECTION_LENGTH:
|
if len(content) >= MIN_SECTION_LENGTH:
|
||||||
book_section = BookSection(
|
book_section = BookSection(
|
||||||
book_id=book.id,
|
book_id=book.id,
|
||||||
|
book=book,
|
||||||
section_title=section.title,
|
section_title=section.title,
|
||||||
section_number=section.number,
|
section_number=section.number,
|
||||||
section_level=level,
|
section_level=level,
|
||||||
@ -187,6 +188,8 @@ def sync_book(
|
|||||||
logger.info("Creating book and sections with relationships")
|
logger.info("Creating book and sections with relationships")
|
||||||
# Create book and sections with relationships
|
# Create book and sections with relationships
|
||||||
book, all_sections = create_book_and_sections(ebook, session, tags)
|
book, all_sections = create_book_and_sections(ebook, session, tags)
|
||||||
|
for section in all_sections:
|
||||||
|
print(section.section_title, section.book)
|
||||||
|
|
||||||
if title:
|
if title:
|
||||||
book.title = title # type: ignore
|
book.title = title # type: ignore
|
||||||
@ -196,10 +199,16 @@ def sync_book(
|
|||||||
book.publisher = publisher # type: ignore
|
book.publisher = publisher # type: ignore
|
||||||
if published:
|
if published:
|
||||||
book.published = datetime.fromisoformat(published) # type: ignore
|
book.published = datetime.fromisoformat(published) # type: ignore
|
||||||
|
if isinstance(book.published, str):
|
||||||
|
book.published = datetime.fromisoformat(book.published) # type: ignore
|
||||||
|
|
||||||
if language:
|
if language:
|
||||||
book.language = language # type: ignore
|
book.language = language # type: ignore
|
||||||
if edition:
|
if edition:
|
||||||
book.edition = edition # type: ignore
|
book.edition = edition # type: ignore
|
||||||
|
|
||||||
|
if isinstance(series, dict):
|
||||||
|
series = series.get("name")
|
||||||
if series:
|
if series:
|
||||||
book.series = series # type: ignore
|
book.series = series # type: ignore
|
||||||
if series_number:
|
if series_number:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user