mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
handle discord threads
This commit is contained in:
parent
6bd7df8ee3
commit
69192f834a
@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
"""Send a DM via the Discord collector API"""
|
||||
def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
||||
"""Send message to a channel by name or ID (ID supports threads)"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel_name": channel_name,
|
||||
"channel": channel,
|
||||
"message": message,
|
||||
},
|
||||
timeout=10,
|
||||
@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to send to channel {channel_name}: {e}")
|
||||
logger.error(f"Failed to send to channel {channel}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
||||
"""Trigger typing indicator for a channel via the Discord collector API"""
|
||||
def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
|
||||
"""Trigger typing indicator for a channel by name or ID (ID supports threads)"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/typing/channel",
|
||||
json={"bot_id": bot_id, "channel_name": channel_name},
|
||||
json={"bot_id": bot_id, "channel": channel},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
||||
logger.error(f"Failed to trigger typing for channel {channel}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel via the Discord collector API"""
|
||||
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
|
||||
"""Send a message to a channel by name or ID (ID supports threads)"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel_name": channel_name,
|
||||
"channel": channel,
|
||||
"message": message,
|
||||
},
|
||||
timeout=10,
|
||||
@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||
logger.error(f"Failed to send message to channel {channel}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ class SendDMRequest(BaseModel):
|
||||
|
||||
class SendChannelRequest(BaseModel):
|
||||
bot_id: int
|
||||
channel_name: str # Channel name (e.g., "memory-errors")
|
||||
channel: int | str # Channel name or ID (ID supports threads)
|
||||
message: str
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel):
|
||||
|
||||
class TypingChannelRequest(BaseModel):
|
||||
bot_id: int
|
||||
channel_name: str
|
||||
channel: int | str # Channel name or ID (ID supports threads)
|
||||
|
||||
|
||||
class Collector:
|
||||
@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
||||
|
||||
try:
|
||||
success = await collector.collector.send_to_channel(
|
||||
request.channel_name, request.message
|
||||
request.channel, request.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send channel message: {e}")
|
||||
@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Message sent to channel {request.channel_name}",
|
||||
"channel": request.channel_name,
|
||||
"message": f"Message sent to channel {request.channel}",
|
||||
"channel": request.channel,
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send message to channel {request.channel_name}",
|
||||
detail=f"Failed to send message to channel {request.channel}",
|
||||
)
|
||||
|
||||
|
||||
@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest):
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
|
||||
try:
|
||||
success = await collector.collector.trigger_typing_channel(request.channel_name)
|
||||
success = await collector.collector.trigger_typing_channel(request.channel)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger channel typing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest):
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to trigger typing for channel {request.channel_name}",
|
||||
detail=f"Failed to trigger typing for channel {request.channel}",
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"channel": request.channel_name,
|
||||
"message": f"Typing triggered for channel {request.channel_name}",
|
||||
"channel": request.channel,
|
||||
"message": f"Typing triggered for channel {request.channel}",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -207,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None:
|
||||
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
||||
create_or_update_channel(session, channel)
|
||||
|
||||
# Sync threads
|
||||
for thread in guild.threads:
|
||||
create_or_update_channel(session, thread)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
@ -266,6 +270,11 @@ class MessageCollector(commands.Bot):
|
||||
if att.content_type and att.content_type.startswith("image/")
|
||||
]
|
||||
|
||||
# Determine message metadata (type, reply, thread)
|
||||
message_type, reply_to_id, thread_id = determine_message_metadata(
|
||||
message
|
||||
)
|
||||
|
||||
# Queue the message for processing
|
||||
add_discord_message.delay(
|
||||
message_id=message.id,
|
||||
@ -275,9 +284,9 @@ class MessageCollector(commands.Bot):
|
||||
server_id=message.guild.id if message.guild else None,
|
||||
content=message.content or "",
|
||||
sent_at=message.created_at.isoformat(),
|
||||
message_reference_id=message.reference.message_id
|
||||
if message.reference
|
||||
else None,
|
||||
message_reference_id=reply_to_id,
|
||||
message_type=message_type,
|
||||
thread_id=thread_id,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -322,6 +331,11 @@ class MessageCollector(commands.Bot):
|
||||
create_or_update_channel(session, channel)
|
||||
channels_updated += 1
|
||||
|
||||
# Refresh all threads in this server
|
||||
for thread in guild.threads:
|
||||
create_or_update_channel(session, thread)
|
||||
channels_updated += 1
|
||||
|
||||
# Refresh all members in this server (if members intent is enabled)
|
||||
if self.intents.members:
|
||||
for member in guild.members:
|
||||
@ -440,15 +454,22 @@ class MessageCollector(commands.Bot):
|
||||
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||
return False
|
||||
|
||||
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel by name across all guilds"""
|
||||
async def send_to_channel(
|
||||
self, channel_identifier: int | str, message: str
|
||||
) -> bool:
|
||||
"""Send a message to a channel by name or ID (supports threads)"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
channel = await self.get_channel_by_name(channel_name)
|
||||
# Get channel by ID or name
|
||||
if isinstance(channel_identifier, int):
|
||||
channel = self.get_channel(channel_identifier)
|
||||
else:
|
||||
channel = await self.get_channel_by_name(channel_identifier)
|
||||
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_name} not found")
|
||||
logger.error(f"Channel {channel_identifier} not found")
|
||||
return False
|
||||
|
||||
# Post-process mentions to convert usernames to IDs
|
||||
@ -456,22 +477,27 @@ class MessageCollector(commands.Bot):
|
||||
processed_message = process_mentions(session, message)
|
||||
|
||||
await channel.send(processed_message)
|
||||
logger.info(f"Sent message to channel {channel_name}")
|
||||
logger.info(f"Sent message to channel {channel_identifier}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||
logger.error(f"Failed to send message to channel {channel_identifier}: {e}")
|
||||
return False
|
||||
|
||||
async def trigger_typing_channel(self, channel_name: str) -> bool:
|
||||
"""Trigger typing indicator in a channel"""
|
||||
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
|
||||
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
channel = await self.get_channel_by_name(channel_name)
|
||||
# Get channel by ID or name
|
||||
if isinstance(channel_identifier, int):
|
||||
channel = self.get_channel(channel_identifier)
|
||||
else:
|
||||
channel = await self.get_channel_by_name(channel_identifier)
|
||||
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_name} not found")
|
||||
logger.error(f"Channel {channel_identifier} not found")
|
||||
return False
|
||||
|
||||
async with channel.typing():
|
||||
@ -479,5 +505,7 @@ class MessageCollector(commands.Bot):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
||||
logger.error(
|
||||
f"Failed to trigger typing for channel {channel_identifier}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
@ -180,7 +180,7 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
return False
|
||||
|
||||
if message.channel and message.channel.server:
|
||||
discord.trigger_typing_channel(bot_id, message.channel.name)
|
||||
discord.trigger_typing_channel(bot_id, cast(int, message.channel_id))
|
||||
else:
|
||||
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
||||
return True
|
||||
@ -242,7 +242,9 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
if discord_message.channel.server:
|
||||
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
||||
discord.send_to_channel(
|
||||
bot_id, cast(int, discord_message.channel_id), response
|
||||
)
|
||||
else:
|
||||
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
||||
|
||||
@ -263,6 +265,8 @@ def add_discord_message(
|
||||
server_id: int | None = None,
|
||||
recipient_id: int | None = None,
|
||||
message_reference_id: int | None = None,
|
||||
message_type: str = "default",
|
||||
thread_id: int | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@ -291,8 +295,9 @@ def add_discord_message(
|
||||
from_id=author_id,
|
||||
recipient_id=recipient_id,
|
||||
message_id=message_id,
|
||||
message_type="reply" if message_reference_id else "default",
|
||||
message_type=message_type,
|
||||
reply_to_message_id=message_reference_id,
|
||||
thread_id=thread_id,
|
||||
images=saved_image_paths or None,
|
||||
)
|
||||
existing_msg = check_content_exists(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user