properly handle mcp redirects

This commit is contained in:
mruwnik 2025-11-03 00:00:02 +00:00
parent 0d9f8beec3
commit 2944a0bce1
9 changed files with 132 additions and 78 deletions

View File

@ -14,6 +14,7 @@ from memory.common.db.models import (
BookSection, BookSection,
Chunk, Chunk,
Comic, Comic,
DiscordMCPServer,
DiscordMessage, DiscordMessage,
EmailAccount, EmailAccount,
EmailAttachment, EmailAttachment,
@ -166,6 +167,37 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
column_sortable_list = ["sent_at"] column_sortable_list = ["sent_at"]
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
column_list = [
"id",
"mcp_server_url",
"client_id",
"discord_bot_user_id",
"state",
"code_verifier",
"access_token",
"refresh_token",
"token_expires_at",
"created_at",
"updated_at",
]
column_searchable_list = [
"mcp_server_url",
"client_id",
"state",
"id",
"discord_bot_user_id",
]
column_sortable_list = [
"created_at",
"updated_at",
"mcp_server_url",
"client_id",
"state",
"id",
]
class ArticleFeedAdmin(ModelView, model=ArticleFeed): class ArticleFeedAdmin(ModelView, model=ArticleFeed):
column_list = [ column_list = [
"id", "id",
@ -328,4 +360,5 @@ def setup_admin(admin: Admin):
admin.add_view(DiscordUserAdmin) admin.add_view(DiscordUserAdmin)
admin.add_view(DiscordServerAdmin) admin.add_view(DiscordServerAdmin)
admin.add_view(DiscordChannelAdmin) admin.add_view(DiscordChannelAdmin)
admin.add_view(DiscordMCPServerAdmin)
admin.add_view(ScheduledLLMCallAdmin) admin.add_view(ScheduledLLMCallAdmin)

View File

@ -1,13 +1,21 @@
from datetime import datetime, timedelta, timezone
import logging import logging
from datetime import datetime, timedelta, timezone
from fastapi import HTTPException, Depends, Request, Response, APIRouter from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy.orm import Session as DBSession
from sqlalchemy.orm import scoped_session
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from memory.common import settings
from memory.common.db.connection import get_session, make_session from memory.common.db.connection import get_session, make_session
from memory.common.db.models.users import User, HumanUser, BotUser, UserSession from memory.common.db.models import (
BotUser,
DiscordMCPServer,
HumanUser,
User,
UserSession,
)
from memory.common.oauth import complete_oauth_flow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -136,6 +144,58 @@ def get_me(user: User = Depends(get_current_user)):
return user.serialize() return user.serialize()
@router.get("/callback/discord")
async def oauth_callback_discord(request: Request):
"""Get current user info"""
code = request.query_params.get("code")
state = request.query_params.get("state")
error = request.query_params.get("error")
logger.info(
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
)
message, title, close, status_code = "", "", "", 200
if error:
logger.error(f"OAuth error: {error}")
message = f"Error: {error}"
title = "❌ Authorization Failed"
status_code = 400
elif not code or not state:
message = "Missing authorization code or state parameter."
title = "❌ Invalid Request"
status_code = 400
else:
# Complete the OAuth flow (exchange code for token)
with make_session() as session:
mcp_server = (
session.query(DiscordMCPServer)
.filter(DiscordMCPServer.state == state)
.first()
)
status_code, message = await complete_oauth_flow(mcp_server, code, state)
session.commit()
if 200 <= status_code < 300:
title = "✅ Authorization Successful!"
close = "You can close this window and return to the MCP server."
else:
title = "❌ Authorization Failed"
return Response(
content=f"""
<html>
<body>
<h1>{title}</h1>
<p>{message}</p>
<p>{close}</p>
</body>
</html>
""",
status_code=status_code,
)
class AuthenticationMiddleware(BaseHTTPMiddleware): class AuthenticationMiddleware(BaseHTTPMiddleware):
"""Middleware to require authentication for all endpoints except whitelisted ones.""" """Middleware to require authentication for all endpoints except whitelisted ones."""

View File

@ -1,22 +1,23 @@
import hashlib import hashlib
import secrets import secrets
from typing import cast
import uuid import uuid
from sqlalchemy.orm import Session from typing import cast
from memory.common.db.models.base import Base
from sqlalchemy import ( from sqlalchemy import (
ARRAY,
Boolean,
CheckConstraint,
Column, Column,
Integer,
String,
DateTime, DateTime,
ForeignKey, ForeignKey,
Boolean, Integer,
ARRAY,
Numeric, Numeric,
CheckConstraint, String,
) )
from sqlalchemy.orm import Session, relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
def hash_password(password: str) -> str: def hash_password(password: str) -> str:

View File

@ -164,9 +164,6 @@ class AnthropicProvider(BaseLLMProvider):
kwargs["temperature"] = 1.0 kwargs["temperature"] = 1.0
kwargs.pop("top_p", None) kwargs.pop("top_p", None)
for k, v in kwargs.items():
print(f"{k}: {v}")
return kwargs return kwargs
def _handle_stream_event( def _handle_stream_event(

View File

@ -7,9 +7,7 @@ from datetime import datetime, timedelta
from urllib.parse import urlencode, urljoin from urllib.parse import urlencode, urljoin
import aiohttp import aiohttp
from sqlalchemy.orm import Session, scoped_session
from memory.common import settings from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.discord import DiscordMCPServer from memory.common.db.models.discord import DiscordMCPServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -99,7 +97,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
authorization_endpoint=authorization_endpoint, authorization_endpoint=authorization_endpoint,
registration_endpoint=registration_endpoint, registration_endpoint=registration_endpoint,
token_endpoint=token_endpoint, token_endpoint=token_endpoint,
redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord", redirect_uri=f"{settings.SERVER_URL}/auth/callback/discord",
) )
@ -181,7 +179,7 @@ async def issue_challenge(
async def complete_oauth_flow( async def complete_oauth_flow(
session: Session | scoped_session, code: str, state: str mcp_server: DiscordMCPServer, code: str, state: str
) -> tuple[int, str]: ) -> tuple[int, str]:
"""Complete OAuth flow by exchanging code for token. """Complete OAuth flow by exchanging code for token.
@ -193,12 +191,6 @@ async def complete_oauth_flow(
Tuple of (status_code, html_message) for the callback response Tuple of (status_code, html_message) for the callback response
""" """
try: try:
mcp_server = (
session.query(DiscordMCPServer)
.filter(DiscordMCPServer.state == state)
.first()
)
if not mcp_server: if not mcp_server:
logger.error(f"Invalid or expired state: {state[:20]}...") logger.error(f"Invalid or expired state: {state[:20]}...")
return 400, "Invalid or expired OAuth state" return 400, "Invalid or expired OAuth state"
@ -254,8 +246,6 @@ async def complete_oauth_flow(
mcp_server.state = None # type: ignore mcp_server.state = None # type: ignore
mcp_server.code_verifier = None # type: ignore mcp_server.code_verifier = None # type: ignore
session.commit()
logger.info( logger.info(
f"Stored tokens for user {mcp_server.discord_bot_user_id}, " f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
f"server {mcp_server.mcp_server_url}" f"server {mcp_server.mcp_server_url}"

View File

@ -18,9 +18,9 @@ from pydantic import BaseModel
from memory.common import settings from memory.common import settings
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser from memory.common.db.models import DiscordMCPServer, DiscordBotUser
from memory.common.oauth import complete_oauth_flow
from memory.discord.collector import MessageCollector from memory.discord.collector import MessageCollector
from memory.discord.oauth import complete_oauth_flow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -279,51 +279,6 @@ async def refresh_metadata():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/oauth/callback/discord", response_class=HTMLResponse)
async def oauth_callback(request: Request):
"""Handle OAuth callback from MCP server after user authorization."""
code = request.query_params.get("code")
state = request.query_params.get("state")
error = request.query_params.get("error")
logger.info(
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
)
message, title, close, status_code = "", "", "", 200
if error:
logger.error(f"OAuth error: {error}")
message = f"Error: {error}"
title = "❌ Authorization Failed"
status_code = 400
elif not code or not state:
message = "Missing authorization code or state parameter."
title = "❌ Invalid Request"
status_code = 400
else:
# Complete the OAuth flow (exchange code for token)
with make_session() as session:
status_code, message = await complete_oauth_flow(session, code, state)
if 200 <= status_code < 300:
title = "✅ Authorization Successful!"
close = "You can close this window and return to the MCP server."
else:
title = "❌ Authorization Failed"
return HTMLResponse(
content=f"""
<html>
<body>
<h1>{title}</h1>
<p>{message}</p>
<p>{close}</p>
</body>
</html>
""",
status_code=status_code,
)
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001): def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
"""Run the Discord API server""" """Run the Discord API server"""
uvicorn.run(app, host=host, port=port, log_level="debug") uvicorn.run(app, host=host, port=port, log_level="debug")

View File

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models.discord import DiscordMCPServer from memory.common.db.models.discord import DiscordMCPServer
from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -13,8 +13,8 @@ from memory.common.db.models import (
DiscordUser, DiscordUser,
ScheduledLLMCall, ScheduledLLMCall,
) )
from memory.common.db.models.users import BotUser
from memory.common.llms.base import create_provider from memory.common.llms.base import create_provider
from memory.common.llms.tools import MCPServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -216,6 +216,7 @@ def call_llm(
system_prompt: str = "", system_prompt: str = "",
messages: list[str | dict[str, Any]] = [], messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None, allowed_tools: Collection[str] | None = None,
mcp_servers: list[MCPServer] | None = None,
num_previous_messages: int = 10, num_previous_messages: int = 10,
) -> str | None: ) -> str | None:
""" """
@ -269,6 +270,7 @@ def call_llm(
messages=provider.as_messages(message_content), messages=provider.as_messages(message_content),
tools=tools, tools=tools,
system_prompt=bot_user.system_prompt + "\n\n" + system_prompt, system_prompt=bot_user.system_prompt + "\n\n" + system_prompt,
mcp_servers=mcp_servers,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS, max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response ).response

View File

@ -10,6 +10,7 @@ import textwrap
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any, cast
from memory.common.llms.tools import MCPServer
import requests import requests
from sqlalchemy import exc as sqlalchemy_exc from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
@ -189,6 +190,20 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id, "message_id": message_id,
} }
mcp_servers = None
if (
discord_message.recipient_user
and discord_message.recipient_user.mcp_servers
):
mcp_servers = [
MCPServer(
name=server.mcp_server_url,
url=server.mcp_server_url,
token=server.access_token,
)
for server in discord_message.recipient_user.mcp_servers
]
try: try:
response = call_llm( response = call_llm(
session, session,
@ -196,6 +211,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
from_user=discord_message.from_user, from_user=discord_message.from_user,
channel=discord_message.channel, channel=discord_message.channel,
model=settings.DISCORD_MODEL, model=settings.DISCORD_MODEL,
mcp_servers=mcp_servers,
system_prompt=discord_message.system_prompt, system_prompt=discord_message.system_prompt,
) )
except Exception: except Exception: