mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
mcp servers for discord bots
This commit is contained in:
parent
6250586d1f
commit
64bb926eba
@ -0,0 +1,67 @@
|
||||
"""discord mcp servers
|
||||
|
||||
Revision ID: 9b887449ea92
|
||||
Revises: 1954477b25f4
|
||||
Create Date: 2025-11-02 22:04:26.259323
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9b887449ea92"
|
||||
down_revision: Union[str, None] = "1954477b25f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"discord_mcp_servers",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||
sa.Column("client_id", sa.Text(), nullable=False),
|
||||
sa.Column("state", sa.Text(), nullable=True),
|
||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["discord_bot_user_id"],
|
||||
["discord_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("state"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"discord_mcp_user_url_idx",
|
||||
"discord_mcp_servers",
|
||||
["discord_bot_user_id", "mcp_server_url"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
|
||||
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
|
||||
op.drop_table("discord_mcp_servers")
|
||||
@ -34,6 +34,7 @@ from memory.common.db.models.discord import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
DiscordMCPServer,
|
||||
)
|
||||
from memory.common.db.models.observations import (
|
||||
ObservationContradiction,
|
||||
@ -106,6 +107,7 @@ __all__ = [
|
||||
"DiscordServer",
|
||||
"DiscordChannel",
|
||||
"DiscordUser",
|
||||
"DiscordMCPServer",
|
||||
# Users
|
||||
"User",
|
||||
"HumanUser",
|
||||
|
||||
@ -127,5 +127,44 @@ class DiscordUser(Base, MessageProcessor):
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
system_user = relationship("User", back_populates="discord_users")
|
||||
mcp_servers = relationship(
|
||||
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||
|
||||
|
||||
class DiscordMCPServer(Base):
|
||||
"""MCP server configuration and OAuth state for Discord users."""
|
||||
|
||||
__tablename__ = "discord_mcp_servers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
discord_bot_user_id = Column(
|
||||
BigInteger, ForeignKey("discord_users.id"), nullable=False
|
||||
)
|
||||
|
||||
# MCP server info
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
code_verifier = Column(Text, nullable=True)
|
||||
|
||||
# OAuth tokens (set after successful authorization)
|
||||
access_token = Column(Text, nullable=True)
|
||||
refresh_token = Column(Text, nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
|
||||
|
||||
__table_args__ = (
|
||||
Index("discord_mcp_state_idx", "state"),
|
||||
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
|
||||
)
|
||||
|
||||
@ -363,7 +363,19 @@ class DiscordMessage(SourceItem):
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
||||
return textwrap.dedent("""
|
||||
<message>
|
||||
<id>{message_id}</id>
|
||||
<from>{from_user}</from>
|
||||
<sent_at>{sent_at}</sent_at>
|
||||
<content>{content}</content>
|
||||
</message>
|
||||
""").format(
|
||||
message_id=self.message_id,
|
||||
from_user=self.from_user.username,
|
||||
sent_at=self.sent_at.isoformat()[:19],
|
||||
content=self.content,
|
||||
)
|
||||
|
||||
def as_content(self) -> dict[str, Any]:
|
||||
"""Return message content ready for LLM (text + images from disk)."""
|
||||
|
||||
@ -94,6 +94,30 @@ def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def add_reaction(bot_id: int, channel: int | str, message_id: int, emoji: str) -> bool:
|
||||
"""Add a reaction to a message in a channel"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/add_reaction",
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel": channel,
|
||||
"message_id": message_id,
|
||||
"emoji": emoji,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Failed to add reaction {emoji} to message {message_id} in channel {channel}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
|
||||
"""Send a message to a channel by name or ID (ID supports threads)"""
|
||||
try:
|
||||
|
||||
@ -17,6 +17,7 @@ from memory.common.db.models import (
|
||||
BotUser,
|
||||
)
|
||||
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
|
||||
from memory.common.discord import add_reaction
|
||||
|
||||
|
||||
UpdateSummaryType = Literal["server", "channel", "user"]
|
||||
@ -209,6 +210,50 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
|
||||
)
|
||||
|
||||
|
||||
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
|
||||
bot_id = cast(int, bot.id)
|
||||
channel_id = channel and channel.id
|
||||
|
||||
def handler(input: ToolInput) -> str:
|
||||
if not isinstance(input, dict):
|
||||
raise ValueError("Input must be a dictionary")
|
||||
try:
|
||||
emoji = input.get("emoji")
|
||||
except ValueError:
|
||||
raise ValueError("Emoji is required")
|
||||
if not emoji:
|
||||
raise ValueError("Emoji is required")
|
||||
|
||||
try:
|
||||
message_id = int(input.get("message_id") or "no id")
|
||||
except ValueError:
|
||||
raise ValueError("Message ID is required")
|
||||
|
||||
success = add_reaction(bot_id, channel_id, message_id, emoji)
|
||||
if not success:
|
||||
return "Failed to add reaction"
|
||||
return "Reaction added"
|
||||
|
||||
return ToolDefinition(
|
||||
name="add_reaction",
|
||||
description="Add a reaction to a message in a channel",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message_id": {
|
||||
"type": "number",
|
||||
"description": "The ID of the message to add the reaction to",
|
||||
},
|
||||
"emoji": {
|
||||
"type": "string",
|
||||
"description": "The emoji to add to the message",
|
||||
},
|
||||
},
|
||||
},
|
||||
function=handler,
|
||||
)
|
||||
|
||||
|
||||
def make_discord_tools(
|
||||
bot: BotUser,
|
||||
author: DiscordUser | None,
|
||||
@ -227,5 +272,6 @@ def make_discord_tools(
|
||||
if channel and channel.server:
|
||||
tools += [
|
||||
make_summary_tool("server", cast(BigInteger, channel.server_id)),
|
||||
make_add_reaction_tool(bot, channel),
|
||||
]
|
||||
return {tool.name: tool for tool in tools}
|
||||
|
||||
@ -12,13 +12,15 @@ from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.users import DiscordBotUser
|
||||
from memory.discord.collector import MessageCollector
|
||||
from memory.discord.oauth import complete_oauth_flow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -45,6 +47,13 @@ class TypingChannelRequest(BaseModel):
|
||||
channel: int | str # Channel name or ID (ID supports threads)
|
||||
|
||||
|
||||
class AddReactionRequest(BaseModel):
|
||||
bot_id: int
|
||||
channel: int | str # Channel name or ID (ID supports threads)
|
||||
message_id: int
|
||||
emoji: str
|
||||
|
||||
|
||||
class Collector:
|
||||
collector: MessageCollector
|
||||
collector_task: asyncio.Task
|
||||
@ -53,6 +62,7 @@ class Collector:
|
||||
bot_name: str
|
||||
|
||||
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||
logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
|
||||
self.collector = collector
|
||||
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||
self.bot_id = cast(int, bot.id)
|
||||
@ -72,7 +82,9 @@ async def lifespan(app: FastAPI):
|
||||
bots = session.query(DiscordBotUser).all()
|
||||
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||
|
||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||
logger.error(
|
||||
f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
@ -216,6 +228,36 @@ async def health_check():
|
||||
}
|
||||
|
||||
|
||||
@app.post("/add_reaction")
|
||||
async def add_reaction_endpoint(request: AddReactionRequest):
|
||||
"""Add a reaction to a message via the collector's Discord client"""
|
||||
collector = app.bots.get(request.bot_id)
|
||||
if not collector:
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
|
||||
try:
|
||||
success = await collector.collector.add_reaction(
|
||||
request.channel, request.message_id, request.emoji
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add reaction: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to add reaction to message {request.message_id}",
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"channel": request.channel,
|
||||
"message_id": request.message_id,
|
||||
"emoji": request.emoji,
|
||||
"message": f"Added reaction {request.emoji} to message {request.message_id}",
|
||||
}
|
||||
|
||||
|
||||
@app.post("/refresh_metadata")
|
||||
async def refresh_metadata():
|
||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||
@ -237,6 +279,51 @@ async def refresh_metadata():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/oauth/callback/discord", response_class=HTMLResponse)
|
||||
async def oauth_callback(request: Request):
|
||||
"""Handle OAuth callback from MCP server after user authorization."""
|
||||
code = request.query_params.get("code")
|
||||
state = request.query_params.get("state")
|
||||
error = request.query_params.get("error")
|
||||
|
||||
logger.info(
|
||||
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
|
||||
)
|
||||
|
||||
message, title, close, status_code = "", "", "", 200
|
||||
if error:
|
||||
logger.error(f"OAuth error: {error}")
|
||||
message = f"Error: {error}"
|
||||
title = "❌ Authorization Failed"
|
||||
status_code = 400
|
||||
elif not code or not state:
|
||||
message = "Missing authorization code or state parameter."
|
||||
title = "❌ Invalid Request"
|
||||
status_code = 400
|
||||
else:
|
||||
# Complete the OAuth flow (exchange code for token)
|
||||
with make_session() as session:
|
||||
status_code, message = await complete_oauth_flow(session, code, state)
|
||||
if 200 <= status_code < 300:
|
||||
title = "✅ Authorization Successful!"
|
||||
close = "You can close this window and return to the MCP server."
|
||||
else:
|
||||
title = "❌ Authorization Failed"
|
||||
|
||||
return HTMLResponse(
|
||||
content=f"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<p>{message}</p>
|
||||
<p>{close}</p>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
||||
"""Run the Discord API server"""
|
||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||
|
||||
@ -229,11 +229,21 @@ class MessageCollector(commands.Bot):
|
||||
intents=intents,
|
||||
help_command=None, # Disable default help
|
||||
)
|
||||
logger.info(f"Initialized collector for {self.user}")
|
||||
|
||||
async def setup_hook(self):
|
||||
"""Register slash commands when the bot is ready."""
|
||||
|
||||
register_slash_commands(self, name=self.user.name)
|
||||
if not (name := self.user.name):
|
||||
logger.error(f"Failed to get user name for {self.user}")
|
||||
return
|
||||
|
||||
name = name.replace("-", "_").lower()
|
||||
try:
|
||||
register_slash_commands(self, name=name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
|
||||
logger.error(f"Registered slash commands for {self.user.name}")
|
||||
|
||||
async def on_ready(self):
|
||||
"""Called when bot connects to Discord"""
|
||||
@ -313,8 +323,6 @@ class MessageCollector(commands.Bot):
|
||||
|
||||
async def refresh_metadata(self) -> dict[str, int]:
|
||||
"""Refresh server and channel metadata from Discord and update database"""
|
||||
print("🔄 Refreshing Discord metadata...")
|
||||
|
||||
servers_updated = 0
|
||||
channels_updated = 0
|
||||
users_updated = 0
|
||||
@ -454,25 +462,33 @@ class MessageCollector(commands.Bot):
|
||||
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||
return False
|
||||
|
||||
async def _get_channel(
|
||||
self, channel_identifier: int | str, check_notifications: bool = True
|
||||
):
|
||||
"""Get channel by ID or name with standard checks"""
|
||||
if check_notifications and not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
logger.debug("Discord notifications disabled")
|
||||
return None
|
||||
|
||||
if isinstance(channel_identifier, int):
|
||||
channel = self.get_channel(channel_identifier)
|
||||
else:
|
||||
channel = await self.get_channel_by_name(channel_identifier)
|
||||
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_identifier} not found")
|
||||
|
||||
return channel
|
||||
|
||||
async def send_to_channel(
|
||||
self, channel_identifier: int | str, message: str
|
||||
) -> bool:
|
||||
"""Send a message to a channel by name or ID (supports threads)"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get channel by ID or name
|
||||
if isinstance(channel_identifier, int):
|
||||
channel = self.get_channel(channel_identifier)
|
||||
else:
|
||||
channel = await self.get_channel_by_name(channel_identifier)
|
||||
|
||||
channel = await self._get_channel(channel_identifier)
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_identifier} not found")
|
||||
return False
|
||||
|
||||
# Post-process mentions to convert usernames to IDs
|
||||
with make_session() as session:
|
||||
processed_message = process_mentions(session, message)
|
||||
|
||||
@ -486,18 +502,9 @@ class MessageCollector(commands.Bot):
|
||||
|
||||
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
|
||||
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get channel by ID or name
|
||||
if isinstance(channel_identifier, int):
|
||||
channel = self.get_channel(channel_identifier)
|
||||
else:
|
||||
channel = await self.get_channel_by_name(channel_identifier)
|
||||
|
||||
channel = await self._get_channel(channel_identifier)
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_identifier} not found")
|
||||
return False
|
||||
|
||||
async with channel.typing():
|
||||
@ -509,3 +516,21 @@ class MessageCollector(commands.Bot):
|
||||
f"Failed to trigger typing for channel {channel_identifier}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def add_reaction(
|
||||
self, channel_identifier: int | str, message_id: int, emoji: str
|
||||
) -> bool:
|
||||
"""Add a reaction to a message in a channel"""
|
||||
try:
|
||||
channel = await self._get_channel(channel_identifier)
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
message = await channel.fetch_message(message_id)
|
||||
await message.add_reaction(emoji)
|
||||
logger.info(f"Added reaction {emoji} to message {message_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add reaction: {e}")
|
||||
return False
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from calendar import c
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal
|
||||
|
||||
@ -10,6 +10,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.mcp import run_mcp_server_command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ScopeLiteral = Literal["server", "channel", "user"]
|
||||
|
||||
@ -50,6 +53,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
"""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
logger.error(f"Slash commands already registered for {name}")
|
||||
return
|
||||
|
||||
setattr(bot, "_memory_commands_registered", True)
|
||||
@ -167,6 +171,21 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_mcp_servers",
|
||||
description="Manage MCP servers for your account",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
)
|
||||
async def mcp_servers_command(
|
||||
interaction: discord.Interaction,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
) -> None:
|
||||
await run_mcp_server_command(interaction, action, url and url.strip(), name)
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
interaction: discord.Interaction,
|
||||
@ -177,17 +196,10 @@ async def _run_interaction_command(
|
||||
**handler_kwargs,
|
||||
) -> None:
|
||||
"""Shared coroutine used by the registered slash commands."""
|
||||
|
||||
try:
|
||||
with make_session() as session:
|
||||
response = run_command(
|
||||
session,
|
||||
interaction,
|
||||
scope,
|
||||
handler=handler,
|
||||
target_user=target_user,
|
||||
**handler_kwargs,
|
||||
)
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
response = handler(context, **handler_kwargs)
|
||||
session.commit()
|
||||
except CommandError as exc: # pragma: no cover - passthrough
|
||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||
@ -199,21 +211,6 @@ async def _run_interaction_command(
|
||||
)
|
||||
|
||||
|
||||
def run_command(
|
||||
session: Session,
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
*,
|
||||
handler: CommandHandler,
|
||||
target_user: discord.User | None = None,
|
||||
**handler_kwargs,
|
||||
) -> CommandResponse:
|
||||
"""Create a :class:`CommandContext` and execute the handler."""
|
||||
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
return handler(context, **handler_kwargs)
|
||||
|
||||
|
||||
def _build_context(
|
||||
session: Session,
|
||||
interaction: discord.Interaction,
|
||||
|
||||
294
src/memory/discord/mcp.py
Normal file
294
src/memory/discord/mcp.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Literal, cast
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_mcp_server(
|
||||
session: Session | scoped_session, user_id: int, url: str
|
||||
) -> DiscordMCPServer | None:
|
||||
return (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(
|
||||
DiscordMCPServer.discord_bot_user_id == user_id,
|
||||
DiscordMCPServer.mcp_server_url == url,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
async def call_mcp_server(
|
||||
url: str, access_token: str, method: str, params: dict = {}
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": int(time.time() * 1000),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"Tools list failed: {resp.status} - {error_text}")
|
||||
raise ValueError(
|
||||
f"Failed to call MCP server: {resp.status} - {error_text}"
|
||||
)
|
||||
|
||||
# Parse SSE stream
|
||||
async for line in resp.content:
|
||||
line_str = line.decode("utf-8").strip()
|
||||
|
||||
# SSE format: "data: {json}"
|
||||
if line_str.startswith("data: "):
|
||||
json_str = line_str[6:] # Remove "data: " prefix
|
||||
try:
|
||||
yield json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
continue # Skip invalid JSON lines
|
||||
|
||||
|
||||
async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
||||
"""List all MCP servers for the user."""
|
||||
with make_session() as session:
|
||||
servers = (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(
|
||||
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not servers:
|
||||
return (
|
||||
"📋 **Your MCP Servers**\n\n"
|
||||
"You don't have any MCP servers configured yet.\n"
|
||||
"Use `/memory_mcp_servers add <url>` to add one."
|
||||
)
|
||||
|
||||
def format_server(server: DiscordMCPServer) -> str:
|
||||
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
||||
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
||||
|
||||
server_list = "\n".join(format_server(s) for s in servers)
|
||||
|
||||
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
||||
|
||||
|
||||
async def handle_mcp_add(
|
||||
interaction: discord.Interaction, url: str, name: str = "memory"
|
||||
) -> str:
|
||||
"""Add a new MCP server via OAuth."""
|
||||
with make_session() as session:
|
||||
if find_mcp_server(session, interaction.user.id, url):
|
||||
return (
|
||||
f"**MCP Server Already Exists**\n\n"
|
||||
f"You already have an MCP server configured at `{url}`.\n"
|
||||
f"Use `/memory_mcp_servers connect {url}` to reconnect."
|
||||
)
|
||||
|
||||
endpoints = await get_endpoints(url)
|
||||
client_id = await register_oauth_client(
|
||||
endpoints,
|
||||
url,
|
||||
f"Discord Bot - {name} ({interaction.user.name})",
|
||||
)
|
||||
mcp_server = DiscordMCPServer(
|
||||
discord_bot_user_id=interaction.user.id,
|
||||
mcp_server_url=url,
|
||||
client_id=client_id,
|
||||
)
|
||||
session.add(mcp_server)
|
||||
session.flush()
|
||||
|
||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Created MCP server record: id={mcp_server.id}, "
|
||||
f"user={interaction.user.id}, url={url}"
|
||||
)
|
||||
|
||||
return (
|
||||
f"🔐 **Add MCP Server**\n\n"
|
||||
f"Server: `{url}`\n"
|
||||
f"Click the link below to authorize:\n{auth_url}\n\n"
|
||||
f"⚠️ Keep this link private!\n"
|
||||
f"💡 You'll be redirected to login and grant access to the MCP server."
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
|
||||
"""Delete an MCP server."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||
if not mcp_server:
|
||||
return (
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
)
|
||||
session.delete(mcp_server)
|
||||
session.commit()
|
||||
|
||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||
|
||||
|
||||
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||
)
|
||||
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||
)
|
||||
|
||||
endpoints = await get_endpoints(url)
|
||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
|
||||
)
|
||||
|
||||
return (
|
||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||
f"Server: `{url}`\n"
|
||||
f"Click the link below to reauthorize:\n{auth_url}\n\n"
|
||||
f"⚠️ Keep this link private!\n"
|
||||
f"💡 You'll be redirected to login and grant access to the MCP server again."
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
|
||||
"""List tools available on an MCP server."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||
)
|
||||
|
||||
if not cast(str | None, mcp_server.access_token):
|
||||
raise ValueError(
|
||||
f"**Not Authorized**\n\n"
|
||||
f"You haven't authorized access to `{url}` yet.\n"
|
||||
f"Use `/memory_mcp_servers connect {url}` to authorize."
|
||||
)
|
||||
|
||||
access_token = cast(str, mcp_server.access_token)
|
||||
|
||||
# Make JSON-RPC request to MCP server
|
||||
tools = None
|
||||
try:
|
||||
async for data in call_mcp_server(url, access_token, "tools/list"):
|
||||
if "result" in data and "tools" in data["result"]:
|
||||
tools = data["result"]["tools"]
|
||||
break
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.exception(f"Failed to connect to MCP server: {exc}")
|
||||
raise ValueError(
|
||||
f"**Connection failed**\n\n"
|
||||
f"Server: `{url}`\n"
|
||||
f"Could not connect to the MCP server: {str(exc)}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(f"Failed to list tools: {exc}")
|
||||
raise ValueError(
|
||||
f"**Error**\n\nServer: `{url}`\nFailed to list tools: {str(exc)}"
|
||||
)
|
||||
|
||||
if tools is None:
|
||||
raise ValueError(
|
||||
f"**Unexpected response format**\n\n"
|
||||
f"Server: `{url}`\n"
|
||||
f"The server returned an unexpected response format."
|
||||
)
|
||||
|
||||
if not tools:
|
||||
return (
|
||||
f"🔧 **MCP Server Tools**\n\n"
|
||||
f"Server: `{url}`\n\n"
|
||||
f"No tools available on this server."
|
||||
)
|
||||
|
||||
# Format tools list
|
||||
tools_list = "\n".join(
|
||||
f"• **{t.get('name', 'unknown')}**: {t.get('description', 'No description')}"
|
||||
for t in tools
|
||||
)
|
||||
|
||||
return (
|
||||
f"🔧 **MCP Server Tools**\n\n"
|
||||
f"Server: `{url}`\n"
|
||||
f"Found {len(tools)} tool(s):\n\n"
|
||||
f"{tools_list}"
|
||||
)
|
||||
|
||||
|
||||
async def run_mcp_server_command(
|
||||
interaction: discord.Interaction,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||
url: str | None,
|
||||
name: str = "memory",
|
||||
) -> None:
|
||||
"""Handle MCP server management commands."""
|
||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
|
||||
return
|
||||
if action != "list" and not url:
|
||||
await interaction.response.send_message(
|
||||
"❌ URL is required for this action", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if action == "list" or not url:
|
||||
result = await handle_mcp_list(interaction)
|
||||
elif action == "add":
|
||||
result = await handle_mcp_add(interaction, url, name)
|
||||
elif action == "delete":
|
||||
result = await handle_mcp_delete(interaction, url)
|
||||
elif action == "connect":
|
||||
result = await handle_mcp_connect(interaction, url)
|
||||
elif action == "tools":
|
||||
result = await handle_mcp_tools(interaction, url)
|
||||
except Exception as exc:
|
||||
result = f"❌ Error: {exc}"
|
||||
await interaction.response.send_message(result, ephemeral=True)
|
||||
@ -141,8 +141,10 @@ def previous_messages(
|
||||
) -> list[DiscordMessage]:
|
||||
messages = session.query(DiscordMessage)
|
||||
if user_id:
|
||||
print(f"user_id: {user_id}")
|
||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||
if channel_id:
|
||||
print(f"channel_id: {channel_id}")
|
||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||
return list(
|
||||
reversed(
|
||||
|
||||
268
src/memory/discord/oauth.py
Normal file
268
src/memory/discord/oauth.py
Normal file
@ -0,0 +1,268 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from base64 import urlsafe_b64encode
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
import aiohttp
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthEndpoints:
|
||||
authorization_endpoint: str
|
||||
registration_endpoint: str
|
||||
token_endpoint: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""Generate PKCE code verifier and challenge.
|
||||
|
||||
Returns:
|
||||
Tuple of (code_verifier, code_challenge)
|
||||
"""
|
||||
# Generate a random code verifier
|
||||
code_verifier = (
|
||||
urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
|
||||
)
|
||||
|
||||
# Create code challenge using S256 method
|
||||
challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = urlsafe_b64encode(challenge_bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
async def discover_oauth_metadata(server_url: str) -> dict | None:
|
||||
"""Discover OAuth metadata from an MCP server."""
|
||||
# Try the standard OAuth discovery endpoint
|
||||
discovery_url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
discovery_url, timeout=aiohttp.ClientTimeout(total=5)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.json()
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to discover OAuth metadata from {discovery_url}: {exc}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_endpoints(url: str) -> OAuthEndpoints:
|
||||
# Discover OAuth endpoints from the target server
|
||||
oauth_metadata = await discover_oauth_metadata(url)
|
||||
|
||||
if not oauth_metadata:
|
||||
raise ValueError(
|
||||
"**Failed to connect to MCP server**\n\n"
|
||||
f"Could not discover OAuth endpoints at `{url}`\n"
|
||||
"Make sure the server is running and supports OAuth 2.0.",
|
||||
)
|
||||
|
||||
authorization_endpoint = oauth_metadata.get("authorization_endpoint")
|
||||
registration_endpoint = oauth_metadata.get("registration_endpoint")
|
||||
token_endpoint = oauth_metadata.get("token_endpoint")
|
||||
|
||||
if not authorization_endpoint:
|
||||
raise ValueError(
|
||||
"**Invalid OAuth configuration**\n\n"
|
||||
f"Server `{url}` did not provide an authorization endpoint.",
|
||||
)
|
||||
|
||||
if not registration_endpoint:
|
||||
raise ValueError(
|
||||
"**Invalid OAuth configuration**\n\n"
|
||||
f"Server `{url}` does not support dynamic client registration.",
|
||||
)
|
||||
|
||||
if not token_endpoint:
|
||||
raise ValueError(
|
||||
"**Invalid OAuth configuration**\n\n"
|
||||
f"Server `{url}` does not provide a token endpoint.",
|
||||
)
|
||||
|
||||
logger.info(f"Authorization endpoint: {authorization_endpoint}")
|
||||
logger.info(f"Registration endpoint: {registration_endpoint}")
|
||||
|
||||
return OAuthEndpoints(
|
||||
authorization_endpoint=authorization_endpoint,
|
||||
registration_endpoint=registration_endpoint,
|
||||
token_endpoint=token_endpoint,
|
||||
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
|
||||
)
|
||||
|
||||
|
||||
async def register_oauth_client(
|
||||
endpoints: OAuthEndpoints,
|
||||
url: str,
|
||||
client_name: str,
|
||||
) -> None:
|
||||
"""Register OAuth client and store client_id in the mcp_server object."""
|
||||
client_metadata = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [endpoints.redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scope": "read write",
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
|
||||
logger.error(f"Registration metadata: {client_metadata}")
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
endpoints.registration_endpoint,
|
||||
json=client_metadata,
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as resp:
|
||||
logger.error(
|
||||
f"Registration response: {resp.status} {await resp.text()}"
|
||||
)
|
||||
resp.raise_for_status()
|
||||
client_info = await resp.json()
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Failed to register OAuth client at {endpoints.registration_endpoint}: {exc}"
|
||||
)
|
||||
|
||||
if not client_info or "client_id" not in client_info:
|
||||
raise ValueError(
|
||||
"**Failed to register OAuth client**\n\n"
|
||||
f"Could not register with the MCP server at `{url}`\n"
|
||||
f"Check the server logs for more details.",
|
||||
)
|
||||
|
||||
client_id = client_info["client_id"]
|
||||
|
||||
logger.info(f"Registered OAuth client: {client_id}")
|
||||
return client_id
|
||||
|
||||
|
||||
async def issue_challenge(
|
||||
mcp_server: DiscordMCPServer,
|
||||
endpoints: OAuthEndpoints,
|
||||
) -> str:
|
||||
"""Generate OAuth challenge and store state in mcp_server object."""
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Store in mcp_server object
|
||||
mcp_server.state = state # type: ignore
|
||||
mcp_server.code_verifier = code_verifier # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
|
||||
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
||||
)
|
||||
|
||||
# Build authorization URL pointing to the target server
|
||||
auth_params = {
|
||||
"client_id": mcp_server.client_id,
|
||||
"redirect_uri": endpoints.redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"scope": "read write",
|
||||
}
|
||||
|
||||
return f"{endpoints.authorization_endpoint}?{urlencode(auth_params)}"
|
||||
|
||||
|
||||
async def complete_oauth_flow(
|
||||
session: Session | scoped_session, code: str, state: str
|
||||
) -> tuple[int, str]:
|
||||
"""Complete OAuth flow by exchanging code for token.
|
||||
|
||||
Args:
|
||||
code: Authorization code from OAuth callback
|
||||
state: State parameter from OAuth callback
|
||||
|
||||
Returns:
|
||||
Tuple of (status_code, html_message) for the callback response
|
||||
"""
|
||||
try:
|
||||
mcp_server = (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(DiscordMCPServer.state == state)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not mcp_server:
|
||||
logger.error(f"Invalid or expired state: {state[:20]}...")
|
||||
return 400, "Invalid or expired OAuth state"
|
||||
|
||||
logger.info(
|
||||
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
|
||||
f"url={mcp_server.mcp_server_url}"
|
||||
)
|
||||
|
||||
# Get OAuth endpoints
|
||||
try:
|
||||
endpoints = await get_endpoints(str(mcp_server.mcp_server_url))
|
||||
except Exception as exc:
|
||||
return 500, f"Failed to get OAuth endpoints: {str(exc)}"
|
||||
|
||||
# Exchange authorization code for access token
|
||||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": endpoints.redirect_uri,
|
||||
"client_id": mcp_server.client_id,
|
||||
"code_verifier": mcp_server.code_verifier,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
endpoints.token_endpoint,
|
||||
data=token_data,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"Token exchange failed: {resp.status} - {error_text}")
|
||||
return 500, f"Token exchange failed: {error_text}"
|
||||
|
||||
tokens = await resp.json()
|
||||
|
||||
access_token = tokens.get("access_token")
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
expires_in = tokens.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
return 500, "Token response did not include access_token"
|
||||
|
||||
logger.info(f"Successfully obtained access token: {access_token[:20]}...")
|
||||
|
||||
# Store tokens and clear temporary OAuth state
|
||||
mcp_server.access_token = access_token # type: ignore
|
||||
mcp_server.refresh_token = refresh_token # type: ignore
|
||||
mcp_server.token_expires_at = datetime.now() + timedelta(seconds=expires_in) # type: ignore
|
||||
|
||||
# Clear temporary OAuth flow data
|
||||
mcp_server.state = None # type: ignore
|
||||
mcp_server.code_verifier = None # type: ignore
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
||||
f"server {mcp_server.mcp_server_url}"
|
||||
)
|
||||
|
||||
return 200, "✅ Authorization successful! You can now use this MCP server."
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(f"Failed to complete OAuth flow: {exc}")
|
||||
return 500, f"Failed to complete OAuth flow: {str(exc)}"
|
||||
Loading…
x
Reference in New Issue
Block a user