unify discord callers

This commit is contained in:
mruwnik 2025-11-02 14:46:43 +00:00
parent 69192f834a
commit 798b4779da
4 changed files with 215 additions and 119 deletions

View File

@ -5,6 +5,7 @@ Database models for the knowledge base system.
import pathlib import pathlib
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from collections.abc import Collection
from typing import Any, Annotated, Sequence, cast from typing import Any, Annotated, Sequence, cast
from PIL import Image from PIL import Image
@ -309,16 +310,16 @@ class DiscordMessage(SourceItem):
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id]) recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
@property @property
def allowed_tools(self) -> list[str]: def allowed_tools(self) -> set[str]:
return ( return set(
(self.channel.allowed_tools if self.channel else []) (self.channel.allowed_tools if self.channel else [])
+ (self.from_user.allowed_tools if self.from_user else []) + (self.from_user.allowed_tools if self.from_user else [])
+ (self.server.allowed_tools if self.server else []) + (self.server.allowed_tools if self.server else [])
) )
@property @property
def disallowed_tools(self) -> list[str]: def disallowed_tools(self) -> set[str]:
return ( return set(
(self.channel.disallowed_tools if self.channel else []) (self.channel.disallowed_tools if self.channel else [])
+ (self.from_user.disallowed_tools if self.from_user else []) + (self.from_user.disallowed_tools if self.from_user else [])
+ (self.server.disallowed_tools if self.server else []) + (self.server.disallowed_tools if self.server else [])
@ -329,6 +330,11 @@ class DiscordMessage(SourceItem):
not self.allowed_tools or tool in self.allowed_tools not self.allowed_tools or tool in self.allowed_tools
) )
def filter_tools(self, tools: Collection[str] | None = None) -> set[str]:
if tools is None:
return self.allowed_tools - self.disallowed_tools
return set(tools) - self.disallowed_tools & self.allowed_tools
@property @property
def ignore_messages(self) -> bool: def ignore_messages(self) -> bool:
return ( return (
@ -359,7 +365,7 @@ class DiscordMessage(SourceItem):
def title(self) -> str: def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
def as_content(self): def as_content(self) -> dict[str, Any]:
"""Return message content ready for LLM (text + images from disk).""" """Return message content ready for LLM (text + images from disk)."""
content = {"text": self.title, "images": []} content = {"text": self.title, "images": []}
for path in cast(list[str] | None, self.images) or []: for path in cast(list[str] | None, self.images) or []:

View File

@ -1,16 +1,20 @@
import logging import logging
import textwrap import textwrap
from collections.abc import Collection
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, cast from typing import Any, cast
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
from memory.common import discord, settings
from memory.common.db.models import ( from memory.common.db.models import (
DiscordChannel, DiscordChannel,
DiscordMessage,
DiscordUser, DiscordUser,
ScheduledLLMCall, ScheduledLLMCall,
DiscordMessage,
) )
from memory.common.db.models.users import BotUser
from memory.common.llms.base import create_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -199,3 +203,94 @@ def comm_channel_prompt(
You will be given the last {max_messages} messages in the conversation. You will be given the last {max_messages} messages in the conversation.
Please react to them appropriately. You can return an empty response if you don't have anything to say. Please react to them appropriately. You can return an empty response if you don't have anything to say.
""").format(server_context=server_context, max_messages=max_messages) """).format(server_context=server_context, max_messages=max_messages)
def call_llm(
session: Session | scoped_session,
bot_user: BotUser,
from_user: DiscordUser | None,
channel: DiscordChannel | None,
model: str,
system_prompt: str = "",
messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None,
num_previous_messages: int = 10,
) -> str | None:
"""
Call LLM with Discord tools support.
Args:
session: Database session
bot_user: Bot user making the call
from_user: Discord user who initiated the interaction
channel: Discord channel (if any)
messages: List of message strings or dicts with text/images
model: LLM model to use
system_prompt: System prompt
allowed_tools: List of allowed tool names (None = all tools allowed)
Returns:
LLM response or None if failed
"""
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
user_id = None
if from_user and not channel:
user_id = from_user.id
prev_messages = previous_messages(
session,
user_id,
channel and channel.id,
max_messages=num_previous_messages,
)
from memory.common.llms.tools.discord import make_discord_tools
tools = make_discord_tools(bot_user, from_user, channel, model=model)
# Filter to allowed tools if specified
if allowed_tools is not None:
tools = {name: tool for name, tool in tools.items() if name in allowed_tools}
message_content = [m.as_content() for m in prev_messages] + messages
return provider.run_with_tools(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
def send_discord_response(
bot_id: int,
response: str,
channel_id: int | None = None,
user_identifier: str | None = None,
) -> bool:
"""
Send a response to Discord channel or user.
Args:
bot_id: Bot user ID
response: Message to send
channel_id: Channel ID (for channel messages)
user_identifier: Username (for DMs)
Returns:
True if sent successfully
"""
if channel_id is not None:
logger.info(f"Sending message to channel {channel_id}")
return discord.send_to_channel(bot_id, channel_id, response)
elif user_identifier is not None:
logger.info(f"Sending DM to {user_identifier}")
return discord.send_dm(bot_id, user_identifier, response)
else:
logger.error("Neither channel_id nor user_identifier provided")
return False

View File

@ -23,9 +23,7 @@ from memory.common.celery_app import (
) )
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser from memory.common.db.models import DiscordMessage, DiscordUser
from memory.common.llms.base import create_provider from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
from memory.common.llms.tools.discord import make_discord_tools
from memory.discord.messages import comm_channel_prompt, previous_messages
from memory.workers.tasks.content_processing import ( from memory.workers.tasks.content_processing import (
check_content_exists, check_content_exists,
create_task_result, create_task_result,
@ -84,54 +82,6 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def call_llm(
session,
message: DiscordMessage,
model: str,
msgs: list[str] = [],
allowed_tools: list[str] | None = None,
) -> str | None:
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
tools = make_discord_tools(
message.recipient_user.system_user,
message.from_user,
message.channel,
model=model,
)
tools = {
name: tool
for name, tool in tools.items()
if message.tool_allowed(name)
and (allowed_tools is None or name in allowed_tools)
}
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
)
messages = previous_messages(
session,
message.recipient_user and message.recipient_user.id,
message.channel and message.channel.id,
max_messages=10,
)
# Build message list: previous messages + current message + any extra text msgs
message_content = [m.as_content() for m in messages + [message]] + msgs
return provider.run_with_tools(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
def should_process(message: DiscordMessage) -> bool: def should_process(message: DiscordMessage) -> bool:
if not ( if not (
settings.DISCORD_PROCESS_MESSAGES settings.DISCORD_PROCESS_MESSAGES
@ -155,16 +105,26 @@ def should_process(message: DiscordMessage) -> bool:
<reason>I want to continue the conversation because I think it's important.</reason> <reason>I want to continue the conversation because I think it's important.</reason>
</response> </response>
""") """)
response = call_llm(
session, system_prompt = message.system_prompt or ""
message, system_prompt += comm_channel_prompt(
settings.SUMMARIZER_MODEL, session, message.recipient_user, message.channel
[msg], )
allowed_tools = [ allowed_tools = [
"update_channel_summary", "update_channel_summary",
"update_user_summary", "update_user_summary",
"update_server_summary", "update_server_summary",
], ]
response = call_llm(
session,
bot_user=message.recipient_user.system_user,
from_user=message.from_user,
channel=message.channel,
model=settings.SUMMARIZER_MODEL,
system_prompt=system_prompt,
messages=[msg],
allowed_tools=message.filter_tools(allowed_tools),
) )
if not response: if not response:
return False return False
@ -230,23 +190,40 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
} }
try: try:
response = call_llm(session, discord_message, settings.DISCORD_MODEL) response = call_llm(
session,
bot_user=discord_message.recipient_user.system_user,
from_user=discord_message.from_user,
channel=discord_message.channel,
model=settings.DISCORD_MODEL,
system_prompt=discord_message.system_prompt,
)
except Exception: except Exception:
logger.exception("Failed to generate Discord response") logger.exception("Failed to generate Discord response")
return {
print("response:", response) "status": "error",
"error": "Failed to generate Discord response",
"message_id": message_id,
}
if not response: if not response:
return { return {
"status": "processed", "status": "no-response",
"message_id": message_id, "message_id": message_id,
} }
if discord_message.channel.server: res = send_discord_response(
discord.send_to_channel( bot_id=bot_id,
bot_id, cast(int, discord_message.channel_id), response response=response,
channel_id=discord_message.channel_id,
user_identifier=discord_message.from_user
and discord_message.from_user.username,
) )
else: if not res:
discord.send_dm(bot_id, discord_message.from_user.username, response) return {
"status": "error",
"error": "Failed to send Discord response",
"message_id": message_id,
}
return { return {
"status": "processed", "status": "processed",

View File

@ -2,41 +2,45 @@ import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import cast from typing import cast
from memory.common.db.connection import make_session from memory.common import settings
from memory.common.db.models import ScheduledLLMCall
from memory.common.celery_app import ( from memory.common.celery_app import (
app,
EXECUTE_SCHEDULED_CALL, EXECUTE_SCHEDULED_CALL,
RUN_SCHEDULED_CALLS, RUN_SCHEDULED_CALLS,
app,
) )
from memory.common import llms, discord from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.discord.messages import call_llm, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
"""Call LLM with tools support for scheduled calls."""
if not scheduled_call.discord_user:
logger.warning("No discord_user for scheduled call - cannot execute")
return None
model = cast(str, scheduled_call.model or settings.DISCORD_MODEL)
system_prompt = cast(str, scheduled_call.system_prompt or "")
message = cast(str, scheduled_call.message)
allowed_tools_list = cast(list[str] | None, scheduled_call.allowed_tools)
return call_llm(
session=session,
bot_user=scheduled_call.user,
from_user=scheduled_call.discord_user,
channel=scheduled_call.discord_channel,
messages=[message],
model=model,
system_prompt=system_prompt,
allowed_tools=allowed_tools_list,
)
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
""" """Send the LLM response to Discord user or channel."""
Send the LLM response to the specified Discord user.
Args:
scheduled_call: The scheduled call object
response: The LLM response to send
"""
# Format the message with topic, model, and response
message_parts = []
if cast(str, scheduled_call.topic):
message_parts.append(f"**Topic:** {scheduled_call.topic}")
if cast(str, scheduled_call.model):
message_parts.append(f"**Model:** {scheduled_call.model}")
message_parts.append(f"**Response:** {response}")
message = "\n".join(message_parts)
# Discord has a 2000 character limit, so we may need to split the message
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
bot_id_value = scheduled_call.user_id bot_id_value = scheduled_call.user_id
if bot_id_value is None: if bot_id_value is None:
logger.warning( logger.warning(
@ -47,15 +51,15 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
bot_id = cast(int, bot_id_value) bot_id = cast(int, bot_id_value)
if discord_user := scheduled_call.discord_user: send_discord_response(
logger.info(f"Sending DM to {discord_user.username}: {message}") bot_id=bot_id,
discord.send_dm(bot_id, discord_user.username, message) response=response,
elif discord_channel := scheduled_call.discord_channel: channel_id=cast(int, scheduled_call.discord_channel.id)
logger.info(f"Broadcasting message to {discord_channel.name}: {message}") if scheduled_call.discord_channel
discord.broadcast_message(bot_id, discord_channel.name, message) else None,
else: user_identifier=scheduled_call.discord_user.username
logger.warning( if scheduled_call.discord_user
f"No Discord user or channel found for scheduled call {scheduled_call.id}" else None,
) )
@ -92,15 +96,29 @@ def execute_scheduled_call(self, scheduled_call_id: str):
logger.info(f"Calling LLM with model {scheduled_call.model}") logger.info(f"Calling LLM with model {scheduled_call.model}")
# Make the LLM call # Make the LLM call with tools support
if scheduled_call.model: try:
response = llms.summarize( response = _call_llm_for_scheduled(session, scheduled_call)
prompt=cast(str, scheduled_call.message), except Exception:
model=cast(str, scheduled_call.model), logger.exception("Failed to generate LLM response")
system_prompt=cast(str, scheduled_call.system_prompt), scheduled_call.status = "failed"
) scheduled_call.error_message = "LLM call failed"
else: session.commit()
response = cast(str, scheduled_call.message) return {
"success": False,
"error": "LLM call failed",
"scheduled_call_id": scheduled_call_id,
}
if not response:
scheduled_call.status = "failed"
scheduled_call.error_message = "No response from LLM"
session.commit()
return {
"success": False,
"error": "No response from LLM",
"scheduled_call_id": scheduled_call_id,
}
# Store the response # Store the response
scheduled_call.response = response scheduled_call.response = response