From 0d9f8beec32726be458627b0a3170b12d61902be Mon Sep 17 00:00:00 2001
From: Daniel O'Connell
Date: Sun, 2 Nov 2025 23:45:49 +0100
Subject: [PATCH] Handle MCP servers in Discord

---
 src/memory/common/llms/anthropic_provider.py | 45 +++++++++++++++++--
 src/memory/common/llms/base.py               | 17 ++++++-
 src/memory/common/llms/openai_provider.py    | 15 ++++--
 src/memory/common/llms/tools/__init__.py     | 10 ++++
 src/memory/discord/collector.py              |  5 +-
 src/memory/discord/commands.py               |  6 +--
 src/memory/discord/mcp.py                    | 43 ++++++++--------
 src/memory/discord/oauth.py                  |  2 +-
 8 files changed, 107 insertions(+), 36 deletions(-)

diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py
index a55a997..a74a7e8 100644
--- a/src/memory/common/llms/anthropic_provider.py
+++ b/src/memory/common/llms/anthropic_provider.py
@@ -9,6 +9,7 @@ import anthropic
 from memory.common.llms.base import (
     BaseLLMProvider,
     ImageContent,
+    MCPServer,
     LLMSettings,
     Message,
     MessageRole,
@@ -103,6 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
         messages: list[Message],
         system_prompt: str | None,
         tools: list[ToolDefinition] | None,
+        mcp_servers: list[MCPServer] | None,
         settings: LLMSettings,
     ) -> dict[str, Any]:
         """Build common request kwargs for API calls."""
@@ -113,7 +115,9 @@ class AnthropicProvider(BaseLLMProvider):
         "messages": anthropic_messages,
         "temperature": settings.temperature,
         "max_tokens": settings.max_tokens,
-        "extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
+        "extra_headers": {
+            "anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"
+        },
     }
 
     # Only include top_p if explicitly set
@@ -129,6 +133,25 @@ class AnthropicProvider(BaseLLMProvider):
         if tools:
             kwargs["tools"] = self._convert_tools(tools)
 
+        if mcp_servers:
+
+            def format_server(server: MCPServer) -> dict[str, Any]:
+                conf: dict[str, Any] = {
+                    "type": "url",
+                    "url": server.url,
+                    "name": server.name,
+                    "authorization_token": server.token,
+                }
+                if server.allowed_tools:
+                    conf["tool_configuration"] = {
+                        "allowed_tools": server.allowed_tools,
+                    }
+                return conf
+
+            kwargs["extra_body"] = {
+                "mcp_servers": [format_server(server) for server in mcp_servers]
+            }
+
         # Enable extended thinking if requested and model supports it
         if self.enable_thinking and self._supports_thinking():
             thinking_budget = min(10000, settings.max_tokens - 1024)
@@ -312,11 +335,14 @@ class AnthropicProvider(BaseLLMProvider):
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> str:
         """Generate a non-streaming response."""
         settings = settings or LLMSettings()
-        kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, mcp_servers, settings
+        )
 
         try:
             response = self.client.messages.create(**kwargs)
@@ -332,11 +358,14 @@ class AnthropicProvider(BaseLLMProvider):
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> Iterator[StreamEvent]:
         """Generate a streaming response."""
         settings = settings or LLMSettings()
-        kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, mcp_servers, settings
+        )
 
         try:
             with self.client.messages.stream(**kwargs) as stream:
@@ -358,11 +387,14 @@ class AnthropicProvider(BaseLLMProvider):
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> str:
         """Generate a non-streaming response asynchronously."""
         settings = settings or LLMSettings()
-        kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, mcp_servers, settings
+        )
 
         try:
             response = await self.async_client.messages.create(**kwargs)
@@ -378,11 +410,14 @@ class AnthropicProvider(BaseLLMProvider):
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> AsyncIterator[StreamEvent]:
         """Generate a streaming response asynchronously."""
         settings = settings or LLMSettings()
-        kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
+        kwargs = self._build_request_kwargs(
+            messages, system_prompt, tools, mcp_servers, settings
+        )
 
         try:
             async with self.async_client.messages.stream(**kwargs) as stream:
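The Anthropic provider is the only one that actually forwards MCP servers: format_server() maps each MCPServer onto the Messages API's MCP connector shape, sent via extra_body, and the request opts into the mcp-client beta alongside the existing web-fetch beta. A sketch of the mapping; the URL and token are placeholders, not values from this patch:

    from memory.common.llms.tools import MCPServer

    server = MCPServer(
        name="memory",
        url="https://mcp.example.com/sse",   # placeholder endpoint
        token="oauth-access-token",          # placeholder bearer token
        allowed_tools=["search_notes"],
    )

    # format_server(server) inside _build_request_kwargs yields:
    # {
    #     "type": "url",
    #     "url": "https://mcp.example.com/sse",
    #     "name": "memory",
    #     "authorization_token": "oauth-access-token",
    #     "tool_configuration": {"allowed_tools": ["search_notes"]},
    # }
    # and every request now carries:
    #     extra_body={"mcp_servers": [<one dict per server>]}
    #     extra_headers={"anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"}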
diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py
index e950287..ac31185 100644
--- a/src/memory/common/llms/base.py
+++ b/src/memory/common/llms/base.py
@@ -11,7 +11,7 @@ from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
 from PIL import Image
 
 from memory.common import settings
-from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
+from memory.common.llms.tools import MCPServer, ToolCall, ToolDefinition, ToolResult
 from memory.common.llms.usage import UsageTracker, RedisUsageTracker
 
 logger = logging.getLogger(__name__)
@@ -434,6 +434,7 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> str:
         """
@@ -456,6 +457,7 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> Iterator[StreamEvent]:
         """
@@ -478,6 +480,7 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> str:
         """
@@ -500,6 +503,7 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> AsyncIterator[StreamEvent]:
         """
@@ -520,6 +524,7 @@
         self,
         messages: list[Message],
         tools: dict[str, ToolDefinition],
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
         system_prompt: str | None = None,
         max_iterations: int = 10,
@@ -551,6 +556,7 @@
             messages=messages,
             system_prompt=system_prompt,
             tools=list(tools.values()),
+            mcp_servers=mcp_servers,
             settings=settings,
         ):
             if event.type == "text":
@@ -583,7 +589,12 @@
 
             # Recursively continue the conversation with reduced iterations
             yield from self.stream_with_tools(
-                messages, tools, settings, system_prompt, max_iterations - 1
+                messages,
+                tools,
+                mcp_servers,
+                settings,
+                system_prompt,
+                max_iterations - 1,
             )
             return  # Exit after recursive call completes
 
@@ -598,6 +609,7 @@
         self,
         messages: list[Message],
         tools: dict[str, ToolDefinition],
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
         system_prompt: str | None = None,
         max_iterations: int = 10,
@@ -606,6 +618,7 @@
         for event in self.stream_with_tools(
             messages=messages,
             tools=tools,
+            mcp_servers=mcp_servers,
             settings=settings,
             system_prompt=system_prompt,
             max_iterations=max_iterations,
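With the base class threading mcp_servers through the abstract generate/stream methods and both tool-loop helpers, a caller hands MCP servers down once and the recursive tool iterations keep passing them along. A minimal sketch, assuming an already-constructed provider and that Message takes role/content keywords (both outside this diff):

    from memory.common.llms.base import Message, MessageRole
    from memory.common.llms.tools import MCPServer

    servers = [
        MCPServer(name="memory", url="https://mcp.example.com/sse", token="...")
    ]

    for event in provider.stream_with_tools(
        messages=[Message(role=MessageRole.USER, content="Summarise my notes")],
        tools={},              # local tools; MCP tools execute on the server side
        mcp_servers=servers,   # re-passed on each recursive iteration
    ):
        if event.type == "text":  # the same check stream_with_tools uses internally
            ...                   # render the chunk; StreamEvent's fields are outside this diff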
diff --git a/src/memory/common/llms/openai_provider.py b/src/memory/common/llms/openai_provider.py
index 8a342cd..e85d1df 100644
--- a/src/memory/common/llms/openai_provider.py
+++ b/src/memory/common/llms/openai_provider.py
@@ -9,6 +9,7 @@ import openai
 from memory.common.llms.base import (
     BaseLLMProvider,
     ImageContent,
+    MCPServer,
     LLMSettings,
     Message,
     StreamEvent,
@@ -175,6 +176,7 @@ class OpenAIProvider(BaseLLMProvider):
         messages: list[Message],
         system_prompt: str | None,
         tools: list[ToolDefinition] | None,
+        mcp_servers: list[MCPServer] | None,
         settings: LLMSettings,
         stream: bool = False,
     ) -> dict[str, Any]:
@@ -185,6 +187,7 @@
             messages: Conversation history
             system_prompt: Optional system prompt
             tools: Optional list of tools
+            mcp_servers: Optional list of MCP servers
             settings: LLM settings
             stream: Whether to enable streaming
@@ -333,12 +336,13 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> str:
         """Generate a non-streaming response."""
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(
-            messages, system_prompt, tools, settings, stream=False
+            messages, system_prompt, tools, mcp_servers, settings, stream=False
         )
 
         try:
@@ -361,12 +365,13 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> Iterator[StreamEvent]:
         """Generate a streaming response."""
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(
-            messages, system_prompt, tools, settings, stream=True
+            messages, system_prompt, tools, mcp_servers, settings, stream=True
         )
 
         if kwargs.get("stream"):
@@ -393,12 +398,13 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> str:
         """Generate a non-streaming response asynchronously."""
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(
-            messages, system_prompt, tools, settings, stream=False
+            messages, system_prompt, tools, mcp_servers, settings, stream=False
         )
 
         try:
@@ -413,12 +419,13 @@
         messages: list[Message],
         system_prompt: str | None = None,
         tools: list[ToolDefinition] | None = None,
+        mcp_servers: list[MCPServer] | None = None,
         settings: LLMSettings | None = None,
     ) -> AsyncIterator[StreamEvent]:
         """Generate a streaming response asynchronously."""
         settings = settings or LLMSettings()
         kwargs = self._build_request_kwargs(
-            messages, system_prompt, tools, settings, stream=True
+            messages, system_prompt, tools, mcp_servers, settings, stream=True
         )
 
         try:
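One asymmetry worth flagging: OpenAIProvider now accepts and documents mcp_servers, but none of its hunks ever reads them inside _build_request_kwargs, so for OpenAI-backed calls they are accepted and silently dropped; only the Anthropic path forwards them. Callers can still be written provider-agnostically; a sketch, with Message keywords assumed as above:

    from memory.common.llms.base import BaseLLMProvider, Message, MessageRole
    from memory.common.llms.tools import MCPServer

    def ask(provider: BaseLLMProvider, question: str, servers: list[MCPServer]) -> str:
        # Identical call for either provider; with OpenAIProvider the servers
        # are currently ignored and the request degrades to a plain completion.
        return provider.generate(
            messages=[Message(role=MessageRole.USER, content=question)],
            mcp_servers=servers,
        )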
diff --git a/src/memory/common/llms/tools/__init__.py b/src/memory/common/llms/tools/__init__.py
index f716e20..a04bd8a 100644
--- a/src/memory/common/llms/tools/__init__.py
+++ b/src/memory/common/llms/tools/__init__.py
@@ -23,6 +23,16 @@ class ToolResult(TypedDict):
     output: str
 
 
+@dataclass
+class MCPServer:
+    """An MCP server."""
+
+    name: str
+    url: str
+    token: str
+    allowed_tools: list[str] | None = None
+
+
 @dataclass
 class ToolDefinition:
     """Definition of a tool that can be called by the LLM."""
diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py
index 852a892..fb03436 100644
--- a/src/memory/discord/collector.py
+++ b/src/memory/discord/collector.py
@@ -234,13 +234,12 @@ class MessageCollector(commands.Bot):
 
     async def setup_hook(self):
         """Register slash commands when the bot is ready."""
-        if not (name := self.user.name):
+        if not self.user:
             logger.error(f"Failed to get user name for {self.user}")
             return
 
-        name = name.replace("-", "_").lower()
         try:
-            register_slash_commands(self, name=name)
+            register_slash_commands(self)
         except Exception as e:
             logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
         logger.error(f"Registered slash commands for {self.user.name}")
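Because format_server() only attaches tool_configuration when allowed_tools is non-empty, the dataclass default of None means "expose every tool the server advertises"; restricting is opt-in. Placeholder values again:

    from memory.common.llms.tools import MCPServer

    # No allowed_tools -> no "tool_configuration" key -> all tools available.
    unrestricted = MCPServer(
        name="memory",
        url="https://mcp.example.com/sse",
        token="oauth-access-token",
    )

    # Restricted to a single tool.
    restricted = MCPServer(
        name="memory",
        url="https://mcp.example.com/sse",
        token="oauth-access-token",
        allowed_tools=["search_notes"],
    )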
diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py
index 82c15ce..69830ab 100644
--- a/src/memory/discord/commands.py
+++ b/src/memory/discord/commands.py
@@ -44,7 +44,7 @@ class CommandContext:
 CommandHandler = Callable[..., CommandResponse]
 
 
-def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
+def register_slash_commands(bot: discord.Client) -> None:
     """Register the collector slash commands on the provided bot.
 
     Args:
@@ -53,7 +53,6 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
     """
 
     if getattr(bot, "_memory_commands_registered", False):
-        logger.error(f"Slash commands already registered for {name}")
         return
     setattr(bot, "_memory_commands_registered", True)
 
@@ -62,6 +61,7 @@
         raise RuntimeError("Bot instance does not support app commands")
 
     tree = bot.tree
+    name = bot.user.name.replace("-", "_").lower() if bot.user else "memory"
 
     @tree.command(
         name=f"{name}_show_prompt", description="Show the current system prompt"
@@ -184,7 +184,7 @@ def register_slash_commands(bot: discord.Client) -> None:
         action: Literal["list", "add", "delete", "connect", "tools"] = "list",
         url: str | None = None,
     ) -> None:
-        await run_mcp_server_command(interaction, action, url and url.strip(), name)
+        await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
 
 
 async def _run_interaction_command(
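Command names are now derived from the bot's own account name instead of a caller-supplied parameter, so two bots in the same guild register distinct commands (with a "memory" fallback if no user is attached yet). For a hypothetical bot account named "Memory-Bot":

    name = "Memory-Bot".replace("-", "_").lower()  # -> "memory_bot"
    # giving commands such as:
    #   /memory_bot_show_prompt
    # and, assuming the MCP command uses the same f"{name}_..." prefix pattern:
    #   /memory_bot_mcp_servers action:add url:https://mcp.example.com/sse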
diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py
index e2f7e9d..accfb6e 100644
--- a/src/memory/discord/mcp.py
+++ b/src/memory/discord/mcp.py
@@ -100,11 +100,15 @@ async def handle_mcp_list(interaction: discord.Interaction) -> str:
 
 
 async def handle_mcp_add(
-    interaction: discord.Interaction, url: str, name: str = "memory"
+    interaction: discord.Interaction,
+    bot_user: discord.User | None,
+    url: str,
 ) -> str:
     """Add a new MCP server via OAuth."""
+    if not bot_user:
+        raise ValueError("Bot user is required")
     with make_session() as session:
-        if find_mcp_server(session, interaction.user.id, url):
+        if find_mcp_server(session, bot_user.id, url):
             return (
                 f"**MCP Server Already Exists**\n\n"
                 f"You already have an MCP server configured at `{url}`.\n"
@@ -115,10 +119,10 @@
         client_id = await register_oauth_client(
             endpoints,
             url,
-            f"Discord Bot - {name} ({interaction.user.name})",
+            f"Discord Bot - {bot_user.name} ({interaction.user.name})",
         )
         mcp_server = DiscordMCPServer(
-            discord_bot_user_id=interaction.user.id,
+            discord_bot_user_id=bot_user.id,
             mcp_server_url=url,
             client_id=client_id,
         )
@@ -142,26 +146,26 @@
     )
 
 
-async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
+async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
     """Delete an MCP server."""
     with make_session() as session:
-        mcp_server = find_mcp_server(session, interaction.user.id, url)
+        mcp_server = find_mcp_server(session, bot_user.id, url)
         if not mcp_server:
             return (
                 f"**MCP Server Not Found**\n\n"
                 f"No MCP server found at `{url}`."
             )
         session.delete(mcp_server)
         session.commit()
 
     return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
 
 
-async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
+async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
     """Reconnect to an existing MCP server (redo OAuth)."""
     with make_session() as session:
-        mcp_server = find_mcp_server(session, interaction.user.id, url)
+        mcp_server = find_mcp_server(session, bot_user.id, url)
         if not mcp_server:
             raise ValueError(
                 f"**MCP Server Not Found**\n\n"
                 f"No MCP server found at `{url}`. "
@@ -180,9 +184,7 @@
 
         session.commit()
 
-    logger.info(
-        f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
-    )
+    logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
 
     return (
         f"🔄 **Reconnect to MCP Server**\n\n"
@@ -193,10 +195,10 @@
     )
 
 
-async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
+async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
     """List tools available on an MCP server."""
     with make_session() as session:
-        mcp_server = find_mcp_server(session, interaction.user.id, url)
+        mcp_server = find_mcp_server(session, bot_user.id, url)
 
         if not mcp_server:
             raise ValueError(
@@ -264,9 +266,9 @@ async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
 
 async def run_mcp_server_command(
     interaction: discord.Interaction,
+    bot_user: discord.User | None,
     action: Literal["list", "add", "delete", "connect", "tools"],
     url: str | None,
-    name: str = "memory",
 ) -> None:
     """Handle MCP server management commands."""
     if action not in ["list", "add", "delete", "connect", "tools"]:
@@ -277,18 +279,23 @@
             "❌ URL is required for this action", ephemeral=True
         )
         return
+    if not bot_user:
+        await interaction.response.send_message(
+            "❌ Bot user is required", ephemeral=True
+        )
+        return
 
     try:
         if action == "list" or not url:
             result = await handle_mcp_list(interaction)
         elif action == "add":
-            result = await handle_mcp_add(interaction, url, name)
+            result = await handle_mcp_add(interaction, bot_user, url)
         elif action == "delete":
-            result = await handle_mcp_delete(interaction, url)
+            result = await handle_mcp_delete(bot_user, url)
         elif action == "connect":
-            result = await handle_mcp_connect(interaction, url)
+            result = await handle_mcp_connect(bot_user, url)
         elif action == "tools":
-            result = await handle_mcp_tools(interaction, url)
+            result = await handle_mcp_tools(bot_user, url)
     except Exception as exc:
         result = f"❌ Error: {exc}"
     await interaction.response.send_message(result, ephemeral=True)
diff --git a/src/memory/discord/oauth.py b/src/memory/discord/oauth.py
index a876598..ce65c7c 100644
--- a/src/memory/discord/oauth.py
+++ b/src/memory/discord/oauth.py
@@ -99,7 +99,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
         authorization_endpoint=authorization_endpoint,
         registration_endpoint=registration_endpoint,
         token_endpoint=token_endpoint,
-        redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
+        redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
     )
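Finally, the OAuth callback is built from the single SERVER_URL setting rather than a host/port pair, which also lets it carry an https scheme. This assumes SERVER_URL includes its scheme (and port, if non-standard) and has no trailing slash; the values below are illustrative:

    SERVER_URL = "https://memory.example.com"
    redirect_uri = f"{SERVER_URL}/oauth/callback/discord"
    # -> "https://memory.example.com/oauth/callback/discord"
    # previously: http://<DISCORD_COLLECTOR_SERVER_URL>:<DISCORD_COLLECTOR_PORT>/oauth/callback/discord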