mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
mcp servers for discord bots
This commit is contained in:
parent
6250586d1f
commit
64bb926eba
@ -0,0 +1,67 @@
|
|||||||
|
"""discord mcp servers
|
||||||
|
|
||||||
|
Revision ID: 9b887449ea92
|
||||||
|
Revises: 1954477b25f4
|
||||||
|
Create Date: 2025-11-02 22:04:26.259323
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "9b887449ea92"
|
||||||
|
down_revision: Union[str, None] = "1954477b25f4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"discord_mcp_servers",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("state", sa.Text(), nullable=True),
|
||||||
|
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||||
|
sa.Column("access_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["discord_bot_user_id"],
|
||||||
|
["discord_users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("state"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"discord_mcp_user_url_idx",
|
||||||
|
"discord_mcp_servers",
|
||||||
|
["discord_bot_user_id", "mcp_server_url"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
|
||||||
|
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
|
||||||
|
op.drop_table("discord_mcp_servers")
|
||||||
@ -34,6 +34,7 @@ from memory.common.db.models.discord import (
|
|||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
|
DiscordMCPServer,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.observations import (
|
from memory.common.db.models.observations import (
|
||||||
ObservationContradiction,
|
ObservationContradiction,
|
||||||
@ -106,6 +107,7 @@ __all__ = [
|
|||||||
"DiscordServer",
|
"DiscordServer",
|
||||||
"DiscordChannel",
|
"DiscordChannel",
|
||||||
"DiscordUser",
|
"DiscordUser",
|
||||||
|
"DiscordMCPServer",
|
||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"HumanUser",
|
"HumanUser",
|
||||||
|
|||||||
@ -127,5 +127,44 @@ class DiscordUser(Base, MessageProcessor):
|
|||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
system_user = relationship("User", back_populates="discord_users")
|
system_user = relationship("User", back_populates="discord_users")
|
||||||
|
mcp_servers = relationship(
|
||||||
|
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMCPServer(Base):
|
||||||
|
"""MCP server configuration and OAuth state for Discord users."""
|
||||||
|
|
||||||
|
__tablename__ = "discord_mcp_servers"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
discord_bot_user_id = Column(
|
||||||
|
BigInteger, ForeignKey("discord_users.id"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# MCP server info
|
||||||
|
mcp_server_url = Column(Text, nullable=False)
|
||||||
|
client_id = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# OAuth flow state (temporary, cleared after token exchange)
|
||||||
|
state = Column(Text, nullable=True, unique=True)
|
||||||
|
code_verifier = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# OAuth tokens (set after successful authorization)
|
||||||
|
access_token = Column(Text, nullable=True)
|
||||||
|
refresh_token = Column(Text, nullable=True)
|
||||||
|
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("discord_mcp_state_idx", "state"),
|
||||||
|
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
|
||||||
|
)
|
||||||
|
|||||||
@ -363,7 +363,19 @@ class DiscordMessage(SourceItem):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
return textwrap.dedent("""
|
||||||
|
<message>
|
||||||
|
<id>{message_id}</id>
|
||||||
|
<from>{from_user}</from>
|
||||||
|
<sent_at>{sent_at}</sent_at>
|
||||||
|
<content>{content}</content>
|
||||||
|
</message>
|
||||||
|
""").format(
|
||||||
|
message_id=self.message_id,
|
||||||
|
from_user=self.from_user.username,
|
||||||
|
sent_at=self.sent_at.isoformat()[:19],
|
||||||
|
content=self.content,
|
||||||
|
)
|
||||||
|
|
||||||
def as_content(self) -> dict[str, Any]:
|
def as_content(self) -> dict[str, Any]:
|
||||||
"""Return message content ready for LLM (text + images from disk)."""
|
"""Return message content ready for LLM (text + images from disk)."""
|
||||||
|
|||||||
@ -94,6 +94,30 @@ def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def add_reaction(bot_id: int, channel: int | str, message_id: int, emoji: str) -> bool:
|
||||||
|
"""Add a reaction to a message in a channel"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{get_api_url()}/add_reaction",
|
||||||
|
json={
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"channel": channel,
|
||||||
|
"message_id": message_id,
|
||||||
|
"emoji": emoji,
|
||||||
|
},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("success", False)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to add reaction {emoji} to message {message_id} in channel {channel}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
|
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
|
||||||
"""Send a message to a channel by name or ID (ID supports threads)"""
|
"""Send a message to a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from memory.common.db.models import (
|
|||||||
BotUser,
|
BotUser,
|
||||||
)
|
)
|
||||||
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
|
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
|
||||||
|
from memory.common.discord import add_reaction
|
||||||
|
|
||||||
|
|
||||||
UpdateSummaryType = Literal["server", "channel", "user"]
|
UpdateSummaryType = Literal["server", "channel", "user"]
|
||||||
@ -209,6 +210,50 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
|
||||||
|
bot_id = cast(int, bot.id)
|
||||||
|
channel_id = channel and channel.id
|
||||||
|
|
||||||
|
def handler(input: ToolInput) -> str:
|
||||||
|
if not isinstance(input, dict):
|
||||||
|
raise ValueError("Input must be a dictionary")
|
||||||
|
try:
|
||||||
|
emoji = input.get("emoji")
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Emoji is required")
|
||||||
|
if not emoji:
|
||||||
|
raise ValueError("Emoji is required")
|
||||||
|
|
||||||
|
try:
|
||||||
|
message_id = int(input.get("message_id") or "no id")
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Message ID is required")
|
||||||
|
|
||||||
|
success = add_reaction(bot_id, channel_id, message_id, emoji)
|
||||||
|
if not success:
|
||||||
|
return "Failed to add reaction"
|
||||||
|
return "Reaction added"
|
||||||
|
|
||||||
|
return ToolDefinition(
|
||||||
|
name="add_reaction",
|
||||||
|
description="Add a reaction to a message in a channel",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"message_id": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The ID of the message to add the reaction to",
|
||||||
|
},
|
||||||
|
"emoji": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The emoji to add to the message",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
function=handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_discord_tools(
|
def make_discord_tools(
|
||||||
bot: BotUser,
|
bot: BotUser,
|
||||||
author: DiscordUser | None,
|
author: DiscordUser | None,
|
||||||
@ -227,5 +272,6 @@ def make_discord_tools(
|
|||||||
if channel and channel.server:
|
if channel and channel.server:
|
||||||
tools += [
|
tools += [
|
||||||
make_summary_tool("server", cast(BigInteger, channel.server_id)),
|
make_summary_tool("server", cast(BigInteger, channel.server_id)),
|
||||||
|
make_add_reaction_tool(bot, channel),
|
||||||
]
|
]
|
||||||
return {tool.name: tool for tool in tools}
|
return {tool.name: tool for tool in tools}
|
||||||
|
|||||||
@ -12,13 +12,15 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.users import DiscordBotUser
|
from memory.common.db.models.users import DiscordBotUser
|
||||||
from memory.discord.collector import MessageCollector
|
from memory.discord.collector import MessageCollector
|
||||||
|
from memory.discord.oauth import complete_oauth_flow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -45,6 +47,13 @@ class TypingChannelRequest(BaseModel):
|
|||||||
channel: int | str # Channel name or ID (ID supports threads)
|
channel: int | str # Channel name or ID (ID supports threads)
|
||||||
|
|
||||||
|
|
||||||
|
class AddReactionRequest(BaseModel):
|
||||||
|
bot_id: int
|
||||||
|
channel: int | str # Channel name or ID (ID supports threads)
|
||||||
|
message_id: int
|
||||||
|
emoji: str
|
||||||
|
|
||||||
|
|
||||||
class Collector:
|
class Collector:
|
||||||
collector: MessageCollector
|
collector: MessageCollector
|
||||||
collector_task: asyncio.Task
|
collector_task: asyncio.Task
|
||||||
@ -53,6 +62,7 @@ class Collector:
|
|||||||
bot_name: str
|
bot_name: str
|
||||||
|
|
||||||
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||||
|
logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
|
||||||
self.collector = collector
|
self.collector = collector
|
||||||
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||||
self.bot_id = cast(int, bot.id)
|
self.bot_id = cast(int, bot.id)
|
||||||
@ -72,7 +82,9 @@ async def lifespan(app: FastAPI):
|
|||||||
bots = session.query(DiscordBotUser).all()
|
bots = session.query(DiscordBotUser).all()
|
||||||
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||||
|
|
||||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
logger.error(
|
||||||
|
f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -216,6 +228,36 @@ async def health_check():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/add_reaction")
|
||||||
|
async def add_reaction_endpoint(request: AddReactionRequest):
|
||||||
|
"""Add a reaction to a message via the collector's Discord client"""
|
||||||
|
collector = app.bots.get(request.bot_id)
|
||||||
|
if not collector:
|
||||||
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = await collector.collector.add_reaction(
|
||||||
|
request.channel, request.message_id, request.emoji
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to add reaction: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Failed to add reaction to message {request.message_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"channel": request.channel,
|
||||||
|
"message_id": request.message_id,
|
||||||
|
"emoji": request.emoji,
|
||||||
|
"message": f"Added reaction {request.emoji} to message {request.message_id}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/refresh_metadata")
|
@app.post("/refresh_metadata")
|
||||||
async def refresh_metadata():
|
async def refresh_metadata():
|
||||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||||
@ -237,6 +279,51 @@ async def refresh_metadata():
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/oauth/callback/discord", response_class=HTMLResponse)
|
||||||
|
async def oauth_callback(request: Request):
|
||||||
|
"""Handle OAuth callback from MCP server after user authorization."""
|
||||||
|
code = request.query_params.get("code")
|
||||||
|
state = request.query_params.get("state")
|
||||||
|
error = request.query_params.get("error")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
message, title, close, status_code = "", "", "", 200
|
||||||
|
if error:
|
||||||
|
logger.error(f"OAuth error: {error}")
|
||||||
|
message = f"Error: {error}"
|
||||||
|
title = "❌ Authorization Failed"
|
||||||
|
status_code = 400
|
||||||
|
elif not code or not state:
|
||||||
|
message = "Missing authorization code or state parameter."
|
||||||
|
title = "❌ Invalid Request"
|
||||||
|
status_code = 400
|
||||||
|
else:
|
||||||
|
# Complete the OAuth flow (exchange code for token)
|
||||||
|
with make_session() as session:
|
||||||
|
status_code, message = await complete_oauth_flow(session, code, state)
|
||||||
|
if 200 <= status_code < 300:
|
||||||
|
title = "✅ Authorization Successful!"
|
||||||
|
close = "You can close this window and return to the MCP server."
|
||||||
|
else:
|
||||||
|
title = "❌ Authorization Failed"
|
||||||
|
|
||||||
|
return HTMLResponse(
|
||||||
|
content=f"""
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
<h1>{title}</h1>
|
||||||
|
<p>{message}</p>
|
||||||
|
<p>{close}</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
status_code=status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
||||||
"""Run the Discord API server"""
|
"""Run the Discord API server"""
|
||||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||||
|
|||||||
@ -229,11 +229,21 @@ class MessageCollector(commands.Bot):
|
|||||||
intents=intents,
|
intents=intents,
|
||||||
help_command=None, # Disable default help
|
help_command=None, # Disable default help
|
||||||
)
|
)
|
||||||
|
logger.info(f"Initialized collector for {self.user}")
|
||||||
|
|
||||||
async def setup_hook(self):
|
async def setup_hook(self):
|
||||||
"""Register slash commands when the bot is ready."""
|
"""Register slash commands when the bot is ready."""
|
||||||
|
|
||||||
register_slash_commands(self, name=self.user.name)
|
if not (name := self.user.name):
|
||||||
|
logger.error(f"Failed to get user name for {self.user}")
|
||||||
|
return
|
||||||
|
|
||||||
|
name = name.replace("-", "_").lower()
|
||||||
|
try:
|
||||||
|
register_slash_commands(self, name=name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
|
||||||
|
logger.error(f"Registered slash commands for {self.user.name}")
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""Called when bot connects to Discord"""
|
"""Called when bot connects to Discord"""
|
||||||
@ -313,8 +323,6 @@ class MessageCollector(commands.Bot):
|
|||||||
|
|
||||||
async def refresh_metadata(self) -> dict[str, int]:
|
async def refresh_metadata(self) -> dict[str, int]:
|
||||||
"""Refresh server and channel metadata from Discord and update database"""
|
"""Refresh server and channel metadata from Discord and update database"""
|
||||||
print("🔄 Refreshing Discord metadata...")
|
|
||||||
|
|
||||||
servers_updated = 0
|
servers_updated = 0
|
||||||
channels_updated = 0
|
channels_updated = 0
|
||||||
users_updated = 0
|
users_updated = 0
|
||||||
@ -454,25 +462,33 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _get_channel(
|
||||||
|
self, channel_identifier: int | str, check_notifications: bool = True
|
||||||
|
):
|
||||||
|
"""Get channel by ID or name with standard checks"""
|
||||||
|
if check_notifications and not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
|
logger.debug("Discord notifications disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(channel_identifier, int):
|
||||||
|
channel = self.get_channel(channel_identifier)
|
||||||
|
else:
|
||||||
|
channel = await self.get_channel_by_name(channel_identifier)
|
||||||
|
|
||||||
|
if not channel:
|
||||||
|
logger.error(f"Channel {channel_identifier} not found")
|
||||||
|
|
||||||
|
return channel
|
||||||
|
|
||||||
async def send_to_channel(
|
async def send_to_channel(
|
||||||
self, channel_identifier: int | str, message: str
|
self, channel_identifier: int | str, message: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Send a message to a channel by name or ID (supports threads)"""
|
"""Send a message to a channel by name or ID (supports threads)"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get channel by ID or name
|
channel = await self._get_channel(channel_identifier)
|
||||||
if isinstance(channel_identifier, int):
|
|
||||||
channel = self.get_channel(channel_identifier)
|
|
||||||
else:
|
|
||||||
channel = await self.get_channel_by_name(channel_identifier)
|
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.error(f"Channel {channel_identifier} not found")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Post-process mentions to convert usernames to IDs
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
processed_message = process_mentions(session, message)
|
processed_message = process_mentions(session, message)
|
||||||
|
|
||||||
@ -486,18 +502,9 @@ class MessageCollector(commands.Bot):
|
|||||||
|
|
||||||
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
|
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
|
||||||
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
|
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get channel by ID or name
|
channel = await self._get_channel(channel_identifier)
|
||||||
if isinstance(channel_identifier, int):
|
|
||||||
channel = self.get_channel(channel_identifier)
|
|
||||||
else:
|
|
||||||
channel = await self.get_channel_by_name(channel_identifier)
|
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.error(f"Channel {channel_identifier} not found")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async with channel.typing():
|
async with channel.typing():
|
||||||
@ -509,3 +516,21 @@ class MessageCollector(commands.Bot):
|
|||||||
f"Failed to trigger typing for channel {channel_identifier}: {e}"
|
f"Failed to trigger typing for channel {channel_identifier}: {e}"
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def add_reaction(
|
||||||
|
self, channel_identifier: int | str, message_id: int, emoji: str
|
||||||
|
) -> bool:
|
||||||
|
"""Add a reaction to a message in a channel"""
|
||||||
|
try:
|
||||||
|
channel = await self._get_channel(channel_identifier)
|
||||||
|
if not channel:
|
||||||
|
return False
|
||||||
|
|
||||||
|
message = await channel.fetch_message(message_id)
|
||||||
|
await message.add_reaction(emoji)
|
||||||
|
logger.info(f"Added reaction {emoji} to message {message_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to add reaction: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""Lightweight slash-command helpers for the Discord collector."""
|
"""Lightweight slash-command helpers for the Discord collector."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from calendar import c
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Literal
|
from typing import Callable, Literal
|
||||||
|
|
||||||
@ -10,6 +10,9 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||||
|
from memory.discord.mcp import run_mcp_server_command
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ScopeLiteral = Literal["server", "channel", "user"]
|
ScopeLiteral = Literal["server", "channel", "user"]
|
||||||
|
|
||||||
@ -50,6 +53,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if getattr(bot, "_memory_commands_registered", False):
|
if getattr(bot, "_memory_commands_registered", False):
|
||||||
|
logger.error(f"Slash commands already registered for {name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
setattr(bot, "_memory_commands_registered", True)
|
setattr(bot, "_memory_commands_registered", True)
|
||||||
@ -167,6 +171,21 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
|||||||
target_user=user,
|
target_user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@tree.command(
|
||||||
|
name=f"{name}_mcp_servers",
|
||||||
|
description="Manage MCP servers for your account",
|
||||||
|
)
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
action="Action to perform",
|
||||||
|
url="MCP server URL (required for add, delete, connect, tools)",
|
||||||
|
)
|
||||||
|
async def mcp_servers_command(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||||
|
url: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
await run_mcp_server_command(interaction, action, url and url.strip(), name)
|
||||||
|
|
||||||
|
|
||||||
async def _run_interaction_command(
|
async def _run_interaction_command(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
@ -177,17 +196,10 @@ async def _run_interaction_command(
|
|||||||
**handler_kwargs,
|
**handler_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Shared coroutine used by the registered slash commands."""
|
"""Shared coroutine used by the registered slash commands."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
response = run_command(
|
context = _build_context(session, interaction, scope, target_user)
|
||||||
session,
|
response = handler(context, **handler_kwargs)
|
||||||
interaction,
|
|
||||||
scope,
|
|
||||||
handler=handler,
|
|
||||||
target_user=target_user,
|
|
||||||
**handler_kwargs,
|
|
||||||
)
|
|
||||||
session.commit()
|
session.commit()
|
||||||
except CommandError as exc: # pragma: no cover - passthrough
|
except CommandError as exc: # pragma: no cover - passthrough
|
||||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||||
@ -199,21 +211,6 @@ async def _run_interaction_command(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_command(
|
|
||||||
session: Session,
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
scope: ScopeLiteral,
|
|
||||||
*,
|
|
||||||
handler: CommandHandler,
|
|
||||||
target_user: discord.User | None = None,
|
|
||||||
**handler_kwargs,
|
|
||||||
) -> CommandResponse:
|
|
||||||
"""Create a :class:`CommandContext` and execute the handler."""
|
|
||||||
|
|
||||||
context = _build_context(session, interaction, scope, target_user)
|
|
||||||
return handler(context, **handler_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_context(
|
def _build_context(
|
||||||
session: Session,
|
session: Session,
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
|
|||||||
294
src/memory/discord/mcp.py
Normal file
294
src/memory/discord/mcp.py
Normal file
@ -0,0 +1,294 @@
|
|||||||
|
"""Lightweight slash-command helpers for the Discord collector."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, AsyncGenerator, Literal, cast
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import discord
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models.discord import DiscordMCPServer
|
||||||
|
from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def find_mcp_server(
|
||||||
|
session: Session | scoped_session, user_id: int, url: str
|
||||||
|
) -> DiscordMCPServer | None:
|
||||||
|
return (
|
||||||
|
session.query(DiscordMCPServer)
|
||||||
|
.filter(
|
||||||
|
DiscordMCPServer.discord_bot_user_id == user_id,
|
||||||
|
DiscordMCPServer.mcp_server_url == url,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def call_mcp_server(
|
||||||
|
url: str, access_token: str, method: str, params: dict = {}
|
||||||
|
) -> AsyncGenerator[Any, None]:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": int(time.time() * 1000),
|
||||||
|
"method": method,
|
||||||
|
"params": params,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as http_session:
|
||||||
|
async with http_session.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=10),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
logger.error(f"Tools list failed: {resp.status} - {error_text}")
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to call MCP server: {resp.status} - {error_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse SSE stream
|
||||||
|
async for line in resp.content:
|
||||||
|
line_str = line.decode("utf-8").strip()
|
||||||
|
|
||||||
|
# SSE format: "data: {json}"
|
||||||
|
if line_str.startswith("data: "):
|
||||||
|
json_str = line_str[6:] # Remove "data: " prefix
|
||||||
|
try:
|
||||||
|
yield json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue # Skip invalid JSON lines
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
||||||
|
"""List all MCP servers for the user."""
|
||||||
|
with make_session() as session:
|
||||||
|
servers = (
|
||||||
|
session.query(DiscordMCPServer)
|
||||||
|
.filter(
|
||||||
|
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not servers:
|
||||||
|
return (
|
||||||
|
"📋 **Your MCP Servers**\n\n"
|
||||||
|
"You don't have any MCP servers configured yet.\n"
|
||||||
|
"Use `/memory_mcp_servers add <url>` to add one."
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_server(server: DiscordMCPServer) -> str:
|
||||||
|
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
||||||
|
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
||||||
|
|
||||||
|
server_list = "\n".join(format_server(s) for s in servers)
|
||||||
|
|
||||||
|
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_add(
|
||||||
|
interaction: discord.Interaction, url: str, name: str = "memory"
|
||||||
|
) -> str:
|
||||||
|
"""Add a new MCP server via OAuth."""
|
||||||
|
with make_session() as session:
|
||||||
|
if find_mcp_server(session, interaction.user.id, url):
|
||||||
|
return (
|
||||||
|
f"**MCP Server Already Exists**\n\n"
|
||||||
|
f"You already have an MCP server configured at `{url}`.\n"
|
||||||
|
f"Use `/memory_mcp_servers connect {url}` to reconnect."
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoints = await get_endpoints(url)
|
||||||
|
client_id = await register_oauth_client(
|
||||||
|
endpoints,
|
||||||
|
url,
|
||||||
|
f"Discord Bot - {name} ({interaction.user.name})",
|
||||||
|
)
|
||||||
|
mcp_server = DiscordMCPServer(
|
||||||
|
discord_bot_user_id=interaction.user.id,
|
||||||
|
mcp_server_url=url,
|
||||||
|
client_id=client_id,
|
||||||
|
)
|
||||||
|
session.add(mcp_server)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created MCP server record: id={mcp_server.id}, "
|
||||||
|
f"user={interaction.user.id}, url={url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"🔐 **Add MCP Server**\n\n"
|
||||||
|
f"Server: `{url}`\n"
|
||||||
|
f"Click the link below to authorize:\n{auth_url}\n\n"
|
||||||
|
f"⚠️ Keep this link private!\n"
|
||||||
|
f"💡 You'll be redirected to login and grant access to the MCP server."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_delete(interaction: discord.Interaction, url: str) -> str:
|
||||||
|
"""Delete an MCP server."""
|
||||||
|
with make_session() as session:
|
||||||
|
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||||
|
if not mcp_server:
|
||||||
|
return (
|
||||||
|
f"**MCP Server Not Found**\n\n"
|
||||||
|
f"You don't have an MCP server configured at `{url}`.\n"
|
||||||
|
)
|
||||||
|
session.delete(mcp_server)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_connect(interaction: discord.Interaction, url: str) -> str:
|
||||||
|
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||||
|
with make_session() as session:
|
||||||
|
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||||
|
if not mcp_server:
|
||||||
|
raise ValueError(
|
||||||
|
f"**MCP Server Not Found**\n\n"
|
||||||
|
f"You don't have an MCP server configured at `{url}`.\n"
|
||||||
|
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not mcp_server:
|
||||||
|
raise ValueError(
|
||||||
|
f"**MCP Server Not Found**\n\n"
|
||||||
|
f"You don't have an MCP server configured at `{url}`.\n"
|
||||||
|
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoints = await get_endpoints(url)
|
||||||
|
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Regenerated OAuth challenge for user={interaction.user.id}, url={url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||||
|
f"Server: `{url}`\n"
|
||||||
|
f"Click the link below to reauthorize:\n{auth_url}\n\n"
|
||||||
|
f"⚠️ Keep this link private!\n"
|
||||||
|
f"💡 You'll be redirected to login and grant access to the MCP server again."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_tools(interaction: discord.Interaction, url: str) -> str:
|
||||||
|
"""List tools available on an MCP server."""
|
||||||
|
with make_session() as session:
|
||||||
|
mcp_server = find_mcp_server(session, interaction.user.id, url)
|
||||||
|
|
||||||
|
if not mcp_server:
|
||||||
|
raise ValueError(
|
||||||
|
f"**MCP Server Not Found**\n\n"
|
||||||
|
f"You don't have an MCP server configured at `{url}`.\n"
|
||||||
|
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not cast(str | None, mcp_server.access_token):
|
||||||
|
raise ValueError(
|
||||||
|
f"**Not Authorized**\n\n"
|
||||||
|
f"You haven't authorized access to `{url}` yet.\n"
|
||||||
|
f"Use `/memory_mcp_servers connect {url}` to authorize."
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token = cast(str, mcp_server.access_token)
|
||||||
|
|
||||||
|
# Make JSON-RPC request to MCP server
|
||||||
|
tools = None
|
||||||
|
try:
|
||||||
|
async for data in call_mcp_server(url, access_token, "tools/list"):
|
||||||
|
if "result" in data and "tools" in data["result"]:
|
||||||
|
tools = data["result"]["tools"]
|
||||||
|
break
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
logger.exception(f"Failed to connect to MCP server: {exc}")
|
||||||
|
raise ValueError(
|
||||||
|
f"**Connection failed**\n\n"
|
||||||
|
f"Server: `{url}`\n"
|
||||||
|
f"Could not connect to the MCP server: {str(exc)}"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(f"Failed to list tools: {exc}")
|
||||||
|
raise ValueError(
|
||||||
|
f"**Error**\n\nServer: `{url}`\nFailed to list tools: {str(exc)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if tools is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"**Unexpected response format**\n\n"
|
||||||
|
f"Server: `{url}`\n"
|
||||||
|
f"The server returned an unexpected response format."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tools:
|
||||||
|
return (
|
||||||
|
f"🔧 **MCP Server Tools**\n\n"
|
||||||
|
f"Server: `{url}`\n\n"
|
||||||
|
f"No tools available on this server."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format tools list
|
||||||
|
tools_list = "\n".join(
|
||||||
|
f"• **{t.get('name', 'unknown')}**: {t.get('description', 'No description')}"
|
||||||
|
for t in tools
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"🔧 **MCP Server Tools**\n\n"
|
||||||
|
f"Server: `{url}`\n"
|
||||||
|
f"Found {len(tools)} tool(s):\n\n"
|
||||||
|
f"{tools_list}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_mcp_server_command(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||||
|
url: str | None,
|
||||||
|
name: str = "memory",
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP server management commands."""
|
||||||
|
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||||
|
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
|
||||||
|
return
|
||||||
|
if action != "list" and not url:
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"❌ URL is required for this action", ephemeral=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if action == "list" or not url:
|
||||||
|
result = await handle_mcp_list(interaction)
|
||||||
|
elif action == "add":
|
||||||
|
result = await handle_mcp_add(interaction, url, name)
|
||||||
|
elif action == "delete":
|
||||||
|
result = await handle_mcp_delete(interaction, url)
|
||||||
|
elif action == "connect":
|
||||||
|
result = await handle_mcp_connect(interaction, url)
|
||||||
|
elif action == "tools":
|
||||||
|
result = await handle_mcp_tools(interaction, url)
|
||||||
|
except Exception as exc:
|
||||||
|
result = f"❌ Error: {exc}"
|
||||||
|
await interaction.response.send_message(result, ephemeral=True)
|
||||||
@ -141,8 +141,10 @@ def previous_messages(
|
|||||||
) -> list[DiscordMessage]:
|
) -> list[DiscordMessage]:
|
||||||
messages = session.query(DiscordMessage)
|
messages = session.query(DiscordMessage)
|
||||||
if user_id:
|
if user_id:
|
||||||
|
print(f"user_id: {user_id}")
|
||||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||||
if channel_id:
|
if channel_id:
|
||||||
|
print(f"channel_id: {channel_id}")
|
||||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||||
return list(
|
return list(
|
||||||
reversed(
|
reversed(
|
||||||
|
|||||||
268
src/memory/discord/oauth.py
Normal file
268
src/memory/discord/oauth.py
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from base64 import urlsafe_b64encode
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from urllib.parse import urlencode, urljoin
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models.discord import DiscordMCPServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthEndpoints:
|
||||||
|
authorization_endpoint: str
|
||||||
|
registration_endpoint: str
|
||||||
|
token_endpoint: str
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pkce_pair() -> tuple[str, str]:
|
||||||
|
"""Generate PKCE code verifier and challenge.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (code_verifier, code_challenge)
|
||||||
|
"""
|
||||||
|
# Generate a random code verifier
|
||||||
|
code_verifier = (
|
||||||
|
urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create code challenge using S256 method
|
||||||
|
challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||||
|
code_challenge = urlsafe_b64encode(challenge_bytes).decode("utf-8").rstrip("=")
|
||||||
|
|
||||||
|
return code_verifier, code_challenge
|
||||||
|
|
||||||
|
|
||||||
|
async def discover_oauth_metadata(server_url: str) -> dict | None:
|
||||||
|
"""Discover OAuth metadata from an MCP server."""
|
||||||
|
# Try the standard OAuth discovery endpoint
|
||||||
|
discovery_url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
discovery_url, timeout=aiohttp.ClientTimeout(total=5)
|
||||||
|
) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
return await resp.json()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Failed to discover OAuth metadata from {discovery_url}: {exc}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_endpoints(url: str) -> OAuthEndpoints:
|
||||||
|
# Discover OAuth endpoints from the target server
|
||||||
|
oauth_metadata = await discover_oauth_metadata(url)
|
||||||
|
|
||||||
|
if not oauth_metadata:
|
||||||
|
raise ValueError(
|
||||||
|
"**Failed to connect to MCP server**\n\n"
|
||||||
|
f"Could not discover OAuth endpoints at `{url}`\n"
|
||||||
|
"Make sure the server is running and supports OAuth 2.0.",
|
||||||
|
)
|
||||||
|
|
||||||
|
authorization_endpoint = oauth_metadata.get("authorization_endpoint")
|
||||||
|
registration_endpoint = oauth_metadata.get("registration_endpoint")
|
||||||
|
token_endpoint = oauth_metadata.get("token_endpoint")
|
||||||
|
|
||||||
|
if not authorization_endpoint:
|
||||||
|
raise ValueError(
|
||||||
|
"**Invalid OAuth configuration**\n\n"
|
||||||
|
f"Server `{url}` did not provide an authorization endpoint.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not registration_endpoint:
|
||||||
|
raise ValueError(
|
||||||
|
"**Invalid OAuth configuration**\n\n"
|
||||||
|
f"Server `{url}` does not support dynamic client registration.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not token_endpoint:
|
||||||
|
raise ValueError(
|
||||||
|
"**Invalid OAuth configuration**\n\n"
|
||||||
|
f"Server `{url}` does not provide a token endpoint.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Authorization endpoint: {authorization_endpoint}")
|
||||||
|
logger.info(f"Registration endpoint: {registration_endpoint}")
|
||||||
|
|
||||||
|
return OAuthEndpoints(
|
||||||
|
authorization_endpoint=authorization_endpoint,
|
||||||
|
registration_endpoint=registration_endpoint,
|
||||||
|
token_endpoint=token_endpoint,
|
||||||
|
redirect_uri=f"http://{settings.DISCORD_COLLECTOR_SERVER_URL}:{settings.DISCORD_COLLECTOR_PORT}/oauth/callback/discord",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def register_oauth_client(
|
||||||
|
endpoints: OAuthEndpoints,
|
||||||
|
url: str,
|
||||||
|
client_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Register OAuth client and store client_id in the mcp_server object."""
|
||||||
|
client_metadata = {
|
||||||
|
"client_name": client_name,
|
||||||
|
"redirect_uris": [endpoints.redirect_uri],
|
||||||
|
"grant_types": ["authorization_code", "refresh_token"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"scope": "read write",
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.error(f"Registration metadata: {client_metadata}")
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
endpoints.registration_endpoint,
|
||||||
|
json=client_metadata,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=5),
|
||||||
|
) as resp:
|
||||||
|
logger.error(
|
||||||
|
f"Registration response: {resp.status} {await resp.text()}"
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
client_info = await resp.json()
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to register OAuth client at {endpoints.registration_endpoint}: {exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not client_info or "client_id" not in client_info:
|
||||||
|
raise ValueError(
|
||||||
|
"**Failed to register OAuth client**\n\n"
|
||||||
|
f"Could not register with the MCP server at `{url}`\n"
|
||||||
|
f"Check the server logs for more details.",
|
||||||
|
)
|
||||||
|
|
||||||
|
client_id = client_info["client_id"]
|
||||||
|
|
||||||
|
logger.info(f"Registered OAuth client: {client_id}")
|
||||||
|
return client_id
|
||||||
|
|
||||||
|
|
||||||
|
async def issue_challenge(
|
||||||
|
mcp_server: DiscordMCPServer,
|
||||||
|
endpoints: OAuthEndpoints,
|
||||||
|
) -> str:
|
||||||
|
"""Generate OAuth challenge and store state in mcp_server object."""
|
||||||
|
code_verifier, code_challenge = generate_pkce_pair()
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
# Store in mcp_server object
|
||||||
|
mcp_server.state = state # type: ignore
|
||||||
|
mcp_server.code_verifier = code_verifier # type: ignore
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
|
||||||
|
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build authorization URL pointing to the target server
|
||||||
|
auth_params = {
|
||||||
|
"client_id": mcp_server.client_id,
|
||||||
|
"redirect_uri": endpoints.redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"scope": "read write",
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"{endpoints.authorization_endpoint}?{urlencode(auth_params)}"
|
||||||
|
|
||||||
|
|
||||||
|
async def complete_oauth_flow(
|
||||||
|
session: Session | scoped_session, code: str, state: str
|
||||||
|
) -> tuple[int, str]:
|
||||||
|
"""Complete OAuth flow by exchanging code for token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from OAuth callback
|
||||||
|
state: State parameter from OAuth callback
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (status_code, html_message) for the callback response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
mcp_server = (
|
||||||
|
session.query(DiscordMCPServer)
|
||||||
|
.filter(DiscordMCPServer.state == state)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not mcp_server:
|
||||||
|
logger.error(f"Invalid or expired state: {state[:20]}...")
|
||||||
|
return 400, "Invalid or expired OAuth state"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
|
||||||
|
f"url={mcp_server.mcp_server_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get OAuth endpoints
|
||||||
|
try:
|
||||||
|
endpoints = await get_endpoints(str(mcp_server.mcp_server_url))
|
||||||
|
except Exception as exc:
|
||||||
|
return 500, f"Failed to get OAuth endpoints: {str(exc)}"
|
||||||
|
|
||||||
|
# Exchange authorization code for access token
|
||||||
|
token_data = {
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": endpoints.redirect_uri,
|
||||||
|
"client_id": mcp_server.client_id,
|
||||||
|
"code_verifier": mcp_server.code_verifier,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as http_session:
|
||||||
|
async with http_session.post(
|
||||||
|
endpoints.token_endpoint,
|
||||||
|
data=token_data,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=10),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
logger.error(f"Token exchange failed: {resp.status} - {error_text}")
|
||||||
|
return 500, f"Token exchange failed: {error_text}"
|
||||||
|
|
||||||
|
tokens = await resp.json()
|
||||||
|
|
||||||
|
access_token = tokens.get("access_token")
|
||||||
|
refresh_token = tokens.get("refresh_token")
|
||||||
|
expires_in = tokens.get("expires_in", 3600)
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
return 500, "Token response did not include access_token"
|
||||||
|
|
||||||
|
logger.info(f"Successfully obtained access token: {access_token[:20]}...")
|
||||||
|
|
||||||
|
# Store tokens and clear temporary OAuth state
|
||||||
|
mcp_server.access_token = access_token # type: ignore
|
||||||
|
mcp_server.refresh_token = refresh_token # type: ignore
|
||||||
|
mcp_server.token_expires_at = datetime.now() + timedelta(seconds=expires_in) # type: ignore
|
||||||
|
|
||||||
|
# Clear temporary OAuth flow data
|
||||||
|
mcp_server.state = None # type: ignore
|
||||||
|
mcp_server.code_verifier = None # type: ignore
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
||||||
|
f"server {mcp_server.mcp_server_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, "✅ Authorization successful! You can now use this MCP server."
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(f"Failed to complete OAuth flow: {exc}")
|
||||||
|
return 500, f"Failed to complete OAuth flow: {str(exc)}"
|
||||||
Loading…
x
Reference in New Issue
Block a user