mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
handle mcp servers in discord
This commit is contained in:
parent
64bb926eba
commit
0d9f8beec3
@ -9,6 +9,7 @@ import anthropic
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
MCPServer,
|
||||
LLMSettings,
|
||||
Message,
|
||||
MessageRole,
|
||||
@ -103,6 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None,
|
||||
tools: list[ToolDefinition] | None,
|
||||
mcp_servers: list[MCPServer] | None,
|
||||
settings: LLMSettings,
|
||||
) -> dict[str, Any]:
|
||||
"""Build common request kwargs for API calls."""
|
||||
@ -113,7 +115,9 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
"messages": anthropic_messages,
|
||||
"temperature": settings.temperature,
|
||||
"max_tokens": settings.max_tokens,
|
||||
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
|
||||
"extra_headers": {
|
||||
"anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"
|
||||
},
|
||||
}
|
||||
|
||||
# Only include top_p if explicitly set
|
||||
@ -129,6 +133,25 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
if tools:
|
||||
kwargs["tools"] = self._convert_tools(tools)
|
||||
|
||||
if mcp_servers:
|
||||
|
||||
def format_server(server: MCPServer) -> dict[str, Any]:
|
||||
conf: dict[str, Any] = {
|
||||
"type": "url",
|
||||
"url": server.url,
|
||||
"name": server.name,
|
||||
"authorization_token": server.token,
|
||||
}
|
||||
if server.allowed_tools:
|
||||
conf["tool_configuration"] = {
|
||||
"allowed_tools": server.allowed_tools,
|
||||
}
|
||||
return conf
|
||||
|
||||
kwargs["extra_body"] = {
|
||||
"mcp_servers": [format_server(server) for server in mcp_servers]
|
||||
}
|
||||
|
||||
# Enable extended thinking if requested and model supports it
|
||||
if self.enable_thinking and self._supports_thinking():
|
||||
thinking_budget = min(10000, settings.max_tokens - 1024)
|
||||
@ -141,6 +164,9 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
kwargs["temperature"] = 1.0
|
||||
kwargs.pop("top_p", None)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
print(f"{k}: {v}")
|
||||
|
||||
return kwargs
|
||||
|
||||
def _handle_stream_event(
|
||||
@ -312,11 +338,14 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, mcp_servers, settings
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.client.messages.create(**kwargs)
|
||||
@ -332,11 +361,14 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, mcp_servers, settings
|
||||
)
|
||||
|
||||
try:
|
||||
with self.client.messages.stream(**kwargs) as stream:
|
||||
@ -358,11 +390,14 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, mcp_servers, settings
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.async_client.messages.create(**kwargs)
|
||||
@ -378,11 +413,14 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, mcp_servers, settings
|
||||
)
|
||||
|
||||
try:
|
||||
async with self.async_client.messages.stream(**kwargs) as stream:
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
||||
from memory.common.llms.tools import MCPServer, ToolCall, ToolDefinition, ToolResult
|
||||
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -434,6 +434,7 @@ class BaseLLMProvider(ABC):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
@ -456,6 +457,7 @@ class BaseLLMProvider(ABC):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""
|
||||
@ -478,6 +480,7 @@ class BaseLLMProvider(ABC):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
@ -500,6 +503,7 @@ class BaseLLMProvider(ABC):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""
|
||||
@ -520,6 +524,7 @@ class BaseLLMProvider(ABC):
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
@ -551,6 +556,7 @@ class BaseLLMProvider(ABC):
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
tools=list(tools.values()),
|
||||
mcp_servers=mcp_servers,
|
||||
settings=settings,
|
||||
):
|
||||
if event.type == "text":
|
||||
@ -583,7 +589,12 @@ class BaseLLMProvider(ABC):
|
||||
|
||||
# Recursively continue the conversation with reduced iterations
|
||||
yield from self.stream_with_tools(
|
||||
messages, tools, settings, system_prompt, max_iterations - 1
|
||||
messages,
|
||||
tools,
|
||||
mcp_servers,
|
||||
settings,
|
||||
system_prompt,
|
||||
max_iterations - 1,
|
||||
)
|
||||
return # Exit after recursive call completes
|
||||
|
||||
@ -598,6 +609,7 @@ class BaseLLMProvider(ABC):
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: dict[str, ToolDefinition],
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
system_prompt: str | None = None,
|
||||
max_iterations: int = 10,
|
||||
@ -606,6 +618,7 @@ class BaseLLMProvider(ABC):
|
||||
for event in self.stream_with_tools(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
mcp_servers=mcp_servers,
|
||||
settings=settings,
|
||||
system_prompt=system_prompt,
|
||||
max_iterations=max_iterations,
|
||||
|
||||
@ -9,6 +9,7 @@ import openai
|
||||
from memory.common.llms.base import (
|
||||
BaseLLMProvider,
|
||||
ImageContent,
|
||||
MCPServer,
|
||||
LLMSettings,
|
||||
Message,
|
||||
StreamEvent,
|
||||
@ -175,6 +176,7 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None,
|
||||
tools: list[ToolDefinition] | None,
|
||||
mcp_servers: list[MCPServer] | None,
|
||||
settings: LLMSettings,
|
||||
stream: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
@ -185,6 +187,7 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages: Conversation history
|
||||
system_prompt: Optional system prompt
|
||||
tools: Optional list of tools
|
||||
mcp_servers: Optional list of MCP servers
|
||||
settings: LLM settings
|
||||
stream: Whether to enable streaming
|
||||
|
||||
@ -333,12 +336,13 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, settings, stream=False
|
||||
messages, system_prompt, tools, mcp_servers, settings, stream=False
|
||||
)
|
||||
|
||||
try:
|
||||
@ -361,12 +365,13 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a streaming response."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, settings, stream=True
|
||||
messages, system_prompt, tools, mcp_servers, settings, stream=True
|
||||
)
|
||||
|
||||
if kwargs.get("stream"):
|
||||
@ -393,12 +398,13 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> str:
|
||||
"""Generate a non-streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, settings, stream=False
|
||||
messages, system_prompt, tools, mcp_servers, settings, stream=False
|
||||
)
|
||||
|
||||
try:
|
||||
@ -413,12 +419,13 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
messages: list[Message],
|
||||
system_prompt: str | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
settings: LLMSettings | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a streaming response asynchronously."""
|
||||
settings = settings or LLMSettings()
|
||||
kwargs = self._build_request_kwargs(
|
||||
messages, system_prompt, tools, settings, stream=True
|
||||
messages, system_prompt, tools, mcp_servers, settings, stream=True
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -23,6 +23,16 @@ class ToolResult(TypedDict):
|
||||
output: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPServer:
|
||||
"""An MCP server."""
|
||||
|
||||
name: str
|
||||
url: str
|
||||
token: str
|
||||
allowed_tools: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""Definition of a tool that can be called by the LLM."""
|
||||
|
||||
@ -234,13 +234,12 @@ class MessageCollector(commands.Bot):
|
||||
async def setup_hook(self):
|
||||
"""Register slash commands when the bot is ready."""
|
||||
|
||||
if not (name := self.user.name):
|
||||
if not self.user:
|
||||
logger.error(f"Failed to get user name for {self.user}")
|
||||
return
|
||||
|
||||
name = name.replace("-", "_").lower()
|
||||
try:
|
||||
register_slash_commands(self, name=name)
|
||||
register_slash_commands(self)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
|
||||
logger.error(f"Registered slash commands for {self.user.name}")
|
||||
|
||||
@ -44,7 +44,7 @@ class CommandContext:
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot.
|
||||
|
||||
Args:
|
||||
@ -53,7 +53,6 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
"""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
logger.error(f"Slash commands already registered for {name}")
|
||||
return
|
||||
|
||||
setattr(bot, "_memory_commands_registered", True)
|
||||
@ -62,6 +61,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
raise RuntimeError("Bot instance does not support app commands")
|
||||
|
||||
tree = bot.tree
|
||||
name = bot.user and bot.user.name.replace("-", "_").lower()
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||
@ -184,7 +184,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
) -> None:
|
||||
await run_mcp_server_command(interaction, action, url and url.strip(), name)
|
||||
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
|
||||
@ -100,11 +100,15 @@ async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
||||
|
||||
|
||||
async def handle_mcp_add(
|
||||
interaction: discord.Interaction, url: str, name: str = "memory"
|
||||
interaction: discord.Interaction,
|
||||
bot_user: discord.User | None,
|
||||
url: str,
|
||||
) -> str:
|
||||
"""Add a new MCP server via OAuth."""
|
||||
if not bot_user:
|
||||
raise ValueError("Bot user is required")
|
||||
with make_session() as session:
|
||||
if find_mcp_server(session, interaction.user.id, url):
|
||||
if find_mcp_server(session, bot_user.id, url):
|
||||
return (
|
||||
f"**MCP Server Already Exists**\n\n"
|
||||
f"You already have an MCP server configured at `{url}`.\n"
|
||||
@ -115,10 +119,10 @@ async def handle_mcp_add(
|
||||
client_id = await register_oauth_client(
|
||||
endpoints,
|
||||
url,
|
||||
f"Discord Bot - {name} ({interaction.user.name})",
|
||||
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
|
||||
)
|
||||
mcp_server = DiscordMCPServer(
|
||||
discord_bot_user_id=interaction.user.id,
|
||||
discord_bot_user_id=bot_user.id,
|
||||
mcp_server_url=url,
|
||||
client_id=client_id,
|
||||
)
|
||||
@ -142,10 +146,10 @@ async def handle_mcp_add(
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
|
||||
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
|
||||
"""Delete an MCP server."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
if not mcp_server:
|
||||
return (
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
@ -157,10 +161,10 @@ async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
|
||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||
|
||||
|
||||
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
||||
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
@ -180,9 +184,7 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
|
||||
)
|
||||
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
|
||||
|
||||
return (
|
||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||
@ -193,10 +195,10 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
|
||||
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
||||
"""List tools available on an MCP server."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
@ -264,9 +266,9 @@ async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
|
||||
|
||||
async def run_mcp_server_command(
|
||||
interaction: discord.Interaction,
|
||||
bot_user: discord.User | None,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||
url: str | None,
|
||||
name: str = "memory",
|
||||
) -> None:
|
||||
"""Handle MCP server management commands."""
|
||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||
@ -277,18 +279,23 @@ async def run_mcp_server_command(
|
||||
"❌ URL is required for this action", ephemeral=True
|
||||
)
|
||||
return
|
||||
if not bot_user:
|
||||
await interaction.response.send_message(
|
||||
"❌ Bot user is required", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if action == "list" or not url:
|
||||
result = await handle_mcp_list(interaction)
|
||||
elif action == "add":
|
||||
result = await handle_mcp_add(interaction, url, name)
|
||||
result = await handle_mcp_add(interaction, bot_user, url)
|
||||
elif action == "delete":
|
||||
result = await handle_mcp_delete(interaction, url)
|
||||
result = await handle_mcp_delete(bot_user, url)
|
||||
elif action == "connect":
|
||||
result = await handle_mcp_connect(interaction, url)
|
||||
result = await handle_mcp_connect(bot_user, url)
|
||||
elif action == "tools":
|
||||
result = await handle_mcp_tools(interaction, url)
|
||||
result = await handle_mcp_tools(bot_user, url)
|
||||
except Exception as exc:
|
||||
result = f"❌ Error: {exc}"
|
||||
await interaction.response.send_message(result, ephemeral=True)
|
||||
|
||||
@ -99,7 +99,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
|
||||
authorization_endpoint=authorization_endpoint,
|
||||
registration_endpoint=registration_endpoint,
|
||||
token_endpoint=token_endpoint,
|
||||
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
|
||||
redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user