mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
add slash commands for discord
This commit is contained in:
parent
c296f3b533
commit
8af07f0dac
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
from memory.discord.commands import register_slash_commands
|
||||
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -199,6 +200,11 @@ class MessageCollector(commands.Bot):
|
||||
help_command=None, # Disable default help
|
||||
)
|
||||
|
||||
async def setup_hook(self):
|
||||
"""Register slash commands when the bot is ready."""
|
||||
|
||||
register_slash_commands(self)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Called when bot connects to Discord"""
|
||||
logger.info(f"Discord collector connected as {self.user}")
|
||||
@ -207,6 +213,11 @@ class MessageCollector(commands.Bot):
|
||||
# Sync server and channel metadata
|
||||
await self.sync_servers_and_channels()
|
||||
|
||||
try:
|
||||
await self.tree.sync()
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.error("Failed to sync slash commands: %s", exc)
|
||||
|
||||
logger.info("Discord message collector ready")
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
|
||||
393
src/memory/discord/commands.py
Normal file
393
src/memory/discord/commands.py
Normal file
@ -0,0 +1,393 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal
|
||||
|
||||
import discord
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
|
||||
ScopeLiteral = Literal["server", "channel", "user"]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
"""Raised when a user-facing error occurs while handling a command."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CommandResponse:
|
||||
"""Value object returned by handlers."""
|
||||
|
||||
content: str
|
||||
ephemeral: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CommandContext:
|
||||
"""All information a handler needs to fulfil a command."""
|
||||
|
||||
session: Session
|
||||
interaction: discord.Interaction
|
||||
actor: DiscordUser
|
||||
scope: ScopeLiteral
|
||||
target: DiscordServer | DiscordChannel | DiscordUser
|
||||
display_name: str
|
||||
|
||||
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot."""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
return
|
||||
|
||||
setattr(bot, "_memory_commands_registered", True)
|
||||
|
||||
if not hasattr(bot, "tree"):
|
||||
raise RuntimeError("Bot instance does not support app commands")
|
||||
|
||||
tree = bot.tree
|
||||
|
||||
@tree.command(name="memory_prompt", description="Show the current system prompt")
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_prompt,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_chattiness",
|
||||
description="Show or update the chattiness threshold for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
value="Optional new threshold value between 0 and 100",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def chattiness_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
value: int | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_chattiness,
|
||||
target_user=user,
|
||||
value=value,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_ignore",
|
||||
description="Toggle whether the bot should ignore messages for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
enabled="Optional flag. Leave empty to enable ignoring.",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def ignore_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
enabled: bool | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_ignore,
|
||||
target_user=user,
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
@tree.command(name="memory_summary", description="Show the stored summary for the target")
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def summary_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_summary,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
interaction: discord.Interaction,
|
||||
*,
|
||||
scope: ScopeLiteral,
|
||||
handler: CommandHandler,
|
||||
target_user: discord.User | None = None,
|
||||
**handler_kwargs,
|
||||
) -> None:
|
||||
"""Shared coroutine used by the registered slash commands."""
|
||||
|
||||
try:
|
||||
with make_session() as session:
|
||||
response = run_command(
|
||||
session,
|
||||
interaction,
|
||||
scope,
|
||||
handler=handler,
|
||||
target_user=target_user,
|
||||
**handler_kwargs,
|
||||
)
|
||||
session.commit()
|
||||
except CommandError as exc: # pragma: no cover - passthrough
|
||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||
return
|
||||
|
||||
await interaction.response.send_message(
|
||||
response.content,
|
||||
ephemeral=response.ephemeral,
|
||||
)
|
||||
|
||||
|
||||
def run_command(
|
||||
session: Session,
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
*,
|
||||
handler: CommandHandler,
|
||||
target_user: discord.User | None = None,
|
||||
**handler_kwargs,
|
||||
) -> CommandResponse:
|
||||
"""Create a :class:`CommandContext` and execute the handler."""
|
||||
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
return handler(context, **handler_kwargs)
|
||||
|
||||
|
||||
def _build_context(
|
||||
session: Session,
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
target_user: discord.User | None,
|
||||
) -> CommandContext:
|
||||
actor = _ensure_user(session, interaction.user)
|
||||
|
||||
if scope == "server":
|
||||
if interaction.guild is None:
|
||||
raise CommandError("This command can only be used inside a server.")
|
||||
|
||||
target = _ensure_server(session, interaction.guild)
|
||||
display_name = f"server **{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
if scope == "channel":
|
||||
channel_obj = interaction.channel
|
||||
if channel_obj is None or not hasattr(channel_obj, "id"):
|
||||
raise CommandError("Unable to determine channel for this interaction.")
|
||||
|
||||
target = _ensure_channel(session, channel_obj, interaction.guild_id)
|
||||
display_name = f"channel **#{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
if scope == "user":
|
||||
discord_user = target_user or interaction.user
|
||||
if discord_user is None:
|
||||
raise CommandError("A target user is required for this command.")
|
||||
|
||||
target = _ensure_user(session, discord_user)
|
||||
display_name = target.display_name or target.username
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=f"user **{display_name}**",
|
||||
)
|
||||
|
||||
raise CommandError(f"Unsupported scope '{scope}'.")
|
||||
|
||||
|
||||
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||
server = session.get(DiscordServer, guild.id)
|
||||
if server is None:
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name=guild.name or f"Server {guild.id}",
|
||||
description=getattr(guild, "description", None),
|
||||
member_count=getattr(guild, "member_count", None),
|
||||
)
|
||||
session.add(server)
|
||||
session.flush()
|
||||
else:
|
||||
if guild.name and server.name != guild.name:
|
||||
server.name = guild.name
|
||||
description = getattr(guild, "description", None)
|
||||
if description and server.description != description:
|
||||
server.description = description
|
||||
member_count = getattr(guild, "member_count", None)
|
||||
if member_count is not None:
|
||||
server.member_count = member_count
|
||||
|
||||
return server
|
||||
|
||||
|
||||
def _ensure_channel(
|
||||
session: Session,
|
||||
channel: discord.abc.Messageable,
|
||||
guild_id: int | None,
|
||||
) -> DiscordChannel:
|
||||
channel_id = getattr(channel, "id", None)
|
||||
if channel_id is None:
|
||||
raise CommandError("Channel is missing an identifier.")
|
||||
|
||||
channel_model = session.get(DiscordChannel, channel_id)
|
||||
if channel_model is None:
|
||||
channel_model = DiscordChannel(
|
||||
id=channel_id,
|
||||
server_id=guild_id,
|
||||
name=getattr(channel, "name", f"Channel {channel_id}"),
|
||||
channel_type=_resolve_channel_type(channel),
|
||||
)
|
||||
session.add(channel_model)
|
||||
session.flush()
|
||||
else:
|
||||
name = getattr(channel, "name", None)
|
||||
if name and channel_model.name != name:
|
||||
channel_model.name = name
|
||||
|
||||
return channel_model
|
||||
|
||||
|
||||
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||
user = session.get(DiscordUser, discord_user.id)
|
||||
display_name = getattr(discord_user, "display_name", discord_user.name)
|
||||
if user is None:
|
||||
user = DiscordUser(
|
||||
id=discord_user.id,
|
||||
username=discord_user.name,
|
||||
display_name=display_name,
|
||||
)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
else:
|
||||
if user.username != discord_user.name:
|
||||
user.username = discord_user.name
|
||||
if display_name and user.display_name != display_name:
|
||||
user.display_name = display_name
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
return "dm"
|
||||
if isinstance(channel, discord.GroupChannel):
|
||||
return "group_dm"
|
||||
if isinstance(channel, discord.Thread):
|
||||
return "thread"
|
||||
if isinstance(channel, discord.VoiceChannel):
|
||||
return "voice"
|
||||
if isinstance(channel, discord.TextChannel):
|
||||
return "text"
|
||||
return getattr(getattr(channel, "type", None), "name", "unknown")
|
||||
|
||||
|
||||
def handle_prompt(context: CommandContext) -> CommandResponse:
|
||||
prompt = getattr(context.target, "system_prompt", None)
|
||||
|
||||
if prompt:
|
||||
return CommandResponse(
|
||||
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
|
||||
)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"No prompt configured for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
def handle_chattiness(
|
||||
context: CommandContext,
|
||||
*,
|
||||
value: int | None,
|
||||
) -> CommandResponse:
|
||||
model = context.target
|
||||
|
||||
if value is None:
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Chattiness threshold for {context.display_name}: "
|
||||
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
||||
)
|
||||
)
|
||||
|
||||
if not 0 <= value <= 100:
|
||||
raise CommandError("Chattiness threshold must be between 0 and 100.")
|
||||
|
||||
setattr(model, "chattiness_threshold", value)
|
||||
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Updated chattiness threshold for {context.display_name} "
|
||||
f"to {value}."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def handle_ignore(
|
||||
context: CommandContext,
|
||||
*,
|
||||
ignore_enabled: bool | None,
|
||||
) -> CommandResponse:
|
||||
model = context.target
|
||||
new_value = True if ignore_enabled is None else ignore_enabled
|
||||
setattr(model, "ignore_messages", new_value)
|
||||
|
||||
verb = "now ignoring" if new_value else "no longer ignoring"
|
||||
return CommandResponse(
|
||||
content=f"The bot is {verb} messages for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
summary = getattr(context.target, "summary", None)
|
||||
|
||||
if summary:
|
||||
return CommandResponse(
|
||||
content=f"Summary for {context.display_name}:\n\n{summary}",
|
||||
)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"No summary stored for {context.display_name}.",
|
||||
)
|
||||
@ -493,10 +493,12 @@ async def test_on_ready():
|
||||
collector.user.name = "TestBot"
|
||||
collector.guilds = [Mock(), Mock()]
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
collector.tree.sync = AsyncMock()
|
||||
|
||||
await collector.on_ready()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
collector.tree.sync.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
172
tests/memory/discord_tests/test_commands.py
Normal file
172
tests/memory/discord_tests/test_commands.py
Normal file
@ -0,0 +1,172 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.commands import (
|
||||
CommandError,
|
||||
CommandResponse,
|
||||
run_command,
|
||||
handle_prompt,
|
||||
handle_chattiness,
|
||||
handle_ignore,
|
||||
handle_summary,
|
||||
)
|
||||
|
||||
|
||||
class DummyInteraction:
|
||||
"""Lightweight stand-in for :class:`discord.Interaction` used in tests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
guild: discord.Guild | None,
|
||||
channel: discord.abc.Messageable | None,
|
||||
user: discord.abc.User,
|
||||
) -> None:
|
||||
self.guild = guild
|
||||
self.channel = channel
|
||||
self.user = user
|
||||
self.guild_id = getattr(guild, "id", None)
|
||||
self.channel_id = getattr(channel, "id", None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def guild() -> discord.Guild:
|
||||
guild = MagicMock(spec=discord.Guild)
|
||||
guild.id = 123
|
||||
guild.name = "Test Guild"
|
||||
guild.description = "Guild description"
|
||||
guild.member_count = 42
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_channel(guild: discord.Guild) -> discord.TextChannel:
|
||||
channel = MagicMock(spec=discord.TextChannel)
|
||||
channel.id = 456
|
||||
channel.name = "general"
|
||||
channel.guild = guild
|
||||
channel.type = discord.ChannelType.text
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discord_user() -> discord.User:
|
||||
user = MagicMock(spec=discord.User)
|
||||
user.id = 789
|
||||
user.name = "command-user"
|
||||
user.display_name = "Commander"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def interaction(guild, text_channel, discord_user) -> DummyInteraction:
|
||||
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
|
||||
|
||||
|
||||
def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="server",
|
||||
handler=handle_prompt,
|
||||
)
|
||||
|
||||
assert isinstance(response, CommandResponse)
|
||||
assert "Be helpful" in response.content
|
||||
|
||||
|
||||
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
handler=handle_prompt,
|
||||
)
|
||||
|
||||
assert "No prompt" in response.content
|
||||
channel = db_session.get(DiscordChannel, text_channel.id)
|
||||
assert channel is not None
|
||||
assert channel.name == text_channel.name
|
||||
|
||||
|
||||
def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="server",
|
||||
handler=handle_chattiness,
|
||||
)
|
||||
|
||||
assert str(server.chattiness_threshold) in response.content
|
||||
|
||||
|
||||
def test_handle_command_chattiness_update(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=80,
|
||||
)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert user_model.chattiness_threshold == 80
|
||||
|
||||
|
||||
def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||
with pytest.raises(CommandError):
|
||||
run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=150,
|
||||
)
|
||||
|
||||
|
||||
def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
handler=handle_ignore,
|
||||
ignore_enabled=True,
|
||||
)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "no longer" not in response.content
|
||||
assert channel.ignore_messages is True
|
||||
|
||||
|
||||
def test_handle_command_summary_missing(db_session, interaction):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_summary,
|
||||
)
|
||||
|
||||
assert "No summary" in response.content
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user