mcp servers for discord bots

This commit is contained in:
Daniel O'Connell 2025-11-02 23:04:03 +01:00
parent 6250586d1f
commit 64bb926eba
12 changed files with 916 additions and 53 deletions

View File

@ -0,0 +1,67 @@
"""discord mcp servers
Revision ID: 9b887449ea92
Revises: 1954477b25f4
Create Date: 2025-11-02 22:04:26.259323
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "9b887449ea92"
down_revision: Union[str, None] = "1954477b25f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"discord_mcp_servers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
sa.Column("mcp_server_url", sa.Text(), nullable=False),
sa.Column("client_id", sa.Text(), nullable=False),
sa.Column("state", sa.Text(), nullable=True),
sa.Column("code_verifier", sa.Text(), nullable=True),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["discord_bot_user_id"],
["discord_users.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("state"),
)
op.create_index(
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
)
op.create_index(
"discord_mcp_user_url_idx",
"discord_mcp_servers",
["discord_bot_user_id", "mcp_server_url"],
unique=False,
)
def downgrade() -> None:
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
op.drop_table("discord_mcp_servers")

View File

@ -34,6 +34,7 @@ from memory.common.db.models.discord import (
DiscordServer,
DiscordChannel,
DiscordUser,
DiscordMCPServer,
)
from memory.common.db.models.observations import (
ObservationContradiction,
@ -106,6 +107,7 @@ __all__ = [
"DiscordServer",
"DiscordChannel",
"DiscordUser",
"DiscordMCPServer",
# Users
"User",
"HumanUser",

View File

@ -127,5 +127,44 @@ class DiscordUser(Base, MessageProcessor):
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
mcp_servers = relationship(
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
)
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
class DiscordMCPServer(Base):
"""MCP server configuration and OAuth state for Discord users."""
__tablename__ = "discord_mcp_servers"
id = Column(Integer, primary_key=True)
discord_bot_user_id = Column(
BigInteger, ForeignKey("discord_users.id"), nullable=False
)
# MCP server info
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)
code_verifier = Column(Text, nullable=True)
# OAuth tokens (set after successful authorization)
access_token = Column(Text, nullable=True)
refresh_token = Column(Text, nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
__table_args__ = (
Index("discord_mcp_state_idx", "state"),
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
)

View File

@ -363,7 +363,19 @@ class DiscordMessage(SourceItem):
@property
def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
return textwrap.dedent("""
<message>
<id>{message_id}</id>
<from>{from_user}</from>
<sent_at>{sent_at}</sent_at>
<content>{content}</content>
</message>
""").format(
message_id=self.message_id,
from_user=self.from_user.username,
sent_at=self.sent_at.isoformat()[:19],
content=self.content,
)
def as_content(self) -> dict[str, Any]:
"""Return message content ready for LLM (text + images from disk)."""

View File

@ -94,6 +94,30 @@ def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
return False
def add_reaction(bot_id: int, channel: int | str, message_id: int, emoji: str) -> bool:
"""Add a reaction to a message in a channel"""
try:
response = requests.post(
f"{get_api_url()}/add_reaction",
json={
"bot_id": bot_id,
"channel": channel,
"message_id": message_id,
"emoji": emoji,
},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(
f"Failed to add reaction {emoji} to message {message_id} in channel {channel}: {e}"
)
return False
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a message to a channel by name or ID (ID supports threads)"""
try:

View File

@ -17,6 +17,7 @@ from memory.common.db.models import (
BotUser,
)
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
from memory.common.discord import add_reaction
UpdateSummaryType = Literal["server", "channel", "user"]
@ -209,6 +210,50 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
)
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
bot_id = cast(int, bot.id)
channel_id = channel and channel.id
def handler(input: ToolInput) -> str:
if not isinstance(input, dict):
raise ValueError("Input must be a dictionary")
try:
emoji = input.get("emoji")
except ValueError:
raise ValueError("Emoji is required")
if not emoji:
raise ValueError("Emoji is required")
try:
message_id = int(input.get("message_id") or "no id")
except ValueError:
raise ValueError("Message ID is required")
success = add_reaction(bot_id, channel_id, message_id, emoji)
if not success:
return "Failed to add reaction"
return "Reaction added"
return ToolDefinition(
name="add_reaction",
description="Add a reaction to a message in a channel",
input_schema={
"type": "object",
"properties": {
"message_id": {
"type": "number",
"description": "The ID of the message to add the reaction to",
},
"emoji": {
"type": "string",
"description": "The emoji to add to the message",
},
},
},
function=handler,
)
def make_discord_tools(
bot: BotUser,
author: DiscordUser | None,
@ -227,5 +272,6 @@ def make_discord_tools(
if channel and channel.server:
tools += [
make_summary_tool("server", cast(BigInteger, channel.server_id)),
make_add_reaction_tool(bot, channel),
]
return {tool.name: tool for tool in tools}

View File

@ -12,13 +12,15 @@ from contextlib import asynccontextmanager
from typing import cast
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser
from memory.discord.collector import MessageCollector
from memory.discord.oauth import complete_oauth_flow
logger = logging.getLogger(__name__)
@ -45,6 +47,13 @@ class TypingChannelRequest(BaseModel):
channel: int | str # Channel name or ID (ID supports threads)
class AddReactionRequest(BaseModel):
bot_id: int
channel: int | str # Channel name or ID (ID supports threads)
message_id: int
emoji: str
class Collector:
collector: MessageCollector
collector_task: asyncio.Task
@ -53,6 +62,7 @@ class Collector:
bot_name: str
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
self.collector = collector
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
self.bot_id = cast(int, bot.id)
@ -72,7 +82,9 @@ async def lifespan(app: FastAPI):
bots = session.query(DiscordBotUser).all()
app.bots = {bot.id: make_collector(bot) for bot in bots}
logger.info(f"Discord collectors started for {len(app.bots)} bots")
logger.error(
f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
)
yield
@ -216,6 +228,36 @@ async def health_check():
}
@app.post("/add_reaction")
async def add_reaction_endpoint(request: AddReactionRequest):
"""Add a reaction to a message via the collector's Discord client"""
collector = app.bots.get(request.bot_id)
if not collector:
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.add_reaction(
request.channel, request.message_id, request.emoji
)
except Exception as e:
logger.error(f"Failed to add reaction: {e}")
raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to add reaction to message {request.message_id}",
)
return {
"success": True,
"channel": request.channel,
"message_id": request.message_id,
"emoji": request.emoji,
"message": f"Added reaction {request.emoji} to message {request.message_id}",
}
@app.post("/refresh_metadata")
async def refresh_metadata():
"""Refresh Discord server/channel/user metadata from Discord API"""
@ -237,6 +279,51 @@ async def refresh_metadata():
raise HTTPException(status_code=500, detail=str(e))
@app.get("/oauth/callback/discord", response_class=HTMLResponse)
async def oauth_callback(request: Request):
"""Handle OAuth callback from MCP server after user authorization."""
code = request.query_params.get("code")
state = request.query_params.get("state")
error = request.query_params.get("error")
logger.info(
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
)
message, title, close, status_code = "", "", "", 200
if error:
logger.error(f"OAuth error: {error}")
message = f"Error: {error}"
title = "❌ Authorization Failed"
status_code = 400
elif not code or not state:
message = "Missing authorization code or state parameter."
title = "❌ Invalid Request"
status_code = 400
else:
# Complete the OAuth flow (exchange code for token)
with make_session() as session:
status_code, message = await complete_oauth_flow(session, code, state)
if 200 <= status_code < 300:
title = "✅ Authorization Successful!"
close = "You can close this window and return to the MCP server."
else:
title = "❌ Authorization Failed"
return HTMLResponse(
content=f"""
<html>
<body>
<h1>{title}</h1>
<p>{message}</p>
<p>{close}</p>
</body>
</html>
""",
status_code=status_code,
)
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
"""Run the Discord API server"""
uvicorn.run(app, host=host, port=port, log_level="debug")

View File

@ -229,11 +229,21 @@ class MessageCollector(commands.Bot):
intents=intents,
help_command=None, # Disable default help
)
logger.info(f"Initialized collector for {self.user}")
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
register_slash_commands(self, name=self.user.name)
if not (name := self.user.name):
logger.error(f"Failed to get user name for {self.user}")
return
name = name.replace("-", "_").lower()
try:
register_slash_commands(self, name=name)
except Exception as e:
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
logger.error(f"Registered slash commands for {self.user.name}")
async def on_ready(self):
"""Called when bot connects to Discord"""
@ -313,8 +323,6 @@ class MessageCollector(commands.Bot):
async def refresh_metadata(self) -> dict[str, int]:
"""Refresh server and channel metadata from Discord and update database"""
print("🔄 Refreshing Discord metadata...")
servers_updated = 0
channels_updated = 0
users_updated = 0
@ -454,25 +462,33 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
async def _get_channel(
self, channel_identifier: int | str, check_notifications: bool = True
):
"""Get channel by ID or name with standard checks"""
if check_notifications and not settings.DISCORD_NOTIFICATIONS_ENABLED:
logger.debug("Discord notifications disabled")
return None
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel:
logger.error(f"Channel {channel_identifier} not found")
return channel
async def send_to_channel(
self, channel_identifier: int | str, message: str
) -> bool:
"""Send a message to a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
# Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
channel = await self._get_channel(channel_identifier)
if not channel:
logger.error(f"Channel {channel_identifier} not found")
return False
# Post-process mentions to convert usernames to IDs
with make_session() as session:
processed_message = process_mentions(session, message)
@ -486,18 +502,9 @@ class MessageCollector(commands.Bot):
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
# Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
channel = await self._get_channel(channel_identifier)
if not channel:
logger.error(f"Channel {channel_identifier} not found")
return False
async with channel.typing():
@ -509,3 +516,21 @@ class MessageCollector(commands.Bot):
f"Failed to trigger typing for channel {channel_identifier}: {e}"
)
return False
async def add_reaction(
self, channel_identifier: int | str, message_id: int, emoji: str
) -> bool:
"""Add a reaction to a message in a channel"""
try:
channel = await self._get_channel(channel_identifier)
if not channel:
return False
message = await channel.fetch_message(message_id)
await message.add_reaction(emoji)
logger.info(f"Added reaction {emoji} to message {message_id}")
return True
except Exception as e:
logger.error(f"Failed to add reaction: {e}")
return False

View File

@ -1,7 +1,7 @@
"""Lightweight slash-command helpers for the Discord collector."""
from __future__ import annotations
from calendar import c
import logging
from dataclasses import dataclass
from typing import Callable, Literal
@ -10,6 +10,9 @@ from sqlalchemy.orm import Session
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.mcp import run_mcp_server_command
logger = logging.getLogger(__name__)
ScopeLiteral = Literal["server", "channel", "user"]
@ -50,6 +53,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
"""
if getattr(bot, "_memory_commands_registered", False):
logger.error(f"Slash commands already registered for {name}")
return
setattr(bot, "_memory_commands_registered", True)
@ -167,6 +171,21 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
target_user=user,
)
@tree.command(
name=f"{name}_mcp_servers",
description="Manage MCP servers for your account",
)
@discord.app_commands.describe(
action="Action to perform",
url="MCP server URL (required for add, delete, connect, tools)",
)
async def mcp_servers_command(
interaction: discord.Interaction,
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
url: str | None = None,
) -> None:
await run_mcp_server_command(interaction, action, url and url.strip(), name)
async def _run_interaction_command(
interaction: discord.Interaction,
@ -177,17 +196,10 @@ async def _run_interaction_command(
**handler_kwargs,
) -> None:
"""Shared coroutine used by the registered slash commands."""
try:
with make_session() as session:
response = run_command(
session,
interaction,
scope,
handler=handler,
target_user=target_user,
**handler_kwargs,
)
context = _build_context(session, interaction, scope, target_user)
response = handler(context, **handler_kwargs)
session.commit()
except CommandError as exc: # pragma: no cover - passthrough
await interaction.response.send_message(str(exc), ephemeral=True)
@ -199,21 +211,6 @@ async def _run_interaction_command(
)
def run_command(
session: Session,
interaction: discord.Interaction,
scope: ScopeLiteral,
*,
handler: CommandHandler,
target_user: discord.User | None = None,
**handler_kwargs,
) -> CommandResponse:
"""Create a :class:`CommandContext` and execute the handler."""
context = _build_context(session, interaction, scope, target_user)
return handler(context, **handler_kwargs)
def _build_context(
session: Session,
interaction: discord.Interaction,

294
src/memory/discord/mcp.py Normal file
View File

@ -0,0 +1,294 @@
"""Lightweight slash-command helpers for the Discord collector."""
import json
import logging
import time
from typing import Any, AsyncGenerator, Literal, cast
import aiohttp
import discord
from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session
from memory.common.db.models.discord import DiscordMCPServer
from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__)
def find_mcp_server(
session: Session | scoped_session, user_id: int, url: str
) -> DiscordMCPServer | None:
return (
session.query(DiscordMCPServer)
.filter(
DiscordMCPServer.discord_bot_user_id == user_id,
DiscordMCPServer.mcp_server_url == url,
)
.first()
)
async def call_mcp_server(
url: str, access_token: str, method: str, params: dict = {}
) -> AsyncGenerator[Any, None]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Authorization": f"Bearer {access_token}",
}
payload = {
"jsonrpc": "2.0",
"id": int(time.time() * 1000),
"method": method,
"params": params,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Tools list failed: {resp.status} - {error_text}")
raise ValueError(
f"Failed to call MCP server: {resp.status} - {error_text}"
)
# Parse SSE stream
async for line in resp.content:
line_str = line.decode("utf-8").strip()
# SSE format: "data: {json}"
if line_str.startswith("data: "):
json_str = line_str[6:] # Remove "data: " prefix
try:
yield json.loads(json_str)
except json.JSONDecodeError:
continue # Skip invalid JSON lines
async def handle_mcp_list(interaction: discord.Interaction) -> str:
"""List all MCP servers for the user."""
with make_session() as session:
servers = (
session.query(DiscordMCPServer)
.filter(
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
)
.all()
)
if not servers:
return (
"📋 **Your MCP Servers**\n\n"
"You don't have any MCP servers configured yet.\n"
"Use `/memory_mcp_servers add <url>` to add one."
)
def format_server(server: DiscordMCPServer) -> str:
con = "🟢" if cast(str | None, server.access_token) else "🔴"
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
server_list = "\n".join(format_server(s) for s in servers)
return f"📋 **Your MCP Servers**\n\n{server_list}"
async def handle_mcp_add(
interaction: discord.Interaction, url: str, name: str = "memory"
) -> str:
"""Add a new MCP server via OAuth."""
with make_session() as session:
if find_mcp_server(session, interaction.user.id, url):
return (
f"**MCP Server Already Exists**\n\n"
f"You already have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers connect {url}` to reconnect."
)
endpoints = await get_endpoints(url)
client_id = await register_oauth_client(
endpoints,
url,
f"Discord Bot - {name} ({interaction.user.name})",
)
mcp_server = DiscordMCPServer(
discord_bot_user_id=interaction.user.id,
mcp_server_url=url,
client_id=client_id,
)
session.add(mcp_server)
session.flush()
auth_url = await issue_challenge(mcp_server, endpoints)
session.commit()
logger.info(
f"Created MCP server record: id={mcp_server.id}, "
f"user={interaction.user.id}, url={url}"
)
return (
f"🔐 **Add MCP Server**\n\n"
f"Server: `{url}`\n"
f"Click the link below to authorize:\n{auth_url}\n\n"
f"⚠️ Keep this link private!\n"
f"💡 You'll be redirected to login and grant access to the MCP server."
)
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
"""Delete an MCP server."""
with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url)
if not mcp_server:
return (
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
)
session.delete(mcp_server)
session.commit()
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
"""Reconnect to an existing MCP server (redo OAuth)."""
with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
endpoints = await get_endpoints(url)
auth_url = await issue_challenge(mcp_server, endpoints)
session.commit()
logger.info(
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
)
return (
f"🔄 **Reconnect to MCP Server**\n\n"
f"Server: `{url}`\n"
f"Click the link below to reauthorize:\n{auth_url}\n\n"
f"⚠️ Keep this link private!\n"
f"💡 You'll be redirected to login and grant access to the MCP server again."
)
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
"""List tools available on an MCP server."""
with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
if not cast(str | None, mcp_server.access_token):
raise ValueError(
f"**Not Authorized**\n\n"
f"You haven't authorized access to `{url}` yet.\n"
f"Use `/memory_mcp_servers connect {url}` to authorize."
)
access_token = cast(str, mcp_server.access_token)
# Make JSON-RPC request to MCP server
tools = None
try:
async for data in call_mcp_server(url, access_token, "tools/list"):
if "result" in data and "tools" in data["result"]:
tools = data["result"]["tools"]
break
except aiohttp.ClientError as exc:
logger.exception(f"Failed to connect to MCP server: {exc}")
raise ValueError(
f"**Connection failed**\n\n"
f"Server: `{url}`\n"
f"Could not connect to the MCP server: {str(exc)}"
)
except Exception as exc:
logger.exception(f"Failed to list tools: {exc}")
raise ValueError(
f"**Error**\n\nServer: `{url}`\nFailed to list tools: {str(exc)}"
)
if tools is None:
raise ValueError(
f"**Unexpected response format**\n\n"
f"Server: `{url}`\n"
f"The server returned an unexpected response format."
)
if not tools:
return (
f"🔧 **MCP Server Tools**\n\n"
f"Server: `{url}`\n\n"
f"No tools available on this server."
)
# Format tools list
tools_list = "\n".join(
f"• **{t.get('name', 'unknown')}**: {t.get('description', 'No description')}"
for t in tools
)
return (
f"🔧 **MCP Server Tools**\n\n"
f"Server: `{url}`\n"
f"Found {len(tools)} tool(s):\n\n"
f"{tools_list}"
)
async def run_mcp_server_command(
interaction: discord.Interaction,
action: Literal["list", "add", "delete", "connect", "tools"],
url: str | None,
name: str = "memory",
) -> None:
"""Handle MCP server management commands."""
if action not in ["list", "add", "delete", "connect", "tools"]:
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
return
if action != "list" and not url:
await interaction.response.send_message(
"❌ URL is required for this action", ephemeral=True
)
return
try:
if action == "list" or not url:
result = await handle_mcp_list(interaction)
elif action == "add":
result = await handle_mcp_add(interaction, url, name)
elif action == "delete":
result = await handle_mcp_delete(interaction, url)
elif action == "connect":
result = await handle_mcp_connect(interaction, url)
elif action == "tools":
result = await handle_mcp_tools(interaction, url)
except Exception as exc:
result = f"❌ Error: {exc}"
await interaction.response.send_message(result, ephemeral=True)

View File

@ -141,8 +141,10 @@ def previous_messages(
) -> list[DiscordMessage]:
messages = session.query(DiscordMessage)
if user_id:
print(f"user_id: {user_id}")
messages = messages.filter(DiscordMessage.recipient_id == user_id)
if channel_id:
print(f"channel_id: {channel_id}")
messages = messages.filter(DiscordMessage.channel_id == channel_id)
return list(
reversed(

268
src/memory/discord/oauth.py Normal file
View File

@ -0,0 +1,268 @@
import hashlib
import logging
import secrets
from base64 import urlsafe_b64encode
from dataclasses import dataclass
from datetime import datetime, timedelta
from urllib.parse import urlencode, urljoin
import aiohttp
from sqlalchemy.orm import Session, scoped_session
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.discord import DiscordMCPServer
logger = logging.getLogger(__name__)
@dataclass
class OAuthEndpoints:
authorization_endpoint: str
registration_endpoint: str
token_endpoint: str
redirect_uri: str
def generate_pkce_pair() -> tuple[str, str]:
"""Generate PKCE code verifier and challenge.
Returns:
Tuple of (code_verifier, code_challenge)
"""
# Generate a random code verifier
code_verifier = (
urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
)
# Create code challenge using S256 method
challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = urlsafe_b64encode(challenge_bytes).decode("utf-8").rstrip("=")
return code_verifier, code_challenge
async def discover_oauth_metadata(server_url: str) -> dict | None:
"""Discover OAuth metadata from an MCP server."""
# Try the standard OAuth discovery endpoint
discovery_url = urljoin(server_url, "/.well-known/oauth-authorization-server")
try:
async with aiohttp.ClientSession() as session:
async with session.get(
discovery_url, timeout=aiohttp.ClientTimeout(total=5)
) as resp:
if resp.status == 200:
return await resp.json()
except Exception as exc:
logger.warning(f"Failed to discover OAuth metadata from {discovery_url}: {exc}")
return None
async def get_endpoints(url: str) -> OAuthEndpoints:
# Discover OAuth endpoints from the target server
oauth_metadata = await discover_oauth_metadata(url)
if not oauth_metadata:
raise ValueError(
"**Failed to connect to MCP server**\n\n"
f"Could not discover OAuth endpoints at `{url}`\n"
"Make sure the server is running and supports OAuth 2.0.",
)
authorization_endpoint = oauth_metadata.get("authorization_endpoint")
registration_endpoint = oauth_metadata.get("registration_endpoint")
token_endpoint = oauth_metadata.get("token_endpoint")
if not authorization_endpoint:
raise ValueError(
"**Invalid OAuth configuration**\n\n"
f"Server `{url}` did not provide an authorization endpoint.",
)
if not registration_endpoint:
raise ValueError(
"**Invalid OAuth configuration**\n\n"
f"Server `{url}` does not support dynamic client registration.",
)
if not token_endpoint:
raise ValueError(
"**Invalid OAuth configuration**\n\n"
f"Server `{url}` does not provide a token endpoint.",
)
logger.info(f"Authorization endpoint: {authorization_endpoint}")
logger.info(f"Registration endpoint: {registration_endpoint}")
return OAuthEndpoints(
authorization_endpoint=authorization_endpoint,
registration_endpoint=registration_endpoint,
token_endpoint=token_endpoint,
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
)
async def register_oauth_client(
endpoints: OAuthEndpoints,
url: str,
client_name: str,
) -> None:
"""Register OAuth client and store client_id in the mcp_server object."""
client_metadata = {
"client_name": client_name,
"redirect_uris": [endpoints.redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scope": "read write",
"token_endpoint_auth_method": "none",
}
logger.error(f"Registration metadata: {client_metadata}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(
endpoints.registration_endpoint,
json=client_metadata,
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
logger.error(
f"Registration response: {resp.status} {await resp.text()}"
)
resp.raise_for_status()
client_info = await resp.json()
except Exception as exc:
raise ValueError(
f"Failed to register OAuth client at {endpoints.registration_endpoint}: {exc}"
)
if not client_info or "client_id" not in client_info:
raise ValueError(
"**Failed to register OAuth client**\n\n"
f"Could not register with the MCP server at `{url}`\n"
f"Check the server logs for more details.",
)
client_id = client_info["client_id"]
logger.info(f"Registered OAuth client: {client_id}")
return client_id
async def issue_challenge(
mcp_server: DiscordMCPServer,
endpoints: OAuthEndpoints,
) -> str:
"""Generate OAuth challenge and store state in mcp_server object."""
code_verifier, code_challenge = generate_pkce_pair()
state = secrets.token_urlsafe(32)
# Store in mcp_server object
mcp_server.state = state # type: ignore
mcp_server.code_verifier = code_verifier # type: ignore
logger.info(
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
)
# Build authorization URL pointing to the target server
auth_params = {
"client_id": mcp_server.client_id,
"redirect_uri": endpoints.redirect_uri,
"response_type": "code",
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"scope": "read write",
}
return f"{endpoints.authorization_endpoint}?{urlencode(auth_params)}"
async def complete_oauth_flow(
session: Session | scoped_session, code: str, state: str
) -> tuple[int, str]:
"""Complete OAuth flow by exchanging code for token.
Args:
code: Authorization code from OAuth callback
state: State parameter from OAuth callback
Returns:
Tuple of (status_code, html_message) for the callback response
"""
try:
mcp_server = (
session.query(DiscordMCPServer)
.filter(DiscordMCPServer.state == state)
.first()
)
if not mcp_server:
logger.error(f"Invalid or expired state: {state[:20]}...")
return 400, "Invalid or expired OAuth state"
logger.info(
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
f"url={mcp_server.mcp_server_url}"
)
# Get OAuth endpoints
try:
endpoints = await get_endpoints(str(mcp_server.mcp_server_url))
except Exception as exc:
return 500, f"Failed to get OAuth endpoints: {str(exc)}"
# Exchange authorization code for access token
token_data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": endpoints.redirect_uri,
"client_id": mcp_server.client_id,
"code_verifier": mcp_server.code_verifier,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
endpoints.token_endpoint,
data=token_data,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Token exchange failed: {resp.status} - {error_text}")
return 500, f"Token exchange failed: {error_text}"
tokens = await resp.json()
access_token = tokens.get("access_token")
refresh_token = tokens.get("refresh_token")
expires_in = tokens.get("expires_in", 3600)
if not access_token:
return 500, "Token response did not include access_token"
logger.info(f"Successfully obtained access token: {access_token[:20]}...")
# Store tokens and clear temporary OAuth state
mcp_server.access_token = access_token # type: ignore
mcp_server.refresh_token = refresh_token # type: ignore
mcp_server.token_expires_at = datetime.now() + timedelta(seconds=expires_in) # type: ignore
# Clear temporary OAuth flow data
mcp_server.state = None # type: ignore
mcp_server.code_verifier = None # type: ignore
session.commit()
logger.info(
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
f"server {mcp_server.mcp_server_url}"
)
return 200, "✅ Authorization successful! You can now use this MCP server."
except Exception as exc:
logger.exception(f"Failed to complete OAuth flow: {exc}")
return 500, f"Failed to complete OAuth flow: {str(exc)}"