add slash commands for discord

This commit is contained in:
Daniel O'Connell 2025-11-01 18:04:38 +01:00
parent c296f3b533
commit 8af07f0dac
4 changed files with 578 additions and 0 deletions

View File

@ -18,6 +18,7 @@ from memory.common.db.models import (
DiscordChannel,
DiscordUser,
)
from memory.discord.commands import register_slash_commands
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
logger = logging.getLogger(__name__)
@ -199,6 +200,11 @@ class MessageCollector(commands.Bot):
help_command=None, # Disable default help
)
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
register_slash_commands(self)
async def on_ready(self):
"""Called when bot connects to Discord"""
logger.info(f"Discord collector connected as {self.user}")
@ -207,6 +213,11 @@ class MessageCollector(commands.Bot):
# Sync server and channel metadata
await self.sync_servers_and_channels()
try:
await self.tree.sync()
except Exception as exc: # pragma: no cover - defensive
logger.error("Failed to sync slash commands: %s", exc)
logger.info("Discord message collector ready")
async def on_message(self, message: discord.Message):

View File

@ -0,0 +1,393 @@
"""Lightweight slash-command helpers for the Discord collector."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Literal
import discord
from sqlalchemy.orm import Session
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
ScopeLiteral = Literal["server", "channel", "user"]
class CommandError(Exception):
"""Raised when a user-facing error occurs while handling a command."""
@dataclass(slots=True)
class CommandResponse:
"""Value object returned by handlers."""
content: str
ephemeral: bool = True
@dataclass(slots=True)
class CommandContext:
"""All information a handler needs to fulfil a command."""
session: Session
interaction: discord.Interaction
actor: DiscordUser
scope: ScopeLiteral
target: DiscordServer | DiscordChannel | DiscordUser
display_name: str
CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client) -> None:
"""Register the collector slash commands on the provided bot."""
if getattr(bot, "_memory_commands_registered", False):
return
setattr(bot, "_memory_commands_registered", True)
if not hasattr(bot, "tree"):
raise RuntimeError("Bot instance does not support app commands")
tree = bot.tree
@tree.command(name="memory_prompt", description="Show the current system prompt")
@discord.app_commands.describe(
scope="Which configuration to inspect",
user="Target user when the scope is 'user'",
)
async def prompt_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_prompt,
target_user=user,
)
@tree.command(
name="memory_chattiness",
description="Show or update the chattiness threshold for the target",
)
@discord.app_commands.describe(
scope="Which configuration to inspect",
value="Optional new threshold value between 0 and 100",
user="Target user when the scope is 'user'",
)
async def chattiness_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
value: int | None = None,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_chattiness,
target_user=user,
value=value,
)
@tree.command(
name="memory_ignore",
description="Toggle whether the bot should ignore messages for the target",
)
@discord.app_commands.describe(
scope="Which configuration to modify",
enabled="Optional flag. Leave empty to enable ignoring.",
user="Target user when the scope is 'user'",
)
async def ignore_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
enabled: bool | None = None,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_ignore,
target_user=user,
ignore_enabled=enabled,
)
@tree.command(name="memory_summary", description="Show the stored summary for the target")
@discord.app_commands.describe(
scope="Which configuration to inspect",
user="Target user when the scope is 'user'",
)
async def summary_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_summary,
target_user=user,
)
async def _run_interaction_command(
interaction: discord.Interaction,
*,
scope: ScopeLiteral,
handler: CommandHandler,
target_user: discord.User | None = None,
**handler_kwargs,
) -> None:
"""Shared coroutine used by the registered slash commands."""
try:
with make_session() as session:
response = run_command(
session,
interaction,
scope,
handler=handler,
target_user=target_user,
**handler_kwargs,
)
session.commit()
except CommandError as exc: # pragma: no cover - passthrough
await interaction.response.send_message(str(exc), ephemeral=True)
return
await interaction.response.send_message(
response.content,
ephemeral=response.ephemeral,
)
def run_command(
session: Session,
interaction: discord.Interaction,
scope: ScopeLiteral,
*,
handler: CommandHandler,
target_user: discord.User | None = None,
**handler_kwargs,
) -> CommandResponse:
"""Create a :class:`CommandContext` and execute the handler."""
context = _build_context(session, interaction, scope, target_user)
return handler(context, **handler_kwargs)
def _build_context(
session: Session,
interaction: discord.Interaction,
scope: ScopeLiteral,
target_user: discord.User | None,
) -> CommandContext:
actor = _ensure_user(session, interaction.user)
if scope == "server":
if interaction.guild is None:
raise CommandError("This command can only be used inside a server.")
target = _ensure_server(session, interaction.guild)
display_name = f"server **{target.name}**"
return CommandContext(
session=session,
interaction=interaction,
actor=actor,
scope=scope,
target=target,
display_name=display_name,
)
if scope == "channel":
channel_obj = interaction.channel
if channel_obj is None or not hasattr(channel_obj, "id"):
raise CommandError("Unable to determine channel for this interaction.")
target = _ensure_channel(session, channel_obj, interaction.guild_id)
display_name = f"channel **#{target.name}**"
return CommandContext(
session=session,
interaction=interaction,
actor=actor,
scope=scope,
target=target,
display_name=display_name,
)
if scope == "user":
discord_user = target_user or interaction.user
if discord_user is None:
raise CommandError("A target user is required for this command.")
target = _ensure_user(session, discord_user)
display_name = target.display_name or target.username
return CommandContext(
session=session,
interaction=interaction,
actor=actor,
scope=scope,
target=target,
display_name=f"user **{display_name}**",
)
raise CommandError(f"Unsupported scope '{scope}'.")
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
server = session.get(DiscordServer, guild.id)
if server is None:
server = DiscordServer(
id=guild.id,
name=guild.name or f"Server {guild.id}",
description=getattr(guild, "description", None),
member_count=getattr(guild, "member_count", None),
)
session.add(server)
session.flush()
else:
if guild.name and server.name != guild.name:
server.name = guild.name
description = getattr(guild, "description", None)
if description and server.description != description:
server.description = description
member_count = getattr(guild, "member_count", None)
if member_count is not None:
server.member_count = member_count
return server
def _ensure_channel(
session: Session,
channel: discord.abc.Messageable,
guild_id: int | None,
) -> DiscordChannel:
channel_id = getattr(channel, "id", None)
if channel_id is None:
raise CommandError("Channel is missing an identifier.")
channel_model = session.get(DiscordChannel, channel_id)
if channel_model is None:
channel_model = DiscordChannel(
id=channel_id,
server_id=guild_id,
name=getattr(channel, "name", f"Channel {channel_id}"),
channel_type=_resolve_channel_type(channel),
)
session.add(channel_model)
session.flush()
else:
name = getattr(channel, "name", None)
if name and channel_model.name != name:
channel_model.name = name
return channel_model
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
user = session.get(DiscordUser, discord_user.id)
display_name = getattr(discord_user, "display_name", discord_user.name)
if user is None:
user = DiscordUser(
id=discord_user.id,
username=discord_user.name,
display_name=display_name,
)
session.add(user)
session.flush()
else:
if user.username != discord_user.name:
user.username = discord_user.name
if display_name and user.display_name != display_name:
user.display_name = display_name
return user
def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
if isinstance(channel, discord.DMChannel):
return "dm"
if isinstance(channel, discord.GroupChannel):
return "group_dm"
if isinstance(channel, discord.Thread):
return "thread"
if isinstance(channel, discord.VoiceChannel):
return "voice"
if isinstance(channel, discord.TextChannel):
return "text"
return getattr(getattr(channel, "type", None), "name", "unknown")
def handle_prompt(context: CommandContext) -> CommandResponse:
prompt = getattr(context.target, "system_prompt", None)
if prompt:
return CommandResponse(
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
)
return CommandResponse(
content=f"No prompt configured for {context.display_name}.",
)
def handle_chattiness(
context: CommandContext,
*,
value: int | None,
) -> CommandResponse:
model = context.target
if value is None:
return CommandResponse(
content=(
f"Chattiness threshold for {context.display_name}: "
f"{getattr(model, 'chattiness_threshold', 'not set')}"
)
)
if not 0 <= value <= 100:
raise CommandError("Chattiness threshold must be between 0 and 100.")
setattr(model, "chattiness_threshold", value)
return CommandResponse(
content=(
f"Updated chattiness threshold for {context.display_name} "
f"to {value}."
)
)
def handle_ignore(
context: CommandContext,
*,
ignore_enabled: bool | None,
) -> CommandResponse:
model = context.target
new_value = True if ignore_enabled is None else ignore_enabled
setattr(model, "ignore_messages", new_value)
verb = "now ignoring" if new_value else "no longer ignoring"
return CommandResponse(
content=f"The bot is {verb} messages for {context.display_name}.",
)
def handle_summary(context: CommandContext) -> CommandResponse:
summary = getattr(context.target, "summary", None)
if summary:
return CommandResponse(
content=f"Summary for {context.display_name}:\n\n{summary}",
)
return CommandResponse(
content=f"No summary stored for {context.display_name}.",
)

View File

@ -493,10 +493,12 @@ async def test_on_ready():
collector.user.name = "TestBot"
collector.guilds = [Mock(), Mock()]
collector.sync_servers_and_channels = AsyncMock()
collector.tree.sync = AsyncMock()
await collector.on_ready()
collector.sync_servers_and_channels.assert_called_once()
collector.tree.sync.assert_awaited()
@pytest.mark.asyncio

View File

@ -0,0 +1,172 @@
import pytest
from unittest.mock import MagicMock
import discord
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.commands import (
CommandError,
CommandResponse,
run_command,
handle_prompt,
handle_chattiness,
handle_ignore,
handle_summary,
)
class DummyInteraction:
"""Lightweight stand-in for :class:`discord.Interaction` used in tests."""
def __init__(
self,
*,
guild: discord.Guild | None,
channel: discord.abc.Messageable | None,
user: discord.abc.User,
) -> None:
self.guild = guild
self.channel = channel
self.user = user
self.guild_id = getattr(guild, "id", None)
self.channel_id = getattr(channel, "id", None)
@pytest.fixture
def guild() -> discord.Guild:
guild = MagicMock(spec=discord.Guild)
guild.id = 123
guild.name = "Test Guild"
guild.description = "Guild description"
guild.member_count = 42
return guild
@pytest.fixture
def text_channel(guild: discord.Guild) -> discord.TextChannel:
channel = MagicMock(spec=discord.TextChannel)
channel.id = 456
channel.name = "general"
channel.guild = guild
channel.type = discord.ChannelType.text
return channel
@pytest.fixture
def discord_user() -> discord.User:
user = MagicMock(spec=discord.User)
user.id = 789
user.name = "command-user"
user.display_name = "Commander"
return user
@pytest.fixture
def interaction(guild, text_channel, discord_user) -> DummyInteraction:
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
def test_handle_command_prompt_server(db_session, guild, interaction):
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
db_session.add(server)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="server",
handler=handle_prompt,
)
assert isinstance(response, CommandResponse)
assert "Be helpful" in response.content
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
response = run_command(
db_session,
interaction,
scope="channel",
handler=handle_prompt,
)
assert "No prompt" in response.content
channel = db_session.get(DiscordChannel, text_channel.id)
assert channel is not None
assert channel.name == text_channel.name
def test_handle_command_chattiness_show(db_session, interaction, guild):
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
db_session.add(server)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="server",
handler=handle_chattiness,
)
assert str(server.chattiness_threshold) in response.content
def test_handle_command_chattiness_update(db_session, interaction):
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
db_session.add(user_model)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="user",
handler=handle_chattiness,
value=80,
)
db_session.flush()
assert "Updated" in response.content
assert user_model.chattiness_threshold == 80
def test_handle_command_chattiness_invalid_value(db_session, interaction):
with pytest.raises(CommandError):
run_command(
db_session,
interaction,
scope="user",
handler=handle_chattiness,
value=150,
)
def test_handle_command_ignore_toggle(db_session, interaction, guild):
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
db_session.add(channel)
db_session.commit()
response = run_command(
db_session,
interaction,
scope="channel",
handler=handle_ignore,
ignore_enabled=True,
)
db_session.flush()
assert "no longer" not in response.content
assert channel.ignore_messages is True
def test_handle_command_summary_missing(db_session, interaction):
response = run_command(
db_session,
interaction,
scope="user",
handler=handle_summary,
)
assert "No summary" in response.content