mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 08:14:05 +01:00
handle discord threads
This commit is contained in:
parent
6bd7df8ee3
commit
69192f834a
@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
||||||
"""Send a DM via the Discord collector API"""
|
"""Send message to a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={
|
json={
|
||||||
"bot_id": bot_id,
|
"bot_id": bot_id,
|
||||||
"channel_name": channel_name,
|
"channel": channel,
|
||||||
"message": message,
|
"message": message,
|
||||||
},
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
|||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to send to channel {channel_name}: {e}")
|
logger.error(f"Failed to send to channel {channel}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
|
||||||
"""Trigger typing indicator for a channel via the Discord collector API"""
|
"""Trigger typing indicator for a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/typing/channel",
|
f"{get_api_url()}/typing/channel",
|
||||||
json={"bot_id": bot_id, "channel_name": channel_name},
|
json={"bot_id": bot_id, "channel": channel},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
|||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
logger.error(f"Failed to trigger typing for channel {channel}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
|
||||||
"""Send a message to a channel via the Discord collector API"""
|
"""Send a message to a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={
|
json={
|
||||||
"bot_id": bot_id,
|
"bot_id": bot_id,
|
||||||
"channel_name": channel_name,
|
"channel": channel,
|
||||||
"message": message,
|
"message": message,
|
||||||
},
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
|||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
logger.error(f"Failed to send message to channel {channel}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class SendDMRequest(BaseModel):
|
|||||||
|
|
||||||
class SendChannelRequest(BaseModel):
|
class SendChannelRequest(BaseModel):
|
||||||
bot_id: int
|
bot_id: int
|
||||||
channel_name: str # Channel name (e.g., "memory-errors")
|
channel: int | str # Channel name or ID (ID supports threads)
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel):
|
|||||||
|
|
||||||
class TypingChannelRequest(BaseModel):
|
class TypingChannelRequest(BaseModel):
|
||||||
bot_id: int
|
bot_id: int
|
||||||
channel_name: str
|
channel: int | str # Channel name or ID (ID supports threads)
|
||||||
|
|
||||||
|
|
||||||
class Collector:
|
class Collector:
|
||||||
@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
success = await collector.collector.send_to_channel(
|
success = await collector.collector.send_to_channel(
|
||||||
request.channel_name, request.message
|
request.channel, request.message
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send channel message: {e}")
|
logger.error(f"Failed to send channel message: {e}")
|
||||||
@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
if success:
|
if success:
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"Message sent to channel {request.channel_name}",
|
"message": f"Message sent to channel {request.channel}",
|
||||||
"channel": request.channel_name,
|
"channel": request.channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Failed to send message to channel {request.channel_name}",
|
detail=f"Failed to send message to channel {request.channel}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest):
|
|||||||
raise HTTPException(status_code=404, detail="Bot not found")
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = await collector.collector.trigger_typing_channel(request.channel_name)
|
success = await collector.collector.trigger_typing_channel(request.channel)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to trigger channel typing: {e}")
|
logger.error(f"Failed to trigger channel typing: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest):
|
|||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Failed to trigger typing for channel {request.channel_name}",
|
detail=f"Failed to trigger typing for channel {request.channel}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"channel": request.channel_name,
|
"channel": request.channel,
|
||||||
"message": f"Typing triggered for channel {request.channel_name}",
|
"message": f"Typing triggered for channel {request.channel}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -207,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None:
|
|||||||
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
||||||
create_or_update_channel(session, channel)
|
create_or_update_channel(session, channel)
|
||||||
|
|
||||||
|
# Sync threads
|
||||||
|
for thread in guild.threads:
|
||||||
|
create_or_update_channel(session, thread)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
@ -266,6 +270,11 @@ class MessageCollector(commands.Bot):
|
|||||||
if att.content_type and att.content_type.startswith("image/")
|
if att.content_type and att.content_type.startswith("image/")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Determine message metadata (type, reply, thread)
|
||||||
|
message_type, reply_to_id, thread_id = determine_message_metadata(
|
||||||
|
message
|
||||||
|
)
|
||||||
|
|
||||||
# Queue the message for processing
|
# Queue the message for processing
|
||||||
add_discord_message.delay(
|
add_discord_message.delay(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
@ -275,9 +284,9 @@ class MessageCollector(commands.Bot):
|
|||||||
server_id=message.guild.id if message.guild else None,
|
server_id=message.guild.id if message.guild else None,
|
||||||
content=message.content or "",
|
content=message.content or "",
|
||||||
sent_at=message.created_at.isoformat(),
|
sent_at=message.created_at.isoformat(),
|
||||||
message_reference_id=message.reference.message_id
|
message_reference_id=reply_to_id,
|
||||||
if message.reference
|
message_type=message_type,
|
||||||
else None,
|
thread_id=thread_id,
|
||||||
image_urls=image_urls,
|
image_urls=image_urls,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -322,6 +331,11 @@ class MessageCollector(commands.Bot):
|
|||||||
create_or_update_channel(session, channel)
|
create_or_update_channel(session, channel)
|
||||||
channels_updated += 1
|
channels_updated += 1
|
||||||
|
|
||||||
|
# Refresh all threads in this server
|
||||||
|
for thread in guild.threads:
|
||||||
|
create_or_update_channel(session, thread)
|
||||||
|
channels_updated += 1
|
||||||
|
|
||||||
# Refresh all members in this server (if members intent is enabled)
|
# Refresh all members in this server (if members intent is enabled)
|
||||||
if self.intents.members:
|
if self.intents.members:
|
||||||
for member in guild.members:
|
for member in guild.members:
|
||||||
@ -440,15 +454,22 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
async def send_to_channel(
|
||||||
"""Send a message to a channel by name across all guilds"""
|
self, channel_identifier: int | str, message: str
|
||||||
|
) -> bool:
|
||||||
|
"""Send a message to a channel by name or ID (supports threads)"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
channel = await self.get_channel_by_name(channel_name)
|
# Get channel by ID or name
|
||||||
|
if isinstance(channel_identifier, int):
|
||||||
|
channel = self.get_channel(channel_identifier)
|
||||||
|
else:
|
||||||
|
channel = await self.get_channel_by_name(channel_identifier)
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.error(f"Channel {channel_name} not found")
|
logger.error(f"Channel {channel_identifier} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Post-process mentions to convert usernames to IDs
|
# Post-process mentions to convert usernames to IDs
|
||||||
@ -456,22 +477,27 @@ class MessageCollector(commands.Bot):
|
|||||||
processed_message = process_mentions(session, message)
|
processed_message = process_mentions(session, message)
|
||||||
|
|
||||||
await channel.send(processed_message)
|
await channel.send(processed_message)
|
||||||
logger.info(f"Sent message to channel {channel_name}")
|
logger.info(f"Sent message to channel {channel_identifier}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
logger.error(f"Failed to send message to channel {channel_identifier}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def trigger_typing_channel(self, channel_name: str) -> bool:
|
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
|
||||||
"""Trigger typing indicator in a channel"""
|
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
channel = await self.get_channel_by_name(channel_name)
|
# Get channel by ID or name
|
||||||
|
if isinstance(channel_identifier, int):
|
||||||
|
channel = self.get_channel(channel_identifier)
|
||||||
|
else:
|
||||||
|
channel = await self.get_channel_by_name(channel_identifier)
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.error(f"Channel {channel_name} not found")
|
logger.error(f"Channel {channel_identifier} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async with channel.typing():
|
async with channel.typing():
|
||||||
@ -479,5 +505,7 @@ class MessageCollector(commands.Bot):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
logger.error(
|
||||||
|
f"Failed to trigger typing for channel {channel_identifier}: {e}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -180,7 +180,7 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if message.channel and message.channel.server:
|
if message.channel and message.channel.server:
|
||||||
discord.trigger_typing_channel(bot_id, message.channel.name)
|
discord.trigger_typing_channel(bot_id, cast(int, message.channel_id))
|
||||||
else:
|
else:
|
||||||
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
||||||
return True
|
return True
|
||||||
@ -242,7 +242,9 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if discord_message.channel.server:
|
if discord_message.channel.server:
|
||||||
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
discord.send_to_channel(
|
||||||
|
bot_id, cast(int, discord_message.channel_id), response
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
||||||
|
|
||||||
@ -263,6 +265,8 @@ def add_discord_message(
|
|||||||
server_id: int | None = None,
|
server_id: int | None = None,
|
||||||
recipient_id: int | None = None,
|
recipient_id: int | None = None,
|
||||||
message_reference_id: int | None = None,
|
message_reference_id: int | None = None,
|
||||||
|
message_type: str = "default",
|
||||||
|
thread_id: int | None = None,
|
||||||
image_urls: list[str] | None = None,
|
image_urls: list[str] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -291,8 +295,9 @@ def add_discord_message(
|
|||||||
from_id=author_id,
|
from_id=author_id,
|
||||||
recipient_id=recipient_id,
|
recipient_id=recipient_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
message_type="reply" if message_reference_id else "default",
|
message_type=message_type,
|
||||||
reply_to_message_id=message_reference_id,
|
reply_to_message_id=message_reference_id,
|
||||||
|
thread_id=thread_id,
|
||||||
images=saved_image_paths or None,
|
images=saved_image_paths or None,
|
||||||
)
|
)
|
||||||
existing_msg = check_content_exists(
|
existing_msg = check_content_exists(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user