handle mcp servers in discord

This commit is contained in:
Daniel O'Connell 2025-11-02 23:45:49 +01:00
parent 64bb926eba
commit 0d9f8beec3
8 changed files with 110 additions and 36 deletions

View File

@ -9,6 +9,7 @@ import anthropic
from memory.common.llms.base import ( from memory.common.llms.base import (
BaseLLMProvider, BaseLLMProvider,
ImageContent, ImageContent,
MCPServer,
LLMSettings, LLMSettings,
Message, Message,
MessageRole, MessageRole,
@ -103,6 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None, system_prompt: str | None,
tools: list[ToolDefinition] | None, tools: list[ToolDefinition] | None,
mcp_servers: list[MCPServer] | None,
settings: LLMSettings, settings: LLMSettings,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build common request kwargs for API calls.""" """Build common request kwargs for API calls."""
@ -113,7 +115,9 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages, "messages": anthropic_messages,
"temperature": settings.temperature, "temperature": settings.temperature,
"max_tokens": settings.max_tokens, "max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"}, "extra_headers": {
"anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"
},
} }
# Only include top_p if explicitly set # Only include top_p if explicitly set
@ -129,6 +133,25 @@ class AnthropicProvider(BaseLLMProvider):
if tools: if tools:
kwargs["tools"] = self._convert_tools(tools) kwargs["tools"] = self._convert_tools(tools)
if mcp_servers:
def format_server(server: MCPServer) -> dict[str, Any]:
conf: dict[str, Any] = {
"type": "url",
"url": server.url,
"name": server.name,
"authorization_token": server.token,
}
if server.allowed_tools:
conf["tool_configuration"] = {
"allowed_tools": server.allowed_tools,
}
return conf
kwargs["extra_body"] = {
"mcp_servers": [format_server(server) for server in mcp_servers]
}
# Enable extended thinking if requested and model supports it # Enable extended thinking if requested and model supports it
if self.enable_thinking and self._supports_thinking(): if self.enable_thinking and self._supports_thinking():
thinking_budget = min(10000, settings.max_tokens - 1024) thinking_budget = min(10000, settings.max_tokens - 1024)
@ -141,6 +164,9 @@ class AnthropicProvider(BaseLLMProvider):
kwargs["temperature"] = 1.0 kwargs["temperature"] = 1.0
kwargs.pop("top_p", None) kwargs.pop("top_p", None)
for k, v in kwargs.items():
print(f"{k}: {v}")
return kwargs return kwargs
def _handle_stream_event( def _handle_stream_event(
@ -312,11 +338,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> str: ) -> str:
"""Generate a non-streaming response.""" """Generate a non-streaming response."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try: try:
response = self.client.messages.create(**kwargs) response = self.client.messages.create(**kwargs)
@ -332,11 +361,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]: ) -> Iterator[StreamEvent]:
"""Generate a streaming response.""" """Generate a streaming response."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try: try:
with self.client.messages.stream(**kwargs) as stream: with self.client.messages.stream(**kwargs) as stream:
@ -358,11 +390,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> str: ) -> str:
"""Generate a non-streaming response asynchronously.""" """Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try: try:
response = await self.async_client.messages.create(**kwargs) response = await self.async_client.messages.create(**kwargs)
@ -378,11 +413,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]: ) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously.""" """Generate a streaming response asynchronously."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings) kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try: try:
async with self.async_client.messages.stream(**kwargs) as stream: async with self.async_client.messages.stream(**kwargs) as stream:

View File

@ -11,7 +11,7 @@ from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
from PIL import Image from PIL import Image
from memory.common import settings from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult from memory.common.llms.tools import MCPServer, ToolCall, ToolDefinition, ToolResult
from memory.common.llms.usage import UsageTracker, RedisUsageTracker from memory.common.llms.usage import UsageTracker, RedisUsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -434,6 +434,7 @@ class BaseLLMProvider(ABC):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> str: ) -> str:
""" """
@ -456,6 +457,7 @@ class BaseLLMProvider(ABC):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]: ) -> Iterator[StreamEvent]:
""" """
@ -478,6 +480,7 @@ class BaseLLMProvider(ABC):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> str: ) -> str:
""" """
@ -500,6 +503,7 @@ class BaseLLMProvider(ABC):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]: ) -> AsyncIterator[StreamEvent]:
""" """
@ -520,6 +524,7 @@ class BaseLLMProvider(ABC):
self, self,
messages: list[Message], messages: list[Message],
tools: dict[str, ToolDefinition], tools: dict[str, ToolDefinition],
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
max_iterations: int = 10, max_iterations: int = 10,
@ -551,6 +556,7 @@ class BaseLLMProvider(ABC):
messages=messages, messages=messages,
system_prompt=system_prompt, system_prompt=system_prompt,
tools=list(tools.values()), tools=list(tools.values()),
mcp_servers=mcp_servers,
settings=settings, settings=settings,
): ):
if event.type == "text": if event.type == "text":
@ -583,7 +589,12 @@ class BaseLLMProvider(ABC):
# Recursively continue the conversation with reduced iterations # Recursively continue the conversation with reduced iterations
yield from self.stream_with_tools( yield from self.stream_with_tools(
messages, tools, settings, system_prompt, max_iterations - 1 messages,
tools,
mcp_servers,
settings,
system_prompt,
max_iterations - 1,
) )
return # Exit after recursive call completes return # Exit after recursive call completes
@ -598,6 +609,7 @@ class BaseLLMProvider(ABC):
self, self,
messages: list[Message], messages: list[Message],
tools: dict[str, ToolDefinition], tools: dict[str, ToolDefinition],
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
max_iterations: int = 10, max_iterations: int = 10,
@ -606,6 +618,7 @@ class BaseLLMProvider(ABC):
for event in self.stream_with_tools( for event in self.stream_with_tools(
messages=messages, messages=messages,
tools=tools, tools=tools,
mcp_servers=mcp_servers,
settings=settings, settings=settings,
system_prompt=system_prompt, system_prompt=system_prompt,
max_iterations=max_iterations, max_iterations=max_iterations,

View File

@ -9,6 +9,7 @@ import openai
from memory.common.llms.base import ( from memory.common.llms.base import (
BaseLLMProvider, BaseLLMProvider,
ImageContent, ImageContent,
MCPServer,
LLMSettings, LLMSettings,
Message, Message,
StreamEvent, StreamEvent,
@ -175,6 +176,7 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None, system_prompt: str | None,
tools: list[ToolDefinition] | None, tools: list[ToolDefinition] | None,
mcp_servers: list[MCPServer] | None,
settings: LLMSettings, settings: LLMSettings,
stream: bool = False, stream: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
@ -185,6 +187,7 @@ class OpenAIProvider(BaseLLMProvider):
messages: Conversation history messages: Conversation history
system_prompt: Optional system prompt system_prompt: Optional system prompt
tools: Optional list of tools tools: Optional list of tools
mcp_servers: Optional list of MCP servers
settings: LLM settings settings: LLM settings
stream: Whether to enable streaming stream: Whether to enable streaming
@ -333,12 +336,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> str: ) -> str:
"""Generate a non-streaming response.""" """Generate a non-streaming response."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs( kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False messages, system_prompt, tools, mcp_servers, settings, stream=False
) )
try: try:
@ -361,12 +365,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]: ) -> Iterator[StreamEvent]:
"""Generate a streaming response.""" """Generate a streaming response."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs( kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True messages, system_prompt, tools, mcp_servers, settings, stream=True
) )
if kwargs.get("stream"): if kwargs.get("stream"):
@ -393,12 +398,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> str: ) -> str:
"""Generate a non-streaming response asynchronously.""" """Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs( kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False messages, system_prompt, tools, mcp_servers, settings, stream=False
) )
try: try:
@ -413,12 +419,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message], messages: list[Message],
system_prompt: str | None = None, system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None, settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]: ) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously.""" """Generate a streaming response asynchronously."""
settings = settings or LLMSettings() settings = settings or LLMSettings()
kwargs = self._build_request_kwargs( kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True messages, system_prompt, tools, mcp_servers, settings, stream=True
) )
try: try:

View File

@ -23,6 +23,16 @@ class ToolResult(TypedDict):
output: str output: str
@dataclass
class MCPServer:
"""An MCP server."""
name: str
url: str
token: str
allowed_tools: list[str] | None = None
@dataclass @dataclass
class ToolDefinition: class ToolDefinition:
"""Definition of a tool that can be called by the LLM.""" """Definition of a tool that can be called by the LLM."""

View File

@ -234,13 +234,12 @@ class MessageCollector(commands.Bot):
async def setup_hook(self): async def setup_hook(self):
"""Register slash commands when the bot is ready.""" """Register slash commands when the bot is ready."""
if not (name := self.user.name): if not self.user:
logger.error(f"Failed to get user name for {self.user}") logger.error(f"Failed to get user name for {self.user}")
return return
name = name.replace("-", "_").lower()
try: try:
register_slash_commands(self, name=name) register_slash_commands(self)
except Exception as e: except Exception as e:
logger.error(f"Failed to register slash commands for {self.user.name}: {e}") logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
logger.error(f"Registered slash commands for {self.user.name}") logger.error(f"Registered slash commands for {self.user.name}")

View File

@ -44,7 +44,7 @@ class CommandContext:
CommandHandler = Callable[..., CommandResponse] CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None: def register_slash_commands(bot: discord.Client) -> None:
"""Register the collector slash commands on the provided bot. """Register the collector slash commands on the provided bot.
Args: Args:
@ -53,7 +53,6 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
""" """
if getattr(bot, "_memory_commands_registered", False): if getattr(bot, "_memory_commands_registered", False):
logger.error(f"Slash commands already registered for {name}")
return return
setattr(bot, "_memory_commands_registered", True) setattr(bot, "_memory_commands_registered", True)
@ -62,6 +61,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
raise RuntimeError("Bot instance does not support app commands") raise RuntimeError("Bot instance does not support app commands")
tree = bot.tree tree = bot.tree
name = bot.user and bot.user.name.replace("-", "_").lower()
@tree.command( @tree.command(
name=f"{name}_show_prompt", description="Show the current system prompt" name=f"{name}_show_prompt", description="Show the current system prompt"
@ -184,7 +184,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
action: Literal["list", "add", "delete", "connect", "tools"] = "list", action: Literal["list", "add", "delete", "connect", "tools"] = "list",
url: str | None = None, url: str | None = None,
) -> None: ) -> None:
await run_mcp_server_command(interaction, action, url and url.strip(), name) await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
async def _run_interaction_command( async def _run_interaction_command(

View File

@ -100,11 +100,15 @@ async def handle_mcp_list(interaction: discord.Interaction) -> str:
async def handle_mcp_add( async def handle_mcp_add(
interaction: discord.Interaction, url: str, name: str = "memory" interaction: discord.Interaction,
bot_user: discord.User | None,
url: str,
) -> str: ) -> str:
"""Add a new MCP server via OAuth.""" """Add a new MCP server via OAuth."""
if not bot_user:
raise ValueError("Bot user is required")
with make_session() as session: with make_session() as session:
if find_mcp_server(session, interaction.user.id, url): if find_mcp_server(session, bot_user.id, url):
return ( return (
f"**MCP Server Already Exists**\n\n" f"**MCP Server Already Exists**\n\n"
f"You already have an MCP server configured at `{url}`.\n" f"You already have an MCP server configured at `{url}`.\n"
@ -115,10 +119,10 @@ async def handle_mcp_add(
client_id = await register_oauth_client( client_id = await register_oauth_client(
endpoints, endpoints,
url, url,
f"Discord Bot - {name} ({interaction.user.name})", f"Discord Bot - {bot_user.name} ({interaction.user.name})",
) )
mcp_server = DiscordMCPServer( mcp_server = DiscordMCPServer(
discord_bot_user_id=interaction.user.id, discord_bot_user_id=bot_user.id,
mcp_server_url=url, mcp_server_url=url,
client_id=client_id, client_id=client_id,
) )
@ -142,10 +146,10 @@ async def handle_mcp_add(
) )
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str: async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
"""Delete an MCP server.""" """Delete an MCP server."""
with make_session() as session: with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url) mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server: if not mcp_server:
return ( return (
f"**MCP Server Not Found**\n\n" f"**MCP Server Not Found**\n\n"
@ -157,10 +161,10 @@ async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed." return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str: async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
"""Reconnect to an existing MCP server (redo OAuth).""" """Reconnect to an existing MCP server (redo OAuth)."""
with make_session() as session: with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url) mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server: if not mcp_server:
raise ValueError( raise ValueError(
f"**MCP Server Not Found**\n\n" f"**MCP Server Not Found**\n\n"
@ -180,9 +184,7 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
session.commit() session.commit()
logger.info( logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
)
return ( return (
f"🔄 **Reconnect to MCP Server**\n\n" f"🔄 **Reconnect to MCP Server**\n\n"
@ -193,10 +195,10 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
) )
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str: async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
"""List tools available on an MCP server.""" """List tools available on an MCP server."""
with make_session() as session: with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url) mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server: if not mcp_server:
raise ValueError( raise ValueError(
@ -264,9 +266,9 @@ async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
async def run_mcp_server_command( async def run_mcp_server_command(
interaction: discord.Interaction, interaction: discord.Interaction,
bot_user: discord.User | None,
action: Literal["list", "add", "delete", "connect", "tools"], action: Literal["list", "add", "delete", "connect", "tools"],
url: str | None, url: str | None,
name: str = "memory",
) -> None: ) -> None:
"""Handle MCP server management commands.""" """Handle MCP server management commands."""
if action not in ["list", "add", "delete", "connect", "tools"]: if action not in ["list", "add", "delete", "connect", "tools"]:
@ -277,18 +279,23 @@ async def run_mcp_server_command(
"❌ URL is required for this action", ephemeral=True "❌ URL is required for this action", ephemeral=True
) )
return return
if not bot_user:
await interaction.response.send_message(
"❌ Bot user is required", ephemeral=True
)
return
try: try:
if action == "list" or not url: if action == "list" or not url:
result = await handle_mcp_list(interaction) result = await handle_mcp_list(interaction)
elif action == "add": elif action == "add":
result = await handle_mcp_add(interaction, url, name) result = await handle_mcp_add(interaction, bot_user, url)
elif action == "delete": elif action == "delete":
result = await handle_mcp_delete(interaction, url) result = await handle_mcp_delete(bot_user, url)
elif action == "connect": elif action == "connect":
result = await handle_mcp_connect(interaction, url) result = await handle_mcp_connect(bot_user, url)
elif action == "tools": elif action == "tools":
result = await handle_mcp_tools(interaction, url) result = await handle_mcp_tools(bot_user, url)
except Exception as exc: except Exception as exc:
result = f"❌ Error: {exc}" result = f"❌ Error: {exc}"
await interaction.response.send_message(result, ephemeral=True) await interaction.response.send_message(result, ephemeral=True)

View File

@ -99,7 +99,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
authorization_endpoint=authorization_endpoint, authorization_endpoint=authorization_endpoint,
registration_endpoint=registration_endpoint, registration_endpoint=registration_endpoint,
token_endpoint=token_endpoint, token_endpoint=token_endpoint,
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord", redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
) )