diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index 2dc7fef..3c011a8 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -5,6 +5,7 @@ Database models for the knowledge base system.
 import pathlib
 import textwrap
 from datetime import datetime
+from collections.abc import Collection
 from typing import Any, Annotated, Sequence, cast
 
 from PIL import Image
@@ -309,16 +310,16 @@ class DiscordMessage(SourceItem):
     recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
 
     @property
-    def allowed_tools(self) -> list[str]:
-        return (
+    def allowed_tools(self) -> set[str]:
+        return set(
             (self.channel.allowed_tools if self.channel else [])
             + (self.from_user.allowed_tools if self.from_user else [])
             + (self.server.allowed_tools if self.server else [])
         )
 
     @property
-    def disallowed_tools(self) -> list[str]:
-        return (
+    def disallowed_tools(self) -> set[str]:
+        return set(
             (self.channel.disallowed_tools if self.channel else [])
             + (self.from_user.disallowed_tools if self.from_user else [])
             + (self.server.disallowed_tools if self.server else [])
@@ -329,6 +330,15 @@ class DiscordMessage(SourceItem):
             not self.allowed_tools or tool in self.allowed_tools
         )
 
+    def filter_tools(self, tools: Collection[str] | None = None) -> set[str]:
+        if tools is None:
+            return self.allowed_tools - self.disallowed_tools
+        # An empty allowed_tools set means "no restriction", matching tool_allowed above
+        filtered = set(tools) - self.disallowed_tools
+        if self.allowed_tools:
+            filtered &= self.allowed_tools
+        return filtered
+
     @property
     def ignore_messages(self) -> bool:
         return (
@@ -359,7 +369,7 @@ class DiscordMessage(SourceItem):
     def title(self) -> str:
         return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
 
-    def as_content(self):
+    def as_content(self) -> dict[str, Any]:
         """Return message content ready for LLM (text + images from disk)."""
         content = {"text": self.title, "images": []}
         for path in cast(list[str] | None, self.images) or []:
diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py
index b6f4cfa..3f4d254 100644
--- a/src/memory/discord/messages.py
+++ b/src/memory/discord/messages.py
@@ -1,16 +1,20 @@
 import logging
 import textwrap
+from collections.abc import Collection
 from datetime import datetime, timezone
 from typing import Any, cast
 
 from sqlalchemy.orm import Session, scoped_session
 
+from memory.common import discord, settings
 from memory.common.db.models import (
     DiscordChannel,
+    DiscordMessage,
     DiscordUser,
     ScheduledLLMCall,
-    DiscordMessage,
 )
+from memory.common.db.models.users import BotUser
+from memory.common.llms.base import create_provider
 
 
 logger = logging.getLogger(__name__)
@@ -199,3 +203,95 @@ def comm_channel_prompt(
     You will be given the last {max_messages} messages in the conversation.
     Please react to them appropriately. You can return an empty response if you don't have anything to say.
     """).format(server_context=server_context, max_messages=max_messages)
+
+
+def call_llm(
+    session: Session | scoped_session,
+    bot_user: BotUser,
+    from_user: DiscordUser | None,
+    channel: DiscordChannel | None,
+    model: str,
+    system_prompt: str = "",
+    messages: list[str | dict[str, Any]] = [],
+    allowed_tools: Collection[str] | None = None,
+    num_previous_messages: int = 10,
+) -> str | None:
+    """
+    Call LLM with Discord tools support.
+
+    Args:
+        session: Database session
+        bot_user: Bot user making the call
+        from_user: Discord user who initiated the interaction
+        channel: Discord channel (if any)
+        model: LLM model to use
+        system_prompt: System prompt
+        messages: List of message strings or dicts with text/images
+        allowed_tools: Allowed tool names (None = all tools allowed)
+        num_previous_messages: Number of previous messages to include as context
+
+    Returns:
+        LLM response or None if failed
+    """
+    provider = create_provider(model=model)
+
+    if provider.usage_tracker.is_rate_limited(model):
+        logger.error(
+            f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
+        )
+        return None
+
+    user_id = None
+    if from_user and not channel:
+        user_id = from_user.id
+    prev_messages = previous_messages(
+        session,
+        user_id,
+        channel and channel.id,
+        max_messages=num_previous_messages,
+    )
+
+    from memory.common.llms.tools.discord import make_discord_tools
+
+    tools = make_discord_tools(bot_user, from_user, channel, model=model)
+
+    # Filter to allowed tools if specified
+    if allowed_tools is not None:
+        tools = {name: tool for name, tool in tools.items() if name in allowed_tools}
+
+    message_content = [m.as_content() for m in prev_messages] + messages
+    return provider.run_with_tools(
+        messages=provider.as_messages(message_content),
+        tools=tools,
+        system_prompt=system_prompt,
+        max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
+    ).response
+
+
+def send_discord_response(
+    bot_id: int,
+    response: str,
+    channel_id: int | None = None,
+    user_identifier: str | None = None,
+) -> bool:
+    """
+    Send a response to Discord channel or user.
+
+    Args:
+        bot_id: Bot user ID
+        response: Message to send
+        channel_id: Channel ID (for channel messages)
+        user_identifier: Username (for DMs)
+
+    Returns:
+        True if sent successfully
+    """
+    if channel_id is not None:
+        logger.info(f"Sending message to channel {channel_id}")
+        return discord.send_to_channel(bot_id, channel_id, response)
+    elif user_identifier is not None:
+        logger.info(f"Sending DM to {user_identifier}")
+        return discord.send_dm(bot_id, user_identifier, response)
+    else:
+        logger.error("Neither channel_id nor user_identifier provided")
+        return False
diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py
index a889eb8..02ce8d4 100644
--- a/src/memory/workers/tasks/discord.py
+++ b/src/memory/workers/tasks/discord.py
@@ -23,9 +23,7 @@ from memory.common.celery_app import (
 )
 from memory.common.db.connection import make_session
 from memory.common.db.models import DiscordMessage, DiscordUser
-from memory.common.llms.base import create_provider
-from memory.common.llms.tools.discord import make_discord_tools
-from memory.discord.messages import comm_channel_prompt, previous_messages
+from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
 from memory.workers.tasks.content_processing import (
     check_content_exists,
     create_task_result,
@@ -84,54 +82,6 @@ def get_prev(
     return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
 
 
-def call_llm(
-    session,
-    message: DiscordMessage,
-    model: str,
-    msgs: list[str] = [],
-    allowed_tools: list[str] | None = None,
-) -> str | None:
-    provider = create_provider(model=model)
-    if provider.usage_tracker.is_rate_limited(model):
-        logger.error(
-            f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
-        )
-        return None
-
-    tools = make_discord_tools(
-        message.recipient_user.system_user,
-        message.from_user,
-        message.channel,
-        model=model,
-    )
-    tools = {
-        name: tool
-        for name, tool in tools.items()
-        if message.tool_allowed(name)
-        and (allowed_tools is None or name in allowed_tools)
-    }
-    system_prompt = message.system_prompt or ""
-    system_prompt += comm_channel_prompt(
-        session, message.recipient_user, message.channel
-    )
-    messages = previous_messages(
-        session,
-        message.recipient_user and message.recipient_user.id,
-        message.channel and message.channel.id,
-        max_messages=10,
-    )
-
-    # Build message list: previous messages + current message + any extra text msgs
-    message_content = [m.as_content() for m in messages + [message]] + msgs
-
-    return provider.run_with_tools(
-        messages=provider.as_messages(message_content),
-        tools=tools,
-        system_prompt=system_prompt,
-        max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
-    ).response
-
-
 def should_process(message: DiscordMessage) -> bool:
     if not (
         settings.DISCORD_PROCESS_MESSAGES
@@ -155,16 +105,26 @@ def should_process(message: DiscordMessage) -> bool:
 
         I want to continue the conversation because I think it's important.
         """)
+
+    system_prompt = message.system_prompt or ""
+    system_prompt += comm_channel_prompt(
+        session, message.recipient_user, message.channel
+    )
+    allowed_tools = [
+        "update_channel_summary",
+        "update_user_summary",
+        "update_server_summary",
+    ]
+
     response = call_llm(
         session,
-        message,
-        settings.SUMMARIZER_MODEL,
-        [msg],
-        allowed_tools=[
-            "update_channel_summary",
-            "update_user_summary",
-            "update_server_summary",
-        ],
+        bot_user=message.recipient_user.system_user,
+        from_user=message.from_user,
+        channel=message.channel,
+        model=settings.SUMMARIZER_MODEL,
+        system_prompt=system_prompt,
+        messages=[msg],
+        allowed_tools=message.filter_tools(allowed_tools),
     )
     if not response:
         return False
@@ -230,23 +190,40 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
         }
 
     try:
-        response = call_llm(session, discord_message, settings.DISCORD_MODEL)
+        response = call_llm(
+            session,
+            bot_user=discord_message.recipient_user.system_user,
+            from_user=discord_message.from_user,
+            channel=discord_message.channel,
+            model=settings.DISCORD_MODEL,
+            system_prompt=discord_message.system_prompt or "",
+        )
     except Exception:
         logger.exception("Failed to generate Discord response")
-
-    print("response:", response)
+        return {
+            "status": "error",
+            "error": "Failed to generate Discord response",
+            "message_id": message_id,
+        }
 
     if not response:
         return {
-            "status": "processed",
+            "status": "no-response",
             "message_id": message_id,
         }
 
-    if discord_message.channel.server:
-        discord.send_to_channel(
-            bot_id, cast(int, discord_message.channel_id), response
-        )
-    else:
-        discord.send_dm(bot_id, discord_message.from_user.username, response)
+    res = send_discord_response(
+        bot_id=bot_id,
+        response=response,
+        channel_id=discord_message.channel_id,
+        user_identifier=discord_message.from_user
+        and discord_message.from_user.username,
+    )
+    if not res:
+        return {
+            "status": "error",
+            "error": "Failed to send Discord response",
+            "message_id": message_id,
+        }
 
     return {
         "status": "processed",
diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py
index 76746b9..14e79a7 100644
--- a/src/memory/workers/tasks/scheduled_calls.py
+++ b/src/memory/workers/tasks/scheduled_calls.py
@@ -2,41 +2,45 @@ import logging
 from datetime import datetime, timezone
 from typing import cast
 
-from memory.common.db.connection import make_session
-from memory.common.db.models import ScheduledLLMCall
+from memory.common import settings
 from memory.common.celery_app import (
-    app,
     EXECUTE_SCHEDULED_CALL,
     RUN_SCHEDULED_CALLS,
+    app,
 )
-from memory.common import llms, discord
+from memory.common.db.connection import make_session
+from memory.common.db.models import ScheduledLLMCall
+from memory.discord.messages import call_llm, send_discord_response
 from memory.workers.tasks.content_processing import safe_task_execution
 
 logger = logging.getLogger(__name__)
 
 
+def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
+    """Call LLM with tools support for scheduled calls."""
+    if not scheduled_call.discord_user:
+        logger.warning("No discord_user for scheduled call - cannot execute")
+        return None
+
+    model = cast(str, scheduled_call.model or settings.DISCORD_MODEL)
+    system_prompt = cast(str, scheduled_call.system_prompt or "")
+    message = cast(str, scheduled_call.message)
+    allowed_tools_list = cast(list[str] | None, scheduled_call.allowed_tools)
+
+    return call_llm(
+        session=session,
+        bot_user=scheduled_call.user,
+        from_user=scheduled_call.discord_user,
+        channel=scheduled_call.discord_channel,
+        messages=[message],
+        model=model,
+        system_prompt=system_prompt,
+        allowed_tools=allowed_tools_list,
+    )
+
+
 def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
-    """
-    Send the LLM response to the specified Discord user.
-
-    Args:
-        scheduled_call: The scheduled call object
-        response: The LLM response to send
-    """
-    # Format the message with topic, model, and response
-    message_parts = []
-    if cast(str, scheduled_call.topic):
-        message_parts.append(f"**Topic:** {scheduled_call.topic}")
-    if cast(str, scheduled_call.model):
-        message_parts.append(f"**Model:** {scheduled_call.model}")
-    message_parts.append(f"**Response:** {response}")
-
-    message = "\n".join(message_parts)
-
-    # Discord has a 2000 character limit, so we may need to split the message
-    if len(message) > 1900:  # Leave some buffer
-        message = message[:1900] + "\n\n... (response truncated)"
-
+    """Send the LLM response to Discord user or channel."""
     bot_id_value = scheduled_call.user_id
     if bot_id_value is None:
         logger.warning(
@@ -47,16 +51,16 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
 
     bot_id = cast(int, bot_id_value)
 
-    if discord_user := scheduled_call.discord_user:
-        logger.info(f"Sending DM to {discord_user.username}: {message}")
-        discord.send_dm(bot_id, discord_user.username, message)
-    elif discord_channel := scheduled_call.discord_channel:
-        logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
-        discord.broadcast_message(bot_id, discord_channel.name, message)
-    else:
-        logger.warning(
-            f"No Discord user or channel found for scheduled call {scheduled_call.id}"
-        )
+    send_discord_response(
+        bot_id=bot_id,
+        response=response,
+        channel_id=cast(int, scheduled_call.discord_channel.id)
+        if scheduled_call.discord_channel
+        else None,
+        user_identifier=scheduled_call.discord_user.username
+        if scheduled_call.discord_user
+        else None,
+    )
 
 
 @app.task(bind=True, name=EXECUTE_SCHEDULED_CALL)
@@ -92,15 +96,29 @@ def execute_scheduled_call(self, scheduled_call_id: str):
 
     logger.info(f"Calling LLM with model {scheduled_call.model}")
 
-    # Make the LLM call
-    if scheduled_call.model:
-        response = llms.summarize(
-            prompt=cast(str, scheduled_call.message),
-            model=cast(str, scheduled_call.model),
-            system_prompt=cast(str, scheduled_call.system_prompt),
-        )
-    else:
-        response = cast(str, scheduled_call.message)
+    # Make the LLM call with tools support
+    try:
+        response = _call_llm_for_scheduled(session, scheduled_call)
+    except Exception:
+        logger.exception("Failed to generate LLM response")
+        scheduled_call.status = "failed"
+        scheduled_call.error_message = "LLM call failed"
+        session.commit()
+        return {
+            "success": False,
+            "error": "LLM call failed",
+            "scheduled_call_id": scheduled_call_id,
+        }
+
+    if not response:
+        scheduled_call.status = "failed"
+        scheduled_call.error_message = "No response from LLM"
+        session.commit()
+        return {
+            "success": False,
+            "error": "No response from LLM",
+            "scheduled_call_id": scheduled_call_id,
+        }
 
     # Store the response
     scheduled_call.response = response
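
For reference, the tool-filtering rule that DiscordMessage.filter_tools encodes can be sketched with plain sets; the tool names below are hypothetical and only illustrate the intended precedence (the deny list always wins, and an empty allow list means no restriction).

    # Allow/deny sets as combined from channel, user, and server settings
    allowed = {"update_channel_summary", "update_user_summary"}
    disallowed = {"update_server_summary"}
    requested = {"update_channel_summary", "update_server_summary", "send_dm"}

    # Requested tools minus anything denied, then narrowed by the allow list if one is set
    filtered = requested - disallowed
    if allowed:
        filtered &= allowed
    print(filtered)  # {'update_channel_summary'}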