From aaa0c2c3cd3ebc99e94ebc5b2fa95f2696ec5b86 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 20 Oct 2025 23:08:34 +0200 Subject: [PATCH] better discord integration --- .../20251020_220911_proper_user_id_type.py | 70 ++++++++++ src/memory/common/db/models/discord.py | 12 ++ src/memory/common/db/models/source_items.py | 26 ++++ src/memory/discord/api.py | 24 ++-- src/memory/workers/tasks/discord.py | 130 ++++++++++-------- 5 files changed, 194 insertions(+), 68 deletions(-) create mode 100644 db/migrations/versions/20251020_220911_proper_user_id_type.py diff --git a/db/migrations/versions/20251020_220911_proper_user_id_type.py b/db/migrations/versions/20251020_220911_proper_user_id_type.py new file mode 100644 index 0000000..676a737 --- /dev/null +++ b/db/migrations/versions/20251020_220911_proper_user_id_type.py @@ -0,0 +1,70 @@ +"""proper user id type + +Revision ID: 7dc03dbf184c +Revises: 35a2c1b610b6 +Create Date: 2025-10-20 22:09:11.243681 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "7dc03dbf184c" +down_revision: Union[str, None] = "35a2c1b610b6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "discord_channels", sa.Column("system_prompt", sa.Text(), nullable=True) + ) + op.add_column( + "discord_channels", + sa.Column( + "chattiness_threshold", sa.Integer(), nullable=False, server_default="50" + ), + ) + op.add_column( + "discord_servers", sa.Column("system_prompt", sa.Text(), nullable=True) + ) + op.add_column( + "discord_servers", + sa.Column( + "chattiness_threshold", sa.Integer(), nullable=False, server_default="50" + ), + ) + op.add_column("discord_users", sa.Column("system_prompt", sa.Text(), nullable=True)) + op.add_column( + "discord_users", + sa.Column( + "chattiness_threshold", sa.Integer(), nullable=False, server_default="50" + ), + ) + op.alter_column( + "scheduled_llm_calls", + "user_id", + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + ) + + +def downgrade() -> None: + op.alter_column( + "scheduled_llm_calls", + "user_id", + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + ) + op.drop_column("discord_users", "chattiness_threshold") + op.drop_column("discord_users", "system_prompt") + op.drop_column("discord_servers", "chattiness_threshold") + op.drop_column("discord_servers", "system_prompt") + op.drop_column("discord_channels", "chattiness_threshold") + op.drop_column("discord_channels", "system_prompt") diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 544ffb0..353727e 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -28,6 +28,18 @@ class MessageProcessor: allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + system_prompt = Column( + Text, + nullable=True, + doc="System prompt for this processor. The precedence is user -> channel -> server -> default.", + ) + chattiness_threshold = Column( + Integer, + nullable=False, + default=50, + doc="The threshold for the bot to continue the conversation, between 0 and 100.", + ) + summary = Column( Text, nullable=True, diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index fd0eac1..34e8307 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -328,6 +328,32 @@ class DiscordMessage(SourceItem): not self.allowed_tools or tool in self.allowed_tools ) + @property + def ignore_messages(self) -> bool: + return ( + (self.server and self.server.ignore_messages) + or (self.channel and self.channel.ignore_messages) + or (self.from_user and self.from_user.ignore_messages) + ) + + @property + def system_prompt(self) -> str: + return ( + (self.from_user and self.from_user.system_prompt) + or (self.channel and self.channel.system_prompt) + or (self.server and self.server.system_prompt) + ) + + @property + def chattiness_threshold(self) -> int: + vals = [ + (self.from_user and self.from_user.chattiness_threshold), + (self.channel and self.channel.chattiness_threshold), + (self.server and self.server.chattiness_threshold), + 90, + ] + return min(val for val in vals if val is not None) + @property def title(self) -> str: return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index 129e96e..53d6604 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -8,6 +8,7 @@ providing HTTP endpoints for sending Discord messages. import asyncio import logging from contextlib import asynccontextmanager +import traceback from fastapi import FastAPI, HTTPException from pydantic import BaseModel @@ -81,21 +82,22 @@ async def send_dm_endpoint(request: SendDMRequest): try: success = await app_state.collector.send_dm(request.user, request.message) - - if not success: - raise HTTPException( - status_code=400, - detail=f"Failed to send DM to {request.user}", - ) - return { - "success": True, - "message": f"DM sent to {request.user}", - "user": request.user, - } except Exception as e: + traceback.print_exc() logger.error(f"Failed to send DM: {e}") raise HTTPException(status_code=500, detail=str(e)) + if not success: + raise HTTPException( + status_code=400, + detail=f"Failed to send DM to {request.user}", + ) + return { + "success": True, + "message": f"DM sent to {request.user}", + "user": request.user, + } + @app.post("/send_channel") async def send_channel_endpoint(request: SendChannelRequest): diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index f146af0..8f6cfb4 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -4,6 +4,7 @@ Celery tasks for Discord message processing. import hashlib import logging +import re import textwrap from datetime import datetime from typing import Any @@ -50,45 +51,85 @@ def get_prev( return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] +def call_llm( + session, + message: DiscordMessage, + model: str, + msgs: list[str] = [], + allowed_tools: list[str] = [], +) -> str | None: + tools = make_discord_tools( + message.recipient_user.system_user, + message.from_user, + message.channel, + model=model, + ) + tools = { + name: tool + for name, tool in tools.items() + if message.tool_allowed(name) and name in allowed_tools + } + system_prompt = message.system_prompt or "" + system_prompt += comm_channel_prompt( + session, message.recipient_user, message.channel + ) + provider = create_provider(model=model) + messages = previous_messages( + session, + message.recipient_user and message.recipient_user.id, + message.channel and message.channel.id, + max_messages=10, + ) + return provider.run_with_tools( + messages=provider.as_messages([m.title for m in messages] + msgs), + tools=tools, + system_prompt=system_prompt, + max_iterations=settings.DISCORD_MAX_TOOL_CALLS, + ).response + + def should_process(message: DiscordMessage) -> bool: if not ( settings.DISCORD_PROCESS_MESSAGES and settings.DISCORD_NOTIFICATIONS_ENABLED - and not ( - (message.server and message.server.ignore_messages) - or (message.channel and message.channel.ignore_messages) - or (message.from_user and message.from_user.ignore_messages) - ) + and not message.ignore_messages ): return False - provider = create_provider(model=settings.SUMMARIZER_MODEL) + if message.from_user == message.recipient_user: + logger.info("Skipping message because from_user == recipient_user") + return False + with make_session() as session: - system_prompt = comm_channel_prompt( - session, message.recipient_user, message.channel - ) - messages = previous_messages( - session, - message.recipient_user and message.recipient_user.id, - message.channel and message.channel.id, - max_messages=10, - ) msg = textwrap.dedent(""" Should you continue the conversation with the user? - Please return "yes" or "no" as: - - yes - - or - - no - + Please return a number between 0 and 100 indicating how much you want to continue the conversation (0 is no, 100 is yes). + Please return the number in the following format: + + + 50 + I want to continue the conversation because I think it's important. + """) - response = provider.generate( - messages=provider.as_messages([m.title for m in messages] + [msg]), - system_prompt=system_prompt, + response = call_llm( + session, + message, + settings.SUMMARIZER_MODEL, + [msg], + allowed_tools=[ + "update_channel_summary", + "update_user_summary", + "update_server_summary", + ], ) - return "yes" in "".join(response.lower().split()) + if not response: + return False + if not (res := re.search(r"(.*)", response)): + return False + try: + return int(res.group(1)) > message.chattiness_threshold + except ValueError: + return False @app.task(name=PROCESS_DISCORD_MESSAGE) @@ -111,39 +152,14 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } - tools = make_discord_tools( - discord_message.recipient_user, - discord_message.from_user, - discord_message.channel, - model=settings.DISCORD_MODEL, - ) - tools = { - name: tool - for name, tool in tools.items() - if discord_message.tool_allowed(name) - } - system_prompt = comm_channel_prompt( - session, discord_message.recipient_user, discord_message.channel - ) - messages = previous_messages( - session, - discord_message.recipient_user and discord_message.recipient_user.id, - discord_message.channel and discord_message.channel.id, - max_messages=10, - ) - provider = create_provider(model=settings.DISCORD_MODEL) - turn = provider.run_with_tools( - messages=provider.as_messages([m.title for m in messages]), - tools=tools, - system_prompt=system_prompt, - max_iterations=settings.DISCORD_MAX_TOOL_CALLS, - ) - if not turn.response: + response = call_llm(session, discord_message, settings.DISCORD_MODEL) + + if not response: pass elif discord_message.channel.server: - discord.send_to_channel(discord_message.channel.name, turn.response) + discord.send_to_channel(discord_message.channel.name, response) else: - discord.send_dm(discord_message.from_user.username, turn.response) + discord.send_dm(discord_message.from_user.username, response) return { "status": "processed",