From 2944a0bce1be46cb5ea27b8a02d550c51b3a3b25 Mon Sep 17 00:00:00 2001 From: mruwnik Date: Mon, 3 Nov 2025 00:00:02 +0000 Subject: [PATCH] properly handle mcp redirects --- src/memory/api/admin.py | 33 +++++++++ src/memory/api/auth.py | 70 ++++++++++++++++++-- src/memory/common/db/models/users.py | 19 +++--- src/memory/common/llms/anthropic_provider.py | 3 - src/memory/{discord => common}/oauth.py | 14 +--- src/memory/discord/api.py | 49 +------------- src/memory/discord/mcp.py | 2 +- src/memory/discord/messages.py | 4 +- src/memory/workers/tasks/discord.py | 16 +++++ 9 files changed, 132 insertions(+), 78 deletions(-) rename src/memory/{discord => common}/oauth.py (95%) diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index 7c075a4..6e76412 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -14,6 +14,7 @@ from memory.common.db.models import ( BookSection, Chunk, Comic, + DiscordMCPServer, DiscordMessage, EmailAccount, EmailAttachment, @@ -166,6 +167,37 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage): column_sortable_list = ["sent_at"] +class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer): + column_list = [ + "id", + "mcp_server_url", + "client_id", + "discord_bot_user_id", + "state", + "code_verifier", + "access_token", + "refresh_token", + "token_expires_at", + "created_at", + "updated_at", + ] + column_searchable_list = [ + "mcp_server_url", + "client_id", + "state", + "id", + "discord_bot_user_id", + ] + column_sortable_list = [ + "created_at", + "updated_at", + "mcp_server_url", + "client_id", + "state", + "id", + ] + + class ArticleFeedAdmin(ModelView, model=ArticleFeed): column_list = [ "id", @@ -328,4 +360,5 @@ def setup_admin(admin: Admin): admin.add_view(DiscordUserAdmin) admin.add_view(DiscordServerAdmin) admin.add_view(DiscordChannelAdmin) + admin.add_view(DiscordMCPServerAdmin) admin.add_view(ScheduledLLMCallAdmin) diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index 515491c..adafcfc 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -1,13 +1,21 @@ -from datetime import datetime, timedelta, timezone import logging +from datetime import datetime, timedelta, timezone -from fastapi import HTTPException, Depends, Request, Response, APIRouter +from fastapi import APIRouter, Depends, HTTPException, Request, Response +from sqlalchemy.orm import Session as DBSession +from sqlalchemy.orm import scoped_session from starlette.middleware.base import BaseHTTPMiddleware -from memory.common import settings -from sqlalchemy.orm import Session as DBSession, scoped_session +from memory.common import settings from memory.common.db.connection import get_session, make_session -from memory.common.db.models.users import User, HumanUser, BotUser, UserSession +from memory.common.db.models import ( + BotUser, + DiscordMCPServer, + HumanUser, + User, + UserSession, +) +from memory.common.oauth import complete_oauth_flow logger = logging.getLogger(__name__) @@ -136,6 +144,58 @@ def get_me(user: User = Depends(get_current_user)): return user.serialize() +@router.get("/callback/discord") +async def oauth_callback_discord(request: Request): + """Get current user info""" + code = request.query_params.get("code") + state = request.query_params.get("state") + error = request.query_params.get("error") + + logger.info( + f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..." + ) + + message, title, close, status_code = "", "", "", 200 + if error: + logger.error(f"OAuth error: {error}") + message = f"Error: {error}" + title = "❌ Authorization Failed" + status_code = 400 + elif not code or not state: + message = "Missing authorization code or state parameter." + title = "❌ Invalid Request" + status_code = 400 + else: + # Complete the OAuth flow (exchange code for token) + with make_session() as session: + mcp_server = ( + session.query(DiscordMCPServer) + .filter(DiscordMCPServer.state == state) + .first() + ) + status_code, message = await complete_oauth_flow(mcp_server, code, state) + session.commit() + + if 200 <= status_code < 300: + title = "✅ Authorization Successful!" + close = "You can close this window and return to the MCP server." + else: + title = "❌ Authorization Failed" + + return Response( + content=f""" + + +

{title}

+

{message}

+

{close}

+ + + """, + status_code=status_code, + ) + + class AuthenticationMiddleware(BaseHTTPMiddleware): """Middleware to require authentication for all endpoints except whitelisted ones.""" diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index a86cf03..af75f8c 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -1,22 +1,23 @@ import hashlib import secrets -from typing import cast import uuid -from sqlalchemy.orm import Session -from memory.common.db.models.base import Base +from typing import cast + from sqlalchemy import ( + ARRAY, + Boolean, + CheckConstraint, Column, - Integer, - String, DateTime, ForeignKey, - Boolean, - ARRAY, + Integer, Numeric, - CheckConstraint, + String, ) +from sqlalchemy.orm import Session, relationship from sqlalchemy.sql import func -from sqlalchemy.orm import relationship + +from memory.common.db.models.base import Base def hash_password(password: str) -> str: diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py index a74a7e8..f7bac60 100644 --- a/src/memory/common/llms/anthropic_provider.py +++ b/src/memory/common/llms/anthropic_provider.py @@ -164,9 +164,6 @@ class AnthropicProvider(BaseLLMProvider): kwargs["temperature"] = 1.0 kwargs.pop("top_p", None) - for k, v in kwargs.items(): - print(f"{k}: {v}") - return kwargs def _handle_stream_event( diff --git a/src/memory/discord/oauth.py b/src/memory/common/oauth.py similarity index 95% rename from src/memory/discord/oauth.py rename to src/memory/common/oauth.py index ce65c7c..117b468 100644 --- a/src/memory/discord/oauth.py +++ b/src/memory/common/oauth.py @@ -7,9 +7,7 @@ from datetime import datetime, timedelta from urllib.parse import urlencode, urljoin import aiohttp -from sqlalchemy.orm import Session, scoped_session from memory.common import settings -from memory.common.db.connection import make_session from memory.common.db.models.discord import DiscordMCPServer logger = logging.getLogger(__name__) @@ -99,7 +97,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints: authorization_endpoint=authorization_endpoint, registration_endpoint=registration_endpoint, token_endpoint=token_endpoint, - redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord", + redirect_uri=f"{settings.SERVER_URL}/auth/callback/discord", ) @@ -181,7 +179,7 @@ async def issue_challenge( async def complete_oauth_flow( - session: Session | scoped_session, code: str, state: str + mcp_server: DiscordMCPServer, code: str, state: str ) -> tuple[int, str]: """Complete OAuth flow by exchanging code for token. @@ -193,12 +191,6 @@ async def complete_oauth_flow( Tuple of (status_code, html_message) for the callback response """ try: - mcp_server = ( - session.query(DiscordMCPServer) - .filter(DiscordMCPServer.state == state) - .first() - ) - if not mcp_server: logger.error(f"Invalid or expired state: {state[:20]}...") return 400, "Invalid or expired OAuth state" @@ -254,8 +246,6 @@ async def complete_oauth_flow( mcp_server.state = None # type: ignore mcp_server.code_verifier = None # type: ignore - session.commit() - logger.info( f"Stored tokens for user {mcp_server.discord_bot_user_id}, " f"server {mcp_server.mcp_server_url}" diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index 451b1ac..7aa7560 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -18,9 +18,9 @@ from pydantic import BaseModel from memory.common import settings from memory.common.db.connection import make_session -from memory.common.db.models.users import DiscordBotUser +from memory.common.db.models import DiscordMCPServer, DiscordBotUser +from memory.common.oauth import complete_oauth_flow from memory.discord.collector import MessageCollector -from memory.discord.oauth import complete_oauth_flow logger = logging.getLogger(__name__) @@ -279,51 +279,6 @@ async def refresh_metadata(): raise HTTPException(status_code=500, detail=str(e)) -@app.get("/oauth/callback/discord", response_class=HTMLResponse) -async def oauth_callback(request: Request): - """Handle OAuth callback from MCP server after user authorization.""" - code = request.query_params.get("code") - state = request.query_params.get("state") - error = request.query_params.get("error") - - logger.info( - f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..." - ) - - message, title, close, status_code = "", "", "", 200 - if error: - logger.error(f"OAuth error: {error}") - message = f"Error: {error}" - title = "❌ Authorization Failed" - status_code = 400 - elif not code or not state: - message = "Missing authorization code or state parameter." - title = "❌ Invalid Request" - status_code = 400 - else: - # Complete the OAuth flow (exchange code for token) - with make_session() as session: - status_code, message = await complete_oauth_flow(session, code, state) - if 200 <= status_code < 300: - title = "✅ Authorization Successful!" - close = "You can close this window and return to the MCP server." - else: - title = "❌ Authorization Failed" - - return HTMLResponse( - content=f""" - - -

{title}

-

{message}

-

{close}

- - - """, - status_code=status_code, - ) - - def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001): """Run the Discord API server""" uvicorn.run(app, host=host, port=port, log_level="debug") diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py index accfb6e..4507f20 100644 --- a/src/memory/discord/mcp.py +++ b/src/memory/discord/mcp.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, scoped_session from memory.common.db.connection import make_session from memory.common.db.models.discord import DiscordMCPServer -from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client +from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client logger = logging.getLogger(__name__) diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py index fe3416b..42a3230 100644 --- a/src/memory/discord/messages.py +++ b/src/memory/discord/messages.py @@ -13,8 +13,8 @@ from memory.common.db.models import ( DiscordUser, ScheduledLLMCall, ) -from memory.common.db.models.users import BotUser from memory.common.llms.base import create_provider +from memory.common.llms.tools import MCPServer logger = logging.getLogger(__name__) @@ -216,6 +216,7 @@ def call_llm( system_prompt: str = "", messages: list[str | dict[str, Any]] = [], allowed_tools: Collection[str] | None = None, + mcp_servers: list[MCPServer] | None = None, num_previous_messages: int = 10, ) -> str | None: """ @@ -269,6 +270,7 @@ def call_llm( messages=provider.as_messages(message_content), tools=tools, system_prompt=bot_user.system_prompt + "\n\n" + system_prompt, + mcp_servers=mcp_servers, max_iterations=settings.DISCORD_MAX_TOOL_CALLS, ).response diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index aa34a0e..d5c47ec 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -10,6 +10,7 @@ import textwrap from datetime import datetime from typing import Any, cast +from memory.common.llms.tools import MCPServer import requests from sqlalchemy import exc as sqlalchemy_exc from sqlalchemy.orm import Session, scoped_session @@ -189,6 +190,20 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } + mcp_servers = None + if ( + discord_message.recipient_user + and discord_message.recipient_user.mcp_servers + ): + mcp_servers = [ + MCPServer( + name=server.mcp_server_url, + url=server.mcp_server_url, + token=server.access_token, + ) + for server in discord_message.recipient_user.mcp_servers + ] + try: response = call_llm( session, @@ -196,6 +211,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]: from_user=discord_message.from_user, channel=discord_message.channel, model=settings.DISCORD_MODEL, + mcp_servers=mcp_servers, system_prompt=discord_message.system_prompt, ) except Exception: