diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index e850059..c2d92fa 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -5,9 +5,10 @@ Simple HTTP client that communicates with the Discord collector's API server. """ import logging -import requests from typing import Any +import requests + from memory.common import settings logger = logging.getLogger(__name__) @@ -37,6 +38,23 @@ def send_dm(bot_id: int, user_identifier: str, message: str) -> bool: return False +def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool: + """Trigger typing indicator for a DM via the Discord collector API""" + try: + response = requests.post( + f"{get_api_url()}/typing/dm", + json={"bot_id": bot_id, "user": user_identifier}, + timeout=10, + ) + response.raise_for_status() + result = response.json() + return result.get("success", False) + + except requests.RequestException as e: + logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}") + return False + + def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: """Send a DM via the Discord collector API""" try: @@ -59,6 +77,23 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: return False +def trigger_typing_channel(bot_id: int, channel_name: str) -> bool: + """Trigger typing indicator for a channel via the Discord collector API""" + try: + response = requests.post( + f"{get_api_url()}/typing/channel", + json={"bot_id": bot_id, "channel_name": channel_name}, + timeout=10, + ) + response.raise_for_status() + result = response.json() + return result.get("success", False) + + except requests.RequestException as e: + logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") + return False + + def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool: """Send a message to a channel via the Discord collector API""" try: diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index a5e21b6..71528b8 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -35,6 +35,16 @@ class SendChannelRequest(BaseModel): message: str +class TypingDMRequest(BaseModel): + bot_id: int + user: int | str + + +class TypingChannelRequest(BaseModel): + bot_id: int + channel_name: str + + class Collector: collector: MessageCollector collector_task: asyncio.Task @@ -109,6 +119,32 @@ async def send_dm_endpoint(request: SendDMRequest): } +@app.post("/typing/dm") +async def trigger_dm_typing(request: TypingDMRequest): + """Trigger a typing indicator for a DM via the collector""" + collector = app.bots.get(request.bot_id) + if not collector: + raise HTTPException(status_code=404, detail="Bot not found") + + try: + success = await collector.collector.trigger_typing_dm(request.user) + except Exception as e: + logger.error(f"Failed to trigger DM typing: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + if not success: + raise HTTPException( + status_code=400, + detail=f"Failed to trigger typing for {request.user}", + ) + + return { + "success": True, + "user": request.user, + "message": f"Typing triggered for {request.user}", + } + + @app.post("/send_channel") async def send_channel_endpoint(request: SendChannelRequest): """Send a message to a channel via the collector's Discord client""" @@ -120,23 +156,48 @@ async def send_channel_endpoint(request: SendChannelRequest): success = await collector.collector.send_to_channel( request.channel_name, request.message ) - - if success: - return { - "success": True, - "message": f"Message sent to channel {request.channel_name}", - "channel": request.channel_name, - } - else: - raise HTTPException( - status_code=400, - detail=f"Failed to send message to channel {request.channel_name}", - ) - except Exception as e: logger.error(f"Failed to send channel message: {e}") raise HTTPException(status_code=500, detail=str(e)) + if success: + return { + "success": True, + "message": f"Message sent to channel {request.channel_name}", + "channel": request.channel_name, + } + + raise HTTPException( + status_code=400, + detail=f"Failed to send message to channel {request.channel_name}", + ) + + +@app.post("/typing/channel") +async def trigger_channel_typing(request: TypingChannelRequest): + """Trigger a typing indicator for a channel via the collector""" + collector = app.bots.get(request.bot_id) + if not collector: + raise HTTPException(status_code=404, detail="Bot not found") + + try: + success = await collector.collector.trigger_typing_channel(request.channel_name) + except Exception as e: + logger.error(f"Failed to trigger channel typing: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + if not success: + raise HTTPException( + status_code=400, + detail=f"Failed to trigger typing for channel {request.channel_name}", + ) + + return { + "success": True, + "channel": request.channel_name, + "message": f"Typing triggered for channel {request.channel_name}", + } + @app.get("/health") async def health_check(): diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index a898458..d37df9d 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -381,6 +381,26 @@ class MessageCollector(commands.Bot): logger.error(f"Failed to send DM to {user_identifier}: {e}") return False + async def trigger_typing_dm(self, user_identifier: int | str) -> bool: + """Trigger typing indicator in a DM""" + try: + user = await self.get_user(user_identifier) + if not user: + logger.error(f"User {user_identifier} not found") + return False + + channel = user.dm_channel or await user.create_dm() + if not channel: + logger.error(f"DM channel not available for {user_identifier}") + return False + + await channel.trigger_typing() + return True + + except Exception as e: + logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}") + return False + async def send_to_channel(self, channel_name: str, message: str) -> bool: """Send a message to a channel by name across all guilds""" if not settings.DISCORD_NOTIFICATIONS_ENABLED: @@ -400,6 +420,24 @@ class MessageCollector(commands.Bot): logger.error(f"Failed to send message to channel {channel_name}: {e}") return False + async def trigger_typing_channel(self, channel_name: str) -> bool: + """Trigger typing indicator in a channel""" + if not settings.DISCORD_NOTIFICATIONS_ENABLED: + return False + + try: + channel = await self.get_channel_by_name(channel_name) + if not channel: + logger.error(f"Channel {channel_name} not found") + return False + + await channel.trigger_typing() + return True + + except Exception as e: + logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") + return False + async def run_collector(): """Run the Discord message collector""" diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index f1b9f9a..ab6b5f6 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -170,14 +170,6 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } - response = call_llm(session, discord_message, settings.DISCORD_MODEL) - - if not response: - return { - "status": "processed", - "message_id": message_id, - } - bot_id = _resolve_bot_id(discord_message) if not bot_id: logger.warning( @@ -189,6 +181,26 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } + if discord_message.channel and discord_message.channel.server: + discord.trigger_typing_channel( + bot_id, discord_message.channel.name + ) + else: + discord.trigger_typing_dm(bot_id, discord_message.from_id) + + response: str | None = None + + try: + response = call_llm(session, discord_message, settings.DISCORD_MODEL) + except Exception: + logger.exception("Failed to generate Discord response") + + if not response: + return { + "status": "processed", + "message_id": message_id, + } + if discord_message.channel.server: discord.send_to_channel(bot_id, discord_message.channel.name, response) else: