mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
handle mcp servers in discord
This commit is contained in:
parent
64bb926eba
commit
0d9f8beec3
@ -9,6 +9,7 @@ import anthropic
|
|||||||
from memory.common.llms.base import (
|
from memory.common.llms.base import (
|
||||||
BaseLLMProvider,
|
BaseLLMProvider,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
|
MCPServer,
|
||||||
LLMSettings,
|
LLMSettings,
|
||||||
Message,
|
Message,
|
||||||
MessageRole,
|
MessageRole,
|
||||||
@ -103,6 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None,
|
system_prompt: str | None,
|
||||||
tools: list[ToolDefinition] | None,
|
tools: list[ToolDefinition] | None,
|
||||||
|
mcp_servers: list[MCPServer] | None,
|
||||||
settings: LLMSettings,
|
settings: LLMSettings,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Build common request kwargs for API calls."""
|
"""Build common request kwargs for API calls."""
|
||||||
@ -113,7 +115,9 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
"messages": anthropic_messages,
|
"messages": anthropic_messages,
|
||||||
"temperature": settings.temperature,
|
"temperature": settings.temperature,
|
||||||
"max_tokens": settings.max_tokens,
|
"max_tokens": settings.max_tokens,
|
||||||
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
|
"extra_headers": {
|
||||||
|
"anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Only include top_p if explicitly set
|
# Only include top_p if explicitly set
|
||||||
@ -129,6 +133,25 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = self._convert_tools(tools)
|
kwargs["tools"] = self._convert_tools(tools)
|
||||||
|
|
||||||
|
if mcp_servers:
|
||||||
|
|
||||||
|
def format_server(server: MCPServer) -> dict[str, Any]:
|
||||||
|
conf: dict[str, Any] = {
|
||||||
|
"type": "url",
|
||||||
|
"url": server.url,
|
||||||
|
"name": server.name,
|
||||||
|
"authorization_token": server.token,
|
||||||
|
}
|
||||||
|
if server.allowed_tools:
|
||||||
|
conf["tool_configuration"] = {
|
||||||
|
"allowed_tools": server.allowed_tools,
|
||||||
|
}
|
||||||
|
return conf
|
||||||
|
|
||||||
|
kwargs["extra_body"] = {
|
||||||
|
"mcp_servers": [format_server(server) for server in mcp_servers]
|
||||||
|
}
|
||||||
|
|
||||||
# Enable extended thinking if requested and model supports it
|
# Enable extended thinking if requested and model supports it
|
||||||
if self.enable_thinking and self._supports_thinking():
|
if self.enable_thinking and self._supports_thinking():
|
||||||
thinking_budget = min(10000, settings.max_tokens - 1024)
|
thinking_budget = min(10000, settings.max_tokens - 1024)
|
||||||
@ -141,6 +164,9 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
kwargs["temperature"] = 1.0
|
kwargs["temperature"] = 1.0
|
||||||
kwargs.pop("top_p", None)
|
kwargs.pop("top_p", None)
|
||||||
|
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _handle_stream_event(
|
def _handle_stream_event(
|
||||||
@ -312,11 +338,14 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a non-streaming response."""
|
"""Generate a non-streaming response."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
kwargs = self._build_request_kwargs(
|
||||||
|
messages, system_prompt, tools, mcp_servers, settings
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.messages.create(**kwargs)
|
response = self.client.messages.create(**kwargs)
|
||||||
@ -332,11 +361,14 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> Iterator[StreamEvent]:
|
) -> Iterator[StreamEvent]:
|
||||||
"""Generate a streaming response."""
|
"""Generate a streaming response."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
kwargs = self._build_request_kwargs(
|
||||||
|
messages, system_prompt, tools, mcp_servers, settings
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.client.messages.stream(**kwargs) as stream:
|
with self.client.messages.stream(**kwargs) as stream:
|
||||||
@ -358,11 +390,14 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a non-streaming response asynchronously."""
|
"""Generate a non-streaming response asynchronously."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
kwargs = self._build_request_kwargs(
|
||||||
|
messages, system_prompt, tools, mcp_servers, settings
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.async_client.messages.create(**kwargs)
|
response = await self.async_client.messages.create(**kwargs)
|
||||||
@ -378,11 +413,14 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> AsyncIterator[StreamEvent]:
|
) -> AsyncIterator[StreamEvent]:
|
||||||
"""Generate a streaming response asynchronously."""
|
"""Generate a streaming response asynchronously."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
|
kwargs = self._build_request_kwargs(
|
||||||
|
messages, system_prompt, tools, mcp_servers, settings
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.async_client.messages.stream(**kwargs) as stream:
|
async with self.async_client.messages.stream(**kwargs) as stream:
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
from memory.common.llms.tools import MCPServer, ToolCall, ToolDefinition, ToolResult
|
||||||
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
|
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -434,6 +434,7 @@ class BaseLLMProvider(ABC):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@ -456,6 +457,7 @@ class BaseLLMProvider(ABC):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> Iterator[StreamEvent]:
|
) -> Iterator[StreamEvent]:
|
||||||
"""
|
"""
|
||||||
@ -478,6 +480,7 @@ class BaseLLMProvider(ABC):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@ -500,6 +503,7 @@ class BaseLLMProvider(ABC):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> AsyncIterator[StreamEvent]:
|
) -> AsyncIterator[StreamEvent]:
|
||||||
"""
|
"""
|
||||||
@ -520,6 +524,7 @@ class BaseLLMProvider(ABC):
|
|||||||
self,
|
self,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
tools: dict[str, ToolDefinition],
|
tools: dict[str, ToolDefinition],
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
max_iterations: int = 10,
|
max_iterations: int = 10,
|
||||||
@ -551,6 +556,7 @@ class BaseLLMProvider(ABC):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tools=list(tools.values()),
|
tools=list(tools.values()),
|
||||||
|
mcp_servers=mcp_servers,
|
||||||
settings=settings,
|
settings=settings,
|
||||||
):
|
):
|
||||||
if event.type == "text":
|
if event.type == "text":
|
||||||
@ -583,7 +589,12 @@ class BaseLLMProvider(ABC):
|
|||||||
|
|
||||||
# Recursively continue the conversation with reduced iterations
|
# Recursively continue the conversation with reduced iterations
|
||||||
yield from self.stream_with_tools(
|
yield from self.stream_with_tools(
|
||||||
messages, tools, settings, system_prompt, max_iterations - 1
|
messages,
|
||||||
|
tools,
|
||||||
|
mcp_servers,
|
||||||
|
settings,
|
||||||
|
system_prompt,
|
||||||
|
max_iterations - 1,
|
||||||
)
|
)
|
||||||
return # Exit after recursive call completes
|
return # Exit after recursive call completes
|
||||||
|
|
||||||
@ -598,6 +609,7 @@ class BaseLLMProvider(ABC):
|
|||||||
self,
|
self,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
tools: dict[str, ToolDefinition],
|
tools: dict[str, ToolDefinition],
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
max_iterations: int = 10,
|
max_iterations: int = 10,
|
||||||
@ -606,6 +618,7 @@ class BaseLLMProvider(ABC):
|
|||||||
for event in self.stream_with_tools(
|
for event in self.stream_with_tools(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
mcp_servers=mcp_servers,
|
||||||
settings=settings,
|
settings=settings,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import openai
|
|||||||
from memory.common.llms.base import (
|
from memory.common.llms.base import (
|
||||||
BaseLLMProvider,
|
BaseLLMProvider,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
|
MCPServer,
|
||||||
LLMSettings,
|
LLMSettings,
|
||||||
Message,
|
Message,
|
||||||
StreamEvent,
|
StreamEvent,
|
||||||
@ -175,6 +176,7 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None,
|
system_prompt: str | None,
|
||||||
tools: list[ToolDefinition] | None,
|
tools: list[ToolDefinition] | None,
|
||||||
|
mcp_servers: list[MCPServer] | None,
|
||||||
settings: LLMSettings,
|
settings: LLMSettings,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@ -185,6 +187,7 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages: Conversation history
|
messages: Conversation history
|
||||||
system_prompt: Optional system prompt
|
system_prompt: Optional system prompt
|
||||||
tools: Optional list of tools
|
tools: Optional list of tools
|
||||||
|
mcp_servers: Optional list of MCP servers
|
||||||
settings: LLM settings
|
settings: LLM settings
|
||||||
stream: Whether to enable streaming
|
stream: Whether to enable streaming
|
||||||
|
|
||||||
@ -333,12 +336,13 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a non-streaming response."""
|
"""Generate a non-streaming response."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(
|
kwargs = self._build_request_kwargs(
|
||||||
messages, system_prompt, tools, settings, stream=False
|
messages, system_prompt, tools, mcp_servers, settings, stream=False
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -361,12 +365,13 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> Iterator[StreamEvent]:
|
) -> Iterator[StreamEvent]:
|
||||||
"""Generate a streaming response."""
|
"""Generate a streaming response."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(
|
kwargs = self._build_request_kwargs(
|
||||||
messages, system_prompt, tools, settings, stream=True
|
messages, system_prompt, tools, mcp_servers, settings, stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if kwargs.get("stream"):
|
if kwargs.get("stream"):
|
||||||
@ -393,12 +398,13 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a non-streaming response asynchronously."""
|
"""Generate a non-streaming response asynchronously."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(
|
kwargs = self._build_request_kwargs(
|
||||||
messages, system_prompt, tools, settings, stream=False
|
messages, system_prompt, tools, mcp_servers, settings, stream=False
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -413,12 +419,13 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
settings: LLMSettings | None = None,
|
settings: LLMSettings | None = None,
|
||||||
) -> AsyncIterator[StreamEvent]:
|
) -> AsyncIterator[StreamEvent]:
|
||||||
"""Generate a streaming response asynchronously."""
|
"""Generate a streaming response asynchronously."""
|
||||||
settings = settings or LLMSettings()
|
settings = settings or LLMSettings()
|
||||||
kwargs = self._build_request_kwargs(
|
kwargs = self._build_request_kwargs(
|
||||||
messages, system_prompt, tools, settings, stream=True
|
messages, system_prompt, tools, mcp_servers, settings, stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -23,6 +23,16 @@ class ToolResult(TypedDict):
|
|||||||
output: str
|
output: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPServer:
|
||||||
|
"""An MCP server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
url: str
|
||||||
|
token: str
|
||||||
|
allowed_tools: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolDefinition:
|
class ToolDefinition:
|
||||||
"""Definition of a tool that can be called by the LLM."""
|
"""Definition of a tool that can be called by the LLM."""
|
||||||
|
|||||||
@ -234,13 +234,12 @@ class MessageCollector(commands.Bot):
|
|||||||
async def setup_hook(self):
|
async def setup_hook(self):
|
||||||
"""Register slash commands when the bot is ready."""
|
"""Register slash commands when the bot is ready."""
|
||||||
|
|
||||||
if not (name := self.user.name):
|
if not self.user:
|
||||||
logger.error(f"Failed to get user name for {self.user}")
|
logger.error(f"Failed to get user name for {self.user}")
|
||||||
return
|
return
|
||||||
|
|
||||||
name = name.replace("-", "_").lower()
|
|
||||||
try:
|
try:
|
||||||
register_slash_commands(self, name=name)
|
register_slash_commands(self)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
|
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
|
||||||
logger.error(f"Registered slash commands for {self.user.name}")
|
logger.error(f"Registered slash commands for {self.user.name}")
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class CommandContext:
|
|||||||
CommandHandler = Callable[..., CommandResponse]
|
CommandHandler = Callable[..., CommandResponse]
|
||||||
|
|
||||||
|
|
||||||
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
def register_slash_commands(bot: discord.Client) -> None:
|
||||||
"""Register the collector slash commands on the provided bot.
|
"""Register the collector slash commands on the provided bot.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -53,7 +53,6 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if getattr(bot, "_memory_commands_registered", False):
|
if getattr(bot, "_memory_commands_registered", False):
|
||||||
logger.error(f"Slash commands already registered for {name}")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
setattr(bot, "_memory_commands_registered", True)
|
setattr(bot, "_memory_commands_registered", True)
|
||||||
@ -62,6 +61,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
|||||||
raise RuntimeError("Bot instance does not support app commands")
|
raise RuntimeError("Bot instance does not support app commands")
|
||||||
|
|
||||||
tree = bot.tree
|
tree = bot.tree
|
||||||
|
name = bot.user and bot.user.name.replace("-", "_").lower()
|
||||||
|
|
||||||
@tree.command(
|
@tree.command(
|
||||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||||
@ -184,7 +184,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
|||||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await run_mcp_server_command(interaction, action, url and url.strip(), name)
|
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
|
||||||
|
|
||||||
|
|
||||||
async def _run_interaction_command(
|
async def _run_interaction_command(
|
||||||
|
|||||||
@ -100,11 +100,15 @@ async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def handle_mcp_add(
|
async def handle_mcp_add(
|
||||||
interaction: discord.Interaction, url: str, name: str = "memory"
|
interaction: discord.Interaction,
|
||||||
|
bot_user: discord.User | None,
|
||||||
|
url: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Add a new MCP server via OAuth."""
|
"""Add a new MCP server via OAuth."""
|
||||||
|
if not bot_user:
|
||||||
|
raise ValueError("Bot user is required")
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
if find_mcp_server(session, interaction.user.id, url):
|
if find_mcp_server(session, bot_user.id, url):
|
||||||
return (
|
return (
|
||||||
f"**MCP Server Already Exists**\n\n"
|
f"**MCP Server Already Exists**\n\n"
|
||||||
f"You already have an MCP server configured at `{url}`.\n"
|
f"You already have an MCP server configured at `{url}`.\n"
|
||||||
@ -115,10 +119,10 @@ async def handle_mcp_add(
|
|||||||
client_id = await register_oauth_client(
|
client_id = await register_oauth_client(
|
||||||
endpoints,
|
endpoints,
|
||||||
url,
|
url,
|
||||||
f"Discord Bot - {name} ({interaction.user.name})",
|
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
|
||||||
)
|
)
|
||||||
mcp_server = DiscordMCPServer(
|
mcp_server = DiscordMCPServer(
|
||||||
discord_bot_user_id=interaction.user.id,
|
discord_bot_user_id=bot_user.id,
|
||||||
mcp_server_url=url,
|
mcp_server_url=url,
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
)
|
)
|
||||||
@ -142,10 +146,10 @@ async def handle_mcp_add(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
|
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
|
||||||
"""Delete an MCP server."""
|
"""Delete an MCP server."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
return (
|
return (
|
||||||
f"**MCP Server Not Found**\n\n"
|
f"**MCP Server Not Found**\n\n"
|
||||||
@ -157,10 +161,10 @@ async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
|
|||||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
||||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"**MCP Server Not Found**\n\n"
|
f"**MCP Server Not Found**\n\n"
|
||||||
@ -180,9 +184,7 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
|
||||||
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||||
@ -193,10 +195,10 @@ async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
|
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
||||||
"""List tools available on an MCP server."""
|
"""List tools available on an MCP server."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||||
|
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -264,9 +266,9 @@ async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
|
|||||||
|
|
||||||
async def run_mcp_server_command(
|
async def run_mcp_server_command(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
|
bot_user: discord.User | None,
|
||||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||||
url: str | None,
|
url: str | None,
|
||||||
name: str = "memory",
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle MCP server management commands."""
|
"""Handle MCP server management commands."""
|
||||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||||
@ -277,18 +279,23 @@ async def run_mcp_server_command(
|
|||||||
"❌ URL is required for this action", ephemeral=True
|
"❌ URL is required for this action", ephemeral=True
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
if not bot_user:
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"❌ Bot user is required", ephemeral=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if action == "list" or not url:
|
if action == "list" or not url:
|
||||||
result = await handle_mcp_list(interaction)
|
result = await handle_mcp_list(interaction)
|
||||||
elif action == "add":
|
elif action == "add":
|
||||||
result = await handle_mcp_add(interaction, url, name)
|
result = await handle_mcp_add(interaction, bot_user, url)
|
||||||
elif action == "delete":
|
elif action == "delete":
|
||||||
result = await handle_mcp_delete(interaction, url)
|
result = await handle_mcp_delete(bot_user, url)
|
||||||
elif action == "connect":
|
elif action == "connect":
|
||||||
result = await handle_mcp_connect(interaction, url)
|
result = await handle_mcp_connect(bot_user, url)
|
||||||
elif action == "tools":
|
elif action == "tools":
|
||||||
result = await handle_mcp_tools(interaction, url)
|
result = await handle_mcp_tools(bot_user, url)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
result = f"❌ Error: {exc}"
|
result = f"❌ Error: {exc}"
|
||||||
await interaction.response.send_message(result, ephemeral=True)
|
await interaction.response.send_message(result, ephemeral=True)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
|
|||||||
authorization_endpoint=authorization_endpoint,
|
authorization_endpoint=authorization_endpoint,
|
||||||
registration_endpoint=registration_endpoint,
|
registration_endpoint=registration_endpoint,
|
||||||
token_endpoint=token_endpoint,
|
token_endpoint=token_endpoint,
|
||||||
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
|
redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user