multiple mcp servers

This commit is contained in:
mruwnik 2025-11-03 16:41:26 +00:00
parent 2d3dc06fdf
commit 8893018af1
15 changed files with 308 additions and 181 deletions

View File

@ -1,5 +1,6 @@
# Agent Guidance # Agent Guidance
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary. - Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved. - Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
- Tests should be written with @pytest.mark.parametrize where applicable and should avoid test classes
- Make sure linting errors get fixed

View File

@ -1,67 +0,0 @@
"""discord mcp servers
Revision ID: 9b887449ea92
Revises: 1954477b25f4
Create Date: 2025-11-02 22:04:26.259323
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "9b887449ea92"
down_revision: Union[str, None] = "1954477b25f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"discord_mcp_servers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
sa.Column("mcp_server_url", sa.Text(), nullable=False),
sa.Column("client_id", sa.Text(), nullable=False),
sa.Column("state", sa.Text(), nullable=True),
sa.Column("code_verifier", sa.Text(), nullable=True),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["discord_bot_user_id"],
["discord_users.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("state"),
)
op.create_index(
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
)
op.create_index(
"discord_mcp_user_url_idx",
"discord_mcp_servers",
["discord_bot_user_id", "mcp_server_url"],
unique=False,
)
def downgrade() -> None:
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
op.drop_table("discord_mcp_servers")

View File

@ -0,0 +1,103 @@
"""mcp servers
Revision ID: 89861d5f1102
Revises: 1954477b25f4
Create Date: 2025-11-03 15:41:26.254854
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "89861d5f1102"
down_revision: Union[str, None] = "1954477b25f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"mcp_servers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("mcp_server_url", sa.Text(), nullable=False),
sa.Column("client_id", sa.Text(), nullable=False),
sa.Column(
"available_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column("state", sa.Text(), nullable=True),
sa.Column("code_verifier", sa.Text(), nullable=True),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("state"),
)
op.create_index("mcp_state_idx", "mcp_servers", ["state"], unique=False)
op.create_table(
"mcp_server_assignments",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("mcp_server_id", sa.Integer(), nullable=False),
sa.Column("entity_type", sa.Text(), nullable=False),
sa.Column("entity_id", sa.BigInteger(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["mcp_server_id"],
["mcp_servers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"mcp_assignment_entity_idx",
"mcp_server_assignments",
["entity_type", "entity_id"],
unique=False,
)
op.create_index(
"mcp_assignment_server_idx",
"mcp_server_assignments",
["mcp_server_id"],
unique=False,
)
op.create_index(
"mcp_assignment_unique_idx",
"mcp_server_assignments",
["mcp_server_id", "entity_type", "entity_id"],
unique=True,
)
def downgrade() -> None:
op.drop_index("mcp_assignment_unique_idx", table_name="mcp_server_assignments")
op.drop_index("mcp_assignment_server_idx", table_name="mcp_server_assignments")
op.drop_index("mcp_assignment_entity_idx", table_name="mcp_server_assignments")
op.drop_table("mcp_server_assignments")
op.drop_index("mcp_state_idx", table_name="mcp_servers")
op.drop_table("mcp_servers")

View File

@ -14,7 +14,7 @@ from memory.common.db.models import (
BookSection, BookSection,
Chunk, Chunk,
Comic, Comic,
DiscordMCPServer, MCPServer,
DiscordMessage, DiscordMessage,
EmailAccount, EmailAccount,
EmailAttachment, EmailAttachment,
@ -167,17 +167,17 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
column_sortable_list = ["sent_at"] column_sortable_list = ["sent_at"]
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer): class MCPServerAdmin(ModelView, model=MCPServer):
column_list = [ column_list = [
"id", "id",
"mcp_server_url", "mcp_server_url",
"client_id", "client_id",
"discord_bot_user_id",
"state", "state",
"code_verifier", "code_verifier",
"access_token", "access_token",
"refresh_token", "refresh_token",
"token_expires_at", "token_expires_at",
"available_tools",
"created_at", "created_at",
"updated_at", "updated_at",
] ]
@ -186,7 +186,6 @@ class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
"client_id", "client_id",
"state", "state",
"id", "id",
"discord_bot_user_id",
] ]
column_sortable_list = [ column_sortable_list = [
"created_at", "created_at",
@ -360,5 +359,5 @@ def setup_admin(admin: Admin):
admin.add_view(DiscordUserAdmin) admin.add_view(DiscordUserAdmin)
admin.add_view(DiscordServerAdmin) admin.add_view(DiscordServerAdmin)
admin.add_view(DiscordChannelAdmin) admin.add_view(DiscordChannelAdmin)
admin.add_view(DiscordMCPServerAdmin) admin.add_view(MCPServerAdmin)
admin.add_view(ScheduledLLMCallAdmin) admin.add_view(ScheduledLLMCallAdmin)

View File

@ -10,7 +10,7 @@ from memory.common import settings
from memory.common.db.connection import get_session, make_session from memory.common.db.connection import get_session, make_session
from memory.common.db.models import ( from memory.common.db.models import (
BotUser, BotUser,
DiscordMCPServer, MCPServer,
HumanUser, HumanUser,
User, User,
UserSession, UserSession,
@ -169,9 +169,7 @@ async def oauth_callback_discord(request: Request):
# Complete the OAuth flow (exchange code for token) # Complete the OAuth flow (exchange code for token)
with make_session() as session: with make_session() as session:
mcp_server = ( mcp_server = (
session.query(DiscordMCPServer) session.query(MCPServer).filter(MCPServer.state == state).first()
.filter(DiscordMCPServer.state == state)
.first()
) )
status_code, message = await complete_oauth_flow(mcp_server, code, state) status_code, message = await complete_oauth_flow(mcp_server, code, state)
session.commit() session.commit()

View File

@ -34,7 +34,8 @@ from memory.common.db.models.discord import (
DiscordServer, DiscordServer,
DiscordChannel, DiscordChannel,
DiscordUser, DiscordUser,
DiscordMCPServer, MCPServer,
MCPServerAssignment,
) )
from memory.common.db.models.observations import ( from memory.common.db.models.observations import (
ObservationContradiction, ObservationContradiction,
@ -107,7 +108,8 @@ __all__ = [
"DiscordServer", "DiscordServer",
"DiscordChannel", "DiscordChannel",
"DiscordUser", "DiscordUser",
"DiscordMCPServer", "MCPServer",
"MCPServerAssignment",
# Users # Users
"User", "User",
"HumanUser", "HumanUser",

View File

@ -127,26 +127,22 @@ class DiscordUser(Base, MessageProcessor):
updated_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users") system_user = relationship("User", back_populates="discord_users")
mcp_servers = relationship(
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
)
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),) __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
class DiscordMCPServer(Base): class MCPServer(Base):
"""MCP server configuration and OAuth state for Discord users.""" """MCP server configuration and OAuth state."""
__tablename__ = "discord_mcp_servers" __tablename__ = "mcp_servers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
discord_bot_user_id = Column(
BigInteger, ForeignKey("discord_users.id"), nullable=False
)
# MCP server info # MCP server info
name = Column(Text, nullable=False)
mcp_server_url = Column(Text, nullable=False) mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False) client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange) # OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True) state = Column(Text, nullable=True, unique=True)
@ -162,9 +158,42 @@ class DiscordMCPServer(Base):
updated_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships # Relationships
discord_user = relationship("DiscordUser", back_populates="mcp_servers") assignments = relationship(
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
)
__table_args__ = (Index("mcp_state_idx", "state"),)
class MCPServerAssignment(Base):
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
__tablename__ = "mcp_server_assignments"
id = Column(Integer, primary_key=True)
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
# Polymorphic entity reference
entity_type = Column(
Text, nullable=False
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
entity_id = Column(BigInteger, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
mcp_server = relationship("MCPServer", back_populates="assignments")
__table_args__ = ( __table_args__ = (
Index("discord_mcp_state_idx", "state"), Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"), Index("mcp_assignment_server_idx", "mcp_server_id"),
Index(
"mcp_assignment_unique_idx",
"mcp_server_id",
"entity_type",
"entity_id",
unique=True,
),
) )

View File

@ -69,7 +69,6 @@ def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
print("Result", result)
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:

View File

@ -55,7 +55,6 @@ CUSTOM_EXTENSIONS = {
def get_mime_type(path: pathlib.Path) -> str: def get_mime_type(path: pathlib.Path) -> str:
mime_type, _ = mimetypes.guess_type(str(path)) mime_type, _ = mimetypes.guess_type(str(path))
if mime_type: if mime_type:
print(f"mime_type: {mime_type}")
return mime_type return mime_type
ext = path.suffix.lower() ext = path.suffix.lower()
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream") return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")

View File

@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
import aiohttp import aiohttp
from memory.common import settings from memory.common import settings
from memory.common.db.models.discord import DiscordMCPServer from memory.common.db.models.discord import MCPServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -148,7 +148,7 @@ async def register_oauth_client(
async def issue_challenge( async def issue_challenge(
mcp_server: DiscordMCPServer, mcp_server: MCPServer,
endpoints: OAuthEndpoints, endpoints: OAuthEndpoints,
) -> str: ) -> str:
"""Generate OAuth challenge and store state in mcp_server object.""" """Generate OAuth challenge and store state in mcp_server object."""
@ -160,7 +160,7 @@ async def issue_challenge(
mcp_server.code_verifier = code_verifier # type: ignore mcp_server.code_verifier = code_verifier # type: ignore
logger.info( logger.info(
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: " f"Generated OAuth state for MCP server {mcp_server.mcp_server_url}: "
f"state={state[:20]}..., verifier={code_verifier[:20]}..." f"state={state[:20]}..., verifier={code_verifier[:20]}..."
) )
@ -179,7 +179,7 @@ async def issue_challenge(
async def complete_oauth_flow( async def complete_oauth_flow(
mcp_server: DiscordMCPServer, code: str, state: str mcp_server: MCPServer, code: str, state: str
) -> tuple[int, str]: ) -> tuple[int, str]:
"""Complete OAuth flow by exchanging code for token. """Complete OAuth flow by exchanging code for token.
@ -196,7 +196,7 @@ async def complete_oauth_flow(
return 400, "Invalid or expired OAuth state" return 400, "Invalid or expired OAuth state"
logger.info( logger.info(
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, " f"Found MCP server config: id={mcp_server.id}, "
f"url={mcp_server.mcp_server_url}" f"url={mcp_server.mcp_server_url}"
) )
@ -247,8 +247,8 @@ async def complete_oauth_flow(
mcp_server.code_verifier = None # type: ignore mcp_server.code_verifier = None # type: ignore
logger.info( logger.info(
f"Stored tokens for user {mcp_server.discord_bot_user_id}, " f"Stored tokens for MCP server id={mcp_server.id}, "
f"server {mcp_server.mcp_server_url}" f"url={mcp_server.mcp_server_url}"
) )
return 200, "✅ Authorization successful! You can now use this MCP server." return 200, "✅ Authorization successful! You can now use this MCP server."

View File

@ -12,14 +12,12 @@ from contextlib import asynccontextmanager
from typing import cast from typing import cast
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel from pydantic import BaseModel
from memory.common import settings from memory.common import settings
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMCPServer, DiscordBotUser from memory.common.db.models import DiscordBotUser
from memory.common.oauth import complete_oauth_flow
from memory.discord.collector import MessageCollector from memory.discord.collector import MessageCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,6 +1,5 @@
"""Lightweight slash-command helpers for the Discord collector.""" """Lightweight slash-command helpers for the Discord collector."""
from calendar import c
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Literal from typing import Callable, Literal
@ -14,7 +13,7 @@ from memory.discord.mcp import run_mcp_server_command
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ScopeLiteral = Literal["server", "channel", "user"] ScopeLiteral = Literal["bot", "server", "channel", "user"]
class CommandError(Exception): class CommandError(Exception):
@ -173,18 +172,29 @@ def register_slash_commands(bot: discord.Client) -> None:
@tree.command( @tree.command(
name=f"{name}_mcp_servers", name=f"{name}_mcp_servers",
description="Manage MCP servers for your account", description="Manage MCP servers for a scope",
) )
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to modify (server, channel, or user)",
action="Action to perform", action="Action to perform",
url="MCP server URL (required for add, delete, connect, tools)", url="MCP server URL (required for add, delete, connect, tools)",
user="Target user when the scope is 'user'",
) )
async def mcp_servers_command( async def mcp_servers_command(
interaction: discord.Interaction, interaction: discord.Interaction,
scope: ScopeLiteral,
action: Literal["list", "add", "delete", "connect", "tools"] = "list", action: Literal["list", "add", "delete", "connect", "tools"] = "list",
url: str | None = None, url: str | None = None,
user: discord.User | None = None,
) -> None: ) -> None:
await run_mcp_server_command(interaction, bot.user, action, url and url.strip()) await _run_interaction_command(
interaction,
scope=scope,
handler=handle_mcp_servers,
target_user=user,
action=action,
url=url and url.strip(),
)
async def _run_interaction_command( async def _run_interaction_command(
@ -199,7 +209,7 @@ async def _run_interaction_command(
try: try:
with make_session() as session: with make_session() as session:
context = _build_context(session, interaction, scope, target_user) context = _build_context(session, interaction, scope, target_user)
response = handler(context, **handler_kwargs) response = await handler(context, **handler_kwargs)
session.commit() session.commit()
except CommandError as exc: # pragma: no cover - passthrough except CommandError as exc: # pragma: no cover - passthrough
await interaction.response.send_message(str(exc), ephemeral=True) await interaction.response.send_message(str(exc), ephemeral=True)
@ -435,3 +445,29 @@ def handle_summary(context: CommandContext) -> CommandResponse:
return CommandResponse( return CommandResponse(
content=f"No summary stored for {context.display_name}.", content=f"No summary stored for {context.display_name}.",
) )
async def handle_mcp_servers(
context: CommandContext,
*,
action: Literal["list", "add", "delete", "connect", "tools"],
url: str | None,
) -> CommandResponse:
"""Handle MCP server commands for a specific scope."""
entity_type_map = {
"server": "DiscordServer",
"channel": "DiscordChannel",
"user": "DiscordUser",
}
entity_type = entity_type_map[context.scope]
entity_id = context.target.id
try:
res = await run_mcp_server_command(
context.interaction.user, action, url, entity_type, entity_id
)
return CommandResponse(content=res)
except Exception as exc:
import traceback
logger.error(f"Error running MCP server command: {traceback.format_exc()}")
raise CommandError(f"Error: {exc}") from exc

View File

@ -10,23 +10,27 @@ import discord
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models.discord import DiscordMCPServer from memory.common.db.models.discord import MCPServer, MCPServerAssignment
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def find_mcp_server( def find_mcp_server(
session: Session | scoped_session, user_id: int, url: str session: Session | scoped_session, entity_type: str, entity_id: int, url: str
) -> DiscordMCPServer | None: ) -> MCPServer | None:
return ( """Find an MCP server assigned to an entity."""
session.query(DiscordMCPServer) assignment = (
session.query(MCPServerAssignment)
.join(MCPServer)
.filter( .filter(
DiscordMCPServer.discord_bot_user_id == user_id, MCPServerAssignment.entity_type == entity_type,
DiscordMCPServer.mcp_server_url == url, MCPServerAssignment.entity_id == entity_id,
MCPServer.mcp_server_url == url,
) )
.first() .first()
) )
return assignment and assignment.mcp_server
async def call_mcp_server( async def call_mcp_server(
@ -72,35 +76,39 @@ async def call_mcp_server(
continue # Skip invalid JSON lines continue # Skip invalid JSON lines
async def handle_mcp_list(interaction: discord.Interaction) -> str: async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
"""List all MCP servers for the user.""" """List all MCP servers for the user."""
with make_session() as session: with make_session() as session:
servers = ( assignments = (
session.query(DiscordMCPServer) session.query(MCPServerAssignment)
.join(MCPServer)
.filter( .filter(
DiscordMCPServer.discord_bot_user_id == interaction.user.id, MCPServerAssignment.entity_type == entity_type,
MCPServerAssignment.entity_id == entity_id,
) )
.all() .all()
) )
if not servers: if not assignments:
return ( return (
"📋 **Your MCP Servers**\n\n" "📋 **Your MCP Servers**\n\n"
"You don't have any MCP servers configured yet.\n" "You don't have any MCP servers configured yet.\n"
"Use `/memory_mcp_servers add <url>` to add one." "Use `/memory_mcp_servers add <url>` to add one."
) )
def format_server(server: DiscordMCPServer) -> str: def format_server(assignment: MCPServerAssignment) -> str:
server = assignment.mcp_server
con = "🟢" if cast(str | None, server.access_token) else "🔴" con = "🟢" if cast(str | None, server.access_token) else "🔴"
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`" return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
server_list = "\n".join(format_server(s) for s in servers) server_list = "\n".join(format_server(a) for a in assignments)
return f"📋 **Your MCP Servers**\n\n{server_list}" return f"📋 **Your MCP Servers**\n\n{server_list}"
async def handle_mcp_add( async def handle_mcp_add(
interaction: discord.Interaction, entity_type: str,
entity_id: int,
bot_user: discord.User | None, bot_user: discord.User | None,
url: str, url: str,
) -> str: ) -> str:
@ -108,7 +116,7 @@ async def handle_mcp_add(
if not bot_user: if not bot_user:
raise ValueError("Bot user is required") raise ValueError("Bot user is required")
with make_session() as session: with make_session() as session:
if find_mcp_server(session, bot_user.id, url): if find_mcp_server(session, entity_type, entity_id, url):
return ( return (
f"**MCP Server Already Exists**\n\n" f"**MCP Server Already Exists**\n\n"
f"You already have an MCP server configured at `{url}`.\n" f"You already have an MCP server configured at `{url}`.\n"
@ -116,25 +124,32 @@ async def handle_mcp_add(
) )
endpoints = await get_endpoints(url) endpoints = await get_endpoints(url)
client_id = await register_oauth_client( name = f"Discord Bot - {bot_user.name} ({entity_type} {entity_id})"
endpoints, client_id = await register_oauth_client(endpoints, url, name)
url,
f"Discord Bot - {bot_user.name} ({interaction.user.name})", # Create MCP server
) mcp_server = MCPServer(
mcp_server = DiscordMCPServer(
discord_bot_user_id=bot_user.id,
mcp_server_url=url, mcp_server_url=url,
client_id=client_id, client_id=client_id,
name=name,
) )
session.add(mcp_server) session.add(mcp_server)
session.flush() session.flush()
assignment = MCPServerAssignment(
mcp_server_id=mcp_server.id,
entity_type=entity_type,
entity_id=entity_id,
)
session.add(assignment)
session.flush()
auth_url = await issue_challenge(mcp_server, endpoints) auth_url = await issue_challenge(mcp_server, endpoints)
session.commit() session.commit()
logger.info( logger.info(
f"Created MCP server record: id={mcp_server.id}, " f"Created MCP server record: id={mcp_server.id}, "
f"user={interaction.user.id}, url={url}" f"{entity_type}={entity_id}, url={url}"
) )
return ( return (
@ -146,32 +161,54 @@ async def handle_mcp_add(
) )
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str: async def handle_mcp_delete(entity_type: str, entity_id: int, url: str) -> str:
"""Delete an MCP server.""" """Delete an MCP server assignment."""
with make_session() as session: with make_session() as session:
mcp_server = find_mcp_server(session, bot_user.id, url) # Find the assignment
if not mcp_server: assignment = (
session.query(MCPServerAssignment)
.join(MCPServer)
.filter(
MCPServerAssignment.entity_type == entity_type,
MCPServerAssignment.entity_id == entity_id,
MCPServer.mcp_server_url == url,
)
.first()
)
if not assignment:
return ( return (
f"**MCP Server Not Found**\n\n" f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n" f"You don't have an MCP server configured at `{url}`.\n"
) )
session.delete(mcp_server)
# Delete the assignment (server will cascade delete if no other assignments exist)
session.delete(assignment)
# Check if server has other assignments
mcp_server = assignment.mcp_server
other_assignments = (
session.query(MCPServerAssignment)
.filter(
MCPServerAssignment.mcp_server_id == mcp_server.id,
MCPServerAssignment.id != assignment.id,
)
.count()
)
# If no other assignments, delete the server too
if other_assignments == 0:
session.delete(mcp_server)
session.commit() session.commit()
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed." return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str: async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
"""Reconnect to an existing MCP server (redo OAuth).""" """Reconnect to an existing MCP server (redo OAuth)."""
with make_session() as session: with make_session() as session:
mcp_server = find_mcp_server(session, bot_user.id, url) mcp_server = find_mcp_server(session, entity_type, entity_id, url)
if not mcp_server:
raise ValueError(
f"**MCP Server Not Found**\n\n"
f"You don't have an MCP server configured at `{url}`.\n"
f"Use `/memory_mcp_servers add {url}` to add it first."
)
if not mcp_server: if not mcp_server:
raise ValueError( raise ValueError(
f"**MCP Server Not Found**\n\n" f"**MCP Server Not Found**\n\n"
@ -184,7 +221,9 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
session.commit() session.commit()
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}") logger.info(
f"Regenerated OAuth challenge for {entity_type}={entity_id}, url={url}"
)
return ( return (
f"🔄 **Reconnect to MCP Server**\n\n" f"🔄 **Reconnect to MCP Server**\n\n"
@ -195,10 +234,10 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
) )
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str: async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
"""List tools available on an MCP server.""" """List tools available on an MCP server."""
with make_session() as session: with make_session() as session:
mcp_server = find_mcp_server(session, bot_user.id, url) mcp_server = find_mcp_server(session, entity_type, entity_id, url)
if not mcp_server: if not mcp_server:
raise ValueError( raise ValueError(
@ -265,37 +304,28 @@ async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
async def run_mcp_server_command( async def run_mcp_server_command(
interaction: discord.Interaction,
bot_user: discord.User | None, bot_user: discord.User | None,
action: Literal["list", "add", "delete", "connect", "tools"], action: Literal["list", "add", "delete", "connect", "tools"],
url: str | None, url: str | None,
entity_type: str,
entity_id: int,
) -> None: ) -> None:
"""Handle MCP server management commands.""" """Handle MCP server management commands."""
if action not in ["list", "add", "delete", "connect", "tools"]: if action not in ["list", "add", "delete", "connect", "tools"]:
await interaction.response.send_message("❌ Invalid action", ephemeral=True) raise ValueError(f"Invalid action: {action}")
return
if action != "list" and not url: if action != "list" and not url:
await interaction.response.send_message( raise ValueError("URL is required for this action")
"❌ URL is required for this action", ephemeral=True
)
return
if not bot_user: if not bot_user:
await interaction.response.send_message( raise ValueError("Bot user is required")
"❌ Bot user is required", ephemeral=True
)
return
try: if action == "list" or not url:
if action == "list" or not url: return await handle_mcp_list(entity_type, entity_id)
result = await handle_mcp_list(interaction) elif action == "add":
elif action == "add": return await handle_mcp_add(entity_type, entity_id, bot_user, url)
result = await handle_mcp_add(interaction, bot_user, url) elif action == "delete":
elif action == "delete": return await handle_mcp_delete(entity_type, entity_id, url)
result = await handle_mcp_delete(bot_user, url) elif action == "connect":
elif action == "connect": return await handle_mcp_connect(entity_type, entity_id, url)
result = await handle_mcp_connect(bot_user, url) elif action == "tools":
elif action == "tools": return await handle_mcp_tools(entity_type, entity_id, url)
result = await handle_mcp_tools(bot_user, url) raise ValueError(f"Invalid action: {action}")
except Exception as exc:
result = f"❌ Error: {exc}"
await interaction.response.send_message(result, ephemeral=True)

View File

@ -113,8 +113,6 @@ def upsert_scheduled_message(
.first() .first()
) )
naive_scheduled_time = scheduled_time.replace(tzinfo=None) naive_scheduled_time = scheduled_time.replace(tzinfo=None)
print(f"naive_scheduled_time: {naive_scheduled_time}")
print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}")
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time: if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
prev_call.status = "cancelled" # type: ignore prev_call.status = "cancelled" # type: ignore
@ -141,10 +139,8 @@ def previous_messages(
) -> list[DiscordMessage]: ) -> list[DiscordMessage]:
messages = session.query(DiscordMessage) messages = session.query(DiscordMessage)
if user_id: if user_id:
print(f"user_id: {user_id}")
messages = messages.filter(DiscordMessage.recipient_id == user_id) messages = messages.filter(DiscordMessage.recipient_id == user_id)
if channel_id: if channel_id:
print(f"channel_id: {channel_id}")
messages = messages.filter(DiscordMessage.channel_id == channel_id) messages = messages.filter(DiscordMessage.channel_id == channel_id)
return list( return list(
reversed( reversed(
@ -294,10 +290,8 @@ def send_discord_response(
True if sent successfully True if sent successfully
""" """
if channel_id is not None: if channel_id is not None:
logger.info(f"Sending message to channel {channel_id}")
return discord.send_to_channel(bot_id, channel_id, response) return discord.send_to_channel(bot_id, channel_id, response)
elif user_identifier is not None: elif user_identifier is not None:
logger.info(f"Sending DM to {user_identifier}")
return discord.send_dm(bot_id, user_identifier, response) return discord.send_dm(bot_id, user_identifier, response)
else: else:
logger.error("Neither channel_id nor user_identifier provided") logger.error("Neither channel_id nor user_identifier provided")

View File

@ -91,6 +91,10 @@ def should_process(message: DiscordMessage) -> bool:
): ):
return False return False
if f"<@{message.recipient_user.id}>" in message.content:
logger.info("Direct mention of the bot, processing message")
return True
if message.from_user == message.recipient_user: if message.from_user == message.recipient_user:
logger.info("Skipping message because from_user == recipient_user") logger.info("Skipping message because from_user == recipient_user")
return False return False
@ -132,6 +136,8 @@ def should_process(message: DiscordMessage) -> bool:
if not (res := re.search(r"<number>(.*)</number>", response)): if not (res := re.search(r"<number>(.*)</number>", response)):
return False return False
try: try:
logger.info(f"chattiness_threshold: {message.chattiness_threshold}")
logger.info(f"haiku desire: {res.group(1)}")
if int(res.group(1)) < 100 - message.chattiness_threshold: if int(res.group(1)) < 100 - message.chattiness_threshold:
return False return False
except ValueError: except ValueError: