diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index c2d92fa..6eab35b 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool: return False -def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: - """Send a DM via the Discord collector API""" +def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool: + """Send message to a channel by name or ID (ID supports threads)""" try: response = requests.post( f"{get_api_url()}/send_channel", json={ "bot_id": bot_id, - "channel_name": channel_name, + "channel": channel, "message": message, }, timeout=10, @@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: return result.get("success", False) except requests.RequestException as e: - logger.error(f"Failed to send to channel {channel_name}: {e}") + logger.error(f"Failed to send to channel {channel}: {e}") return False -def trigger_typing_channel(bot_id: int, channel_name: str) -> bool: - """Trigger typing indicator for a channel via the Discord collector API""" +def trigger_typing_channel(bot_id: int, channel: int | str) -> bool: + """Trigger typing indicator for a channel by name or ID (ID supports threads)""" try: response = requests.post( f"{get_api_url()}/typing/channel", - json={"bot_id": bot_id, "channel_name": channel_name}, + json={"bot_id": bot_id, "channel": channel}, timeout=10, ) response.raise_for_status() @@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool: return result.get("success", False) except requests.RequestException as e: - logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") + logger.error(f"Failed to trigger typing for channel {channel}: {e}") return False -def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool: - """Send a message to a channel via the Discord collector API""" +def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool: + """Send a message to a channel by name or ID (ID supports threads)""" try: response = requests.post( f"{get_api_url()}/send_channel", json={ "bot_id": bot_id, - "channel_name": channel_name, + "channel": channel, "message": message, }, timeout=10, @@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool: return result.get("success", False) except requests.RequestException as e: - logger.error(f"Failed to send message to channel {channel_name}: {e}") + logger.error(f"Failed to send message to channel {channel}: {e}") return False diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index 71528b8..bb43f0f 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -31,7 +31,7 @@ class SendDMRequest(BaseModel): class SendChannelRequest(BaseModel): bot_id: int - channel_name: str # Channel name (e.g., "memory-errors") + channel: int | str # Channel name or ID (ID supports threads) message: str @@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel): class TypingChannelRequest(BaseModel): bot_id: int - channel_name: str + channel: int | str # Channel name or ID (ID supports threads) class Collector: @@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest): try: success = await collector.collector.send_to_channel( - request.channel_name, request.message + request.channel, request.message ) except Exception as e: logger.error(f"Failed to send channel message: {e}") @@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest): if success: return { "success": True, - "message": f"Message sent to channel {request.channel_name}", - "channel": request.channel_name, + "message": f"Message sent to channel {request.channel}", + "channel": request.channel, } raise HTTPException( status_code=400, - detail=f"Failed to send message to channel {request.channel_name}", + detail=f"Failed to send message to channel {request.channel}", ) @@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest): raise HTTPException(status_code=404, detail="Bot not found") try: - success = await collector.collector.trigger_typing_channel(request.channel_name) + success = await collector.collector.trigger_typing_channel(request.channel) except Exception as e: logger.error(f"Failed to trigger channel typing: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest): if not success: raise HTTPException( status_code=400, - detail=f"Failed to trigger typing for channel {request.channel_name}", + detail=f"Failed to trigger typing for channel {request.channel}", ) return { "success": True, - "channel": request.channel_name, - "message": f"Typing triggered for channel {request.channel_name}", + "channel": request.channel, + "message": f"Typing triggered for channel {request.channel}", } diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index eb6da56..bb5b6c6 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -207,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None: if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)): create_or_update_channel(session, channel) + # Sync threads + for thread in guild.threads: + create_or_update_channel(session, thread) + session.commit() @@ -266,6 +270,11 @@ class MessageCollector(commands.Bot): if att.content_type and att.content_type.startswith("image/") ] + # Determine message metadata (type, reply, thread) + message_type, reply_to_id, thread_id = determine_message_metadata( + message + ) + # Queue the message for processing add_discord_message.delay( message_id=message.id, @@ -275,9 +284,9 @@ class MessageCollector(commands.Bot): server_id=message.guild.id if message.guild else None, content=message.content or "", sent_at=message.created_at.isoformat(), - message_reference_id=message.reference.message_id - if message.reference - else None, + message_reference_id=reply_to_id, + message_type=message_type, + thread_id=thread_id, image_urls=image_urls, ) except Exception as e: @@ -322,6 +331,11 @@ class MessageCollector(commands.Bot): create_or_update_channel(session, channel) channels_updated += 1 + # Refresh all threads in this server + for thread in guild.threads: + create_or_update_channel(session, thread) + channels_updated += 1 + # Refresh all members in this server (if members intent is enabled) if self.intents.members: for member in guild.members: @@ -440,15 +454,22 @@ class MessageCollector(commands.Bot): logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}") return False - async def send_to_channel(self, channel_name: str, message: str) -> bool: - """Send a message to a channel by name across all guilds""" + async def send_to_channel( + self, channel_identifier: int | str, message: str + ) -> bool: + """Send a message to a channel by name or ID (supports threads)""" if not settings.DISCORD_NOTIFICATIONS_ENABLED: return False try: - channel = await self.get_channel_by_name(channel_name) + # Get channel by ID or name + if isinstance(channel_identifier, int): + channel = self.get_channel(channel_identifier) + else: + channel = await self.get_channel_by_name(channel_identifier) + if not channel: - logger.error(f"Channel {channel_name} not found") + logger.error(f"Channel {channel_identifier} not found") return False # Post-process mentions to convert usernames to IDs @@ -456,22 +477,27 @@ class MessageCollector(commands.Bot): processed_message = process_mentions(session, message) await channel.send(processed_message) - logger.info(f"Sent message to channel {channel_name}") + logger.info(f"Sent message to channel {channel_identifier}") return True except Exception as e: - logger.error(f"Failed to send message to channel {channel_name}: {e}") + logger.error(f"Failed to send message to channel {channel_identifier}: {e}") return False - async def trigger_typing_channel(self, channel_name: str) -> bool: - """Trigger typing indicator in a channel""" + async def trigger_typing_channel(self, channel_identifier: int | str) -> bool: + """Trigger typing indicator in a channel by name or ID (supports threads)""" if not settings.DISCORD_NOTIFICATIONS_ENABLED: return False try: - channel = await self.get_channel_by_name(channel_name) + # Get channel by ID or name + if isinstance(channel_identifier, int): + channel = self.get_channel(channel_identifier) + else: + channel = await self.get_channel_by_name(channel_identifier) + if not channel: - logger.error(f"Channel {channel_name} not found") + logger.error(f"Channel {channel_identifier} not found") return False async with channel.typing(): @@ -479,5 +505,7 @@ class MessageCollector(commands.Bot): return True except Exception as e: - logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") + logger.error( + f"Failed to trigger typing for channel {channel_identifier}: {e}" + ) return False diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index b5f79ac..a889eb8 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -180,7 +180,7 @@ def should_process(message: DiscordMessage) -> bool: return False if message.channel and message.channel.server: - discord.trigger_typing_channel(bot_id, message.channel.name) + discord.trigger_typing_channel(bot_id, cast(int, message.channel_id)) else: discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id)) return True @@ -242,7 +242,9 @@ def process_discord_message(message_id: int) -> dict[str, Any]: } if discord_message.channel.server: - discord.send_to_channel(bot_id, discord_message.channel.name, response) + discord.send_to_channel( + bot_id, cast(int, discord_message.channel_id), response + ) else: discord.send_dm(bot_id, discord_message.from_user.username, response) @@ -263,6 +265,8 @@ def add_discord_message( server_id: int | None = None, recipient_id: int | None = None, message_reference_id: int | None = None, + message_type: str = "default", + thread_id: int | None = None, image_urls: list[str] | None = None, ) -> dict[str, Any]: """ @@ -291,8 +295,9 @@ def add_discord_message( from_id=author_id, recipient_id=recipient_id, message_id=message_id, - message_type="reply" if message_reference_id else "default", + message_type=message_type, reply_to_message_id=message_reference_id, + thread_id=thread_id, images=saved_image_paths or None, ) existing_msg = check_content_exists(