diff --git a/db/migrations/versions/20251020_220911_proper_user_id_type.py b/db/migrations/versions/20251020_220911_proper_user_id_type.py
new file mode 100644
index 0000000..676a737
--- /dev/null
+++ b/db/migrations/versions/20251020_220911_proper_user_id_type.py
@@ -0,0 +1,70 @@
+"""proper user id type
+
+Revision ID: 7dc03dbf184c
+Revises: 35a2c1b610b6
+Create Date: 2025-10-20 22:09:11.243681
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "7dc03dbf184c"
+down_revision: Union[str, None] = "35a2c1b610b6"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ op.add_column(
+ "discord_channels", sa.Column("system_prompt", sa.Text(), nullable=True)
+ )
+ op.add_column(
+ "discord_channels",
+ sa.Column(
+ "chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
+ ),
+ )
+ op.add_column(
+ "discord_servers", sa.Column("system_prompt", sa.Text(), nullable=True)
+ )
+ op.add_column(
+ "discord_servers",
+ sa.Column(
+ "chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
+ ),
+ )
+ op.add_column("discord_users", sa.Column("system_prompt", sa.Text(), nullable=True))
+ op.add_column(
+ "discord_users",
+ sa.Column(
+ "chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
+ ),
+ )
+ op.alter_column(
+ "scheduled_llm_calls",
+ "user_id",
+ existing_type=sa.INTEGER(),
+ type_=sa.BigInteger(),
+ existing_nullable=False,
+ )
+
+
+def downgrade() -> None:
+ op.alter_column(
+ "scheduled_llm_calls",
+ "user_id",
+ existing_type=sa.BigInteger(),
+ type_=sa.INTEGER(),
+ existing_nullable=False,
+ )
+ op.drop_column("discord_users", "chattiness_threshold")
+ op.drop_column("discord_users", "system_prompt")
+ op.drop_column("discord_servers", "chattiness_threshold")
+ op.drop_column("discord_servers", "system_prompt")
+ op.drop_column("discord_channels", "chattiness_threshold")
+ op.drop_column("discord_channels", "system_prompt")
diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py
index 544ffb0..353727e 100644
--- a/src/memory/common/db/models/discord.py
+++ b/src/memory/common/db/models/discord.py
@@ -28,6 +28,18 @@ class MessageProcessor:
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+ system_prompt = Column(
+ Text,
+ nullable=True,
+ doc="System prompt for this processor. The precedence is user -> channel -> server -> default.",
+ )
+ chattiness_threshold = Column(
+ Integer,
+ nullable=False,
+ default=50,
+ doc="The threshold for the bot to continue the conversation, between 0 and 100.",
+ )
+
summary = Column(
Text,
nullable=True,
diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index fd0eac1..34e8307 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -328,6 +328,32 @@ class DiscordMessage(SourceItem):
not self.allowed_tools or tool in self.allowed_tools
)
+ @property
+ def ignore_messages(self) -> bool:
+ return (
+ (self.server and self.server.ignore_messages)
+ or (self.channel and self.channel.ignore_messages)
+ or (self.from_user and self.from_user.ignore_messages)
+ )
+
+ @property
+ def system_prompt(self) -> str:
+ return (
+ (self.from_user and self.from_user.system_prompt)
+ or (self.channel and self.channel.system_prompt)
+ or (self.server and self.server.system_prompt)
+ )
+
+ @property
+ def chattiness_threshold(self) -> int:
+ vals = [
+ (self.from_user and self.from_user.chattiness_threshold),
+ (self.channel and self.channel.chattiness_threshold),
+ (self.server and self.server.chattiness_threshold),
+ 90,
+ ]
+ return min(val for val in vals if val is not None)
+
@property
def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py
index 129e96e..53d6604 100644
--- a/src/memory/discord/api.py
+++ b/src/memory/discord/api.py
@@ -8,6 +8,7 @@ providing HTTP endpoints for sending Discord messages.
import asyncio
import logging
from contextlib import asynccontextmanager
+import traceback
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
@@ -81,21 +82,22 @@ async def send_dm_endpoint(request: SendDMRequest):
try:
success = await app_state.collector.send_dm(request.user, request.message)
-
- if not success:
- raise HTTPException(
- status_code=400,
- detail=f"Failed to send DM to {request.user}",
- )
- return {
- "success": True,
- "message": f"DM sent to {request.user}",
- "user": request.user,
- }
except Exception as e:
+ traceback.print_exc()
logger.error(f"Failed to send DM: {e}")
raise HTTPException(status_code=500, detail=str(e))
+ if not success:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to send DM to {request.user}",
+ )
+ return {
+ "success": True,
+ "message": f"DM sent to {request.user}",
+ "user": request.user,
+ }
+
@app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest):
diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py
index f146af0..8f6cfb4 100644
--- a/src/memory/workers/tasks/discord.py
+++ b/src/memory/workers/tasks/discord.py
@@ -4,6 +4,7 @@ Celery tasks for Discord message processing.
import hashlib
import logging
+import re
import textwrap
from datetime import datetime
from typing import Any
@@ -50,45 +51,85 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
+def call_llm(
+ session,
+ message: DiscordMessage,
+ model: str,
+ msgs: list[str] = [],
+ allowed_tools: list[str] = [],
+) -> str | None:
+ tools = make_discord_tools(
+ message.recipient_user.system_user,
+ message.from_user,
+ message.channel,
+ model=model,
+ )
+ tools = {
+ name: tool
+ for name, tool in tools.items()
+ if message.tool_allowed(name) and name in allowed_tools
+ }
+ system_prompt = message.system_prompt or ""
+ system_prompt += comm_channel_prompt(
+ session, message.recipient_user, message.channel
+ )
+ provider = create_provider(model=model)
+ messages = previous_messages(
+ session,
+ message.recipient_user and message.recipient_user.id,
+ message.channel and message.channel.id,
+ max_messages=10,
+ )
+ return provider.run_with_tools(
+ messages=provider.as_messages([m.title for m in messages] + msgs),
+ tools=tools,
+ system_prompt=system_prompt,
+ max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
+ ).response
+
+
def should_process(message: DiscordMessage) -> bool:
if not (
settings.DISCORD_PROCESS_MESSAGES
and settings.DISCORD_NOTIFICATIONS_ENABLED
- and not (
- (message.server and message.server.ignore_messages)
- or (message.channel and message.channel.ignore_messages)
- or (message.from_user and message.from_user.ignore_messages)
- )
+ and not message.ignore_messages
):
return False
- provider = create_provider(model=settings.SUMMARIZER_MODEL)
+ if message.from_user == message.recipient_user:
+ logger.info("Skipping message because from_user == recipient_user")
+ return False
+
with make_session() as session:
- system_prompt = comm_channel_prompt(
- session, message.recipient_user, message.channel
- )
- messages = previous_messages(
- session,
- message.recipient_user and message.recipient_user.id,
- message.channel and message.channel.id,
- max_messages=10,
- )
msg = textwrap.dedent("""
Should you continue the conversation with the user?
- Please return "yes" or "no" as:
-
- yes
-
- or
-
- no
-
+ Please return a number between 0 and 100 indicating how much you want to continue the conversation (0 is no, 100 is yes).
+ Please return the number in the following format:
+
+
+ 50
+ I want to continue the conversation because I think it's important.
+
""")
- response = provider.generate(
- messages=provider.as_messages([m.title for m in messages] + [msg]),
- system_prompt=system_prompt,
+ response = call_llm(
+ session,
+ message,
+ settings.SUMMARIZER_MODEL,
+ [msg],
+ allowed_tools=[
+ "update_channel_summary",
+ "update_user_summary",
+ "update_server_summary",
+ ],
)
- return "yes" in "".join(response.lower().split())
+ if not response:
+ return False
+ if not (res := re.search(r"(.*)", response)):
+ return False
+ try:
+ return int(res.group(1)) > message.chattiness_threshold
+ except ValueError:
+ return False
@app.task(name=PROCESS_DISCORD_MESSAGE)
@@ -111,39 +152,14 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
- tools = make_discord_tools(
- discord_message.recipient_user,
- discord_message.from_user,
- discord_message.channel,
- model=settings.DISCORD_MODEL,
- )
- tools = {
- name: tool
- for name, tool in tools.items()
- if discord_message.tool_allowed(name)
- }
- system_prompt = comm_channel_prompt(
- session, discord_message.recipient_user, discord_message.channel
- )
- messages = previous_messages(
- session,
- discord_message.recipient_user and discord_message.recipient_user.id,
- discord_message.channel and discord_message.channel.id,
- max_messages=10,
- )
- provider = create_provider(model=settings.DISCORD_MODEL)
- turn = provider.run_with_tools(
- messages=provider.as_messages([m.title for m in messages]),
- tools=tools,
- system_prompt=system_prompt,
- max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
- )
- if not turn.response:
+ response = call_llm(session, discord_message, settings.DISCORD_MODEL)
+
+ if not response:
pass
elif discord_message.channel.server:
- discord.send_to_channel(discord_message.channel.name, turn.response)
+ discord.send_to_channel(discord_message.channel.name, response)
else:
- discord.send_dm(discord_message.from_user.username, turn.response)
+ discord.send_dm(discord_message.from_user.username, response)
return {
"status": "processed",