From 8893018af1aeb40b7ca284fa8d03a31ba1d73722 Mon Sep 17 00:00:00 2001 From: mruwnik Date: Mon, 3 Nov 2025 16:41:26 +0000 Subject: [PATCH] multiple mcp servers --- AGENTS.md | 3 +- .../20251102_220426_discord_mcp_servers.py | 67 -------- .../versions/20251103_154126_mcp_servers.py | 103 +++++++++++ src/memory/api/admin.py | 9 +- src/memory/api/auth.py | 6 +- src/memory/common/db/models/__init__.py | 6 +- src/memory/common/db/models/discord.py | 53 ++++-- src/memory/common/discord.py | 1 - src/memory/common/extract.py | 1 - src/memory/common/oauth.py | 14 +- src/memory/discord/api.py | 6 +- src/memory/discord/commands.py | 46 ++++- src/memory/discord/mcp.py | 162 +++++++++++------- src/memory/discord/messages.py | 6 - src/memory/workers/tasks/discord.py | 6 + 15 files changed, 308 insertions(+), 181 deletions(-) delete mode 100644 db/migrations/versions/20251102_220426_discord_mcp_servers.py create mode 100644 db/migrations/versions/20251103_154126_mcp_servers.py diff --git a/AGENTS.md b/AGENTS.md index 00c6668..2df0bbf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,6 @@ # Agent Guidance - Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary. -- Treat LLM model identifiers as `/` strings throughout the codebase. - Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved. +- Tests should be written with @pytest.mark.parametrize where applicable and should avoid test classes +- Make sure linting errors get fixed diff --git a/db/migrations/versions/20251102_220426_discord_mcp_servers.py b/db/migrations/versions/20251102_220426_discord_mcp_servers.py deleted file mode 100644 index deb35d5..0000000 --- a/db/migrations/versions/20251102_220426_discord_mcp_servers.py +++ /dev/null @@ -1,67 +0,0 @@ -"""discord mcp servers - -Revision ID: 9b887449ea92 -Revises: 1954477b25f4 -Create Date: 2025-11-02 22:04:26.259323 - -""" - -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = "9b887449ea92" -down_revision: Union[str, None] = "1954477b25f4" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - op.create_table( - "discord_mcp_servers", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False), - sa.Column("mcp_server_url", sa.Text(), nullable=False), - sa.Column("client_id", sa.Text(), nullable=False), - sa.Column("state", sa.Text(), nullable=True), - sa.Column("code_verifier", sa.Text(), nullable=True), - sa.Column("access_token", sa.Text(), nullable=True), - sa.Column("refresh_token", sa.Text(), nullable=True), - sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.text("now()"), - nullable=True, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.text("now()"), - nullable=True, - ), - sa.ForeignKeyConstraint( - ["discord_bot_user_id"], - ["discord_users.id"], - ), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("state"), - ) - op.create_index( - "discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False - ) - op.create_index( - "discord_mcp_user_url_idx", - "discord_mcp_servers", - ["discord_bot_user_id", "mcp_server_url"], - unique=False, - ) - - -def downgrade() -> None: - op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers") - op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers") - op.drop_table("discord_mcp_servers") diff --git a/db/migrations/versions/20251103_154126_mcp_servers.py b/db/migrations/versions/20251103_154126_mcp_servers.py new file mode 100644 index 0000000..dd009c3 --- /dev/null +++ b/db/migrations/versions/20251103_154126_mcp_servers.py @@ -0,0 +1,103 @@ +"""mcp servers + +Revision ID: 89861d5f1102 +Revises: 1954477b25f4 +Create Date: 2025-11-03 15:41:26.254854 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "89861d5f1102" +down_revision: Union[str, None] = "1954477b25f4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "mcp_servers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("mcp_server_url", sa.Text(), nullable=False), + sa.Column("client_id", sa.Text(), nullable=False), + sa.Column( + "available_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column("state", sa.Text(), nullable=True), + sa.Column("code_verifier", sa.Text(), nullable=True), + sa.Column("access_token", sa.Text(), nullable=True), + sa.Column("refresh_token", sa.Text(), nullable=True), + sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("state"), + ) + op.create_index("mcp_state_idx", "mcp_servers", ["state"], unique=False) + op.create_table( + "mcp_server_assignments", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("mcp_server_id", sa.Integer(), nullable=False), + sa.Column("entity_type", sa.Text(), nullable=False), + sa.Column("entity_id", sa.BigInteger(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["mcp_server_id"], + ["mcp_servers.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "mcp_assignment_entity_idx", + "mcp_server_assignments", + ["entity_type", "entity_id"], + unique=False, + ) + op.create_index( + "mcp_assignment_server_idx", + "mcp_server_assignments", + ["mcp_server_id"], + unique=False, + ) + op.create_index( + "mcp_assignment_unique_idx", + "mcp_server_assignments", + ["mcp_server_id", "entity_type", "entity_id"], + unique=True, + ) + + +def downgrade() -> None: + op.drop_index("mcp_assignment_unique_idx", table_name="mcp_server_assignments") + op.drop_index("mcp_assignment_server_idx", table_name="mcp_server_assignments") + op.drop_index("mcp_assignment_entity_idx", table_name="mcp_server_assignments") + op.drop_table("mcp_server_assignments") + op.drop_index("mcp_state_idx", table_name="mcp_servers") + op.drop_table("mcp_servers") diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index 6e76412..b7c7cf2 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -14,7 +14,7 @@ from memory.common.db.models import ( BookSection, Chunk, Comic, - DiscordMCPServer, + MCPServer, DiscordMessage, EmailAccount, EmailAttachment, @@ -167,17 +167,17 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage): column_sortable_list = ["sent_at"] -class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer): +class MCPServerAdmin(ModelView, model=MCPServer): column_list = [ "id", "mcp_server_url", "client_id", - "discord_bot_user_id", "state", "code_verifier", "access_token", "refresh_token", "token_expires_at", + "available_tools", "created_at", "updated_at", ] @@ -186,7 +186,6 @@ class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer): "client_id", "state", "id", - "discord_bot_user_id", ] column_sortable_list = [ "created_at", @@ -360,5 +359,5 @@ def setup_admin(admin: Admin): admin.add_view(DiscordUserAdmin) admin.add_view(DiscordServerAdmin) admin.add_view(DiscordChannelAdmin) - admin.add_view(DiscordMCPServerAdmin) + admin.add_view(MCPServerAdmin) admin.add_view(ScheduledLLMCallAdmin) diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index adafcfc..0567762 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -10,7 +10,7 @@ from memory.common import settings from memory.common.db.connection import get_session, make_session from memory.common.db.models import ( BotUser, - DiscordMCPServer, + MCPServer, HumanUser, User, UserSession, @@ -169,9 +169,7 @@ async def oauth_callback_discord(request: Request): # Complete the OAuth flow (exchange code for token) with make_session() as session: mcp_server = ( - session.query(DiscordMCPServer) - .filter(DiscordMCPServer.state == state) - .first() + session.query(MCPServer).filter(MCPServer.state == state).first() ) status_code, message = await complete_oauth_flow(mcp_server, code, state) session.commit() diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 1198d8c..c2e262b 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -34,7 +34,8 @@ from memory.common.db.models.discord import ( DiscordServer, DiscordChannel, DiscordUser, - DiscordMCPServer, + MCPServer, + MCPServerAssignment, ) from memory.common.db.models.observations import ( ObservationContradiction, @@ -107,7 +108,8 @@ __all__ = [ "DiscordServer", "DiscordChannel", "DiscordUser", - "DiscordMCPServer", + "MCPServer", + "MCPServerAssignment", # Users "User", "HumanUser", diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 638deec..ff4a405 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -127,26 +127,22 @@ class DiscordUser(Base, MessageProcessor): updated_at = Column(DateTime(timezone=True), server_default=func.now()) system_user = relationship("User", back_populates="discord_users") - mcp_servers = relationship( - "DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan" - ) __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),) -class DiscordMCPServer(Base): - """MCP server configuration and OAuth state for Discord users.""" +class MCPServer(Base): + """MCP server configuration and OAuth state.""" - __tablename__ = "discord_mcp_servers" + __tablename__ = "mcp_servers" id = Column(Integer, primary_key=True) - discord_bot_user_id = Column( - BigInteger, ForeignKey("discord_users.id"), nullable=False - ) # MCP server info + name = Column(Text, nullable=False) mcp_server_url = Column(Text, nullable=False) client_id = Column(Text, nullable=False) + available_tools = Column(ARRAY(Text), nullable=False, server_default="{}") # OAuth flow state (temporary, cleared after token exchange) state = Column(Text, nullable=True, unique=True) @@ -162,9 +158,42 @@ class DiscordMCPServer(Base): updated_at = Column(DateTime(timezone=True), server_default=func.now()) # Relationships - discord_user = relationship("DiscordUser", back_populates="mcp_servers") + assignments = relationship( + "MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan" + ) + + __table_args__ = (Index("mcp_state_idx", "state"),) + + +class MCPServerAssignment(Base): + """Assignment of MCP servers to entities (users, channels, servers, etc.).""" + + __tablename__ = "mcp_server_assignments" + + id = Column(Integer, primary_key=True) + mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False) + + # Polymorphic entity reference + entity_type = Column( + Text, nullable=False + ) # "User", "DiscordUser", "DiscordServer", "DiscordChannel" + entity_id = Column(BigInteger, nullable=False) + + # Timestamps + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + # Relationships + mcp_server = relationship("MCPServer", back_populates="assignments") __table_args__ = ( - Index("discord_mcp_state_idx", "state"), - Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"), + Index("mcp_assignment_entity_idx", "entity_type", "entity_id"), + Index("mcp_assignment_server_idx", "mcp_server_id"), + Index( + "mcp_assignment_unique_idx", + "mcp_server_id", + "entity_type", + "entity_id", + unique=True, + ), ) diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index a324276..567d31e 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -69,7 +69,6 @@ def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool: ) response.raise_for_status() result = response.json() - print("Result", result) return result.get("success", False) except requests.RequestException as e: diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py index 92be227..157b96a 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -55,7 +55,6 @@ CUSTOM_EXTENSIONS = { def get_mime_type(path: pathlib.Path) -> str: mime_type, _ = mimetypes.guess_type(str(path)) if mime_type: - print(f"mime_type: {mime_type}") return mime_type ext = path.suffix.lower() return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream") diff --git a/src/memory/common/oauth.py b/src/memory/common/oauth.py index 117b468..82bd111 100644 --- a/src/memory/common/oauth.py +++ b/src/memory/common/oauth.py @@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin import aiohttp from memory.common import settings -from memory.common.db.models.discord import DiscordMCPServer +from memory.common.db.models.discord import MCPServer logger = logging.getLogger(__name__) @@ -148,7 +148,7 @@ async def register_oauth_client( async def issue_challenge( - mcp_server: DiscordMCPServer, + mcp_server: MCPServer, endpoints: OAuthEndpoints, ) -> str: """Generate OAuth challenge and store state in mcp_server object.""" @@ -160,7 +160,7 @@ async def issue_challenge( mcp_server.code_verifier = code_verifier # type: ignore logger.info( - f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: " + f"Generated OAuth state for MCP server {mcp_server.mcp_server_url}: " f"state={state[:20]}..., verifier={code_verifier[:20]}..." ) @@ -179,7 +179,7 @@ async def issue_challenge( async def complete_oauth_flow( - mcp_server: DiscordMCPServer, code: str, state: str + mcp_server: MCPServer, code: str, state: str ) -> tuple[int, str]: """Complete OAuth flow by exchanging code for token. @@ -196,7 +196,7 @@ async def complete_oauth_flow( return 400, "Invalid or expired OAuth state" logger.info( - f"Found MCP server config: user={mcp_server.discord_bot_user_id}, " + f"Found MCP server config: id={mcp_server.id}, " f"url={mcp_server.mcp_server_url}" ) @@ -247,8 +247,8 @@ async def complete_oauth_flow( mcp_server.code_verifier = None # type: ignore logger.info( - f"Stored tokens for user {mcp_server.discord_bot_user_id}, " - f"server {mcp_server.mcp_server_url}" + f"Stored tokens for MCP server id={mcp_server.id}, " + f"url={mcp_server.mcp_server_url}" ) return 200, "✅ Authorization successful! You can now use this MCP server." diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index 7aa7560..93f8d27 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -12,14 +12,12 @@ from contextlib import asynccontextmanager from typing import cast import uvicorn -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, HTTPException from pydantic import BaseModel from memory.common import settings from memory.common.db.connection import make_session -from memory.common.db.models import DiscordMCPServer, DiscordBotUser -from memory.common.oauth import complete_oauth_flow +from memory.common.db.models import DiscordBotUser from memory.discord.collector import MessageCollector logger = logging.getLogger(__name__) diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py index 69830ab..fbc0300 100644 --- a/src/memory/discord/commands.py +++ b/src/memory/discord/commands.py @@ -1,6 +1,5 @@ """Lightweight slash-command helpers for the Discord collector.""" -from calendar import c import logging from dataclasses import dataclass from typing import Callable, Literal @@ -14,7 +13,7 @@ from memory.discord.mcp import run_mcp_server_command logger = logging.getLogger(__name__) -ScopeLiteral = Literal["server", "channel", "user"] +ScopeLiteral = Literal["bot", "server", "channel", "user"] class CommandError(Exception): @@ -173,18 +172,29 @@ def register_slash_commands(bot: discord.Client) -> None: @tree.command( name=f"{name}_mcp_servers", - description="Manage MCP servers for your account", + description="Manage MCP servers for a scope", ) @discord.app_commands.describe( + scope="Which configuration to modify (server, channel, or user)", action="Action to perform", url="MCP server URL (required for add, delete, connect, tools)", + user="Target user when the scope is 'user'", ) async def mcp_servers_command( interaction: discord.Interaction, + scope: ScopeLiteral, action: Literal["list", "add", "delete", "connect", "tools"] = "list", url: str | None = None, + user: discord.User | None = None, ) -> None: - await run_mcp_server_command(interaction, bot.user, action, url and url.strip()) + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_mcp_servers, + target_user=user, + action=action, + url=url and url.strip(), + ) async def _run_interaction_command( @@ -199,7 +209,7 @@ async def _run_interaction_command( try: with make_session() as session: context = _build_context(session, interaction, scope, target_user) - response = handler(context, **handler_kwargs) + response = await handler(context, **handler_kwargs) session.commit() except CommandError as exc: # pragma: no cover - passthrough await interaction.response.send_message(str(exc), ephemeral=True) @@ -435,3 +445,29 @@ def handle_summary(context: CommandContext) -> CommandResponse: return CommandResponse( content=f"No summary stored for {context.display_name}.", ) + + +async def handle_mcp_servers( + context: CommandContext, + *, + action: Literal["list", "add", "delete", "connect", "tools"], + url: str | None, +) -> CommandResponse: + """Handle MCP server commands for a specific scope.""" + entity_type_map = { + "server": "DiscordServer", + "channel": "DiscordChannel", + "user": "DiscordUser", + } + entity_type = entity_type_map[context.scope] + entity_id = context.target.id + try: + res = await run_mcp_server_command( + context.interaction.user, action, url, entity_type, entity_id + ) + return CommandResponse(content=res) + except Exception as exc: + import traceback + + logger.error(f"Error running MCP server command: {traceback.format_exc()}") + raise CommandError(f"Error: {exc}") from exc diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py index 4507f20..a0f42aa 100644 --- a/src/memory/discord/mcp.py +++ b/src/memory/discord/mcp.py @@ -10,23 +10,27 @@ import discord from sqlalchemy.orm import Session, scoped_session from memory.common.db.connection import make_session -from memory.common.db.models.discord import DiscordMCPServer +from memory.common.db.models.discord import MCPServer, MCPServerAssignment from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client logger = logging.getLogger(__name__) def find_mcp_server( - session: Session | scoped_session, user_id: int, url: str -) -> DiscordMCPServer | None: - return ( - session.query(DiscordMCPServer) + session: Session | scoped_session, entity_type: str, entity_id: int, url: str +) -> MCPServer | None: + """Find an MCP server assigned to an entity.""" + assignment = ( + session.query(MCPServerAssignment) + .join(MCPServer) .filter( - DiscordMCPServer.discord_bot_user_id == user_id, - DiscordMCPServer.mcp_server_url == url, + MCPServerAssignment.entity_type == entity_type, + MCPServerAssignment.entity_id == entity_id, + MCPServer.mcp_server_url == url, ) .first() ) + return assignment and assignment.mcp_server async def call_mcp_server( @@ -72,35 +76,39 @@ async def call_mcp_server( continue # Skip invalid JSON lines -async def handle_mcp_list(interaction: discord.Interaction) -> str: +async def handle_mcp_list(entity_type: str, entity_id: int) -> str: """List all MCP servers for the user.""" with make_session() as session: - servers = ( - session.query(DiscordMCPServer) + assignments = ( + session.query(MCPServerAssignment) + .join(MCPServer) .filter( - DiscordMCPServer.discord_bot_user_id == interaction.user.id, + MCPServerAssignment.entity_type == entity_type, + MCPServerAssignment.entity_id == entity_id, ) .all() ) - if not servers: + if not assignments: return ( "📋 **Your MCP Servers**\n\n" "You don't have any MCP servers configured yet.\n" "Use `/memory_mcp_servers add ` to add one." ) - def format_server(server: DiscordMCPServer) -> str: + def format_server(assignment: MCPServerAssignment) -> str: + server = assignment.mcp_server con = "🟢" if cast(str | None, server.access_token) else "🔴" return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`" - server_list = "\n".join(format_server(s) for s in servers) + server_list = "\n".join(format_server(a) for a in assignments) return f"📋 **Your MCP Servers**\n\n{server_list}" async def handle_mcp_add( - interaction: discord.Interaction, + entity_type: str, + entity_id: int, bot_user: discord.User | None, url: str, ) -> str: @@ -108,7 +116,7 @@ async def handle_mcp_add( if not bot_user: raise ValueError("Bot user is required") with make_session() as session: - if find_mcp_server(session, bot_user.id, url): + if find_mcp_server(session, entity_type, entity_id, url): return ( f"**MCP Server Already Exists**\n\n" f"You already have an MCP server configured at `{url}`.\n" @@ -116,25 +124,32 @@ async def handle_mcp_add( ) endpoints = await get_endpoints(url) - client_id = await register_oauth_client( - endpoints, - url, - f"Discord Bot - {bot_user.name} ({interaction.user.name})", - ) - mcp_server = DiscordMCPServer( - discord_bot_user_id=bot_user.id, + name = f"Discord Bot - {bot_user.name} ({entity_type} {entity_id})" + client_id = await register_oauth_client(endpoints, url, name) + + # Create MCP server + mcp_server = MCPServer( mcp_server_url=url, client_id=client_id, + name=name, ) session.add(mcp_server) session.flush() + assignment = MCPServerAssignment( + mcp_server_id=mcp_server.id, + entity_type=entity_type, + entity_id=entity_id, + ) + session.add(assignment) + session.flush() + auth_url = await issue_challenge(mcp_server, endpoints) session.commit() logger.info( f"Created MCP server record: id={mcp_server.id}, " - f"user={interaction.user.id}, url={url}" + f"{entity_type}={entity_id}, url={url}" ) return ( @@ -146,32 +161,54 @@ async def handle_mcp_add( ) -async def handle_mcp_delete(bot_user: discord.User, url: str) -> str: - """Delete an MCP server.""" +async def handle_mcp_delete(entity_type: str, entity_id: int, url: str) -> str: + """Delete an MCP server assignment.""" with make_session() as session: - mcp_server = find_mcp_server(session, bot_user.id, url) - if not mcp_server: + # Find the assignment + assignment = ( + session.query(MCPServerAssignment) + .join(MCPServer) + .filter( + MCPServerAssignment.entity_type == entity_type, + MCPServerAssignment.entity_id == entity_id, + MCPServer.mcp_server_url == url, + ) + .first() + ) + + if not assignment: return ( f"**MCP Server Not Found**\n\n" f"You don't have an MCP server configured at `{url}`.\n" ) - session.delete(mcp_server) + + # Delete the assignment (server will cascade delete if no other assignments exist) + session.delete(assignment) + + # Check if server has other assignments + mcp_server = assignment.mcp_server + other_assignments = ( + session.query(MCPServerAssignment) + .filter( + MCPServerAssignment.mcp_server_id == mcp_server.id, + MCPServerAssignment.id != assignment.id, + ) + .count() + ) + + # If no other assignments, delete the server too + if other_assignments == 0: + session.delete(mcp_server) + session.commit() return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed." -async def handle_mcp_connect(bot_user: discord.User, url: str) -> str: +async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str: """Reconnect to an existing MCP server (redo OAuth).""" with make_session() as session: - mcp_server = find_mcp_server(session, bot_user.id, url) - if not mcp_server: - raise ValueError( - f"**MCP Server Not Found**\n\n" - f"You don't have an MCP server configured at `{url}`.\n" - f"Use `/memory_mcp_servers add {url}` to add it first." - ) - + mcp_server = find_mcp_server(session, entity_type, entity_id, url) if not mcp_server: raise ValueError( f"**MCP Server Not Found**\n\n" @@ -184,7 +221,9 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str: session.commit() - logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}") + logger.info( + f"Regenerated OAuth challenge for {entity_type}={entity_id}, url={url}" + ) return ( f"🔄 **Reconnect to MCP Server**\n\n" @@ -195,10 +234,10 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str: ) -async def handle_mcp_tools(bot_user: discord.User, url: str) -> str: +async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str: """List tools available on an MCP server.""" with make_session() as session: - mcp_server = find_mcp_server(session, bot_user.id, url) + mcp_server = find_mcp_server(session, entity_type, entity_id, url) if not mcp_server: raise ValueError( @@ -265,37 +304,28 @@ async def handle_mcp_tools(bot_user: discord.User, url: str) -> str: async def run_mcp_server_command( - interaction: discord.Interaction, bot_user: discord.User | None, action: Literal["list", "add", "delete", "connect", "tools"], url: str | None, + entity_type: str, + entity_id: int, ) -> None: """Handle MCP server management commands.""" if action not in ["list", "add", "delete", "connect", "tools"]: - await interaction.response.send_message("❌ Invalid action", ephemeral=True) - return + raise ValueError(f"Invalid action: {action}") if action != "list" and not url: - await interaction.response.send_message( - "❌ URL is required for this action", ephemeral=True - ) - return + raise ValueError("URL is required for this action") if not bot_user: - await interaction.response.send_message( - "❌ Bot user is required", ephemeral=True - ) - return + raise ValueError("Bot user is required") - try: - if action == "list" or not url: - result = await handle_mcp_list(interaction) - elif action == "add": - result = await handle_mcp_add(interaction, bot_user, url) - elif action == "delete": - result = await handle_mcp_delete(bot_user, url) - elif action == "connect": - result = await handle_mcp_connect(bot_user, url) - elif action == "tools": - result = await handle_mcp_tools(bot_user, url) - except Exception as exc: - result = f"❌ Error: {exc}" - await interaction.response.send_message(result, ephemeral=True) + if action == "list" or not url: + return await handle_mcp_list(entity_type, entity_id) + elif action == "add": + return await handle_mcp_add(entity_type, entity_id, bot_user, url) + elif action == "delete": + return await handle_mcp_delete(entity_type, entity_id, url) + elif action == "connect": + return await handle_mcp_connect(entity_type, entity_id, url) + elif action == "tools": + return await handle_mcp_tools(entity_type, entity_id, url) + raise ValueError(f"Invalid action: {action}") diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py index 9d6c810..d815dd3 100644 --- a/src/memory/discord/messages.py +++ b/src/memory/discord/messages.py @@ -113,8 +113,6 @@ def upsert_scheduled_message( .first() ) naive_scheduled_time = scheduled_time.replace(tzinfo=None) - print(f"naive_scheduled_time: {naive_scheduled_time}") - print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}") if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time: prev_call.status = "cancelled" # type: ignore @@ -141,10 +139,8 @@ def previous_messages( ) -> list[DiscordMessage]: messages = session.query(DiscordMessage) if user_id: - print(f"user_id: {user_id}") messages = messages.filter(DiscordMessage.recipient_id == user_id) if channel_id: - print(f"channel_id: {channel_id}") messages = messages.filter(DiscordMessage.channel_id == channel_id) return list( reversed( @@ -294,10 +290,8 @@ def send_discord_response( True if sent successfully """ if channel_id is not None: - logger.info(f"Sending message to channel {channel_id}") return discord.send_to_channel(bot_id, channel_id, response) elif user_identifier is not None: - logger.info(f"Sending DM to {user_identifier}") return discord.send_dm(bot_id, user_identifier, response) else: logger.error("Neither channel_id nor user_identifier provided") diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index d5c47ec..3662926 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -91,6 +91,10 @@ def should_process(message: DiscordMessage) -> bool: ): return False + if f"<@{message.recipient_user.id}>" in message.content: + logger.info("Direct mention of the bot, processing message") + return True + if message.from_user == message.recipient_user: logger.info("Skipping message because from_user == recipient_user") return False @@ -132,6 +136,8 @@ def should_process(message: DiscordMessage) -> bool: if not (res := re.search(r"(.*)", response)): return False try: + logger.info(f"chattiness_threshold: {message.chattiness_threshold}") + logger.info(f"haiku desire: {res.group(1)}") if int(res.group(1)) < 100 - message.chattiness_threshold: return False except ValueError: