handle discord threads

This commit is contained in:
mruwnik 2025-11-02 11:23:31 +00:00
parent 6bd7df8ee3
commit 69192f834a
4 changed files with 72 additions and 39 deletions

View File

@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
return False return False
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a DM via the Discord collector API""" """Send message to a channel by name or ID (ID supports threads)"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={ json={
"bot_id": bot_id, "bot_id": bot_id,
"channel_name": channel_name, "channel": channel,
"message": message, "message": message,
}, },
timeout=10, timeout=10,
@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to send to channel {channel_name}: {e}") logger.error(f"Failed to send to channel {channel}: {e}")
return False return False
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool: def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
"""Trigger typing indicator for a channel via the Discord collector API""" """Trigger typing indicator for a channel by name or ID (ID supports threads)"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/typing/channel", f"{get_api_url()}/typing/channel",
json={"bot_id": bot_id, "channel_name": channel_name}, json={"bot_id": bot_id, "channel": channel},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") logger.error(f"Failed to trigger typing for channel {channel}: {e}")
return False return False
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool: def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API""" """Send a message to a channel by name or ID (ID supports threads)"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={ json={
"bot_id": bot_id, "bot_id": bot_id,
"channel_name": channel_name, "channel": channel,
"message": message, "message": message,
}, },
timeout=10, timeout=10,
@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}") logger.error(f"Failed to send message to channel {channel}: {e}")
return False return False

View File

@ -31,7 +31,7 @@ class SendDMRequest(BaseModel):
class SendChannelRequest(BaseModel): class SendChannelRequest(BaseModel):
bot_id: int bot_id: int
channel_name: str # Channel name (e.g., "memory-errors") channel: int | str # Channel name or ID (ID supports threads)
message: str message: str
@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel):
class TypingChannelRequest(BaseModel): class TypingChannelRequest(BaseModel):
bot_id: int bot_id: int
channel_name: str channel: int | str # Channel name or ID (ID supports threads)
class Collector: class Collector:
@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest):
try: try:
success = await collector.collector.send_to_channel( success = await collector.collector.send_to_channel(
request.channel_name, request.message request.channel, request.message
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send channel message: {e}") logger.error(f"Failed to send channel message: {e}")
@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest):
if success: if success:
return { return {
"success": True, "success": True,
"message": f"Message sent to channel {request.channel_name}", "message": f"Message sent to channel {request.channel}",
"channel": request.channel_name, "channel": request.channel,
} }
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Failed to send message to channel {request.channel_name}", detail=f"Failed to send message to channel {request.channel}",
) )
@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest):
raise HTTPException(status_code=404, detail="Bot not found") raise HTTPException(status_code=404, detail="Bot not found")
try: try:
success = await collector.collector.trigger_typing_channel(request.channel_name) success = await collector.collector.trigger_typing_channel(request.channel)
except Exception as e: except Exception as e:
logger.error(f"Failed to trigger channel typing: {e}") logger.error(f"Failed to trigger channel typing: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest):
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Failed to trigger typing for channel {request.channel_name}", detail=f"Failed to trigger typing for channel {request.channel}",
) )
return { return {
"success": True, "success": True,
"channel": request.channel_name, "channel": request.channel,
"message": f"Typing triggered for channel {request.channel_name}", "message": f"Typing triggered for channel {request.channel}",
} }

View File

@ -207,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)): if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel) create_or_update_channel(session, channel)
# Sync threads
for thread in guild.threads:
create_or_update_channel(session, thread)
session.commit() session.commit()
@ -266,6 +270,11 @@ class MessageCollector(commands.Bot):
if att.content_type and att.content_type.startswith("image/") if att.content_type and att.content_type.startswith("image/")
] ]
# Determine message metadata (type, reply, thread)
message_type, reply_to_id, thread_id = determine_message_metadata(
message
)
# Queue the message for processing # Queue the message for processing
add_discord_message.delay( add_discord_message.delay(
message_id=message.id, message_id=message.id,
@ -275,9 +284,9 @@ class MessageCollector(commands.Bot):
server_id=message.guild.id if message.guild else None, server_id=message.guild.id if message.guild else None,
content=message.content or "", content=message.content or "",
sent_at=message.created_at.isoformat(), sent_at=message.created_at.isoformat(),
message_reference_id=message.reference.message_id message_reference_id=reply_to_id,
if message.reference message_type=message_type,
else None, thread_id=thread_id,
image_urls=image_urls, image_urls=image_urls,
) )
except Exception as e: except Exception as e:
@ -322,6 +331,11 @@ class MessageCollector(commands.Bot):
create_or_update_channel(session, channel) create_or_update_channel(session, channel)
channels_updated += 1 channels_updated += 1
# Refresh all threads in this server
for thread in guild.threads:
create_or_update_channel(session, thread)
channels_updated += 1
# Refresh all members in this server (if members intent is enabled) # Refresh all members in this server (if members intent is enabled)
if self.intents.members: if self.intents.members:
for member in guild.members: for member in guild.members:
@ -440,15 +454,22 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}") logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False return False
async def send_to_channel(self, channel_name: str, message: str) -> bool: async def send_to_channel(
"""Send a message to a channel by name across all guilds""" self, channel_identifier: int | str, message: str
) -> bool:
"""Send a message to a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED: if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False return False
try: try:
channel = await self.get_channel_by_name(channel_name) # Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel: if not channel:
logger.error(f"Channel {channel_name} not found") logger.error(f"Channel {channel_identifier} not found")
return False return False
# Post-process mentions to convert usernames to IDs # Post-process mentions to convert usernames to IDs
@ -456,22 +477,27 @@ class MessageCollector(commands.Bot):
processed_message = process_mentions(session, message) processed_message = process_mentions(session, message)
await channel.send(processed_message) await channel.send(processed_message)
logger.info(f"Sent message to channel {channel_name}") logger.info(f"Sent message to channel {channel_identifier}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}") logger.error(f"Failed to send message to channel {channel_identifier}: {e}")
return False return False
async def trigger_typing_channel(self, channel_name: str) -> bool: async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
"""Trigger typing indicator in a channel""" """Trigger typing indicator in a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED: if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False return False
try: try:
channel = await self.get_channel_by_name(channel_name) # Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel: if not channel:
logger.error(f"Channel {channel_name} not found") logger.error(f"Channel {channel_identifier} not found")
return False return False
async with channel.typing(): async with channel.typing():
@ -479,5 +505,7 @@ class MessageCollector(commands.Bot):
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") logger.error(
f"Failed to trigger typing for channel {channel_identifier}: {e}"
)
return False return False

View File

@ -180,7 +180,7 @@ def should_process(message: DiscordMessage) -> bool:
return False return False
if message.channel and message.channel.server: if message.channel and message.channel.server:
discord.trigger_typing_channel(bot_id, message.channel.name) discord.trigger_typing_channel(bot_id, cast(int, message.channel_id))
else: else:
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id)) discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
return True return True
@ -242,7 +242,9 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
} }
if discord_message.channel.server: if discord_message.channel.server:
discord.send_to_channel(bot_id, discord_message.channel.name, response) discord.send_to_channel(
bot_id, cast(int, discord_message.channel_id), response
)
else: else:
discord.send_dm(bot_id, discord_message.from_user.username, response) discord.send_dm(bot_id, discord_message.from_user.username, response)
@ -263,6 +265,8 @@ def add_discord_message(
server_id: int | None = None, server_id: int | None = None,
recipient_id: int | None = None, recipient_id: int | None = None,
message_reference_id: int | None = None, message_reference_id: int | None = None,
message_type: str = "default",
thread_id: int | None = None,
image_urls: list[str] | None = None, image_urls: list[str] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
@ -291,8 +295,9 @@ def add_discord_message(
from_id=author_id, from_id=author_id,
recipient_id=recipient_id, recipient_id=recipient_id,
message_id=message_id, message_id=message_id,
message_type="reply" if message_reference_id else "default", message_type=message_type,
reply_to_message_id=message_reference_id, reply_to_message_id=message_reference_id,
thread_id=thread_id,
images=saved_image_paths or None, images=saved_image_paths or None,
) )
existing_msg = check_content_exists( existing_msg = check_content_exists(