diff --git a/db/migrations/versions/20251102_220426_discord_mcp_servers.py b/db/migrations/versions/20251102_220426_discord_mcp_servers.py
new file mode 100644
index 0000000..deb35d5
--- /dev/null
+++ b/db/migrations/versions/20251102_220426_discord_mcp_servers.py
@@ -0,0 +1,67 @@
+"""discord mcp servers
+
+Revision ID: 9b887449ea92
+Revises: 1954477b25f4
+Create Date: 2025-11-02 22:04:26.259323
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "9b887449ea92"
+down_revision: Union[str, None] = "1954477b25f4"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ op.create_table(
+ "discord_mcp_servers",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
+ sa.Column("mcp_server_url", sa.Text(), nullable=False),
+ sa.Column("client_id", sa.Text(), nullable=False),
+ sa.Column("state", sa.Text(), nullable=True),
+ sa.Column("code_verifier", sa.Text(), nullable=True),
+ sa.Column("access_token", sa.Text(), nullable=True),
+ sa.Column("refresh_token", sa.Text(), nullable=True),
+ sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column(
+ "created_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("now()"),
+ nullable=True,
+ ),
+ sa.Column(
+ "updated_at",
+ sa.DateTime(timezone=True),
+ server_default=sa.text("now()"),
+ nullable=True,
+ ),
+ sa.ForeignKeyConstraint(
+ ["discord_bot_user_id"],
+ ["discord_users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("state"),
+ )
+ op.create_index(
+ "discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
+ )
+ op.create_index(
+ "discord_mcp_user_url_idx",
+ "discord_mcp_servers",
+ ["discord_bot_user_id", "mcp_server_url"],
+ unique=False,
+ )
+
+
+def downgrade() -> None:
+ op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
+ op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
+ op.drop_table("discord_mcp_servers")
diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py
index 140fa62..1198d8c 100644
--- a/src/memory/common/db/models/__init__.py
+++ b/src/memory/common/db/models/__init__.py
@@ -34,6 +34,7 @@ from memory.common.db.models.discord import (
DiscordServer,
DiscordChannel,
DiscordUser,
+ DiscordMCPServer,
)
from memory.common.db.models.observations import (
ObservationContradiction,
@@ -106,6 +107,7 @@ __all__ = [
"DiscordServer",
"DiscordChannel",
"DiscordUser",
+ "DiscordMCPServer",
# Users
"User",
"HumanUser",
diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py
index 3775ca1..638deec 100644
--- a/src/memory/common/db/models/discord.py
+++ b/src/memory/common/db/models/discord.py
@@ -127,5 +127,44 @@ class DiscordUser(Base, MessageProcessor):
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
+ mcp_servers = relationship(
+ "DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
+ )
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
+
+
+class DiscordMCPServer(Base):
+ """MCP server configuration and OAuth state for Discord users."""
+
+ __tablename__ = "discord_mcp_servers"
+
+ id = Column(Integer, primary_key=True)
+ discord_bot_user_id = Column(
+ BigInteger, ForeignKey("discord_users.id"), nullable=False
+ )
+
+ # MCP server info
+ mcp_server_url = Column(Text, nullable=False)
+ client_id = Column(Text, nullable=False)
+
+ # OAuth flow state (temporary, cleared after token exchange)
+ state = Column(Text, nullable=True, unique=True)
+ code_verifier = Column(Text, nullable=True)
+
+ # OAuth tokens (set after successful authorization)
+ access_token = Column(Text, nullable=True)
+ refresh_token = Column(Text, nullable=True)
+ token_expires_at = Column(DateTime(timezone=True), nullable=True)
+
+ # Timestamps
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
+ updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+ # Relationships
+ discord_user = relationship("DiscordUser", back_populates="mcp_servers")
+
+ __table_args__ = (
+ Index("discord_mcp_state_idx", "state"),
+ Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
+ )
diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index 3c011a8..827b1be 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -363,7 +363,19 @@ class DiscordMessage(SourceItem):
@property
def title(self) -> str:
- return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
+ return textwrap.dedent("""
+
+ {message_id}
+ {from_user}
+ {sent_at}
+ {content}
+
+ """).format(
+ message_id=self.message_id,
+ from_user=self.from_user.username,
+ sent_at=self.sent_at.isoformat()[:19],
+ content=self.content,
+ )
def as_content(self) -> dict[str, Any]:
"""Return message content ready for LLM (text + images from disk)."""
diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py
index 6eab35b..a324276 100644
--- a/src/memory/common/discord.py
+++ b/src/memory/common/discord.py
@@ -94,6 +94,30 @@ def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
return False
+def add_reaction(bot_id: int, channel: int | str, message_id: int, emoji: str) -> bool:
+ """Add a reaction to a message in a channel"""
+ try:
+ response = requests.post(
+ f"{get_api_url()}/add_reaction",
+ json={
+ "bot_id": bot_id,
+ "channel": channel,
+ "message_id": message_id,
+ "emoji": emoji,
+ },
+ timeout=10,
+ )
+ response.raise_for_status()
+ result = response.json()
+ return result.get("success", False)
+
+ except requests.RequestException as e:
+ logger.error(
+ f"Failed to add reaction {emoji} to message {message_id} in channel {channel}: {e}"
+ )
+ return False
+
+
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a message to a channel by name or ID (ID supports threads)"""
try:
diff --git a/src/memory/common/llms/tools/discord.py b/src/memory/common/llms/tools/discord.py
index 8492be3..d334bb5 100644
--- a/src/memory/common/llms/tools/discord.py
+++ b/src/memory/common/llms/tools/discord.py
@@ -17,6 +17,7 @@ from memory.common.db.models import (
BotUser,
)
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
+from memory.common.discord import add_reaction
UpdateSummaryType = Literal["server", "channel", "user"]
@@ -209,6 +210,50 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
)
+def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
+ bot_id = cast(int, bot.id)
+ channel_id = channel and channel.id
+
+ def handler(input: ToolInput) -> str:
+ if not isinstance(input, dict):
+ raise ValueError("Input must be a dictionary")
+ try:
+ emoji = input.get("emoji")
+ except ValueError:
+ raise ValueError("Emoji is required")
+ if not emoji:
+ raise ValueError("Emoji is required")
+
+ try:
+ message_id = int(input.get("message_id") or "no id")
+ except ValueError:
+ raise ValueError("Message ID is required")
+
+ success = add_reaction(bot_id, channel_id, message_id, emoji)
+ if not success:
+ return "Failed to add reaction"
+ return "Reaction added"
+
+ return ToolDefinition(
+ name="add_reaction",
+ description="Add a reaction to a message in a channel",
+ input_schema={
+ "type": "object",
+ "properties": {
+ "message_id": {
+ "type": "number",
+ "description": "The ID of the message to add the reaction to",
+ },
+ "emoji": {
+ "type": "string",
+ "description": "The emoji to add to the message",
+ },
+ },
+ },
+ function=handler,
+ )
+
+
def make_discord_tools(
bot: BotUser,
author: DiscordUser | None,
@@ -227,5 +272,6 @@ def make_discord_tools(
if channel and channel.server:
tools += [
make_summary_tool("server", cast(BigInteger, channel.server_id)),
+ make_add_reaction_tool(bot, channel),
]
return {tool.name: tool for tool in tools}
diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py
index bb43f0f..451b1ac 100644
--- a/src/memory/discord/api.py
+++ b/src/memory/discord/api.py
@@ -12,13 +12,15 @@ from contextlib import asynccontextmanager
from typing import cast
import uvicorn
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser
from memory.discord.collector import MessageCollector
+from memory.discord.oauth import complete_oauth_flow
logger = logging.getLogger(__name__)
@@ -45,6 +47,13 @@ class TypingChannelRequest(BaseModel):
channel: int | str # Channel name or ID (ID supports threads)
+class AddReactionRequest(BaseModel):
+ bot_id: int
+ channel: int | str # Channel name or ID (ID supports threads)
+ message_id: int
+ emoji: str
+
+
class Collector:
collector: MessageCollector
collector_task: asyncio.Task
@@ -53,6 +62,7 @@ class Collector:
bot_name: str
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
+ logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
self.collector = collector
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
self.bot_id = cast(int, bot.id)
@@ -72,7 +82,9 @@ async def lifespan(app: FastAPI):
bots = session.query(DiscordBotUser).all()
app.bots = {bot.id: make_collector(bot) for bot in bots}
- logger.info(f"Discord collectors started for {len(app.bots)} bots")
+ logger.error(
+ f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
+ )
yield
@@ -216,6 +228,36 @@ async def health_check():
}
+@app.post("/add_reaction")
+async def add_reaction_endpoint(request: AddReactionRequest):
+ """Add a reaction to a message via the collector's Discord client"""
+ collector = app.bots.get(request.bot_id)
+ if not collector:
+ raise HTTPException(status_code=404, detail="Bot not found")
+
+ try:
+ success = await collector.collector.add_reaction(
+ request.channel, request.message_id, request.emoji
+ )
+ except Exception as e:
+ logger.error(f"Failed to add reaction: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+ if not success:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to add reaction to message {request.message_id}",
+ )
+
+ return {
+ "success": True,
+ "channel": request.channel,
+ "message_id": request.message_id,
+ "emoji": request.emoji,
+ "message": f"Added reaction {request.emoji} to message {request.message_id}",
+ }
+
+
@app.post("/refresh_metadata")
async def refresh_metadata():
"""Refresh Discord server/channel/user metadata from Discord API"""
@@ -237,6 +279,51 @@ async def refresh_metadata():
raise HTTPException(status_code=500, detail=str(e))
+@app.get("/oauth/callback/discord", response_class=HTMLResponse)
+async def oauth_callback(request: Request):
+ """Handle OAuth callback from MCP server after user authorization."""
+ code = request.query_params.get("code")
+ state = request.query_params.get("state")
+ error = request.query_params.get("error")
+
+ logger.info(
+ f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
+ )
+
+ message, title, close, status_code = "", "", "", 200
+ if error:
+ logger.error(f"OAuth error: {error}")
+ message = f"Error: {error}"
+ title = "❌ Authorization Failed"
+ status_code = 400
+ elif not code or not state:
+ message = "Missing authorization code or state parameter."
+ title = "❌ Invalid Request"
+ status_code = 400
+ else:
+ # Complete the OAuth flow (exchange code for token)
+ with make_session() as session:
+ status_code, message = await complete_oauth_flow(session, code, state)
+ if 200 <= status_code < 300:
+ title = "✅ Authorization Successful!"
+ close = "You can close this window and return to the MCP server."
+ else:
+ title = "❌ Authorization Failed"
+
+ return HTMLResponse(
+ content=f"""
+
+
+ {title}
+ {message}
+ {close}
+
+
+ """,
+ status_code=status_code,
+ )
+
+
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
"""Run the Discord API server"""
uvicorn.run(app, host=host, port=port, log_level="debug")
diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py
index bb5b6c6..852a892 100644
--- a/src/memory/discord/collector.py
+++ b/src/memory/discord/collector.py
@@ -229,11 +229,21 @@ class MessageCollector(commands.Bot):
intents=intents,
help_command=None, # Disable default help
)
+ logger.info(f"Initialized collector for {self.user}")
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
- register_slash_commands(self, name=self.user.name)
+ if not (name := self.user.name):
+ logger.error(f"Failed to get user name for {self.user}")
+ return
+
+ name = name.replace("-", "_").lower()
+ try:
+ register_slash_commands(self, name=name)
+ except Exception as e:
+ logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
+ logger.error(f"Registered slash commands for {self.user.name}")
async def on_ready(self):
"""Called when bot connects to Discord"""
@@ -313,8 +323,6 @@ class MessageCollector(commands.Bot):
async def refresh_metadata(self) -> dict[str, int]:
"""Refresh server and channel metadata from Discord and update database"""
- print("🔄 Refreshing Discord metadata...")
-
servers_updated = 0
channels_updated = 0
users_updated = 0
@@ -454,25 +462,33 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
+ async def _get_channel(
+ self, channel_identifier: int | str, check_notifications: bool = True
+ ):
+ """Get channel by ID or name with standard checks"""
+ if check_notifications and not settings.DISCORD_NOTIFICATIONS_ENABLED:
+ logger.debug("Discord notifications disabled")
+ return None
+
+ if isinstance(channel_identifier, int):
+ channel = self.get_channel(channel_identifier)
+ else:
+ channel = await self.get_channel_by_name(channel_identifier)
+
+ if not channel:
+ logger.error(f"Channel {channel_identifier} not found")
+
+ return channel
+
async def send_to_channel(
self, channel_identifier: int | str, message: str
) -> bool:
"""Send a message to a channel by name or ID (supports threads)"""
- if not settings.DISCORD_NOTIFICATIONS_ENABLED:
- return False
-
try:
- # Get channel by ID or name
- if isinstance(channel_identifier, int):
- channel = self.get_channel(channel_identifier)
- else:
- channel = await self.get_channel_by_name(channel_identifier)
-
+ channel = await self._get_channel(channel_identifier)
if not channel:
- logger.error(f"Channel {channel_identifier} not found")
return False
- # Post-process mentions to convert usernames to IDs
with make_session() as session:
processed_message = process_mentions(session, message)
@@ -486,18 +502,9 @@ class MessageCollector(commands.Bot):
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
- if not settings.DISCORD_NOTIFICATIONS_ENABLED:
- return False
-
try:
- # Get channel by ID or name
- if isinstance(channel_identifier, int):
- channel = self.get_channel(channel_identifier)
- else:
- channel = await self.get_channel_by_name(channel_identifier)
-
+ channel = await self._get_channel(channel_identifier)
if not channel:
- logger.error(f"Channel {channel_identifier} not found")
return False
async with channel.typing():
@@ -509,3 +516,21 @@ class MessageCollector(commands.Bot):
f"Failed to trigger typing for channel {channel_identifier}: {e}"
)
return False
+
+ async def add_reaction(
+ self, channel_identifier: int | str, message_id: int, emoji: str
+ ) -> bool:
+ """Add a reaction to a message in a channel"""
+ try:
+ channel = await self._get_channel(channel_identifier)
+ if not channel:
+ return False
+
+ message = await channel.fetch_message(message_id)
+ await message.add_reaction(emoji)
+ logger.info(f"Added reaction {emoji} to message {message_id}")
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to add reaction: {e}")
+ return False
diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py
index bd3a234..82c15ce 100644
--- a/src/memory/discord/commands.py
+++ b/src/memory/discord/commands.py
@@ -1,7 +1,7 @@
"""Lightweight slash-command helpers for the Discord collector."""
-from __future__ import annotations
-
+from calendar import c
+import logging
from dataclasses import dataclass
from typing import Callable, Literal
@@ -10,6 +10,9 @@ from sqlalchemy.orm import Session
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
+from memory.discord.mcp import run_mcp_server_command
+
+logger = logging.getLogger(__name__)
ScopeLiteral = Literal["server", "channel", "user"]
@@ -50,6 +53,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
"""
if getattr(bot, "_memory_commands_registered", False):
+ logger.error(f"Slash commands already registered for {name}")
return
setattr(bot, "_memory_commands_registered", True)
@@ -167,6 +171,21 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
target_user=user,
)
+ @tree.command(
+ name=f"{name}_mcp_servers",
+ description="Manage MCP servers for your account",
+ )
+ @discord.app_commands.describe(
+ action="Action to perform",
+ url="MCP server URL (required for add, delete, connect, tools)",
+ )
+ async def mcp_servers_command(
+ interaction: discord.Interaction,
+ action: Literal["list", "add", "delete", "connect", "tools"] = "list",
+ url: str | None = None,
+ ) -> None:
+ await run_mcp_server_command(interaction, action, url and url.strip(), name)
+
async def _run_interaction_command(
interaction: discord.Interaction,
@@ -177,17 +196,10 @@ async def _run_interaction_command(
**handler_kwargs,
) -> None:
"""Shared coroutine used by the registered slash commands."""
-
try:
with make_session() as session:
- response = run_command(
- session,
- interaction,
- scope,
- handler=handler,
- target_user=target_user,
- **handler_kwargs,
- )
+ context = _build_context(session, interaction, scope, target_user)
+ response = handler(context, **handler_kwargs)
session.commit()
except CommandError as exc: # pragma: no cover - passthrough
await interaction.response.send_message(str(exc), ephemeral=True)
@@ -199,21 +211,6 @@ async def _run_interaction_command(
)
-def run_command(
- session: Session,
- interaction: discord.Interaction,
- scope: ScopeLiteral,
- *,
- handler: CommandHandler,
- target_user: discord.User | None = None,
- **handler_kwargs,
-) -> CommandResponse:
- """Create a :class:`CommandContext` and execute the handler."""
-
- context = _build_context(session, interaction, scope, target_user)
- return handler(context, **handler_kwargs)
-
-
def _build_context(
session: Session,
interaction: discord.Interaction,
diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py
new file mode 100644
index 0000000..e2f7e9d
--- /dev/null
+++ b/src/memory/discord/mcp.py
@@ -0,0 +1,294 @@
+"""Lightweight slash-command helpers for the Discord collector."""
+
+import json
+import logging
+import time
+from typing import Any, AsyncGenerator, Literal, cast
+
+import aiohttp
+import discord
+from sqlalchemy.orm import Session, scoped_session
+
+from memory.common.db.connection import make_session
+from memory.common.db.models.discord import DiscordMCPServer
+from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client
+
+logger = logging.getLogger(__name__)
+
+
+def find_mcp_server(
+ session: Session | scoped_session, user_id: int, url: str
+) -> DiscordMCPServer | None:
+ return (
+ session.query(DiscordMCPServer)
+ .filter(
+ DiscordMCPServer.discord_bot_user_id == user_id,
+ DiscordMCPServer.mcp_server_url == url,
+ )
+ .first()
+ )
+
+
+async def call_mcp_server(
+ url: str, access_token: str, method: str, params: dict = {}
+) -> AsyncGenerator[Any, None]:
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json, text/event-stream",
+ "Authorization": f"Bearer {access_token}",
+ }
+
+ payload = {
+ "jsonrpc": "2.0",
+ "id": int(time.time() * 1000),
+ "method": method,
+ "params": params,
+ }
+
+ async with aiohttp.ClientSession() as http_session:
+ async with http_session.post(
+ url,
+ json=payload,
+ headers=headers,
+ timeout=aiohttp.ClientTimeout(total=10),
+ ) as resp:
+ if resp.status != 200:
+ error_text = await resp.text()
+ logger.error(f"Tools list failed: {resp.status} - {error_text}")
+ raise ValueError(
+ f"Failed to call MCP server: {resp.status} - {error_text}"
+ )
+
+ # Parse SSE stream
+ async for line in resp.content:
+ line_str = line.decode("utf-8").strip()
+
+ # SSE format: "data: {json}"
+ if line_str.startswith("data: "):
+ json_str = line_str[6:] # Remove "data: " prefix
+ try:
+ yield json.loads(json_str)
+ except json.JSONDecodeError:
+ continue # Skip invalid JSON lines
+
+
+async def handle_mcp_list(interaction: discord.Interaction) -> str:
+ """List all MCP servers for the user."""
+ with make_session() as session:
+ servers = (
+ session.query(DiscordMCPServer)
+ .filter(
+ DiscordMCPServer.discord_bot_user_id == interaction.user.id,
+ )
+ .all()
+ )
+
+ if not servers:
+ return (
+ "📋 **Your MCP Servers**\n\n"
+ "You don't have any MCP servers configured yet.\n"
+ "Use `/memory_mcp_servers add ` to add one."
+ )
+
+ def format_server(server: DiscordMCPServer) -> str:
+ con = "🟢" if cast(str | None, server.access_token) else "🔴"
+ return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
+
+ server_list = "\n".join(format_server(s) for s in servers)
+
+ return f"📋 **Your MCP Servers**\n\n{server_list}"
+
+
+async def handle_mcp_add(
+ interaction: discord.Interaction, url: str, name: str = "memory"
+) -> str:
+ """Add a new MCP server via OAuth."""
+ with make_session() as session:
+ if find_mcp_server(session, interaction.user.id, url):
+ return (
+ f"**MCP Server Already Exists**\n\n"
+ f"You already have an MCP server configured at `{url}`.\n"
+ f"Use `/memory_mcp_servers connect {url}` to reconnect."
+ )
+
+ endpoints = await get_endpoints(url)
+ client_id = await register_oauth_client(
+ endpoints,
+ url,
+ f"Discord Bot - {name} ({interaction.user.name})",
+ )
+ mcp_server = DiscordMCPServer(
+ discord_bot_user_id=interaction.user.id,
+ mcp_server_url=url,
+ client_id=client_id,
+ )
+ session.add(mcp_server)
+ session.flush()
+
+ auth_url = await issue_challenge(mcp_server, endpoints)
+ session.commit()
+
+ logger.info(
+ f"Created MCP server record: id={mcp_server.id}, "
+ f"user={interaction.user.id}, url={url}"
+ )
+
+ return (
+ f"🔐 **Add MCP Server**\n\n"
+ f"Server: `{url}`\n"
+ f"Click the link below to authorize:\n{auth_url}\n\n"
+ f"⚠️ Keep this link private!\n"
+ f"💡 You'll be redirected to login and grant access to the MCP server."
+ )
+
+
+async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
+ """Delete an MCP server."""
+ with make_session() as session:
+ mcp_server = find_mcp_server(session, interaction.user.id, url)
+ if not mcp_server:
+ return (
+ f"**MCP Server Not Found**\n\n"
+ f"You don't have an MCP server configured at `{url}`.\n"
+ )
+ session.delete(mcp_server)
+ session.commit()
+
+ return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
+
+
+async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
+ """Reconnect to an existing MCP server (redo OAuth)."""
+ with make_session() as session:
+ mcp_server = find_mcp_server(session, interaction.user.id, url)
+ if not mcp_server:
+ raise ValueError(
+ f"**MCP Server Not Found**\n\n"
+ f"You don't have an MCP server configured at `{url}`.\n"
+ f"Use `/memory_mcp_servers add {url}` to add it first."
+ )
+
+ if not mcp_server:
+ raise ValueError(
+ f"**MCP Server Not Found**\n\n"
+ f"You don't have an MCP server configured at `{url}`.\n"
+ f"Use `/memory_mcp_servers add {url}` to add it first."
+ )
+
+ endpoints = await get_endpoints(url)
+ auth_url = await issue_challenge(mcp_server, endpoints)
+
+ session.commit()
+
+ logger.info(
+ f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
+ )
+
+ return (
+ f"🔄 **Reconnect to MCP Server**\n\n"
+ f"Server: `{url}`\n"
+ f"Click the link below to reauthorize:\n{auth_url}\n\n"
+ f"⚠️ Keep this link private!\n"
+ f"💡 You'll be redirected to login and grant access to the MCP server again."
+ )
+
+
+async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
+ """List tools available on an MCP server."""
+ with make_session() as session:
+ mcp_server = find_mcp_server(session, interaction.user.id, url)
+
+ if not mcp_server:
+ raise ValueError(
+ f"**MCP Server Not Found**\n\n"
+ f"You don't have an MCP server configured at `{url}`.\n"
+ f"Use `/memory_mcp_servers add {url}` to add it first."
+ )
+
+ if not cast(str | None, mcp_server.access_token):
+ raise ValueError(
+ f"**Not Authorized**\n\n"
+ f"You haven't authorized access to `{url}` yet.\n"
+ f"Use `/memory_mcp_servers connect {url}` to authorize."
+ )
+
+ access_token = cast(str, mcp_server.access_token)
+
+ # Make JSON-RPC request to MCP server
+ tools = None
+ try:
+ async for data in call_mcp_server(url, access_token, "tools/list"):
+ if "result" in data and "tools" in data["result"]:
+ tools = data["result"]["tools"]
+ break
+ except aiohttp.ClientError as exc:
+ logger.exception(f"Failed to connect to MCP server: {exc}")
+ raise ValueError(
+ f"**Connection failed**\n\n"
+ f"Server: `{url}`\n"
+ f"Could not connect to the MCP server: {str(exc)}"
+ )
+ except Exception as exc:
+ logger.exception(f"Failed to list tools: {exc}")
+ raise ValueError(
+ f"**Error**\n\nServer: `{url}`\nFailed to list tools: {str(exc)}"
+ )
+
+ if tools is None:
+ raise ValueError(
+ f"**Unexpected response format**\n\n"
+ f"Server: `{url}`\n"
+ f"The server returned an unexpected response format."
+ )
+
+ if not tools:
+ return (
+ f"🔧 **MCP Server Tools**\n\n"
+ f"Server: `{url}`\n\n"
+ f"No tools available on this server."
+ )
+
+ # Format tools list
+ tools_list = "\n".join(
+ f"• **{t.get('name', 'unknown')}**: {t.get('description', 'No description')}"
+ for t in tools
+ )
+
+ return (
+ f"🔧 **MCP Server Tools**\n\n"
+ f"Server: `{url}`\n"
+ f"Found {len(tools)} tool(s):\n\n"
+ f"{tools_list}"
+ )
+
+
+async def run_mcp_server_command(
+ interaction: discord.Interaction,
+ action: Literal["list", "add", "delete", "connect", "tools"],
+ url: str | None,
+ name: str = "memory",
+) -> None:
+ """Handle MCP server management commands."""
+ if action not in ["list", "add", "delete", "connect", "tools"]:
+ await interaction.response.send_message("❌ Invalid action", ephemeral=True)
+ return
+ if action != "list" and not url:
+ await interaction.response.send_message(
+ "❌ URL is required for this action", ephemeral=True
+ )
+ return
+
+ try:
+ if action == "list" or not url:
+ result = await handle_mcp_list(interaction)
+ elif action == "add":
+ result = await handle_mcp_add(interaction, url, name)
+ elif action == "delete":
+ result = await handle_mcp_delete(interaction, url)
+ elif action == "connect":
+ result = await handle_mcp_connect(interaction, url)
+ elif action == "tools":
+ result = await handle_mcp_tools(interaction, url)
+ except Exception as exc:
+ result = f"❌ Error: {exc}"
+ await interaction.response.send_message(result, ephemeral=True)
diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py
index 75968be..fe3416b 100644
--- a/src/memory/discord/messages.py
+++ b/src/memory/discord/messages.py
@@ -141,8 +141,10 @@ def previous_messages(
) -> list[DiscordMessage]:
messages = session.query(DiscordMessage)
if user_id:
+ print(f"user_id: {user_id}")
messages = messages.filter(DiscordMessage.recipient_id == user_id)
if channel_id:
+ print(f"channel_id: {channel_id}")
messages = messages.filter(DiscordMessage.channel_id == channel_id)
return list(
reversed(
diff --git a/src/memory/discord/oauth.py b/src/memory/discord/oauth.py
new file mode 100644
index 0000000..a876598
--- /dev/null
+++ b/src/memory/discord/oauth.py
@@ -0,0 +1,268 @@
+import hashlib
+import logging
+import secrets
+from base64 import urlsafe_b64encode
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from urllib.parse import urlencode, urljoin
+
+import aiohttp
+from sqlalchemy.orm import Session, scoped_session
+from memory.common import settings
+from memory.common.db.connection import make_session
+from memory.common.db.models.discord import DiscordMCPServer
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class OAuthEndpoints:
+ authorization_endpoint: str
+ registration_endpoint: str
+ token_endpoint: str
+ redirect_uri: str
+
+
+def generate_pkce_pair() -> tuple[str, str]:
+ """Generate PKCE code verifier and challenge.
+
+ Returns:
+ Tuple of (code_verifier, code_challenge)
+ """
+ # Generate a random code verifier
+ code_verifier = (
+ urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
+ )
+
+ # Create code challenge using S256 method
+ challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
+ code_challenge = urlsafe_b64encode(challenge_bytes).decode("utf-8").rstrip("=")
+
+ return code_verifier, code_challenge
+
+
+async def discover_oauth_metadata(server_url: str) -> dict | None:
+ """Discover OAuth metadata from an MCP server."""
+ # Try the standard OAuth discovery endpoint
+ discovery_url = urljoin(server_url, "/.well-known/oauth-authorization-server")
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ discovery_url, timeout=aiohttp.ClientTimeout(total=5)
+ ) as resp:
+ if resp.status == 200:
+ return await resp.json()
+ except Exception as exc:
+ logger.warning(f"Failed to discover OAuth metadata from {discovery_url}: {exc}")
+
+ return None
+
+
+async def get_endpoints(url: str) -> OAuthEndpoints:
+ # Discover OAuth endpoints from the target server
+ oauth_metadata = await discover_oauth_metadata(url)
+
+ if not oauth_metadata:
+ raise ValueError(
+ "**Failed to connect to MCP server**\n\n"
+ f"Could not discover OAuth endpoints at `{url}`\n"
+ "Make sure the server is running and supports OAuth 2.0.",
+ )
+
+ authorization_endpoint = oauth_metadata.get("authorization_endpoint")
+ registration_endpoint = oauth_metadata.get("registration_endpoint")
+ token_endpoint = oauth_metadata.get("token_endpoint")
+
+ if not authorization_endpoint:
+ raise ValueError(
+ "**Invalid OAuth configuration**\n\n"
+ f"Server `{url}` did not provide an authorization endpoint.",
+ )
+
+ if not registration_endpoint:
+ raise ValueError(
+ "**Invalid OAuth configuration**\n\n"
+ f"Server `{url}` does not support dynamic client registration.",
+ )
+
+ if not token_endpoint:
+ raise ValueError(
+ "**Invalid OAuth configuration**\n\n"
+ f"Server `{url}` does not provide a token endpoint.",
+ )
+
+ logger.info(f"Authorization endpoint: {authorization_endpoint}")
+ logger.info(f"Registration endpoint: {registration_endpoint}")
+
+ return OAuthEndpoints(
+ authorization_endpoint=authorization_endpoint,
+ registration_endpoint=registration_endpoint,
+ token_endpoint=token_endpoint,
+ redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
+ )
+
+
+async def register_oauth_client(
+ endpoints: OAuthEndpoints,
+ url: str,
+ client_name: str,
+) -> None:
+ """Register OAuth client and store client_id in the mcp_server object."""
+ client_metadata = {
+ "client_name": client_name,
+ "redirect_uris": [endpoints.redirect_uri],
+ "grant_types": ["authorization_code", "refresh_token"],
+ "response_types": ["code"],
+ "scope": "read write",
+ "token_endpoint_auth_method": "none",
+ }
+
+ logger.error(f"Registration metadata: {client_metadata}")
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ endpoints.registration_endpoint,
+ json=client_metadata,
+ timeout=aiohttp.ClientTimeout(total=5),
+ ) as resp:
+ logger.error(
+ f"Registration response: {resp.status} {await resp.text()}"
+ )
+ resp.raise_for_status()
+ client_info = await resp.json()
+ except Exception as exc:
+ raise ValueError(
+ f"Failed to register OAuth client at {endpoints.registration_endpoint}: {exc}"
+ )
+
+ if not client_info or "client_id" not in client_info:
+ raise ValueError(
+ "**Failed to register OAuth client**\n\n"
+ f"Could not register with the MCP server at `{url}`\n"
+ f"Check the server logs for more details.",
+ )
+
+ client_id = client_info["client_id"]
+
+ logger.info(f"Registered OAuth client: {client_id}")
+ return client_id
+
+
+async def issue_challenge(
+ mcp_server: DiscordMCPServer,
+ endpoints: OAuthEndpoints,
+) -> str:
+ """Generate OAuth challenge and store state in mcp_server object."""
+ code_verifier, code_challenge = generate_pkce_pair()
+ state = secrets.token_urlsafe(32)
+
+ # Store in mcp_server object
+ mcp_server.state = state # type: ignore
+ mcp_server.code_verifier = code_verifier # type: ignore
+
+ logger.info(
+ f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
+ f"state={state[:20]}..., verifier={code_verifier[:20]}..."
+ )
+
+ # Build authorization URL pointing to the target server
+ auth_params = {
+ "client_id": mcp_server.client_id,
+ "redirect_uri": endpoints.redirect_uri,
+ "response_type": "code",
+ "state": state,
+ "code_challenge": code_challenge,
+ "code_challenge_method": "S256",
+ "scope": "read write",
+ }
+
+ return f"{endpoints.authorization_endpoint}?{urlencode(auth_params)}"
+
+
+async def complete_oauth_flow(
+ session: Session | scoped_session, code: str, state: str
+) -> tuple[int, str]:
+ """Complete OAuth flow by exchanging code for token.
+
+ Args:
+ code: Authorization code from OAuth callback
+ state: State parameter from OAuth callback
+
+ Returns:
+ Tuple of (status_code, html_message) for the callback response
+ """
+ try:
+ mcp_server = (
+ session.query(DiscordMCPServer)
+ .filter(DiscordMCPServer.state == state)
+ .first()
+ )
+
+ if not mcp_server:
+ logger.error(f"Invalid or expired state: {state[:20]}...")
+ return 400, "Invalid or expired OAuth state"
+
+ logger.info(
+ f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
+ f"url={mcp_server.mcp_server_url}"
+ )
+
+ # Get OAuth endpoints
+ try:
+ endpoints = await get_endpoints(str(mcp_server.mcp_server_url))
+ except Exception as exc:
+ return 500, f"Failed to get OAuth endpoints: {str(exc)}"
+
+ # Exchange authorization code for access token
+ token_data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": endpoints.redirect_uri,
+ "client_id": mcp_server.client_id,
+ "code_verifier": mcp_server.code_verifier,
+ }
+
+ async with aiohttp.ClientSession() as http_session:
+ async with http_session.post(
+ endpoints.token_endpoint,
+ data=token_data,
+ timeout=aiohttp.ClientTimeout(total=10),
+ ) as resp:
+ if resp.status != 200:
+ error_text = await resp.text()
+ logger.error(f"Token exchange failed: {resp.status} - {error_text}")
+ return 500, f"Token exchange failed: {error_text}"
+
+ tokens = await resp.json()
+
+ access_token = tokens.get("access_token")
+ refresh_token = tokens.get("refresh_token")
+ expires_in = tokens.get("expires_in", 3600)
+
+ if not access_token:
+ return 500, "Token response did not include access_token"
+
+ logger.info(f"Successfully obtained access token: {access_token[:20]}...")
+
+ # Store tokens and clear temporary OAuth state
+ mcp_server.access_token = access_token # type: ignore
+ mcp_server.refresh_token = refresh_token # type: ignore
+ mcp_server.token_expires_at = datetime.now() + timedelta(seconds=expires_in) # type: ignore
+
+ # Clear temporary OAuth flow data
+ mcp_server.state = None # type: ignore
+ mcp_server.code_verifier = None # type: ignore
+
+ session.commit()
+
+ logger.info(
+ f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
+ f"server {mcp_server.mcp_server_url}"
+ )
+
+ return 200, "✅ Authorization successful! You can now use this MCP server."
+
+ except Exception as exc:
+ logger.exception(f"Failed to complete OAuth flow: {exc}")
+ return 500, f"Failed to complete OAuth flow: {str(exc)}"