Handle MCP servers in Discord

This commit is contained in:
Daniel O'Connell 2025-11-02 23:45:49 +01:00
parent 64bb926eba
commit 0d9f8beec3
8 changed files with 110 additions and 36 deletions

View File

@ -9,6 +9,7 @@ import anthropic
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
MCPServer,
LLMSettings,
Message,
MessageRole,
@ -103,6 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None,
tools: list[ToolDefinition] | None,
mcp_servers: list[MCPServer] | None,
settings: LLMSettings,
) -> dict[str, Any]:
"""Build common request kwargs for API calls."""
@ -113,7 +115,9 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
"extra_headers": {
"anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"
},
}
# Only include top_p if explicitly set
@ -129,6 +133,25 @@ class AnthropicProvider(BaseLLMProvider):
if tools:
kwargs["tools"] = self._convert_tools(tools)
if mcp_servers:
def format_server(server: MCPServer) -> dict[str, Any]:
conf: dict[str, Any] = {
"type": "url",
"url": server.url,
"name": server.name,
"authorization_token": server.token,
}
if server.allowed_tools:
conf["tool_configuration"] = {
"allowed_tools": server.allowed_tools,
}
return conf
kwargs["extra_body"] = {
"mcp_servers": [format_server(server) for server in mcp_servers]
}
# Enable extended thinking if requested and model supports it
if self.enable_thinking and self._supports_thinking():
thinking_budget = min(10000, settings.max_tokens - 1024)
@ -141,6 +164,9 @@ class AnthropicProvider(BaseLLMProvider):
kwargs["temperature"] = 1.0
kwargs.pop("top_p", None)
for k, v in kwargs.items():
print(f"{k}: {v}")
return kwargs
def _handle_stream_event(
@ -312,11 +338,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
response = self.client.messages.create(**kwargs)
@ -332,11 +361,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
with self.client.messages.stream(**kwargs) as stream:
@ -358,11 +390,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
response = await self.async_client.messages.create(**kwargs)
@ -378,11 +413,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
async with self.async_client.messages.stream(**kwargs) as stream:

View File

@ -11,7 +11,7 @@ from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
from PIL import Image
from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
from memory.common.llms.tools import MCPServer, ToolCall, ToolDefinition, ToolResult
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
logger = logging.getLogger(__name__)
@ -434,6 +434,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
@ -456,6 +457,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""
@ -478,6 +480,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
@ -500,6 +503,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""
@ -520,6 +524,7 @@ class BaseLLMProvider(ABC):
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
@ -551,6 +556,7 @@ class BaseLLMProvider(ABC):
messages=messages,
system_prompt=system_prompt,
tools=list(tools.values()),
mcp_servers=mcp_servers,
settings=settings,
):
if event.type == "text":
@ -583,7 +589,12 @@ class BaseLLMProvider(ABC):
# Recursively continue the conversation with reduced iterations
yield from self.stream_with_tools(
messages, tools, settings, system_prompt, max_iterations - 1
messages,
tools,
mcp_servers,
settings,
system_prompt,
max_iterations - 1,
)
return # Exit after recursive call completes
@ -598,6 +609,7 @@ class BaseLLMProvider(ABC):
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
@ -606,6 +618,7 @@ class BaseLLMProvider(ABC):
for event in self.stream_with_tools(
messages=messages,
tools=tools,
mcp_servers=mcp_servers,
settings=settings,
system_prompt=system_prompt,
max_iterations=max_iterations,

View File

@ -9,6 +9,7 @@ import openai
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
MCPServer,
LLMSettings,
Message,
StreamEvent,
@ -175,6 +176,7 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None,
tools: list[ToolDefinition] | None,
mcp_servers: list[MCPServer] | None,
settings: LLMSettings,
stream: bool = False,
) -> dict[str, Any]:
@ -185,6 +187,7 @@ class OpenAIProvider(BaseLLMProvider):
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools
mcp_servers: Optional list of MCP servers
settings: LLM settings
stream: Whether to enable streaming
@ -333,12 +336,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False
messages, system_prompt, tools, mcp_servers, settings, stream=False
)
try:
@ -361,12 +365,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True
messages, system_prompt, tools, mcp_servers, settings, stream=True
)
if kwargs.get("stream"):
@ -393,12 +398,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False
messages, system_prompt, tools, mcp_servers, settings, stream=False
)
try:
@ -413,12 +419,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True
messages, system_prompt, tools, mcp_servers, settings, stream=True
)
try:

View File

@ -23,6 +23,16 @@ class ToolResult(TypedDict):
output: str
@dataclass
class MCPServer:
    """Connection details for a remote MCP (Model Context Protocol) server.

    Instances are threaded through the LLM provider call chain and serialized
    into the provider request (e.g. the Anthropic provider builds an
    ``mcp_servers`` entry with ``url``, ``name`` and ``authorization_token``
    from these fields).
    """

    # Identifier sent to the provider as the server's name.
    name: str
    # Base URL of the MCP server.
    url: str
    # OAuth/bearer token forwarded as the server's authorization_token.
    token: str
    # Optional restriction of which tools may be used; only included in the
    # provider request when non-empty (server-side default applies otherwise).
    allowed_tools: list[str] | None = None
@dataclass
class ToolDefinition:
"""Definition of a tool that can be called by the LLM."""

View File

@ -234,13 +234,12 @@ class MessageCollector(commands.Bot):
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
if not (name := self.user.name):
if not self.user:
logger.error(f"Failed to get user name for {self.user}")
return
name = name.replace("-", "_").lower()
try:
register_slash_commands(self, name=name)
register_slash_commands(self)
except Exception as e:
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
logger.error(f"Registered slash commands for {self.user.name}")

View File

@ -44,7 +44,7 @@ class CommandContext:
CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
def register_slash_commands(bot: discord.Client) -> None:
"""Register the collector slash commands on the provided bot.
Args:
@ -53,7 +53,6 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
"""
if getattr(bot, "_memory_commands_registered", False):
logger.error(f"Slash commands already registered for {name}")
return
setattr(bot, "_memory_commands_registered", True)
@ -62,6 +61,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
raise RuntimeError("Bot instance does not support app commands")
tree = bot.tree
name = bot.user and bot.user.name.replace("-", "_").lower()
@tree.command(
name=f"{name}_show_prompt", description="Show the current system prompt"
@ -184,7 +184,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
url: str | None = None,
) -> None:
await run_mcp_server_command(interaction, action, url and url.strip(), name)
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
async def _run_interaction_command(

View File

@ -100,11 +100,15 @@ async def handle_mcp_list(interaction: discord.Interaction) -> str:
async def handle_mcp_add(
interaction: discord.Interaction, url: str, name: str = "memory"
interaction: discord.Interaction,
bot_user: discord.User | None,
url: str,
) -> str:
"""Add a new MCP server via OAuth."""
if not bot_user:
raise ValueError("Bot user is required")
with make_session() as session:
if find_mcp_server(session, interaction.user.id, url):
if find_mcp_server(session, bot_user.id, url):
return (
f"**MCP Server Already Exists**\n\n"
f"You already have an MCP server configured at `{url}`.\n"
@ -115,10 +119,10 @@ async def handle_mcp_add(
client_id = await register_oauth_client(
endpoints,
url,
f"Discord Bot - {name} ({interaction.user.name})",
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
)
mcp_server = DiscordMCPServer(
discord_bot_user_id=interaction.user.id,
discord_bot_user_id=bot_user.id,
mcp_server_url=url,
client_id=client_id,
)
@ -142,10 +146,10 @@ async def handle_mcp_add(
)
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
"""Delete an MCP server."""
with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url)
mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server:
return (
f"**MCP Server Not Found**\n\n"
@ -157,10 +161,10 @@ async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
"""Reconnect to an existing MCP server (redo OAuth)."""
with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url)
mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
@ -180,9 +184,7 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
session.commit()
logger.info(
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
)
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
return (
f"🔄 **Reconnect to MCP Server**\n\n"
@ -193,10 +195,10 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
)
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
"""List tools available on an MCP server."""
with make_session() as session:
mcp_server = find_mcp_server(session, interaction.user.id, url)
mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server:
raise ValueError(
@ -264,9 +266,9 @@ async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
async def run_mcp_server_command(
interaction: discord.Interaction,
bot_user: discord.User | None,
action: Literal["list", "add", "delete", "connect", "tools"],
url: str | None,
name: str = "memory",
) -> None:
"""Handle MCP server management commands."""
if action not in ["list", "add", "delete", "connect", "tools"]:
@ -277,18 +279,23 @@ async def run_mcp_server_command(
"❌ URL is required for this action", ephemeral=True
)
return
if not bot_user:
await interaction.response.send_message(
"❌ Bot user is required", ephemeral=True
)
return
try:
if action == "list" or not url:
result = await handle_mcp_list(interaction)
elif action == "add":
result = await handle_mcp_add(interaction, url, name)
result = await handle_mcp_add(interaction, bot_user, url)
elif action == "delete":
result = await handle_mcp_delete(interaction, url)
result = await handle_mcp_delete(bot_user, url)
elif action == "connect":
result = await handle_mcp_connect(interaction, url)
result = await handle_mcp_connect(bot_user, url)
elif action == "tools":
result = await handle_mcp_tools(interaction, url)
result = await handle_mcp_tools(bot_user, url)
except Exception as exc:
result = f"❌ Error: {exc}"
await interaction.response.send_message(result, ephemeral=True)

View File

@ -99,7 +99,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
authorization_endpoint=authorization_endpoint,
registration_endpoint=registration_endpoint,
token_endpoint=token_endpoint,
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
)