From 8af07f0dac029e53c0579447ce9092478b694f2f Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sat, 1 Nov 2025 18:04:38 +0100 Subject: [PATCH] add slash commands for discord --- src/memory/discord/collector.py | 11 + src/memory/discord/commands.py | 393 +++++++++++++++++++ tests/memory/discord_tests/test_collector.py | 2 + tests/memory/discord_tests/test_commands.py | 172 ++++++++ 4 files changed, 578 insertions(+) create mode 100644 src/memory/discord/commands.py create mode 100644 tests/memory/discord_tests/test_commands.py diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index c0bdc0c..ab0803c 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -18,6 +18,7 @@ from memory.common.db.models import ( DiscordChannel, DiscordUser, ) +from memory.discord.commands import register_slash_commands from memory.workers.tasks.discord import add_discord_message, edit_discord_message logger = logging.getLogger(__name__) @@ -199,6 +200,11 @@ class MessageCollector(commands.Bot): help_command=None, # Disable default help ) + async def setup_hook(self): + """Register slash commands when the bot is ready.""" + + register_slash_commands(self) + async def on_ready(self): """Called when bot connects to Discord""" logger.info(f"Discord collector connected as {self.user}") @@ -207,6 +213,11 @@ class MessageCollector(commands.Bot): # Sync server and channel metadata await self.sync_servers_and_channels() + try: + await self.tree.sync() + except Exception as exc: # pragma: no cover - defensive + logger.error("Failed to sync slash commands: %s", exc) + logger.info("Discord message collector ready") async def on_message(self, message: discord.Message): diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py new file mode 100644 index 0000000..4697a3e --- /dev/null +++ b/src/memory/discord/commands.py @@ -0,0 +1,393 @@ +"""Lightweight slash-command helpers for the Discord collector.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Literal + +import discord +from sqlalchemy.orm import Session + +from memory.common.db.connection import make_session +from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser + +ScopeLiteral = Literal["server", "channel", "user"] + + +class CommandError(Exception): + """Raised when a user-facing error occurs while handling a command.""" + + +@dataclass(slots=True) +class CommandResponse: + """Value object returned by handlers.""" + + content: str + ephemeral: bool = True + + +@dataclass(slots=True) +class CommandContext: + """All information a handler needs to fulfil a command.""" + + session: Session + interaction: discord.Interaction + actor: DiscordUser + scope: ScopeLiteral + target: DiscordServer | DiscordChannel | DiscordUser + display_name: str + + +CommandHandler = Callable[..., CommandResponse] + + +def register_slash_commands(bot: discord.Client) -> None: + """Register the collector slash commands on the provided bot.""" + + if getattr(bot, "_memory_commands_registered", False): + return + + setattr(bot, "_memory_commands_registered", True) + + if not hasattr(bot, "tree"): + raise RuntimeError("Bot instance does not support app commands") + + tree = bot.tree + + @tree.command(name="memory_prompt", description="Show the current system prompt") + @discord.app_commands.describe( + scope="Which configuration to inspect", + user="Target user when the scope is 'user'", + ) + async def prompt_command( + interaction: discord.Interaction, + scope: ScopeLiteral, + user: discord.User | None = None, + ) -> None: + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_prompt, + target_user=user, + ) + + @tree.command( + name="memory_chattiness", + description="Show or update the chattiness threshold for the target", + ) + @discord.app_commands.describe( + scope="Which configuration to inspect", + value="Optional new threshold value between 0 and 100", + user="Target user when the scope is 'user'", + ) + async def chattiness_command( + interaction: discord.Interaction, + scope: ScopeLiteral, + value: int | None = None, + user: discord.User | None = None, + ) -> None: + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_chattiness, + target_user=user, + value=value, + ) + + @tree.command( + name="memory_ignore", + description="Toggle whether the bot should ignore messages for the target", + ) + @discord.app_commands.describe( + scope="Which configuration to modify", + enabled="Optional flag. Leave empty to enable ignoring.", + user="Target user when the scope is 'user'", + ) + async def ignore_command( + interaction: discord.Interaction, + scope: ScopeLiteral, + enabled: bool | None = None, + user: discord.User | None = None, + ) -> None: + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_ignore, + target_user=user, + ignore_enabled=enabled, + ) + + @tree.command(name="memory_summary", description="Show the stored summary for the target") + @discord.app_commands.describe( + scope="Which configuration to inspect", + user="Target user when the scope is 'user'", + ) + async def summary_command( + interaction: discord.Interaction, + scope: ScopeLiteral, + user: discord.User | None = None, + ) -> None: + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_summary, + target_user=user, + ) + + +async def _run_interaction_command( + interaction: discord.Interaction, + *, + scope: ScopeLiteral, + handler: CommandHandler, + target_user: discord.User | None = None, + **handler_kwargs, +) -> None: + """Shared coroutine used by the registered slash commands.""" + + try: + with make_session() as session: + response = run_command( + session, + interaction, + scope, + handler=handler, + target_user=target_user, + **handler_kwargs, + ) + session.commit() + except CommandError as exc: # pragma: no cover - passthrough + await interaction.response.send_message(str(exc), ephemeral=True) + return + + await interaction.response.send_message( + response.content, + ephemeral=response.ephemeral, + ) + + +def run_command( + session: Session, + interaction: discord.Interaction, + scope: ScopeLiteral, + *, + handler: CommandHandler, + target_user: discord.User | None = None, + **handler_kwargs, +) -> CommandResponse: + """Create a :class:`CommandContext` and execute the handler.""" + + context = _build_context(session, interaction, scope, target_user) + return handler(context, **handler_kwargs) + + +def _build_context( + session: Session, + interaction: discord.Interaction, + scope: ScopeLiteral, + target_user: discord.User | None, +) -> CommandContext: + actor = _ensure_user(session, interaction.user) + + if scope == "server": + if interaction.guild is None: + raise CommandError("This command can only be used inside a server.") + + target = _ensure_server(session, interaction.guild) + display_name = f"server **{target.name}**" + return CommandContext( + session=session, + interaction=interaction, + actor=actor, + scope=scope, + target=target, + display_name=display_name, + ) + + if scope == "channel": + channel_obj = interaction.channel + if channel_obj is None or not hasattr(channel_obj, "id"): + raise CommandError("Unable to determine channel for this interaction.") + + target = _ensure_channel(session, channel_obj, interaction.guild_id) + display_name = f"channel **#{target.name}**" + return CommandContext( + session=session, + interaction=interaction, + actor=actor, + scope=scope, + target=target, + display_name=display_name, + ) + + if scope == "user": + discord_user = target_user or interaction.user + if discord_user is None: + raise CommandError("A target user is required for this command.") + + target = _ensure_user(session, discord_user) + display_name = target.display_name or target.username + return CommandContext( + session=session, + interaction=interaction, + actor=actor, + scope=scope, + target=target, + display_name=f"user **{display_name}**", + ) + + raise CommandError(f"Unsupported scope '{scope}'.") + + +def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer: + server = session.get(DiscordServer, guild.id) + if server is None: + server = DiscordServer( + id=guild.id, + name=guild.name or f"Server {guild.id}", + description=getattr(guild, "description", None), + member_count=getattr(guild, "member_count", None), + ) + session.add(server) + session.flush() + else: + if guild.name and server.name != guild.name: + server.name = guild.name + description = getattr(guild, "description", None) + if description and server.description != description: + server.description = description + member_count = getattr(guild, "member_count", None) + if member_count is not None: + server.member_count = member_count + + return server + + +def _ensure_channel( + session: Session, + channel: discord.abc.Messageable, + guild_id: int | None, +) -> DiscordChannel: + channel_id = getattr(channel, "id", None) + if channel_id is None: + raise CommandError("Channel is missing an identifier.") + + channel_model = session.get(DiscordChannel, channel_id) + if channel_model is None: + channel_model = DiscordChannel( + id=channel_id, + server_id=guild_id, + name=getattr(channel, "name", f"Channel {channel_id}"), + channel_type=_resolve_channel_type(channel), + ) + session.add(channel_model) + session.flush() + else: + name = getattr(channel, "name", None) + if name and channel_model.name != name: + channel_model.name = name + + return channel_model + + +def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser: + user = session.get(DiscordUser, discord_user.id) + display_name = getattr(discord_user, "display_name", discord_user.name) + if user is None: + user = DiscordUser( + id=discord_user.id, + username=discord_user.name, + display_name=display_name, + ) + session.add(user) + session.flush() + else: + if user.username != discord_user.name: + user.username = discord_user.name + if display_name and user.display_name != display_name: + user.display_name = display_name + + return user + + +def _resolve_channel_type(channel: discord.abc.Messageable) -> str: + if isinstance(channel, discord.DMChannel): + return "dm" + if isinstance(channel, discord.GroupChannel): + return "group_dm" + if isinstance(channel, discord.Thread): + return "thread" + if isinstance(channel, discord.VoiceChannel): + return "voice" + if isinstance(channel, discord.TextChannel): + return "text" + return getattr(getattr(channel, "type", None), "name", "unknown") + + +def handle_prompt(context: CommandContext) -> CommandResponse: + prompt = getattr(context.target, "system_prompt", None) + + if prompt: + return CommandResponse( + content=f"Current prompt for {context.display_name}:\n\n{prompt}", + ) + + return CommandResponse( + content=f"No prompt configured for {context.display_name}.", + ) + + +def handle_chattiness( + context: CommandContext, + *, + value: int | None, +) -> CommandResponse: + model = context.target + + if value is None: + return CommandResponse( + content=( + f"Chattiness threshold for {context.display_name}: " + f"{getattr(model, 'chattiness_threshold', 'not set')}" + ) + ) + + if not 0 <= value <= 100: + raise CommandError("Chattiness threshold must be between 0 and 100.") + + setattr(model, "chattiness_threshold", value) + + return CommandResponse( + content=( + f"Updated chattiness threshold for {context.display_name} " + f"to {value}." + ) + ) + + +def handle_ignore( + context: CommandContext, + *, + ignore_enabled: bool | None, +) -> CommandResponse: + model = context.target + new_value = True if ignore_enabled is None else ignore_enabled + setattr(model, "ignore_messages", new_value) + + verb = "now ignoring" if new_value else "no longer ignoring" + return CommandResponse( + content=f"The bot is {verb} messages for {context.display_name}.", + ) + + +def handle_summary(context: CommandContext) -> CommandResponse: + summary = getattr(context.target, "summary", None) + + if summary: + return CommandResponse( + content=f"Summary for {context.display_name}:\n\n{summary}", + ) + + return CommandResponse( + content=f"No summary stored for {context.display_name}.", + ) diff --git a/tests/memory/discord_tests/test_collector.py b/tests/memory/discord_tests/test_collector.py index 63f7aef..0e37a1f 100644 --- a/tests/memory/discord_tests/test_collector.py +++ b/tests/memory/discord_tests/test_collector.py @@ -493,10 +493,12 @@ async def test_on_ready(): collector.user.name = "TestBot" collector.guilds = [Mock(), Mock()] collector.sync_servers_and_channels = AsyncMock() + collector.tree.sync = AsyncMock() await collector.on_ready() collector.sync_servers_and_channels.assert_called_once() + collector.tree.sync.assert_awaited() @pytest.mark.asyncio diff --git a/tests/memory/discord_tests/test_commands.py b/tests/memory/discord_tests/test_commands.py new file mode 100644 index 0000000..6daf4f9 --- /dev/null +++ b/tests/memory/discord_tests/test_commands.py @@ -0,0 +1,172 @@ +import pytest +from unittest.mock import MagicMock + +import discord + +from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser +from memory.discord.commands import ( + CommandError, + CommandResponse, + run_command, + handle_prompt, + handle_chattiness, + handle_ignore, + handle_summary, +) + + +class DummyInteraction: + """Lightweight stand-in for :class:`discord.Interaction` used in tests.""" + + def __init__( + self, + *, + guild: discord.Guild | None, + channel: discord.abc.Messageable | None, + user: discord.abc.User, + ) -> None: + self.guild = guild + self.channel = channel + self.user = user + self.guild_id = getattr(guild, "id", None) + self.channel_id = getattr(channel, "id", None) + + +@pytest.fixture +def guild() -> discord.Guild: + guild = MagicMock(spec=discord.Guild) + guild.id = 123 + guild.name = "Test Guild" + guild.description = "Guild description" + guild.member_count = 42 + return guild + + +@pytest.fixture +def text_channel(guild: discord.Guild) -> discord.TextChannel: + channel = MagicMock(spec=discord.TextChannel) + channel.id = 456 + channel.name = "general" + channel.guild = guild + channel.type = discord.ChannelType.text + return channel + + +@pytest.fixture +def discord_user() -> discord.User: + user = MagicMock(spec=discord.User) + user.id = 789 + user.name = "command-user" + user.display_name = "Commander" + return user + + +@pytest.fixture +def interaction(guild, text_channel, discord_user) -> DummyInteraction: + return DummyInteraction(guild=guild, channel=text_channel, user=discord_user) + + +def test_handle_command_prompt_server(db_session, guild, interaction): + server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful") + db_session.add(server) + db_session.commit() + + response = run_command( + db_session, + interaction, + scope="server", + handler=handle_prompt, + ) + + assert isinstance(response, CommandResponse) + assert "Be helpful" in response.content + + +def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel): + response = run_command( + db_session, + interaction, + scope="channel", + handler=handle_prompt, + ) + + assert "No prompt" in response.content + channel = db_session.get(DiscordChannel, text_channel.id) + assert channel is not None + assert channel.name == text_channel.name + + +def test_handle_command_chattiness_show(db_session, interaction, guild): + server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73) + db_session.add(server) + db_session.commit() + + response = run_command( + db_session, + interaction, + scope="server", + handler=handle_chattiness, + ) + + assert str(server.chattiness_threshold) in response.content + + +def test_handle_command_chattiness_update(db_session, interaction): + user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15) + db_session.add(user_model) + db_session.commit() + + response = run_command( + db_session, + interaction, + scope="user", + handler=handle_chattiness, + value=80, + ) + + db_session.flush() + + assert "Updated" in response.content + assert user_model.chattiness_threshold == 80 + + +def test_handle_command_chattiness_invalid_value(db_session, interaction): + with pytest.raises(CommandError): + run_command( + db_session, + interaction, + scope="user", + handler=handle_chattiness, + value=150, + ) + + +def test_handle_command_ignore_toggle(db_session, interaction, guild): + channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id) + db_session.add(channel) + db_session.commit() + + response = run_command( + db_session, + interaction, + scope="channel", + handler=handle_ignore, + ignore_enabled=True, + ) + + db_session.flush() + + assert "no longer" not in response.content + assert channel.ignore_messages is True + + +def test_handle_command_summary_missing(db_session, interaction): + response = run_command( + db_session, + interaction, + scope="user", + handler=handle_summary, + ) + + assert "No summary" in response.content +