mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 08:14:05 +01:00
unify discord callers
This commit is contained in:
parent
69192f834a
commit
798b4779da
@ -5,6 +5,7 @@ Database models for the knowledge base system.
|
|||||||
import pathlib
|
import pathlib
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from collections.abc import Collection
|
||||||
from typing import Any, Annotated, Sequence, cast
|
from typing import Any, Annotated, Sequence, cast
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -309,16 +310,16 @@ class DiscordMessage(SourceItem):
|
|||||||
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
|
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allowed_tools(self) -> list[str]:
|
def allowed_tools(self) -> set[str]:
|
||||||
return (
|
return set(
|
||||||
(self.channel.allowed_tools if self.channel else [])
|
(self.channel.allowed_tools if self.channel else [])
|
||||||
+ (self.from_user.allowed_tools if self.from_user else [])
|
+ (self.from_user.allowed_tools if self.from_user else [])
|
||||||
+ (self.server.allowed_tools if self.server else [])
|
+ (self.server.allowed_tools if self.server else [])
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disallowed_tools(self) -> list[str]:
|
def disallowed_tools(self) -> set[str]:
|
||||||
return (
|
return set(
|
||||||
(self.channel.disallowed_tools if self.channel else [])
|
(self.channel.disallowed_tools if self.channel else [])
|
||||||
+ (self.from_user.disallowed_tools if self.from_user else [])
|
+ (self.from_user.disallowed_tools if self.from_user else [])
|
||||||
+ (self.server.disallowed_tools if self.server else [])
|
+ (self.server.disallowed_tools if self.server else [])
|
||||||
@ -329,6 +330,11 @@ class DiscordMessage(SourceItem):
|
|||||||
not self.allowed_tools or tool in self.allowed_tools
|
not self.allowed_tools or tool in self.allowed_tools
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def filter_tools(self, tools: Collection[str] | None = None) -> set[str]:
|
||||||
|
if tools is None:
|
||||||
|
return self.allowed_tools - self.disallowed_tools
|
||||||
|
return set(tools) - self.disallowed_tools & self.allowed_tools
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ignore_messages(self) -> bool:
|
def ignore_messages(self) -> bool:
|
||||||
return (
|
return (
|
||||||
@ -359,7 +365,7 @@ class DiscordMessage(SourceItem):
|
|||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
||||||
|
|
||||||
def as_content(self):
|
def as_content(self) -> dict[str, Any]:
|
||||||
"""Return message content ready for LLM (text + images from disk)."""
|
"""Return message content ready for LLM (text + images from disk)."""
|
||||||
content = {"text": self.title, "images": []}
|
content = {"text": self.title, "images": []}
|
||||||
for path in cast(list[str] | None, self.images) or []:
|
for path in cast(list[str] | None, self.images) or []:
|
||||||
|
|||||||
@ -1,16 +1,20 @@
|
|||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from collections.abc import Collection
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from memory.common import discord, settings
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
|
DiscordMessage,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
DiscordMessage,
|
|
||||||
)
|
)
|
||||||
|
from memory.common.db.models.users import BotUser
|
||||||
|
from memory.common.llms.base import create_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -199,3 +203,94 @@ def comm_channel_prompt(
|
|||||||
You will be given the last {max_messages} messages in the conversation.
|
You will be given the last {max_messages} messages in the conversation.
|
||||||
Please react to them appropriately. You can return an empty response if you don't have anything to say.
|
Please react to them appropriately. You can return an empty response if you don't have anything to say.
|
||||||
""").format(server_context=server_context, max_messages=max_messages)
|
""").format(server_context=server_context, max_messages=max_messages)
|
||||||
|
|
||||||
|
|
||||||
|
def call_llm(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
bot_user: BotUser,
|
||||||
|
from_user: DiscordUser | None,
|
||||||
|
channel: DiscordChannel | None,
|
||||||
|
model: str,
|
||||||
|
system_prompt: str = "",
|
||||||
|
messages: list[str | dict[str, Any]] = [],
|
||||||
|
allowed_tools: Collection[str] | None = None,
|
||||||
|
num_previous_messages: int = 10,
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Call LLM with Discord tools support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
bot_user: Bot user making the call
|
||||||
|
from_user: Discord user who initiated the interaction
|
||||||
|
channel: Discord channel (if any)
|
||||||
|
messages: List of message strings or dicts with text/images
|
||||||
|
model: LLM model to use
|
||||||
|
system_prompt: System prompt
|
||||||
|
allowed_tools: List of allowed tool names (None = all tools allowed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLM response or None if failed
|
||||||
|
"""
|
||||||
|
provider = create_provider(model=model)
|
||||||
|
|
||||||
|
if provider.usage_tracker.is_rate_limited(model):
|
||||||
|
logger.error(
|
||||||
|
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
if from_user and not channel:
|
||||||
|
user_id = from_user.id
|
||||||
|
prev_messages = previous_messages(
|
||||||
|
session,
|
||||||
|
user_id,
|
||||||
|
channel and channel.id,
|
||||||
|
max_messages=num_previous_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
from memory.common.llms.tools.discord import make_discord_tools
|
||||||
|
|
||||||
|
tools = make_discord_tools(bot_user, from_user, channel, model=model)
|
||||||
|
|
||||||
|
# Filter to allowed tools if specified
|
||||||
|
if allowed_tools is not None:
|
||||||
|
tools = {name: tool for name, tool in tools.items() if name in allowed_tools}
|
||||||
|
|
||||||
|
message_content = [m.as_content() for m in prev_messages] + messages
|
||||||
|
return provider.run_with_tools(
|
||||||
|
messages=provider.as_messages(message_content),
|
||||||
|
tools=tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
|
).response
|
||||||
|
|
||||||
|
|
||||||
|
def send_discord_response(
|
||||||
|
bot_id: int,
|
||||||
|
response: str,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
user_identifier: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Send a response to Discord channel or user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot_id: Bot user ID
|
||||||
|
response: Message to send
|
||||||
|
channel_id: Channel ID (for channel messages)
|
||||||
|
user_identifier: Username (for DMs)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if sent successfully
|
||||||
|
"""
|
||||||
|
if channel_id is not None:
|
||||||
|
logger.info(f"Sending message to channel {channel_id}")
|
||||||
|
return discord.send_to_channel(bot_id, channel_id, response)
|
||||||
|
elif user_identifier is not None:
|
||||||
|
logger.info(f"Sending DM to {user_identifier}")
|
||||||
|
return discord.send_dm(bot_id, user_identifier, response)
|
||||||
|
else:
|
||||||
|
logger.error("Neither channel_id nor user_identifier provided")
|
||||||
|
return False
|
||||||
|
|||||||
@ -23,9 +23,7 @@ from memory.common.celery_app import (
|
|||||||
)
|
)
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import DiscordMessage, DiscordUser
|
from memory.common.db.models import DiscordMessage, DiscordUser
|
||||||
from memory.common.llms.base import create_provider
|
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
|
||||||
from memory.common.llms.tools.discord import make_discord_tools
|
|
||||||
from memory.discord.messages import comm_channel_prompt, previous_messages
|
|
||||||
from memory.workers.tasks.content_processing import (
|
from memory.workers.tasks.content_processing import (
|
||||||
check_content_exists,
|
check_content_exists,
|
||||||
create_task_result,
|
create_task_result,
|
||||||
@ -84,54 +82,6 @@ def get_prev(
|
|||||||
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
||||||
|
|
||||||
|
|
||||||
def call_llm(
|
|
||||||
session,
|
|
||||||
message: DiscordMessage,
|
|
||||||
model: str,
|
|
||||||
msgs: list[str] = [],
|
|
||||||
allowed_tools: list[str] | None = None,
|
|
||||||
) -> str | None:
|
|
||||||
provider = create_provider(model=model)
|
|
||||||
if provider.usage_tracker.is_rate_limited(model):
|
|
||||||
logger.error(
|
|
||||||
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
tools = make_discord_tools(
|
|
||||||
message.recipient_user.system_user,
|
|
||||||
message.from_user,
|
|
||||||
message.channel,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
tools = {
|
|
||||||
name: tool
|
|
||||||
for name, tool in tools.items()
|
|
||||||
if message.tool_allowed(name)
|
|
||||||
and (allowed_tools is None or name in allowed_tools)
|
|
||||||
}
|
|
||||||
system_prompt = message.system_prompt or ""
|
|
||||||
system_prompt += comm_channel_prompt(
|
|
||||||
session, message.recipient_user, message.channel
|
|
||||||
)
|
|
||||||
messages = previous_messages(
|
|
||||||
session,
|
|
||||||
message.recipient_user and message.recipient_user.id,
|
|
||||||
message.channel and message.channel.id,
|
|
||||||
max_messages=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build message list: previous messages + current message + any extra text msgs
|
|
||||||
message_content = [m.as_content() for m in messages + [message]] + msgs
|
|
||||||
|
|
||||||
return provider.run_with_tools(
|
|
||||||
messages=provider.as_messages(message_content),
|
|
||||||
tools=tools,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
|
||||||
).response
|
|
||||||
|
|
||||||
|
|
||||||
def should_process(message: DiscordMessage) -> bool:
|
def should_process(message: DiscordMessage) -> bool:
|
||||||
if not (
|
if not (
|
||||||
settings.DISCORD_PROCESS_MESSAGES
|
settings.DISCORD_PROCESS_MESSAGES
|
||||||
@ -155,16 +105,26 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
<reason>I want to continue the conversation because I think it's important.</reason>
|
<reason>I want to continue the conversation because I think it's important.</reason>
|
||||||
</response>
|
</response>
|
||||||
""")
|
""")
|
||||||
response = call_llm(
|
|
||||||
session,
|
system_prompt = message.system_prompt or ""
|
||||||
message,
|
system_prompt += comm_channel_prompt(
|
||||||
settings.SUMMARIZER_MODEL,
|
session, message.recipient_user, message.channel
|
||||||
[msg],
|
)
|
||||||
allowed_tools=[
|
allowed_tools = [
|
||||||
"update_channel_summary",
|
"update_channel_summary",
|
||||||
"update_user_summary",
|
"update_user_summary",
|
||||||
"update_server_summary",
|
"update_server_summary",
|
||||||
],
|
]
|
||||||
|
|
||||||
|
response = call_llm(
|
||||||
|
session,
|
||||||
|
bot_user=message.recipient_user.system_user,
|
||||||
|
from_user=message.from_user,
|
||||||
|
channel=message.channel,
|
||||||
|
model=settings.SUMMARIZER_MODEL,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
messages=[msg],
|
||||||
|
allowed_tools=message.filter_tools(allowed_tools),
|
||||||
)
|
)
|
||||||
if not response:
|
if not response:
|
||||||
return False
|
return False
|
||||||
@ -230,23 +190,40 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
response = call_llm(
|
||||||
|
session,
|
||||||
|
bot_user=discord_message.recipient_user.system_user,
|
||||||
|
from_user=discord_message.from_user,
|
||||||
|
channel=discord_message.channel,
|
||||||
|
model=settings.DISCORD_MODEL,
|
||||||
|
system_prompt=discord_message.system_prompt,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to generate Discord response")
|
logger.exception("Failed to generate Discord response")
|
||||||
|
return {
|
||||||
print("response:", response)
|
"status": "error",
|
||||||
|
"error": "Failed to generate Discord response",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
if not response:
|
if not response:
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "no-response",
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if discord_message.channel.server:
|
res = send_discord_response(
|
||||||
discord.send_to_channel(
|
bot_id=bot_id,
|
||||||
bot_id, cast(int, discord_message.channel_id), response
|
response=response,
|
||||||
|
channel_id=discord_message.channel_id,
|
||||||
|
user_identifier=discord_message.from_user
|
||||||
|
and discord_message.from_user.username,
|
||||||
)
|
)
|
||||||
else:
|
if not res:
|
||||||
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": "Failed to send Discord response",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "processed",
|
||||||
|
|||||||
@ -2,41 +2,45 @@ import logging
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common import settings
|
||||||
from memory.common.db.models import ScheduledLLMCall
|
|
||||||
from memory.common.celery_app import (
|
from memory.common.celery_app import (
|
||||||
app,
|
|
||||||
EXECUTE_SCHEDULED_CALL,
|
EXECUTE_SCHEDULED_CALL,
|
||||||
RUN_SCHEDULED_CALLS,
|
RUN_SCHEDULED_CALLS,
|
||||||
|
app,
|
||||||
)
|
)
|
||||||
from memory.common import llms, discord
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import ScheduledLLMCall
|
||||||
|
from memory.discord.messages import call_llm, send_discord_response
|
||||||
from memory.workers.tasks.content_processing import safe_task_execution
|
from memory.workers.tasks.content_processing import safe_task_execution
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
|
||||||
|
"""Call LLM with tools support for scheduled calls."""
|
||||||
|
if not scheduled_call.discord_user:
|
||||||
|
logger.warning("No discord_user for scheduled call - cannot execute")
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = cast(str, scheduled_call.model or settings.DISCORD_MODEL)
|
||||||
|
system_prompt = cast(str, scheduled_call.system_prompt or "")
|
||||||
|
message = cast(str, scheduled_call.message)
|
||||||
|
allowed_tools_list = cast(list[str] | None, scheduled_call.allowed_tools)
|
||||||
|
|
||||||
|
return call_llm(
|
||||||
|
session=session,
|
||||||
|
bot_user=scheduled_call.user,
|
||||||
|
from_user=scheduled_call.discord_user,
|
||||||
|
channel=scheduled_call.discord_channel,
|
||||||
|
messages=[message],
|
||||||
|
model=model,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
allowed_tools=allowed_tools_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
||||||
"""
|
"""Send the LLM response to Discord user or channel."""
|
||||||
Send the LLM response to the specified Discord user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scheduled_call: The scheduled call object
|
|
||||||
response: The LLM response to send
|
|
||||||
"""
|
|
||||||
# Format the message with topic, model, and response
|
|
||||||
message_parts = []
|
|
||||||
if cast(str, scheduled_call.topic):
|
|
||||||
message_parts.append(f"**Topic:** {scheduled_call.topic}")
|
|
||||||
if cast(str, scheduled_call.model):
|
|
||||||
message_parts.append(f"**Model:** {scheduled_call.model}")
|
|
||||||
message_parts.append(f"**Response:** {response}")
|
|
||||||
|
|
||||||
message = "\n".join(message_parts)
|
|
||||||
|
|
||||||
# Discord has a 2000 character limit, so we may need to split the message
|
|
||||||
if len(message) > 1900: # Leave some buffer
|
|
||||||
message = message[:1900] + "\n\n... (response truncated)"
|
|
||||||
|
|
||||||
bot_id_value = scheduled_call.user_id
|
bot_id_value = scheduled_call.user_id
|
||||||
if bot_id_value is None:
|
if bot_id_value is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -47,15 +51,15 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
|||||||
|
|
||||||
bot_id = cast(int, bot_id_value)
|
bot_id = cast(int, bot_id_value)
|
||||||
|
|
||||||
if discord_user := scheduled_call.discord_user:
|
send_discord_response(
|
||||||
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
bot_id=bot_id,
|
||||||
discord.send_dm(bot_id, discord_user.username, message)
|
response=response,
|
||||||
elif discord_channel := scheduled_call.discord_channel:
|
channel_id=cast(int, scheduled_call.discord_channel.id)
|
||||||
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
if scheduled_call.discord_channel
|
||||||
discord.broadcast_message(bot_id, discord_channel.name, message)
|
else None,
|
||||||
else:
|
user_identifier=scheduled_call.discord_user.username
|
||||||
logger.warning(
|
if scheduled_call.discord_user
|
||||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -92,15 +96,29 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
|
|
||||||
logger.info(f"Calling LLM with model {scheduled_call.model}")
|
logger.info(f"Calling LLM with model {scheduled_call.model}")
|
||||||
|
|
||||||
# Make the LLM call
|
# Make the LLM call with tools support
|
||||||
if scheduled_call.model:
|
try:
|
||||||
response = llms.summarize(
|
response = _call_llm_for_scheduled(session, scheduled_call)
|
||||||
prompt=cast(str, scheduled_call.message),
|
except Exception:
|
||||||
model=cast(str, scheduled_call.model),
|
logger.exception("Failed to generate LLM response")
|
||||||
system_prompt=cast(str, scheduled_call.system_prompt),
|
scheduled_call.status = "failed"
|
||||||
)
|
scheduled_call.error_message = "LLM call failed"
|
||||||
else:
|
session.commit()
|
||||||
response = cast(str, scheduled_call.message)
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "LLM call failed",
|
||||||
|
"scheduled_call_id": scheduled_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not response:
|
||||||
|
scheduled_call.status = "failed"
|
||||||
|
scheduled_call.error_message = "No response from LLM"
|
||||||
|
session.commit()
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "No response from LLM",
|
||||||
|
"scheduled_call_id": scheduled_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
# Store the response
|
# Store the response
|
||||||
scheduled_call.response = response
|
scheduled_call.response = response
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user