Compare commits

...

4 Commits

Author SHA1 Message Date
2944a0bce1 properly handle mcp redirects 2025-11-03 00:00:02 +00:00
Daniel O'Connell
0d9f8beec3 handle mcp servers in discord 2025-11-02 23:49:50 +01:00
Daniel O'Connell
64bb926eba mcp servers for discord bots 2025-11-02 23:49:44 +01:00
Daniel O'Connell
6250586d1f prompt from bot user 2025-11-02 23:49:35 +01:00
20 changed files with 1073 additions and 82 deletions

View File

@ -0,0 +1,67 @@
"""discord mcp servers
Revision ID: 9b887449ea92
Revises: 1954477b25f4
Create Date: 2025-11-02 22:04:26.259323
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "9b887449ea92"
down_revision: Union[str, None] = "1954477b25f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"discord_mcp_servers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
sa.Column("mcp_server_url", sa.Text(), nullable=False),
sa.Column("client_id", sa.Text(), nullable=False),
sa.Column("state", sa.Text(), nullable=True),
sa.Column("code_verifier", sa.Text(), nullable=True),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["discord_bot_user_id"],
["discord_users.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("state"),
)
op.create_index(
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
)
op.create_index(
"discord_mcp_user_url_idx",
"discord_mcp_servers",
["discord_bot_user_id", "mcp_server_url"],
unique=False,
)
def downgrade() -> None:
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
op.drop_table("discord_mcp_servers")

View File

@ -14,6 +14,7 @@ from memory.common.db.models import (
BookSection,
Chunk,
Comic,
DiscordMCPServer,
DiscordMessage,
EmailAccount,
EmailAttachment,
@ -166,6 +167,37 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
column_sortable_list = ["sent_at"]
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
column_list = [
"id",
"mcp_server_url",
"client_id",
"discord_bot_user_id",
"state",
"code_verifier",
"access_token",
"refresh_token",
"token_expires_at",
"created_at",
"updated_at",
]
column_searchable_list = [
"mcp_server_url",
"client_id",
"state",
"id",
"discord_bot_user_id",
]
column_sortable_list = [
"created_at",
"updated_at",
"mcp_server_url",
"client_id",
"state",
"id",
]
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
column_list = [
"id",
@ -328,4 +360,5 @@ def setup_admin(admin: Admin):
admin.add_view(DiscordUserAdmin)
admin.add_view(DiscordServerAdmin)
admin.add_view(DiscordChannelAdmin)
admin.add_view(DiscordMCPServerAdmin)
admin.add_view(ScheduledLLMCallAdmin)

View File

@ -1,13 +1,21 @@
from datetime import datetime, timedelta, timezone
import logging
from datetime import datetime, timedelta, timezone
from fastapi import HTTPException, Depends, Request, Response, APIRouter
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy.orm import Session as DBSession
from sqlalchemy.orm import scoped_session
from starlette.middleware.base import BaseHTTPMiddleware
from memory.common import settings
from sqlalchemy.orm import Session as DBSession, scoped_session
from memory.common import settings
from memory.common.db.connection import get_session, make_session
from memory.common.db.models.users import User, HumanUser, BotUser, UserSession
from memory.common.db.models import (
BotUser,
DiscordMCPServer,
HumanUser,
User,
UserSession,
)
from memory.common.oauth import complete_oauth_flow
logger = logging.getLogger(__name__)
@ -136,6 +144,58 @@ def get_me(user: User = Depends(get_current_user)):
return user.serialize()
@router.get("/callback/discord")
async def oauth_callback_discord(request: Request):
"""Get current user info"""
code = request.query_params.get("code")
state = request.query_params.get("state")
error = request.query_params.get("error")
logger.info(
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
)
message, title, close, status_code = "", "", "", 200
if error:
logger.error(f"OAuth error: {error}")
message = f"Error: {error}"
title = "❌ Authorization Failed"
status_code = 400
elif not code or not state:
message = "Missing authorization code or state parameter."
title = "❌ Invalid Request"
status_code = 400
else:
# Complete the OAuth flow (exchange code for token)
with make_session() as session:
mcp_server = (
session.query(DiscordMCPServer)
.filter(DiscordMCPServer.state == state)
.first()
)
status_code, message = await complete_oauth_flow(mcp_server, code, state)
session.commit()
if 200 <= status_code < 300:
title = "✅ Authorization Successful!"
close = "You can close this window and return to the MCP server."
else:
title = "❌ Authorization Failed"
return Response(
content=f"""
<html>
<body>
<h1>{title}</h1>
<p>{message}</p>
<p>{close}</p>
</body>
</html>
""",
status_code=status_code,
)
class AuthenticationMiddleware(BaseHTTPMiddleware):
"""Middleware to require authentication for all endpoints except whitelisted ones."""

View File

@ -34,6 +34,7 @@ from memory.common.db.models.discord import (
DiscordServer,
DiscordChannel,
DiscordUser,
DiscordMCPServer,
)
from memory.common.db.models.observations import (
ObservationContradiction,
@ -106,6 +107,7 @@ __all__ = [
"DiscordServer",
"DiscordChannel",
"DiscordUser",
"DiscordMCPServer",
# Users
"User",
"HumanUser",

View File

@ -127,5 +127,44 @@ class DiscordUser(Base, MessageProcessor):
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
mcp_servers = relationship(
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
)
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
class DiscordMCPServer(Base):
"""MCP server configuration and OAuth state for Discord users."""
__tablename__ = "discord_mcp_servers"
id = Column(Integer, primary_key=True)
discord_bot_user_id = Column(
BigInteger, ForeignKey("discord_users.id"), nullable=False
)
# MCP server info
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)
code_verifier = Column(Text, nullable=True)
# OAuth tokens (set after successful authorization)
access_token = Column(Text, nullable=True)
refresh_token = Column(Text, nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
__table_args__ = (
Index("discord_mcp_state_idx", "state"),
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
)

View File

@ -363,7 +363,19 @@ class DiscordMessage(SourceItem):
@property
def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
return textwrap.dedent("""
<message>
<id>{message_id}</id>
<from>{from_user}</from>
<sent_at>{sent_at}</sent_at>
<content>{content}</content>
</message>
""").format(
message_id=self.message_id,
from_user=self.from_user.username,
sent_at=self.sent_at.isoformat()[:19],
content=self.content,
)
def as_content(self) -> dict[str, Any]:
"""Return message content ready for LLM (text + images from disk)."""

View File

@ -1,22 +1,23 @@
import hashlib
import secrets
from typing import cast
import uuid
from sqlalchemy.orm import Session
from memory.common.db.models.base import Base
from typing import cast
from sqlalchemy import (
ARRAY,
Boolean,
CheckConstraint,
Column,
Integer,
String,
DateTime,
ForeignKey,
Boolean,
ARRAY,
Integer,
Numeric,
CheckConstraint,
String,
)
from sqlalchemy.orm import Session, relationship
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
def hash_password(password: str) -> str:

View File

@ -94,6 +94,30 @@ def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
return False
def add_reaction(bot_id: int, channel: int | str, message_id: int, emoji: str) -> bool:
"""Add a reaction to a message in a channel"""
try:
response = requests.post(
f"{get_api_url()}/add_reaction",
json={
"bot_id": bot_id,
"channel": channel,
"message_id": message_id,
"emoji": emoji,
},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(
f"Failed to add reaction {emoji} to message {message_id} in channel {channel}: {e}"
)
return False
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a message to a channel by name or ID (ID supports threads)"""
try:

View File

@ -9,6 +9,7 @@ import anthropic
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
MCPServer,
LLMSettings,
Message,
MessageRole,
@ -103,6 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None,
tools: list[ToolDefinition] | None,
mcp_servers: list[MCPServer] | None,
settings: LLMSettings,
) -> dict[str, Any]:
"""Build common request kwargs for API calls."""
@ -113,7 +115,9 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
"extra_headers": {
"anthropic-beta": "web-fetch-2025-09-10,mcp-client-2025-04-04"
},
}
# Only include top_p if explicitly set
@ -129,6 +133,25 @@ class AnthropicProvider(BaseLLMProvider):
if tools:
kwargs["tools"] = self._convert_tools(tools)
if mcp_servers:
def format_server(server: MCPServer) -> dict[str, Any]:
conf: dict[str, Any] = {
"type": "url",
"url": server.url,
"name": server.name,
"authorization_token": server.token,
}
if server.allowed_tools:
conf["tool_configuration"] = {
"allowed_tools": server.allowed_tools,
}
return conf
kwargs["extra_body"] = {
"mcp_servers": [format_server(server) for server in mcp_servers]
}
# Enable extended thinking if requested and model supports it
if self.enable_thinking and self._supports_thinking():
thinking_budget = min(10000, settings.max_tokens - 1024)
@ -312,11 +335,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
response = self.client.messages.create(**kwargs)
@ -332,11 +358,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
with self.client.messages.stream(**kwargs) as stream:
@ -358,11 +387,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
response = await self.async_client.messages.create(**kwargs)
@ -378,11 +410,14 @@ class AnthropicProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, mcp_servers, settings
)
try:
async with self.async_client.messages.stream(**kwargs) as stream:

View File

@ -11,7 +11,7 @@ from typing import Any, AsyncIterator, Iterator, Literal, Union, cast
from PIL import Image
from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
from memory.common.llms.tools import MCPServer, ToolCall, ToolDefinition, ToolResult
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
logger = logging.getLogger(__name__)
@ -434,6 +434,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
@ -456,6 +457,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""
@ -478,6 +480,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
@ -500,6 +503,7 @@ class BaseLLMProvider(ABC):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""
@ -520,6 +524,7 @@ class BaseLLMProvider(ABC):
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
@ -551,6 +556,7 @@ class BaseLLMProvider(ABC):
messages=messages,
system_prompt=system_prompt,
tools=list(tools.values()),
mcp_servers=mcp_servers,
settings=settings,
):
if event.type == "text":
@ -583,7 +589,12 @@ class BaseLLMProvider(ABC):
# Recursively continue the conversation with reduced iterations
yield from self.stream_with_tools(
messages, tools, settings, system_prompt, max_iterations - 1
messages,
tools,
mcp_servers,
settings,
system_prompt,
max_iterations - 1,
)
return # Exit after recursive call completes
@ -598,6 +609,7 @@ class BaseLLMProvider(ABC):
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
@ -606,6 +618,7 @@ class BaseLLMProvider(ABC):
for event in self.stream_with_tools(
messages=messages,
tools=tools,
mcp_servers=mcp_servers,
settings=settings,
system_prompt=system_prompt,
max_iterations=max_iterations,

View File

@ -9,6 +9,7 @@ import openai
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
MCPServer,
LLMSettings,
Message,
StreamEvent,
@ -175,6 +176,7 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None,
tools: list[ToolDefinition] | None,
mcp_servers: list[MCPServer] | None,
settings: LLMSettings,
stream: bool = False,
) -> dict[str, Any]:
@ -185,6 +187,7 @@ class OpenAIProvider(BaseLLMProvider):
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools
mcp_servers: Optional list of MCP servers
settings: LLM settings
stream: Whether to enable streaming
@ -333,12 +336,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False
messages, system_prompt, tools, mcp_servers, settings, stream=False
)
try:
@ -361,12 +365,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True
messages, system_prompt, tools, mcp_servers, settings, stream=True
)
if kwargs.get("stream"):
@ -393,12 +398,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=False
messages, system_prompt, tools, mcp_servers, settings, stream=False
)
try:
@ -413,12 +419,13 @@ class OpenAIProvider(BaseLLMProvider):
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
mcp_servers: list[MCPServer] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(
messages, system_prompt, tools, settings, stream=True
messages, system_prompt, tools, mcp_servers, settings, stream=True
)
try:

View File

@ -23,6 +23,16 @@ class ToolResult(TypedDict):
output: str
@dataclass
class MCPServer:
"""An MCP server."""
name: str
url: str
token: str
allowed_tools: list[str] | None = None
@dataclass
class ToolDefinition:
"""Definition of a tool that can be called by the LLM."""

View File

@ -17,6 +17,7 @@ from memory.common.db.models import (
BotUser,
)
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
from memory.common.discord import add_reaction
UpdateSummaryType = Literal["server", "channel", "user"]
@ -209,6 +210,50 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
)
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
bot_id = cast(int, bot.id)
channel_id = channel and channel.id
def handler(input: ToolInput) -> str:
if not isinstance(input, dict):
raise ValueError("Input must be a dictionary")
try:
emoji = input.get("emoji")
except ValueError:
raise ValueError("Emoji is required")
if not emoji:
raise ValueError("Emoji is required")
try:
message_id = int(input.get("message_id") or "no id")
except ValueError:
raise ValueError("Message ID is required")
success = add_reaction(bot_id, channel_id, message_id, emoji)
if not success:
return "Failed to add reaction"
return "Reaction added"
return ToolDefinition(
name="add_reaction",
description="Add a reaction to a message in a channel",
input_schema={
"type": "object",
"properties": {
"message_id": {
"type": "number",
"description": "The ID of the message to add the reaction to",
},
"emoji": {
"type": "string",
"description": "The emoji to add to the message",
},
},
},
function=handler,
)
def make_discord_tools(
bot: BotUser,
author: DiscordUser | None,
@ -227,5 +272,6 @@ def make_discord_tools(
if channel and channel.server:
tools += [
make_summary_tool("server", cast(BigInteger, channel.server_id)),
make_add_reaction_tool(bot, channel),
]
return {tool.name: tool for tool in tools}

258
src/memory/common/oauth.py Normal file
View File

@ -0,0 +1,258 @@
import hashlib
import logging
import secrets
from base64 import urlsafe_b64encode
from dataclasses import dataclass
from datetime import datetime, timedelta
from urllib.parse import urlencode, urljoin
import aiohttp
from memory.common import settings
from memory.common.db.models.discord import DiscordMCPServer
logger = logging.getLogger(__name__)
@dataclass
class OAuthEndpoints:
authorization_endpoint: str
registration_endpoint: str
token_endpoint: str
redirect_uri: str
def generate_pkce_pair() -> tuple[str, str]:
"""Generate PKCE code verifier and challenge.
Returns:
Tuple of (code_verifier, code_challenge)
"""
# Generate a random code verifier
code_verifier = (
urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
)
# Create code challenge using S256 method
challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = urlsafe_b64encode(challenge_bytes).decode("utf-8").rstrip("=")
return code_verifier, code_challenge
async def discover_oauth_metadata(server_url: str) -> dict | None:
"""Discover OAuth metadata from an MCP server."""
# Try the standard OAuth discovery endpoint
discovery_url = urljoin(server_url, "/.well-known/oauth-authorization-server")
try:
async with aiohttp.ClientSession() as session:
async with session.get(
discovery_url, timeout=aiohttp.ClientTimeout(total=5)
) as resp:
if resp.status == 200:
return await resp.json()
except Exception as exc:
logger.warning(f"Failed to discover OAuth metadata from {discovery_url}: {exc}")
return None
async def get_endpoints(url: str) -> OAuthEndpoints:
# Discover OAuth endpoints from the target server
oauth_metadata = await discover_oauth_metadata(url)
if not oauth_metadata:
raise ValueError(
"**Failed to connect to MCP server**\n\n"
f"Could not discover OAuth endpoints at `{url}`\n"
"Make sure the server is running and supports OAuth 2.0.",
)
authorization_endpoint = oauth_metadata.get("authorization_endpoint")
registration_endpoint = oauth_metadata.get("registration_endpoint")
token_endpoint = oauth_metadata.get("token_endpoint")
if not authorization_endpoint:
raise ValueError(
"**Invalid OAuth configuration**\n\n"
f"Server `{url}` did not provide an authorization endpoint.",
)
if not registration_endpoint:
raise ValueError(
"**Invalid OAuth configuration**\n\n"
f"Server `{url}` does not support dynamic client registration.",
)
if not token_endpoint:
raise ValueError(
"**Invalid OAuth configuration**\n\n"
f"Server `{url}` does not provide a token endpoint.",
)
logger.info(f"Authorization endpoint: {authorization_endpoint}")
logger.info(f"Registration endpoint: {registration_endpoint}")
return OAuthEndpoints(
authorization_endpoint=authorization_endpoint,
registration_endpoint=registration_endpoint,
token_endpoint=token_endpoint,
redirect_uri=f"{settings.SERVER_URL}/auth/callback/discord",
)
async def register_oauth_client(
endpoints: OAuthEndpoints,
url: str,
client_name: str,
) -> None:
"""Register OAuth client and store client_id in the mcp_server object."""
client_metadata = {
"client_name": client_name,
"redirect_uris": [endpoints.redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scope": "read write",
"token_endpoint_auth_method": "none",
}
logger.error(f"Registration metadata: {client_metadata}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(
endpoints.registration_endpoint,
json=client_metadata,
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
logger.error(
f"Registration response: {resp.status} {await resp.text()}"
)
resp.raise_for_status()
client_info = await resp.json()
except Exception as exc:
raise ValueError(
f"Failed to register OAuth client at {endpoints.registration_endpoint}: {exc}"
)
if not client_info or "client_id" not in client_info:
raise ValueError(
"**Failed to register OAuth client**\n\n"
f"Could not register with the MCP server at `{url}`\n"
f"Check the server logs for more details.",
)
client_id = client_info["client_id"]
logger.info(f"Registered OAuth client: {client_id}")
return client_id
async def issue_challenge(
mcp_server: DiscordMCPServer,
endpoints: OAuthEndpoints,
) -> str:
"""Generate OAuth challenge and store state in mcp_server object."""
code_verifier, code_challenge = generate_pkce_pair()
state = secrets.token_urlsafe(32)
# Store in mcp_server object
mcp_server.state = state # type: ignore
mcp_server.code_verifier = code_verifier # type: ignore
logger.info(
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
)
# Build authorization URL pointing to the target server
auth_params = {
"client_id": mcp_server.client_id,
"redirect_uri": endpoints.redirect_uri,
"response_type": "code",
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"scope": "read write",
}
return f"{endpoints.authorization_endpoint}?{urlencode(auth_params)}"
async def complete_oauth_flow(
mcp_server: DiscordMCPServer, code: str, state: str
) -> tuple[int, str]:
"""Complete OAuth flow by exchanging code for token.
Args:
code: Authorization code from OAuth callback
state: State parameter from OAuth callback
Returns:
Tuple of (status_code, html_message) for the callback response
"""
try:
if not mcp_server:
logger.error(f"Invalid or expired state: {state[:20]}...")
return 400, "Invalid or expired OAuth state"
logger.info(
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
f"url={mcp_server.mcp_server_url}"
)
# Get OAuth endpoints
try:
endpoints = await get_endpoints(str(mcp_server.mcp_server_url))
except Exception as exc:
return 500, f"Failed to get OAuth endpoints: {str(exc)}"
# Exchange authorization code for access token
token_data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": endpoints.redirect_uri,
"client_id": mcp_server.client_id,
"code_verifier": mcp_server.code_verifier,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
endpoints.token_endpoint,
data=token_data,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Token exchange failed: {resp.status} - {error_text}")
return 500, f"Token exchange failed: {error_text}"
tokens = await resp.json()
access_token = tokens.get("access_token")
refresh_token = tokens.get("refresh_token")
expires_in = tokens.get("expires_in", 3600)
if not access_token:
return 500, "Token response did not include access_token"
logger.info(f"Successfully obtained access token: {access_token[:20]}...")
# Store tokens and clear temporary OAuth state
mcp_server.access_token = access_token # type: ignore
mcp_server.refresh_token = refresh_token # type: ignore
mcp_server.token_expires_at = datetime.now() + timedelta(seconds=expires_in) # type: ignore
# Clear temporary OAuth flow data
mcp_server.state = None # type: ignore
mcp_server.code_verifier = None # type: ignore
logger.info(
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
f"server {mcp_server.mcp_server_url}"
)
return 200, "✅ Authorization successful! You can now use this MCP server."
except Exception as exc:
logger.exception(f"Failed to complete OAuth flow: {exc}")
return 500, f"Failed to complete OAuth flow: {str(exc)}"

View File

@ -12,12 +12,14 @@ from contextlib import asynccontextmanager
from typing import cast
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
from memory.common.oauth import complete_oauth_flow
from memory.discord.collector import MessageCollector
logger = logging.getLogger(__name__)
@ -45,6 +47,13 @@ class TypingChannelRequest(BaseModel):
channel: int | str # Channel name or ID (ID supports threads)
class AddReactionRequest(BaseModel):
bot_id: int
channel: int | str # Channel name or ID (ID supports threads)
message_id: int
emoji: str
class Collector:
collector: MessageCollector
collector_task: asyncio.Task
@ -53,6 +62,7 @@ class Collector:
bot_name: str
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
self.collector = collector
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
self.bot_id = cast(int, bot.id)
@ -72,7 +82,9 @@ async def lifespan(app: FastAPI):
bots = session.query(DiscordBotUser).all()
app.bots = {bot.id: make_collector(bot) for bot in bots}
logger.info(f"Discord collectors started for {len(app.bots)} bots")
logger.error(
f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
)
yield
@ -216,6 +228,36 @@ async def health_check():
}
@app.post("/add_reaction")
async def add_reaction_endpoint(request: AddReactionRequest):
"""Add a reaction to a message via the collector's Discord client"""
collector = app.bots.get(request.bot_id)
if not collector:
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.add_reaction(
request.channel, request.message_id, request.emoji
)
except Exception as e:
logger.error(f"Failed to add reaction: {e}")
raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to add reaction to message {request.message_id}",
)
return {
"success": True,
"channel": request.channel,
"message_id": request.message_id,
"emoji": request.emoji,
"message": f"Added reaction {request.emoji} to message {request.message_id}",
}
@app.post("/refresh_metadata")
async def refresh_metadata():
"""Refresh Discord server/channel/user metadata from Discord API"""

View File

@ -229,11 +229,20 @@ class MessageCollector(commands.Bot):
intents=intents,
help_command=None, # Disable default help
)
logger.info(f"Initialized collector for {self.user}")
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
register_slash_commands(self, name=self.user.name)
if not self.user:
logger.error(f"Failed to get user name for {self.user}")
return
try:
register_slash_commands(self)
except Exception as e:
logger.error(f"Failed to register slash commands for {self.user.name}: {e}")
logger.error(f"Registered slash commands for {self.user.name}")
async def on_ready(self):
"""Called when bot connects to Discord"""
@ -313,8 +322,6 @@ class MessageCollector(commands.Bot):
async def refresh_metadata(self) -> dict[str, int]:
"""Refresh server and channel metadata from Discord and update database"""
print("🔄 Refreshing Discord metadata...")
servers_updated = 0
channels_updated = 0
users_updated = 0
@ -454,25 +461,33 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
async def _get_channel(
self, channel_identifier: int | str, check_notifications: bool = True
):
"""Get channel by ID or name with standard checks"""
if check_notifications and not settings.DISCORD_NOTIFICATIONS_ENABLED:
logger.debug("Discord notifications disabled")
return None
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel:
logger.error(f"Channel {channel_identifier} not found")
return channel
async def send_to_channel(
self, channel_identifier: int | str, message: str
) -> bool:
"""Send a message to a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
# Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
channel = await self._get_channel(channel_identifier)
if not channel:
logger.error(f"Channel {channel_identifier} not found")
return False
# Post-process mentions to convert usernames to IDs
with make_session() as session:
processed_message = process_mentions(session, message)
@ -486,18 +501,9 @@ class MessageCollector(commands.Bot):
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
# Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
channel = await self._get_channel(channel_identifier)
if not channel:
logger.error(f"Channel {channel_identifier} not found")
return False
async with channel.typing():
@ -509,3 +515,21 @@ class MessageCollector(commands.Bot):
f"Failed to trigger typing for channel {channel_identifier}: {e}"
)
return False
async def add_reaction(
self, channel_identifier: int | str, message_id: int, emoji: str
) -> bool:
"""Add a reaction to a message in a channel"""
try:
channel = await self._get_channel(channel_identifier)
if not channel:
return False
message = await channel.fetch_message(message_id)
await message.add_reaction(emoji)
logger.info(f"Added reaction {emoji} to message {message_id}")
return True
except Exception as e:
logger.error(f"Failed to add reaction: {e}")
return False

View File

@ -1,7 +1,7 @@
"""Lightweight slash-command helpers for the Discord collector."""
from __future__ import annotations
from calendar import c
import logging
from dataclasses import dataclass
from typing import Callable, Literal
@ -10,6 +10,9 @@ from sqlalchemy.orm import Session
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.mcp import run_mcp_server_command
logger = logging.getLogger(__name__)
ScopeLiteral = Literal["server", "channel", "user"]
@ -41,7 +44,7 @@ class CommandContext:
CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
def register_slash_commands(bot: discord.Client) -> None:
"""Register the collector slash commands on the provided bot.
Args:
@ -58,6 +61,7 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
raise RuntimeError("Bot instance does not support app commands")
tree = bot.tree
name = bot.user and bot.user.name.replace("-", "_").lower()
@tree.command(
name=f"{name}_show_prompt", description="Show the current system prompt"
@ -167,6 +171,21 @@ def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
target_user=user,
)
@tree.command(
name=f"{name}_mcp_servers",
description="Manage MCP servers for your account",
)
@discord.app_commands.describe(
action="Action to perform",
url="MCP server URL (required for add, delete, connect, tools)",
)
async def mcp_servers_command(
interaction: discord.Interaction,
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
url: str | None = None,
) -> None:
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
async def _run_interaction_command(
interaction: discord.Interaction,
@ -177,17 +196,10 @@ async def _run_interaction_command(
**handler_kwargs,
) -> None:
"""Shared coroutine used by the registered slash commands."""
try:
with make_session() as session:
response = run_command(
session,
interaction,
scope,
handler=handler,
target_user=target_user,
**handler_kwargs,
)
context = _build_context(session, interaction, scope, target_user)
response = handler(context, **handler_kwargs)
session.commit()
except CommandError as exc: # pragma: no cover - passthrough
await interaction.response.send_message(str(exc), ephemeral=True)
@ -199,21 +211,6 @@ async def _run_interaction_command(
)
def run_command(
session: Session,
interaction: discord.Interaction,
scope: ScopeLiteral,
*,
handler: CommandHandler,
target_user: discord.User | None = None,
**handler_kwargs,
) -> CommandResponse:
"""Create a :class:`CommandContext` and execute the handler."""
context = _build_context(session, interaction, scope, target_user)
return handler(context, **handler_kwargs)
def _build_context(
session: Session,
interaction: discord.Interaction,

301
src/memory/discord/mcp.py Normal file
View File

@ -0,0 +1,301 @@
"""Lightweight slash-command helpers for the Discord collector."""
import json
import logging
import time
from typing import Any, AsyncGenerator, Literal, cast
import aiohttp
import discord
from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session
from memory.common.db.models.discord import DiscordMCPServer
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__)
def find_mcp_server(
session: Session | scoped_session, user_id: int, url: str
) -> DiscordMCPServer | None:
return (
session.query(DiscordMCPServer)
.filter(
DiscordMCPServer.discord_bot_user_id == user_id,
DiscordMCPServer.mcp_server_url == url,
)
.first()
)
async def call_mcp_server(
url: str, access_token: str, method: str, params: dict = {}
) -> AsyncGenerator[Any, None]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Authorization": f"Bearer {access_token}",
}
payload = {
"jsonrpc": "2.0",
"id": int(time.time() * 1000),
"method": method,
"params": params,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Tools list failed: {resp.status} - {error_text}")
raise ValueError(
f"Failed to call MCP server: {resp.status} - {error_text}"
)
# Parse SSE stream
async for line in resp.content:
line_str = line.decode("utf-8").strip()
# SSE format: "data: {json}"
if line_str.startswith("data: "):
json_str = line_str[6:] # Remove "data: " prefix
try:
yield json.loads(json_str)
except json.JSONDecodeError:
continue # Skip invalid JSON lines
async def handle_mcp_list(interaction: discord.Interaction) -> str:
"""List all MCP servers for the user."""
with make_session() as session:
servers = (
session.query(DiscordMCPServer)
.filter(
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
)
.all()
)
if not servers:
return (
"📋 **Your MCP Servers**\n\n"
"You don't have any MCP servers configured yet.\n"
"Use `/memory_mcp_servers add <url>` to add one."
)
def format_server(server: DiscordMCPServer) -> str:
con = "🟢" if cast(str | None, server.access_token) else "🔴"
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
server_list = "\n".join(format_server(s) for s in servers)
return f"📋 **Your MCP Servers**\n\n{server_list}"
async def handle_mcp_add(
interaction: discord.Interaction,
bot_user: discord.User | None,
url: str,
) -> str:
"""Add a new MCP server via OAuth."""
if not bot_user:
raise ValueError("Bot user is required")
with make_session() as session:
if find_mcp_server(session, bot_user.id, url):
return (
f"**MCP Server Already Exists**\n\n"
f"You already have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers connect {url}` to reconnect."
)
endpoints = await get_endpoints(url)
client_id = await register_oauth_client(
endpoints,
url,
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
)
mcp_server = DiscordMCPServer(
discord_bot_user_id=bot_user.id,
mcp_server_url=url,
client_id=client_id,
)
session.add(mcp_server)
session.flush()
auth_url = await issue_challenge(mcp_server, endpoints)
session.commit()
logger.info(
f"Created MCP server record: id={mcp_server.id}, "
f"user={interaction.user.id}, url={url}"
)
return (
f"🔐 **Add MCP Server**\n\n"
f"Server: `{url}`\n"
f"Click the link below to authorize:\n{auth_url}\n\n"
f"⚠️ Keep this link private!\n"
f"💡 You'll be redirected to login and grant access to the MCP server."
)
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
"""Delete an MCP server."""
with make_session() as session:
mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server:
return (
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
)
session.delete(mcp_server)
session.commit()
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
"""Reconnect to an existing MCP server (redo OAuth)."""
with make_session() as session:
mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
endpoints = await get_endpoints(url)
auth_url = await issue_challenge(mcp_server, endpoints)
session.commit()
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
return (
f"🔄 **Reconnect to MCP Server**\n\n"
f"Server: `{url}`\n"
f"Click the link below to reauthorize:\n{auth_url}\n\n"
f"⚠️ Keep this link private!\n"
f"💡 You'll be redirected to login and grant access to the MCP server again."
)
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
"""List tools available on an MCP server."""
with make_session() as session:
mcp_server = find_mcp_server(session, bot_user.id, url)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
if not cast(str | None, mcp_server.access_token):
raise ValueError(
f"**Not Authorized**\n\n"
f"You haven't authorized access to `{url}` yet.\n"
f"Use `/memory_mcp_servers connect {url}` to authorize."
)
access_token = cast(str, mcp_server.access_token)
# Make JSON-RPC request to MCP server
tools = None
try:
async for data in call_mcp_server(url, access_token, "tools/list"):
if "result" in data and "tools" in data["result"]:
tools = data["result"]["tools"]
break
except aiohttp.ClientError as exc:
logger.exception(f"Failed to connect to MCP server: {exc}")
raise ValueError(
f"**Connection failed**\n\n"
f"Server: `{url}`\n"
f"Could not connect to the MCP server: {str(exc)}"
)
except Exception as exc:
logger.exception(f"Failed to list tools: {exc}")
raise ValueError(
f"**Error**\n\nServer: `{url}`\nFailed to list tools: {str(exc)}"
)
if tools is None:
raise ValueError(
f"**Unexpected response format**\n\n"
f"Server: `{url}`\n"
f"The server returned an unexpected response format."
)
if not tools:
return (
f"🔧 **MCP Server Tools**\n\n"
f"Server: `{url}`\n\n"
f"No tools available on this server."
)
# Format tools list
tools_list = "\n".join(
f"• **{t.get('name', 'unknown')}**: {t.get('description', 'No description')}"
for t in tools
)
return (
f"🔧 **MCP Server Tools**\n\n"
f"Server: `{url}`\n"
f"Found {len(tools)} tool(s):\n\n"
f"{tools_list}"
)
async def run_mcp_server_command(
interaction: discord.Interaction,
bot_user: discord.User | None,
action: Literal["list", "add", "delete", "connect", "tools"],
url: str | None,
) -> None:
"""Handle MCP server management commands."""
if action not in ["list", "add", "delete", "connect", "tools"]:
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
return
if action != "list" and not url:
await interaction.response.send_message(
"❌ URL is required for this action", ephemeral=True
)
return
if not bot_user:
await interaction.response.send_message(
"❌ Bot user is required", ephemeral=True
)
return
try:
if action == "list" or not url:
result = await handle_mcp_list(interaction)
elif action == "add":
result = await handle_mcp_add(interaction, bot_user, url)
elif action == "delete":
result = await handle_mcp_delete(bot_user, url)
elif action == "connect":
result = await handle_mcp_connect(bot_user, url)
elif action == "tools":
result = await handle_mcp_tools(bot_user, url)
except Exception as exc:
result = f"❌ Error: {exc}"
await interaction.response.send_message(result, ephemeral=True)

View File

@ -13,8 +13,8 @@ from memory.common.db.models import (
DiscordUser,
ScheduledLLMCall,
)
from memory.common.db.models.users import BotUser
from memory.common.llms.base import create_provider
from memory.common.llms.tools import MCPServer
logger = logging.getLogger(__name__)
@ -141,8 +141,10 @@ def previous_messages(
) -> list[DiscordMessage]:
messages = session.query(DiscordMessage)
if user_id:
print(f"user_id: {user_id}")
messages = messages.filter(DiscordMessage.recipient_id == user_id)
if channel_id:
print(f"channel_id: {channel_id}")
messages = messages.filter(DiscordMessage.channel_id == channel_id)
return list(
reversed(
@ -214,6 +216,7 @@ def call_llm(
system_prompt: str = "",
messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None,
mcp_servers: list[MCPServer] | None = None,
num_previous_messages: int = 10,
) -> str | None:
"""
@ -266,7 +269,8 @@ def call_llm(
return provider.run_with_tools(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=system_prompt,
system_prompt=bot_user.system_prompt + "\n\n" + system_prompt,
mcp_servers=mcp_servers,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response

View File

@ -10,6 +10,7 @@ import textwrap
from datetime import datetime
from typing import Any, cast
from memory.common.llms.tools import MCPServer
import requests
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session
@ -189,6 +190,20 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
mcp_servers = None
if (
discord_message.recipient_user
and discord_message.recipient_user.mcp_servers
):
mcp_servers = [
MCPServer(
name=server.mcp_server_url,
url=server.mcp_server_url,
token=server.access_token,
)
for server in discord_message.recipient_user.mcp_servers
]
try:
response = call_llm(
session,
@ -196,6 +211,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
from_user=discord_message.from_user,
channel=discord_message.channel,
model=settings.DISCORD_MODEL,
mcp_servers=mcp_servers,
system_prompt=discord_message.system_prompt,
)
except Exception: