diff --git a/db/migrations/versions/20251102_220426_discord_mcp_servers.py b/db/migrations/versions/20251102_220426_discord_mcp_servers.py new file mode 100644 index 0000000..deb35d5 --- /dev/null +++ b/db/migrations/versions/20251102_220426_discord_mcp_servers.py @@ -0,0 +1,67 @@ +"""discord mcp servers + +Revision ID: 9b887449ea92 +Revises: 1954477b25f4 +Create Date: 2025-11-02 22:04:26.259323 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "9b887449ea92" +down_revision: Union[str, None] = "1954477b25f4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "discord_mcp_servers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False), + sa.Column("mcp_server_url", sa.Text(), nullable=False), + sa.Column("client_id", sa.Text(), nullable=False), + sa.Column("state", sa.Text(), nullable=True), + sa.Column("code_verifier", sa.Text(), nullable=True), + sa.Column("access_token", sa.Text(), nullable=True), + sa.Column("refresh_token", sa.Text(), nullable=True), + sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["discord_bot_user_id"], + ["discord_users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("state"), + ) + op.create_index( + "discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False + ) + op.create_index( + "discord_mcp_user_url_idx", + "discord_mcp_servers", + ["discord_bot_user_id", "mcp_server_url"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers") + op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers") + op.drop_table("discord_mcp_servers") diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 140fa62..1198d8c 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -34,6 +34,7 @@ from memory.common.db.models.discord import ( DiscordServer, DiscordChannel, DiscordUser, + DiscordMCPServer, ) from memory.common.db.models.observations import ( ObservationContradiction, @@ -106,6 +107,7 @@ __all__ = [ "DiscordServer", "DiscordChannel", "DiscordUser", + "DiscordMCPServer", # Users "User", "HumanUser", diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 3775ca1..638deec 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -127,5 +127,44 @@ class DiscordUser(Base, MessageProcessor): updated_at = Column(DateTime(timezone=True), server_default=func.now()) system_user = relationship("User", back_populates="discord_users") + mcp_servers = relationship( + "DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan" + ) __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),) + + +class DiscordMCPServer(Base): + """MCP server configuration and OAuth state for Discord users.""" + + __tablename__ = "discord_mcp_servers" + + id = Column(Integer, primary_key=True) + discord_bot_user_id = Column( + BigInteger, ForeignKey("discord_users.id"), nullable=False + ) + + # MCP server info + mcp_server_url = Column(Text, nullable=False) + client_id = Column(Text, nullable=False) + + # OAuth flow state (temporary, cleared after token exchange) + state = Column(Text, nullable=True, unique=True) + code_verifier = Column(Text, nullable=True) + + # OAuth tokens (set after successful authorization) + access_token = Column(Text, nullable=True) + refresh_token = Column(Text, nullable=True) + token_expires_at = Column(DateTime(timezone=True), nullable=True) + + # Timestamps + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + # Relationships + discord_user = relationship("DiscordUser", back_populates="mcp_servers") + + __table_args__ = ( + Index("discord_mcp_state_idx", "state"), + Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"), + ) diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index 3c011a8..827b1be 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -363,7 +363,19 @@ class DiscordMessage(SourceItem): @property def title(self) -> str: - return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" + return textwrap.dedent(""" + + {message_id} + {from_user} + {sent_at} + {content} + + """).format( + message_id=self.message_id, + from_user=self.from_user.username, + sent_at=self.sent_at.isoformat()[:19], + content=self.content, + ) def as_content(self) -> dict[str, Any]: """Return message content ready for LLM (text + images from disk).""" diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index 6eab35b..a324276 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -94,6 +94,30 @@ def trigger_typing_channel(bot_id: int, channel: int | str) -> bool: return False +def add_reaction(bot_id: int, channel: int | str, message_id: int, emoji: str) -> bool: + """Add a reaction to a message in a channel""" + try: + response = requests.post( + f"{get_api_url()}/add_reaction", + json={ + "bot_id": bot_id, + "channel": channel, + "message_id": message_id, + "emoji": emoji, + }, + timeout=10, + ) + response.raise_for_status() + result = response.json() + return result.get("success", False) + + except requests.RequestException as e: + logger.error( + f"Failed to add reaction {emoji} to message {message_id} in channel {channel}: {e}" + ) + return False + + def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool: """Send a message to a channel by name or ID (ID supports threads)""" try: diff --git a/src/memory/common/llms/tools/discord.py b/src/memory/common/llms/tools/discord.py index 8492be3..d334bb5 100644 --- a/src/memory/common/llms/tools/discord.py +++ b/src/memory/common/llms/tools/discord.py @@ -17,6 +17,7 @@ from memory.common.db.models import ( BotUser, ) from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler +from memory.common.discord import add_reaction UpdateSummaryType = Literal["server", "channel", "user"] @@ -209,6 +210,50 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini ) +def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition: + bot_id = cast(int, bot.id) + channel_id = channel and channel.id + + def handler(input: ToolInput) -> str: + if not isinstance(input, dict): + raise ValueError("Input must be a dictionary") + try: + emoji = input.get("emoji") + except ValueError: + raise ValueError("Emoji is required") + if not emoji: + raise ValueError("Emoji is required") + + try: + message_id = int(input.get("message_id") or "no id") + except ValueError: + raise ValueError("Message ID is required") + + success = add_reaction(bot_id, channel_id, message_id, emoji) + if not success: + return "Failed to add reaction" + return "Reaction added" + + return ToolDefinition( + name="add_reaction", + description="Add a reaction to a message in a channel", + input_schema={ + "type": "object", + "properties": { + "message_id": { + "type": "number", + "description": "The ID of the message to add the reaction to", + }, + "emoji": { + "type": "string", + "description": "The emoji to add to the message", + }, + }, + }, + function=handler, + ) + + def make_discord_tools( bot: BotUser, author: DiscordUser | None, @@ -227,5 +272,6 @@ def make_discord_tools( if channel and channel.server: tools += [ make_summary_tool("server", cast(BigInteger, channel.server_id)), + make_add_reaction_tool(bot, channel), ] return {tool.name: tool for tool in tools} diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index bb43f0f..451b1ac 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -12,13 +12,15 @@ from contextlib import asynccontextmanager from typing import cast import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import HTMLResponse from pydantic import BaseModel from memory.common import settings from memory.common.db.connection import make_session from memory.common.db.models.users import DiscordBotUser from memory.discord.collector import MessageCollector +from memory.discord.oauth import complete_oauth_flow logger = logging.getLogger(__name__) @@ -45,6 +47,13 @@ class TypingChannelRequest(BaseModel): channel: int | str # Channel name or ID (ID supports threads) +class AddReactionRequest(BaseModel): + bot_id: int + channel: int | str # Channel name or ID (ID supports threads) + message_id: int + emoji: str + + class Collector: collector: MessageCollector collector_task: asyncio.Task @@ -53,6 +62,7 @@ class Collector: bot_name: str def __init__(self, collector: MessageCollector, bot: DiscordBotUser): + logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}") self.collector = collector self.collector_task = asyncio.create_task(collector.start(str(bot.api_key))) self.bot_id = cast(int, bot.id) @@ -72,7 +82,9 @@ async def lifespan(app: FastAPI): bots = session.query(DiscordBotUser).all() app.bots = {bot.id: make_collector(bot) for bot in bots} - logger.info(f"Discord collectors started for {len(app.bots)} bots") + logger.error( + f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}" + ) yield @@ -216,6 +228,36 @@ async def health_check(): } +@app.post("/add_reaction") +async def add_reaction_endpoint(request: AddReactionRequest): + """Add a reaction to a message via the collector's Discord client""" + collector = app.bots.get(request.bot_id) + if not collector: + raise HTTPException(status_code=404, detail="Bot not found") + + try: + success = await collector.collector.add_reaction( + request.channel, request.message_id, request.emoji + ) + except Exception as e: + logger.error(f"Failed to add reaction: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + if not success: + raise HTTPException( + status_code=400, + detail=f"Failed to add reaction to message {request.message_id}", + ) + + return { + "success": True, + "channel": request.channel, + "message_id": request.message_id, + "emoji": request.emoji, + "message": f"Added reaction {request.emoji} to message {request.message_id}", + } + + @app.post("/refresh_metadata") async def refresh_metadata(): """Refresh Discord server/channel/user metadata from Discord API""" @@ -237,6 +279,51 @@ async def refresh_metadata(): raise HTTPException(status_code=500, detail=str(e)) +@app.get("/oauth/callback/discord", response_class=HTMLResponse) +async def oauth_callback(request: Request): + """Handle OAuth callback from MCP server after user authorization.""" + code = request.query_params.get("code") + state = request.query_params.get("state") + error = request.query_params.get("error") + + logger.info( + f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..." + ) + + message, title, close, status_code = "", "", "", 200 + if error: + logger.error(f"OAuth error: {error}") + message = f"Error: {error}" + title = "❌ Authorization Failed" + status_code = 400 + elif not code or not state: + message = "Missing authorization code or state parameter." + title = "❌ Invalid Request" + status_code = 400 + else: + # Complete the OAuth flow (exchange code for token) + with make_session() as session: + status_code, message = await complete_oauth_flow(session, code, state) + if 200 <= status_code < 300: + title = "✅ Authorization Successful!" + close = "You can close this window and return to the MCP server." + else: + title = "❌ Authorization Failed" + + return HTMLResponse( + content=f""" + + +

{title}

+

{message}

+

{close}

+ + + """, + status_code=status_code, + ) + + def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001): """Run the Discord API server""" uvicorn.run(app, host=host, port=port, log_level="debug") diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index bb5b6c6..852a892 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -229,11 +229,21 @@ class MessageCollector(commands.Bot): intents=intents, help_command=None, # Disable default help ) + logger.info(f"Initialized collector for {self.user}") async def setup_hook(self): """Register slash commands when the bot is ready.""" - register_slash_commands(self, name=self.user.name) + if not (name := self.user.name): + logger.error(f"Failed to get user name for {self.user}") + return + + name = name.replace("-", "_").lower() + try: + register_slash_commands(self, name=name) + except Exception as e: + logger.error(f"Failed to register slash commands for {self.user.name}: {e}") + logger.error(f"Registered slash commands for {self.user.name}") async def on_ready(self): """Called when bot connects to Discord""" @@ -313,8 +323,6 @@ class MessageCollector(commands.Bot): async def refresh_metadata(self) -> dict[str, int]: """Refresh server and channel metadata from Discord and update database""" - print("🔄 Refreshing Discord metadata...") - servers_updated = 0 channels_updated = 0 users_updated = 0 @@ -454,25 +462,33 @@ class MessageCollector(commands.Bot): logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}") return False + async def _get_channel( + self, channel_identifier: int | str, check_notifications: bool = True + ): + """Get channel by ID or name with standard checks""" + if check_notifications and not settings.DISCORD_NOTIFICATIONS_ENABLED: + logger.debug("Discord notifications disabled") + return None + + if isinstance(channel_identifier, int): + channel = self.get_channel(channel_identifier) + else: + channel = await self.get_channel_by_name(channel_identifier) + + if not channel: + logger.error(f"Channel {channel_identifier} not found") + + return channel + async def send_to_channel( self, channel_identifier: int | str, message: str ) -> bool: """Send a message to a channel by name or ID (supports threads)""" - if not settings.DISCORD_NOTIFICATIONS_ENABLED: - return False - try: - # Get channel by ID or name - if isinstance(channel_identifier, int): - channel = self.get_channel(channel_identifier) - else: - channel = await self.get_channel_by_name(channel_identifier) - + channel = await self._get_channel(channel_identifier) if not channel: - logger.error(f"Channel {channel_identifier} not found") return False - # Post-process mentions to convert usernames to IDs with make_session() as session: processed_message = process_mentions(session, message) @@ -486,18 +502,9 @@ class MessageCollector(commands.Bot): async def trigger_typing_channel(self, channel_identifier: int | str) -> bool: """Trigger typing indicator in a channel by name or ID (supports threads)""" - if not settings.DISCORD_NOTIFICATIONS_ENABLED: - return False - try: - # Get channel by ID or name - if isinstance(channel_identifier, int): - channel = self.get_channel(channel_identifier) - else: - channel = await self.get_channel_by_name(channel_identifier) - + channel = await self._get_channel(channel_identifier) if not channel: - logger.error(f"Channel {channel_identifier} not found") return False async with channel.typing(): @@ -509,3 +516,21 @@ class MessageCollector(commands.Bot): f"Failed to trigger typing for channel {channel_identifier}: {e}" ) return False + + async def add_reaction( + self, channel_identifier: int | str, message_id: int, emoji: str + ) -> bool: + """Add a reaction to a message in a channel""" + try: + channel = await self._get_channel(channel_identifier) + if not channel: + return False + + message = await channel.fetch_message(message_id) + await message.add_reaction(emoji) + logger.info(f"Added reaction {emoji} to message {message_id}") + return True + + except Exception as e: + logger.error(f"Failed to add reaction: {e}") + return False diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py index bd3a234..82c15ce 100644 --- a/src/memory/discord/commands.py +++ b/src/memory/discord/commands.py @@ -1,7 +1,7 @@ """Lightweight slash-command helpers for the Discord collector.""" -from __future__ import annotations - +from calendar import c +import logging from dataclasses import dataclass from typing import Callable, Literal @@ -10,6 +10,9 @@ from sqlalchemy.orm import Session from memory.common.db.connection import make_session from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser +from memory.discord.mcp import run_mcp_server_command + +logger = logging.getLogger(__name__) ScopeLiteral = Literal["server", "channel", "user"] @@ -50,6 +53,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None: """ if getattr(bot, "_memory_commands_registered", False): + logger.error(f"Slash commands already registered for {name}") return setattr(bot, "_memory_commands_registered", True) @@ -167,6 +171,21 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None: target_user=user, ) + @tree.command( + name=f"{name}_mcp_servers", + description="Manage MCP servers for your account", + ) + @discord.app_commands.describe( + action="Action to perform", + url="MCP server URL (required for add, delete, connect, tools)", + ) + async def mcp_servers_command( + interaction: discord.Interaction, + action: Literal["list", "add", "delete", "connect", "tools"] = "list", + url: str | None = None, + ) -> None: + await run_mcp_server_command(interaction, action, url and url.strip(), name) + async def _run_interaction_command( interaction: discord.Interaction, @@ -177,17 +196,10 @@ async def _run_interaction_command( **handler_kwargs, ) -> None: """Shared coroutine used by the registered slash commands.""" - try: with make_session() as session: - response = run_command( - session, - interaction, - scope, - handler=handler, - target_user=target_user, - **handler_kwargs, - ) + context = _build_context(session, interaction, scope, target_user) + response = handler(context, **handler_kwargs) session.commit() except CommandError as exc: # pragma: no cover - passthrough await interaction.response.send_message(str(exc), ephemeral=True) @@ -199,21 +211,6 @@ async def _run_interaction_command( ) -def run_command( - session: Session, - interaction: discord.Interaction, - scope: ScopeLiteral, - *, - handler: CommandHandler, - target_user: discord.User | None = None, - **handler_kwargs, -) -> CommandResponse: - """Create a :class:`CommandContext` and execute the handler.""" - - context = _build_context(session, interaction, scope, target_user) - return handler(context, **handler_kwargs) - - def _build_context( session: Session, interaction: discord.Interaction, diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py new file mode 100644 index 0000000..e2f7e9d --- /dev/null +++ b/src/memory/discord/mcp.py @@ -0,0 +1,294 @@ +"""Lightweight slash-command helpers for the Discord collector.""" + +import json +import logging +import time +from typing import Any, AsyncGenerator, Literal, cast + +import aiohttp +import discord +from sqlalchemy.orm import Session, scoped_session + +from memory.common.db.connection import make_session +from memory.common.db.models.discord import DiscordMCPServer +from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client + +logger = logging.getLogger(__name__) + + +def find_mcp_server( + session: Session | scoped_session, user_id: int, url: str +) -> DiscordMCPServer | None: + return ( + session.query(DiscordMCPServer) + .filter( + DiscordMCPServer.discord_bot_user_id == user_id, + DiscordMCPServer.mcp_server_url == url, + ) + .first() + ) + + +async def call_mcp_server( + url: str, access_token: str, method: str, params: dict = {} +) -> AsyncGenerator[Any, None]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Authorization": f"Bearer {access_token}", + } + + payload = { + "jsonrpc": "2.0", + "id": int(time.time() * 1000), + "method": method, + "params": params, + } + + async with aiohttp.ClientSession() as http_session: + async with http_session.post( + url, + json=payload, + headers=headers, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status != 200: + error_text = await resp.text() + logger.error(f"Tools list failed: {resp.status} - {error_text}") + raise ValueError( + f"Failed to call MCP server: {resp.status} - {error_text}" + ) + + # Parse SSE stream + async for line in resp.content: + line_str = line.decode("utf-8").strip() + + # SSE format: "data: {json}" + if line_str.startswith("data: "): + json_str = line_str[6:] # Remove "data: " prefix + try: + yield json.loads(json_str) + except json.JSONDecodeError: + continue # Skip invalid JSON lines + + +async def handle_mcp_list(interaction: discord.Interaction) -> str: + """List all MCP servers for the user.""" + with make_session() as session: + servers = ( + session.query(DiscordMCPServer) + .filter( + DiscordMCPServer.discord_bot_user_id == interaction.user.id, + ) + .all() + ) + + if not servers: + return ( + "📋 **Your MCP Servers**\n\n" + "You don't have any MCP servers configured yet.\n" + "Use `/memory_mcp_servers add ` to add one." + ) + + def format_server(server: DiscordMCPServer) -> str: + con = "🟢" if cast(str | None, server.access_token) else "🔴" + return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`" + + server_list = "\n".join(format_server(s) for s in servers) + + return f"📋 **Your MCP Servers**\n\n{server_list}" + + +async def handle_mcp_add( + interaction: discord.Interaction, url: str, name: str = "memory" +) -> str: + """Add a new MCP server via OAuth.""" + with make_session() as session: + if find_mcp_server(session, interaction.user.id, url): + return ( + f"**MCP Server Already Exists**\n\n" + f"You already have an MCP server configured at `{url}`.\n" + f"Use `/memory_mcp_servers connect {url}` to reconnect." + ) + + endpoints = await get_endpoints(url) + client_id = await register_oauth_client( + endpoints, + url, + f"Discord Bot - {name} ({interaction.user.name})", + ) + mcp_server = DiscordMCPServer( + discord_bot_user_id=interaction.user.id, + mcp_server_url=url, + client_id=client_id, + ) + session.add(mcp_server) + session.flush() + + auth_url = await issue_challenge(mcp_server, endpoints) + session.commit() + + logger.info( + f"Created MCP server record: id={mcp_server.id}, " + f"user={interaction.user.id}, url={url}" + ) + + return ( + f"🔐 **Add MCP Server**\n\n" + f"Server: `{url}`\n" + f"Click the link below to authorize:\n{auth_url}\n\n" + f"⚠️ Keep this link private!\n" + f"💡 You'll be redirected to login and grant access to the MCP server." + ) + + +async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str: + """Delete an MCP server.""" + with make_session() as session: + mcp_server = find_mcp_server(session, interaction.user.id, url) + if not mcp_server: + return ( + f"**MCP Server Not Found**\n\n" + f"You don't have an MCP server configured at `{url}`.\n" + ) + session.delete(mcp_server) + session.commit() + + return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed." + + +async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str: + """Reconnect to an existing MCP server (redo OAuth).""" + with make_session() as session: + mcp_server = find_mcp_server(session, interaction.user.id, url) + if not mcp_server: + raise ValueError( + f"**MCP Server Not Found**\n\n" + f"You don't have an MCP server configured at `{url}`.\n" + f"Use `/memory_mcp_servers add {url}` to add it first." + ) + + if not mcp_server: + raise ValueError( + f"**MCP Server Not Found**\n\n" + f"You don't have an MCP server configured at `{url}`.\n" + f"Use `/memory_mcp_servers add {url}` to add it first." + ) + + endpoints = await get_endpoints(url) + auth_url = await issue_challenge(mcp_server, endpoints) + + session.commit() + + logger.info( + f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}" + ) + + return ( + f"🔄 **Reconnect to MCP Server**\n\n" + f"Server: `{url}`\n" + f"Click the link below to reauthorize:\n{auth_url}\n\n" + f"⚠️ Keep this link private!\n" + f"💡 You'll be redirected to login and grant access to the MCP server again." + ) + + +async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str: + """List tools available on an MCP server.""" + with make_session() as session: + mcp_server = find_mcp_server(session, interaction.user.id, url) + + if not mcp_server: + raise ValueError( + f"**MCP Server Not Found**\n\n" + f"You don't have an MCP server configured at `{url}`.\n" + f"Use `/memory_mcp_servers add {url}` to add it first." + ) + + if not cast(str | None, mcp_server.access_token): + raise ValueError( + f"**Not Authorized**\n\n" + f"You haven't authorized access to `{url}` yet.\n" + f"Use `/memory_mcp_servers connect {url}` to authorize." + ) + + access_token = cast(str, mcp_server.access_token) + + # Make JSON-RPC request to MCP server + tools = None + try: + async for data in call_mcp_server(url, access_token, "tools/list"): + if "result" in data and "tools" in data["result"]: + tools = data["result"]["tools"] + break + except aiohttp.ClientError as exc: + logger.exception(f"Failed to connect to MCP server: {exc}") + raise ValueError( + f"**Connection failed**\n\n" + f"Server: `{url}`\n" + f"Could not connect to the MCP server: {str(exc)}" + ) + except Exception as exc: + logger.exception(f"Failed to list tools: {exc}") + raise ValueError( + f"**Error**\n\nServer: `{url}`\nFailed to list tools: {str(exc)}" + ) + + if tools is None: + raise ValueError( + f"**Unexpected response format**\n\n" + f"Server: `{url}`\n" + f"The server returned an unexpected response format." + ) + + if not tools: + return ( + f"🔧 **MCP Server Tools**\n\n" + f"Server: `{url}`\n\n" + f"No tools available on this server." + ) + + # Format tools list + tools_list = "\n".join( + f"• **{t.get('name', 'unknown')}**: {t.get('description', 'No description')}" + for t in tools + ) + + return ( + f"🔧 **MCP Server Tools**\n\n" + f"Server: `{url}`\n" + f"Found {len(tools)} tool(s):\n\n" + f"{tools_list}" + ) + + +async def run_mcp_server_command( + interaction: discord.Interaction, + action: Literal["list", "add", "delete", "connect", "tools"], + url: str | None, + name: str = "memory", +) -> None: + """Handle MCP server management commands.""" + if action not in ["list", "add", "delete", "connect", "tools"]: + await interaction.response.send_message("❌ Invalid action", ephemeral=True) + return + if action != "list" and not url: + await interaction.response.send_message( + "❌ URL is required for this action", ephemeral=True + ) + return + + try: + if action == "list" or not url: + result = await handle_mcp_list(interaction) + elif action == "add": + result = await handle_mcp_add(interaction, url, name) + elif action == "delete": + result = await handle_mcp_delete(interaction, url) + elif action == "connect": + result = await handle_mcp_connect(interaction, url) + elif action == "tools": + result = await handle_mcp_tools(interaction, url) + except Exception as exc: + result = f"❌ Error: {exc}" + await interaction.response.send_message(result, ephemeral=True) diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py index 75968be..fe3416b 100644 --- a/src/memory/discord/messages.py +++ b/src/memory/discord/messages.py @@ -141,8 +141,10 @@ def previous_messages( ) -> list[DiscordMessage]: messages = session.query(DiscordMessage) if user_id: + print(f"user_id: {user_id}") messages = messages.filter(DiscordMessage.recipient_id == user_id) if channel_id: + print(f"channel_id: {channel_id}") messages = messages.filter(DiscordMessage.channel_id == channel_id) return list( reversed( diff --git a/src/memory/discord/oauth.py b/src/memory/discord/oauth.py new file mode 100644 index 0000000..a876598 --- /dev/null +++ b/src/memory/discord/oauth.py @@ -0,0 +1,268 @@ +import hashlib +import logging +import secrets +from base64 import urlsafe_b64encode +from dataclasses import dataclass +from datetime import datetime, timedelta +from urllib.parse import urlencode, urljoin + +import aiohttp +from sqlalchemy.orm import Session, scoped_session +from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models.discord import DiscordMCPServer + +logger = logging.getLogger(__name__) + + +@dataclass +class OAuthEndpoints: + authorization_endpoint: str + registration_endpoint: str + token_endpoint: str + redirect_uri: str + + +def generate_pkce_pair() -> tuple[str, str]: + """Generate PKCE code verifier and challenge. + + Returns: + Tuple of (code_verifier, code_challenge) + """ + # Generate a random code verifier + code_verifier = ( + urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") + ) + + # Create code challenge using S256 method + challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest() + code_challenge = urlsafe_b64encode(challenge_bytes).decode("utf-8").rstrip("=") + + return code_verifier, code_challenge + + +async def discover_oauth_metadata(server_url: str) -> dict | None: + """Discover OAuth metadata from an MCP server.""" + # Try the standard OAuth discovery endpoint + discovery_url = urljoin(server_url, "/.well-known/oauth-authorization-server") + + try: + async with aiohttp.ClientSession() as session: + async with session.get( + discovery_url, timeout=aiohttp.ClientTimeout(total=5) + ) as resp: + if resp.status == 200: + return await resp.json() + except Exception as exc: + logger.warning(f"Failed to discover OAuth metadata from {discovery_url}: {exc}") + + return None + + +async def get_endpoints(url: str) -> OAuthEndpoints: + # Discover OAuth endpoints from the target server + oauth_metadata = await discover_oauth_metadata(url) + + if not oauth_metadata: + raise ValueError( + "**Failed to connect to MCP server**\n\n" + f"Could not discover OAuth endpoints at `{url}`\n" + "Make sure the server is running and supports OAuth 2.0.", + ) + + authorization_endpoint = oauth_metadata.get("authorization_endpoint") + registration_endpoint = oauth_metadata.get("registration_endpoint") + token_endpoint = oauth_metadata.get("token_endpoint") + + if not authorization_endpoint: + raise ValueError( + "**Invalid OAuth configuration**\n\n" + f"Server `{url}` did not provide an authorization endpoint.", + ) + + if not registration_endpoint: + raise ValueError( + "**Invalid OAuth configuration**\n\n" + f"Server `{url}` does not support dynamic client registration.", + ) + + if not token_endpoint: + raise ValueError( + "**Invalid OAuth configuration**\n\n" + f"Server `{url}` does not provide a token endpoint.", + ) + + logger.info(f"Authorization endpoint: {authorization_endpoint}") + logger.info(f"Registration endpoint: {registration_endpoint}") + + return OAuthEndpoints( + authorization_endpoint=authorization_endpoint, + registration_endpoint=registration_endpoint, + token_endpoint=token_endpoint, + redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord", + ) + + +async def register_oauth_client( + endpoints: OAuthEndpoints, + url: str, + client_name: str, +) -> None: + """Register OAuth client and store client_id in the mcp_server object.""" + client_metadata = { + "client_name": client_name, + "redirect_uris": [endpoints.redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "read write", + "token_endpoint_auth_method": "none", + } + + logger.error(f"Registration metadata: {client_metadata}") + try: + async with aiohttp.ClientSession() as session: + async with session.post( + endpoints.registration_endpoint, + json=client_metadata, + timeout=aiohttp.ClientTimeout(total=5), + ) as resp: + logger.error( + f"Registration response: {resp.status} {await resp.text()}" + ) + resp.raise_for_status() + client_info = await resp.json() + except Exception as exc: + raise ValueError( + f"Failed to register OAuth client at {endpoints.registration_endpoint}: {exc}" + ) + + if not client_info or "client_id" not in client_info: + raise ValueError( + "**Failed to register OAuth client**\n\n" + f"Could not register with the MCP server at `{url}`\n" + f"Check the server logs for more details.", + ) + + client_id = client_info["client_id"] + + logger.info(f"Registered OAuth client: {client_id}") + return client_id + + +async def issue_challenge( + mcp_server: DiscordMCPServer, + endpoints: OAuthEndpoints, +) -> str: + """Generate OAuth challenge and store state in mcp_server object.""" + code_verifier, code_challenge = generate_pkce_pair() + state = secrets.token_urlsafe(32) + + # Store in mcp_server object + mcp_server.state = state # type: ignore + mcp_server.code_verifier = code_verifier # type: ignore + + logger.info( + f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: " + f"state={state[:20]}..., verifier={code_verifier[:20]}..." + ) + + # Build authorization URL pointing to the target server + auth_params = { + "client_id": mcp_server.client_id, + "redirect_uri": endpoints.redirect_uri, + "response_type": "code", + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "scope": "read write", + } + + return f"{endpoints.authorization_endpoint}?{urlencode(auth_params)}" + + +async def complete_oauth_flow( + session: Session | scoped_session, code: str, state: str +) -> tuple[int, str]: + """Complete OAuth flow by exchanging code for token. + + Args: + code: Authorization code from OAuth callback + state: State parameter from OAuth callback + + Returns: + Tuple of (status_code, html_message) for the callback response + """ + try: + mcp_server = ( + session.query(DiscordMCPServer) + .filter(DiscordMCPServer.state == state) + .first() + ) + + if not mcp_server: + logger.error(f"Invalid or expired state: {state[:20]}...") + return 400, "Invalid or expired OAuth state" + + logger.info( + f"Found MCP server config: user={mcp_server.discord_bot_user_id}, " + f"url={mcp_server.mcp_server_url}" + ) + + # Get OAuth endpoints + try: + endpoints = await get_endpoints(str(mcp_server.mcp_server_url)) + except Exception as exc: + return 500, f"Failed to get OAuth endpoints: {str(exc)}" + + # Exchange authorization code for access token + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": endpoints.redirect_uri, + "client_id": mcp_server.client_id, + "code_verifier": mcp_server.code_verifier, + } + + async with aiohttp.ClientSession() as http_session: + async with http_session.post( + endpoints.token_endpoint, + data=token_data, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status != 200: + error_text = await resp.text() + logger.error(f"Token exchange failed: {resp.status} - {error_text}") + return 500, f"Token exchange failed: {error_text}" + + tokens = await resp.json() + + access_token = tokens.get("access_token") + refresh_token = tokens.get("refresh_token") + expires_in = tokens.get("expires_in", 3600) + + if not access_token: + return 500, "Token response did not include access_token" + + logger.info(f"Successfully obtained access token: {access_token[:20]}...") + + # Store tokens and clear temporary OAuth state + mcp_server.access_token = access_token # type: ignore + mcp_server.refresh_token = refresh_token # type: ignore + mcp_server.token_expires_at = datetime.now() + timedelta(seconds=expires_in) # type: ignore + + # Clear temporary OAuth flow data + mcp_server.state = None # type: ignore + mcp_server.code_verifier = None # type: ignore + + session.commit() + + logger.info( + f"Stored tokens for user {mcp_server.discord_bot_user_id}, " + f"server {mcp_server.mcp_server_url}" + ) + + return 200, "✅ Authorization successful! You can now use this MCP server." + + except Exception as exc: + logger.exception(f"Failed to complete OAuth flow: {exc}") + return 500, f"Failed to complete OAuth flow: {str(exc)}"