better discord integration

This commit is contained in:
Daniel O'Connell 2025-10-20 23:08:34 +02:00
parent 1a3cf9c931
commit aaa0c2c3cd
5 changed files with 194 additions and 68 deletions

View File

@ -0,0 +1,70 @@
"""proper user id type
Revision ID: 7dc03dbf184c
Revises: 35a2c1b610b6
Create Date: 2025-10-20 22:09:11.243681
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "7dc03dbf184c"
down_revision: Union[str, None] = "35a2c1b610b6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"discord_channels", sa.Column("system_prompt", sa.Text(), nullable=True)
)
op.add_column(
"discord_channels",
sa.Column(
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
),
)
op.add_column(
"discord_servers", sa.Column("system_prompt", sa.Text(), nullable=True)
)
op.add_column(
"discord_servers",
sa.Column(
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
),
)
op.add_column("discord_users", sa.Column("system_prompt", sa.Text(), nullable=True))
op.add_column(
"discord_users",
sa.Column(
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
),
)
op.alter_column(
"scheduled_llm_calls",
"user_id",
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"scheduled_llm_calls",
"user_id",
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
)
op.drop_column("discord_users", "chattiness_threshold")
op.drop_column("discord_users", "system_prompt")
op.drop_column("discord_servers", "chattiness_threshold")
op.drop_column("discord_servers", "system_prompt")
op.drop_column("discord_channels", "chattiness_threshold")
op.drop_column("discord_channels", "system_prompt")

View File

@ -28,6 +28,18 @@ class MessageProcessor:
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
system_prompt = Column(
Text,
nullable=True,
doc="System prompt for this processor. The precedence is user -> channel -> server -> default.",
)
chattiness_threshold = Column(
Integer,
nullable=False,
default=50,
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
)
summary = Column( summary = Column(
Text, Text,
nullable=True, nullable=True,

View File

@ -328,6 +328,32 @@ class DiscordMessage(SourceItem):
not self.allowed_tools or tool in self.allowed_tools not self.allowed_tools or tool in self.allowed_tools
) )
@property
def ignore_messages(self) -> bool:
return (
(self.server and self.server.ignore_messages)
or (self.channel and self.channel.ignore_messages)
or (self.from_user and self.from_user.ignore_messages)
)
@property
def system_prompt(self) -> str:
return (
(self.from_user and self.from_user.system_prompt)
or (self.channel and self.channel.system_prompt)
or (self.server and self.server.system_prompt)
)
@property
def chattiness_threshold(self) -> int:
vals = [
(self.from_user and self.from_user.chattiness_threshold),
(self.channel and self.channel.chattiness_threshold),
(self.server and self.server.chattiness_threshold),
90,
]
return min(val for val in vals if val is not None)
@property @property
def title(self) -> str: def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"

View File

@ -8,6 +8,7 @@ providing HTTP endpoints for sending Discord messages.
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import traceback
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
@ -81,6 +82,10 @@ async def send_dm_endpoint(request: SendDMRequest):
try: try:
success = await app_state.collector.send_dm(request.user, request.message) success = await app_state.collector.send_dm(request.user, request.message)
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to send DM: {e}")
raise HTTPException(status_code=500, detail=str(e))
if not success: if not success:
raise HTTPException( raise HTTPException(
@ -92,9 +97,6 @@ async def send_dm_endpoint(request: SendDMRequest):
"message": f"DM sent to {request.user}", "message": f"DM sent to {request.user}",
"user": request.user, "user": request.user,
} }
except Exception as e:
logger.error(f"Failed to send DM: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/send_channel") @app.post("/send_channel")

View File

@ -4,6 +4,7 @@ Celery tasks for Discord message processing.
import hashlib import hashlib
import logging import logging
import re
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@ -50,45 +51,85 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def should_process(message: DiscordMessage) -> bool: def call_llm(
if not ( session,
settings.DISCORD_PROCESS_MESSAGES message: DiscordMessage,
and settings.DISCORD_NOTIFICATIONS_ENABLED model: str,
and not ( msgs: list[str] = [],
(message.server and message.server.ignore_messages) allowed_tools: list[str] = [],
or (message.channel and message.channel.ignore_messages) ) -> str | None:
or (message.from_user and message.from_user.ignore_messages) tools = make_discord_tools(
message.recipient_user.system_user,
message.from_user,
message.channel,
model=model,
) )
): tools = {
return False name: tool
for name, tool in tools.items()
provider = create_provider(model=settings.SUMMARIZER_MODEL) if message.tool_allowed(name) and name in allowed_tools
with make_session() as session: }
system_prompt = comm_channel_prompt( system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel session, message.recipient_user, message.channel
) )
provider = create_provider(model=model)
messages = previous_messages( messages = previous_messages(
session, session,
message.recipient_user and message.recipient_user.id, message.recipient_user and message.recipient_user.id,
message.channel and message.channel.id, message.channel and message.channel.id,
max_messages=10, max_messages=10,
) )
return provider.run_with_tools(
messages=provider.as_messages([m.title for m in messages] + msgs),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
def should_process(message: DiscordMessage) -> bool:
if not (
settings.DISCORD_PROCESS_MESSAGES
and settings.DISCORD_NOTIFICATIONS_ENABLED
and not message.ignore_messages
):
return False
if message.from_user == message.recipient_user:
logger.info("Skipping message because from_user == recipient_user")
return False
with make_session() as session:
msg = textwrap.dedent(""" msg = textwrap.dedent("""
Should you continue the conversation with the user? Should you continue the conversation with the user?
Please return "yes" or "no" as: Please return a number between 0 and 100 indicating how much you want to continue the conversation (0 is no, 100 is yes).
Please return the number in the following format:
<response>yes</response>
or
<response>no</response>
<response>
<number>50</number>
<reason>I want to continue the conversation because I think it's important.</reason>
</response>
""") """)
response = provider.generate( response = call_llm(
messages=provider.as_messages([m.title for m in messages] + [msg]), session,
system_prompt=system_prompt, message,
settings.SUMMARIZER_MODEL,
[msg],
allowed_tools=[
"update_channel_summary",
"update_user_summary",
"update_server_summary",
],
) )
return "<response>yes</response>" in "".join(response.lower().split()) if not response:
return False
if not (res := re.search(r"<number>(.*)</number>", response)):
return False
try:
return int(res.group(1)) > message.chattiness_threshold
except ValueError:
return False
@app.task(name=PROCESS_DISCORD_MESSAGE) @app.task(name=PROCESS_DISCORD_MESSAGE)
@ -111,39 +152,14 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id, "message_id": message_id,
} }
tools = make_discord_tools( response = call_llm(session, discord_message, settings.DISCORD_MODEL)
discord_message.recipient_user,
discord_message.from_user, if not response:
discord_message.channel,
model=settings.DISCORD_MODEL,
)
tools = {
name: tool
for name, tool in tools.items()
if discord_message.tool_allowed(name)
}
system_prompt = comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel
)
messages = previous_messages(
session,
discord_message.recipient_user and discord_message.recipient_user.id,
discord_message.channel and discord_message.channel.id,
max_messages=10,
)
provider = create_provider(model=settings.DISCORD_MODEL)
turn = provider.run_with_tools(
messages=provider.as_messages([m.title for m in messages]),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
)
if not turn.response:
pass pass
elif discord_message.channel.server: elif discord_message.channel.server:
discord.send_to_channel(discord_message.channel.name, turn.response) discord.send_to_channel(discord_message.channel.name, response)
else: else:
discord.send_dm(discord_message.from_user.username, turn.response) discord.send_dm(discord_message.from_user.username, response)
return { return {
"status": "processed", "status": "processed",