mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
list discord command
This commit is contained in:
parent
8893018af1
commit
b568222e88
@ -51,21 +51,40 @@ class MessageProcessor:
|
||||
),
|
||||
)
|
||||
|
||||
def as_xml(self) -> str:
|
||||
return (
|
||||
textwrap.dedent("""
|
||||
<{type}>
|
||||
<name>{name}</name>
|
||||
<summary>{summary}</summary>
|
||||
</{type}>
|
||||
""")
|
||||
.format(
|
||||
type=self.__class__.__tablename__[8:], # type: ignore
|
||||
name=getattr(self, "name", None) or getattr(self, "username", None),
|
||||
summary=self.summary,
|
||||
)
|
||||
.strip()
|
||||
)
|
||||
@property
|
||||
def entity_type(self) -> str:
|
||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||
|
||||
def to_xml(self, fields: list[str]) -> str:
|
||||
def indent(key: str, text: str) -> str:
|
||||
res = textwrap.dedent("""
|
||||
<{key}>
|
||||
{text}
|
||||
</{key}>
|
||||
""").format(key=key, text=textwrap.indent(text, " "))
|
||||
return res.strip()
|
||||
|
||||
vals = []
|
||||
if "name" in fields:
|
||||
vals.append(indent("name", self.name))
|
||||
if "system_prompt" in fields:
|
||||
vals.append(indent("system_prompt", self.system_prompt or ""))
|
||||
if "summary" in fields:
|
||||
vals.append(indent("summary", self.summary or ""))
|
||||
if "mcp_servers" in fields:
|
||||
servers = [s.as_xml() for s in self.mcp_servers]
|
||||
vals.append(indent("mcp_servers", "\n".join(servers)))
|
||||
|
||||
return indent(self.entity_type, "\n".join(vals)) # type: ignore
|
||||
|
||||
def xml_prompt(self) -> str:
|
||||
return self.to_xml(["name", "system_prompt"]) if self.system_prompt else ""
|
||||
|
||||
def xml_summary(self) -> str:
|
||||
return self.to_xml(["name", "summary"])
|
||||
|
||||
def xml_mcp_servers(self) -> str:
|
||||
return self.to_xml(["mcp_servers"])
|
||||
|
||||
|
||||
class DiscordServer(Base, MessageProcessor):
|
||||
@ -130,6 +149,10 @@ class DiscordUser(Base, MessageProcessor):
|
||||
|
||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.username
|
||||
|
||||
|
||||
class MCPServer(Base):
|
||||
"""MCP server configuration and OAuth state."""
|
||||
@ -164,6 +187,30 @@ class MCPServer(Base):
|
||||
|
||||
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||
|
||||
def as_xml(self) -> str:
|
||||
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
||||
return textwrap.dedent("""
|
||||
<mcp_server>
|
||||
<name>
|
||||
{name}
|
||||
</name>
|
||||
<mcp_server_url>
|
||||
{mcp_server_url}
|
||||
</mcp_server_url>
|
||||
<client_id>
|
||||
{client_id}
|
||||
</client_id>
|
||||
<available_tools>
|
||||
{available_tools}
|
||||
</available_tools>
|
||||
</mcp_server>
|
||||
""").format(
|
||||
name=self.name,
|
||||
mcp_server_url=self.mcp_server_url,
|
||||
client_id=self.client_id,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
|
||||
class MCPServerAssignment(Base):
|
||||
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||
|
||||
@ -345,11 +345,12 @@ class DiscordMessage(SourceItem):
|
||||
|
||||
@property
|
||||
def system_prompt(self) -> str:
|
||||
return (
|
||||
(self.from_user and self.from_user.system_prompt)
|
||||
or (self.channel and self.channel.system_prompt)
|
||||
or (self.server and self.server.system_prompt)
|
||||
)
|
||||
prompts = [
|
||||
(self.from_user and self.from_user.system_prompt),
|
||||
(self.channel and self.channel.system_prompt),
|
||||
(self.server and self.server.system_prompt),
|
||||
]
|
||||
return "\n\n".join(p for p in prompts if p)
|
||||
|
||||
@property
|
||||
def chattiness_threshold(self) -> int:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal
|
||||
@ -8,12 +9,34 @@ import discord
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.common.db.models import (
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
DiscordUser,
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
)
|
||||
from memory.discord.mcp import run_mcp_server_command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ScopeLiteral = Literal["bot", "server", "channel", "user"]
|
||||
ScopeLiteral = Literal["bot", "me", "server", "channel", "user"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscordObjects:
|
||||
bot: DiscordUser
|
||||
server: DiscordServer | None
|
||||
channel: DiscordChannel | None
|
||||
user: DiscordUser | None
|
||||
|
||||
@property
|
||||
def items(self):
|
||||
items = [self.bot, self.server, self.channel, self.user]
|
||||
return [item for item in items if item is not None]
|
||||
|
||||
|
||||
ListHandler = Callable[[DiscordObjects], str]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
@ -43,12 +66,322 @@ class CommandContext:
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
async def respond(
|
||||
interaction: discord.Interaction, content: str, ephemeral: bool = True
|
||||
) -> None:
|
||||
"""Send a response to the interaction, as file if too large."""
|
||||
max_length = 1900
|
||||
if len(content) <= max_length:
|
||||
await interaction.response.send_message(content, ephemeral=ephemeral)
|
||||
return
|
||||
|
||||
file = discord.File(io.BytesIO(content.encode("utf-8")), filename="response.txt")
|
||||
await interaction.response.send_message(
|
||||
"Response too large, sending as file:", file=file, ephemeral=ephemeral
|
||||
)
|
||||
|
||||
|
||||
def with_object_context(
|
||||
bot: discord.Client,
|
||||
interaction: discord.Interaction,
|
||||
handler: ListHandler,
|
||||
user: discord.User | None,
|
||||
) -> str:
|
||||
"""Execute handler with Discord objects context."""
|
||||
server = interaction.guild
|
||||
channel = interaction.channel
|
||||
target_user = user or interaction.user
|
||||
with make_session() as session:
|
||||
objects = DiscordObjects(
|
||||
bot=ensure_user(session, bot.user),
|
||||
server=server and ensure_server(session, server),
|
||||
channel=channel and _ensure_channel(session, channel, server and server.id),
|
||||
user=ensure_user(session, target_user),
|
||||
)
|
||||
return handler(objects)
|
||||
|
||||
|
||||
def _create_scope_group(
|
||||
parent: discord.app_commands.Group,
|
||||
scope: ScopeLiteral,
|
||||
name: str,
|
||||
description: str,
|
||||
) -> discord.app_commands.Group:
|
||||
"""Create a command group for a scope (bot/me/server/channel).
|
||||
|
||||
Args:
|
||||
parent: Parent command group
|
||||
scope: Scope literal (bot, me, server, channel)
|
||||
name: Group name
|
||||
description: Group description
|
||||
"""
|
||||
group = discord.app_commands.Group(
|
||||
name=name, description=description, parent=parent
|
||||
)
|
||||
|
||||
@group.command(name="prompt", description=f"Manage {name}'s system prompt")
|
||||
@discord.app_commands.describe(prompt="The system prompt to set")
|
||||
async def prompt_cmd(interaction: discord.Interaction, prompt: str | None = None):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_prompt, prompt=prompt
|
||||
)
|
||||
|
||||
@group.command(name="chattiness", description=f"Show/set {name}'s chattiness")
|
||||
@discord.app_commands.describe(value="Optional new chattiness value (0-100)")
|
||||
async def chattiness_cmd(
|
||||
interaction: discord.Interaction, value: int | None = None
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_chattiness, value=value
|
||||
)
|
||||
|
||||
# Ignore command
|
||||
@group.command(name="ignore", description=f"Toggle bot ignoring {name} messages")
|
||||
@discord.app_commands.describe(enabled="Whether to ignore messages")
|
||||
async def ignore_cmd(interaction: discord.Interaction, enabled: bool | None = None):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_ignore, ignore_enabled=enabled
|
||||
)
|
||||
|
||||
# Summary command
|
||||
@group.command(name="summary", description=f"Show {name}'s summary")
|
||||
async def summary_cmd(interaction: discord.Interaction):
|
||||
await _run_interaction_command(interaction, scope=scope, handler=handle_summary)
|
||||
|
||||
# MCP command
|
||||
@group.command(name="mcp", description=f"Manage {name}'s MCP servers")
|
||||
@discord.app_commands.describe(
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
)
|
||||
async def mcp_cmd(
|
||||
interaction: discord.Interaction,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_mcp_servers,
|
||||
action=action,
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
def _create_user_scope_group(
|
||||
parent: discord.app_commands.Group,
|
||||
name: str,
|
||||
description: str,
|
||||
) -> discord.app_commands.Group:
|
||||
"""Create command group for user scope (requires user parameter).
|
||||
|
||||
Args:
|
||||
parent: Parent command group
|
||||
name: Group name
|
||||
description: Group description
|
||||
"""
|
||||
group = discord.app_commands.Group(
|
||||
name=name, description=description, parent=parent
|
||||
)
|
||||
scope = "user"
|
||||
|
||||
@group.command(name="prompt", description=f"Manage {name}'s system prompt")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user", prompt="The system prompt to set"
|
||||
)
|
||||
async def prompt_cmd(
|
||||
interaction: discord.Interaction, user: discord.User, prompt: str | None = None
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_prompt,
|
||||
target_user=user,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
@group.command(name="chattiness", description=f"Show/set {name}'s chattiness")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user", value="Optional new chattiness value (0-100)"
|
||||
)
|
||||
async def chattiness_cmd(
|
||||
interaction: discord.Interaction, user: discord.User, value: int | None = None
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_chattiness,
|
||||
target_user=user,
|
||||
value=value,
|
||||
)
|
||||
|
||||
# Ignore command
|
||||
@group.command(name="ignore", description=f"Toggle bot ignoring {name} messages")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user", enabled="Whether to ignore messages"
|
||||
)
|
||||
async def ignore_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User,
|
||||
enabled: bool | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_ignore,
|
||||
target_user=user,
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
# Summary command
|
||||
@group.command(name="summary", description=f"Show {name}'s summary")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def summary_cmd(interaction: discord.Interaction, user: discord.User):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_summary, target_user=user
|
||||
)
|
||||
|
||||
# MCP command
|
||||
@group.command(name="mcp", description=f"Manage {name}'s MCP servers")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user",
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
)
|
||||
async def mcp_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_mcp_servers,
|
||||
target_user=user,
|
||||
action=action,
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
def create_list_group(
|
||||
bot: discord.Client, parent: discord.app_commands.Group
|
||||
) -> discord.app_commands.Group:
|
||||
"""Create command group for listing settings.
|
||||
|
||||
Args:
|
||||
parent: Parent command group
|
||||
"""
|
||||
group = discord.app_commands.Group(
|
||||
name="list", description="List settings", parent=parent
|
||||
)
|
||||
|
||||
@group.command(name="prompt", description="List full system prompt")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def prompt_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
prompts = [o.xml_prompt() for o in objects.items if o.system_prompt]
|
||||
return "\n\n".join(prompts)
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(name="chattiness", description="Show {name}'s chattiness")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def chattiness_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
values = [
|
||||
o.chattiness_threshold
|
||||
for o in objects.items
|
||||
if o.chattiness_threshold is not None
|
||||
]
|
||||
val = min(values) if values else 50
|
||||
if objects.user:
|
||||
return f"Total current chattiness for {objects.user.username}: {val}"
|
||||
return f"Total current chattiness: {val}"
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(
|
||||
name="ignore", description="Does this bot ignore messages for this user?"
|
||||
)
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def ignore_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User | None = None,
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
should_ignore = any(o.ignore_messages for o in objects.items)
|
||||
if should_ignore:
|
||||
return f"The bot ignores messages for {objects.user}."
|
||||
return f"The bot does not ignore messages for {objects.user}."
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(name="summary", description="Show the full summary")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def summary_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
summaries = [o.xml_summary() for o in objects.items if o.summary]
|
||||
return "\n\n".join(summaries)
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(name="mcp", description="All used MCP servers")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def mcp_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
logger.error(f"Listing MCP servers for {user}")
|
||||
ids = [
|
||||
interaction.guild_id,
|
||||
interaction.channel_id,
|
||||
(user or interaction.user).id,
|
||||
bot.user.id,
|
||||
]
|
||||
with make_session() as session:
|
||||
mcp_servers = (
|
||||
session.query(MCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_id.in_(i for i in ids if i is not None)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
|
||||
res = "\n\n".join(mcp_servers)
|
||||
await respond(interaction, res)
|
||||
# def handler(objects: DiscordObjects) -> str:
|
||||
# servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers]
|
||||
# return "\n\n".join(servers) if servers else "No MCP servers configured."
|
||||
|
||||
# try:
|
||||
# res = with_object_context(bot, interaction, handler, user)
|
||||
# except Exception as exc:
|
||||
# logger.error(f"Error listing MCP servers: {exc}", exc_info=True)
|
||||
# return CommandResponse(content="Error listing MCP servers.")
|
||||
# await respond(interaction, res)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot.
|
||||
|
||||
Args:
|
||||
bot: Discord bot client
|
||||
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
||||
"""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
@ -62,139 +395,21 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
tree = bot.tree
|
||||
name = bot.user and bot.user.name.replace("-", "_").lower()
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def show_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_prompt,
|
||||
target_user=user,
|
||||
# Create main command group
|
||||
memory_group = discord.app_commands.Group(
|
||||
name=name or "memory", description=f"{name} bot configuration and management"
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_set_prompt",
|
||||
description="Set the system prompt for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
prompt="The system prompt to set",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def set_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
prompt: str,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_set_prompt,
|
||||
target_user=user,
|
||||
prompt=prompt,
|
||||
)
|
||||
# Create scope groups
|
||||
_create_scope_group(memory_group, "bot", "bot", "Bot-wide settings")
|
||||
_create_scope_group(memory_group, "me", "me", "Your personal settings")
|
||||
_create_scope_group(memory_group, "server", "server", "Server-wide settings")
|
||||
_create_scope_group(memory_group, "channel", "channel", "Channel-specific settings")
|
||||
_create_user_scope_group(memory_group, "user", "Manage other users' settings")
|
||||
create_list_group(bot, memory_group)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_chattiness",
|
||||
description="Show or update the chattiness for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
value="Optional new chattiness value between 0 and 100",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def chattiness_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
value: int | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_chattiness,
|
||||
target_user=user,
|
||||
value=value,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_ignore",
|
||||
description="Toggle whether the bot should ignore messages for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
enabled="Optional flag. Leave empty to enable ignoring.",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def ignore_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
enabled: bool | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_ignore,
|
||||
target_user=user,
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_show_summary",
|
||||
description="Show the stored summary for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def summary_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_summary,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_mcp_servers",
|
||||
description="Manage MCP servers for a scope",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify (server, channel, or user)",
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def mcp_servers_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_mcp_servers,
|
||||
target_user=user,
|
||||
action=action,
|
||||
url=url and url.strip(),
|
||||
)
|
||||
# Register main group
|
||||
tree.add_command(memory_group)
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
@ -208,17 +423,16 @@ async def _run_interaction_command(
|
||||
"""Shared coroutine used by the registered slash commands."""
|
||||
try:
|
||||
with make_session() as session:
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
# Get bot from interaction client if needed for bot scope
|
||||
bot = getattr(interaction, "client", None)
|
||||
context = _build_context(session, interaction, scope, target_user, bot)
|
||||
response = await handler(context, **handler_kwargs)
|
||||
session.commit()
|
||||
except CommandError as exc: # pragma: no cover - passthrough
|
||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||
await respond(interaction, str(exc))
|
||||
return
|
||||
|
||||
await interaction.response.send_message(
|
||||
response.content,
|
||||
ephemeral=response.ephemeral,
|
||||
)
|
||||
await respond(interaction, response.content, response.ephemeral)
|
||||
|
||||
|
||||
def _build_context(
|
||||
@ -226,60 +440,55 @@ def _build_context(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
target_user: discord.User | None,
|
||||
bot: discord.Client | None = None,
|
||||
) -> CommandContext:
|
||||
actor = _ensure_user(session, interaction.user)
|
||||
actor = ensure_user(session, interaction.user)
|
||||
|
||||
if scope == "server":
|
||||
# Determine target and display name based on scope
|
||||
if scope == "bot":
|
||||
if not bot or not bot.user:
|
||||
raise CommandError("Bot user is not available.")
|
||||
target = ensure_user(session, bot.user)
|
||||
display_name = f"bot **{bot.user.name}**"
|
||||
|
||||
elif scope == "me":
|
||||
target = ensure_user(session, interaction.user)
|
||||
name = target.display_name or target.username
|
||||
display_name = f"you (**{name}**)"
|
||||
|
||||
elif scope == "server":
|
||||
if interaction.guild is None:
|
||||
raise CommandError("This command can only be used inside a server.")
|
||||
|
||||
target = _ensure_server(session, interaction.guild)
|
||||
target = ensure_server(session, interaction.guild)
|
||||
display_name = f"server **{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
if scope == "channel":
|
||||
channel_obj = interaction.channel
|
||||
if channel_obj is None or not hasattr(channel_obj, "id"):
|
||||
elif scope == "channel":
|
||||
if interaction.channel is None or not hasattr(interaction.channel, "id"):
|
||||
raise CommandError("Unable to determine channel for this interaction.")
|
||||
|
||||
target = _ensure_channel(session, channel_obj, interaction.guild_id)
|
||||
target = _ensure_channel(session, interaction.channel, interaction.guild_id)
|
||||
display_name = f"channel **#{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
if scope == "user":
|
||||
discord_user = target_user or interaction.user
|
||||
if discord_user is None:
|
||||
elif scope == "user":
|
||||
if target_user is None:
|
||||
raise CommandError("A target user is required for this command.")
|
||||
target = ensure_user(session, target_user)
|
||||
name = target.display_name or target.username
|
||||
display_name = f"user **{name}**"
|
||||
|
||||
target = _ensure_user(session, discord_user)
|
||||
display_name = target.display_name or target.username
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=f"user **{display_name}**",
|
||||
)
|
||||
|
||||
else:
|
||||
raise CommandError(f"Unsupported scope '{scope}'.")
|
||||
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||
|
||||
def ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||
server = session.get(DiscordServer, guild.id)
|
||||
if server is None:
|
||||
server = DiscordServer(
|
||||
@ -330,7 +539,7 @@ def _ensure_channel(
|
||||
return channel_model
|
||||
|
||||
|
||||
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||
def ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||
user = session.get(DiscordUser, discord_user.id)
|
||||
display_name = getattr(discord_user, "display_name", discord_user.name)
|
||||
if user is None:
|
||||
@ -364,32 +573,23 @@ def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
|
||||
return getattr(getattr(channel, "type", None), "name", "unknown")
|
||||
|
||||
|
||||
def handle_prompt(context: CommandContext) -> CommandResponse:
|
||||
async def handle_prompt(
|
||||
context: CommandContext, *, prompt: str | None = None
|
||||
) -> CommandResponse:
|
||||
if prompt is not None:
|
||||
prompt = prompt or None
|
||||
setattr(context.target, "system_prompt", prompt)
|
||||
else:
|
||||
prompt = getattr(context.target, "system_prompt", None)
|
||||
|
||||
if prompt:
|
||||
return CommandResponse(
|
||||
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
|
||||
)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"No prompt configured for {context.display_name}.",
|
||||
)
|
||||
content = f"Current prompt for {context.display_name}:\n\n{prompt}"
|
||||
else:
|
||||
content = f"No prompt configured for {context.display_name}."
|
||||
return CommandResponse(content=content)
|
||||
|
||||
|
||||
def handle_set_prompt(
|
||||
context: CommandContext,
|
||||
*,
|
||||
prompt: str,
|
||||
) -> CommandResponse:
|
||||
setattr(context.target, "system_prompt", prompt)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"Updated system prompt for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
def handle_chattiness(
|
||||
async def handle_chattiness(
|
||||
context: CommandContext,
|
||||
*,
|
||||
value: int | None,
|
||||
@ -419,7 +619,7 @@ def handle_chattiness(
|
||||
)
|
||||
|
||||
|
||||
def handle_ignore(
|
||||
async def handle_ignore(
|
||||
context: CommandContext,
|
||||
*,
|
||||
ignore_enabled: bool | None,
|
||||
@ -434,7 +634,7 @@ def handle_ignore(
|
||||
)
|
||||
|
||||
|
||||
def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
async def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
summary = getattr(context.target, "summary", None)
|
||||
|
||||
if summary:
|
||||
@ -454,20 +654,22 @@ async def handle_mcp_servers(
|
||||
url: str | None,
|
||||
) -> CommandResponse:
|
||||
"""Handle MCP server commands for a specific scope."""
|
||||
entity_type_map = {
|
||||
# Map scope to database entity type
|
||||
entity_type = {
|
||||
"bot": "DiscordUser",
|
||||
"me": "DiscordUser",
|
||||
"user": "DiscordUser",
|
||||
"server": "DiscordServer",
|
||||
"channel": "DiscordChannel",
|
||||
"user": "DiscordUser",
|
||||
}
|
||||
entity_type = entity_type_map[context.scope]
|
||||
entity_id = context.target.id
|
||||
}[context.scope]
|
||||
|
||||
bot_user = getattr(getattr(context.interaction, "client", None), "user", None)
|
||||
|
||||
try:
|
||||
res = await run_mcp_server_command(
|
||||
context.interaction.user, action, url, entity_type, entity_id
|
||||
bot_user, action, url, entity_type, context.target.id
|
||||
)
|
||||
return CommandResponse(content=res)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error running MCP server command: {traceback.format_exc()}")
|
||||
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
||||
raise CommandError(f"Error: {exc}") from exc
|
||||
|
||||
@ -210,6 +210,11 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
for server in discord_message.recipient_user.mcp_servers
|
||||
]
|
||||
|
||||
system_prompt = discord_message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, discord_message.recipient_user, discord_message.channel
|
||||
)
|
||||
|
||||
try:
|
||||
response = call_llm(
|
||||
session,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user