better discord integration

This commit is contained in:
Daniel O'Connell 2025-10-20 23:08:34 +02:00
parent 1a3cf9c931
commit aaa0c2c3cd
5 changed files with 194 additions and 68 deletions

View File

@ -0,0 +1,70 @@
"""proper user id type
Revision ID: 7dc03dbf184c
Revises: 35a2c1b610b6
Create Date: 2025-10-20 22:09:11.243681
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "7dc03dbf184c"
down_revision: Union[str, None] = "35a2c1b610b6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"discord_channels", sa.Column("system_prompt", sa.Text(), nullable=True)
)
op.add_column(
"discord_channels",
sa.Column(
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
),
)
op.add_column(
"discord_servers", sa.Column("system_prompt", sa.Text(), nullable=True)
)
op.add_column(
"discord_servers",
sa.Column(
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
),
)
op.add_column("discord_users", sa.Column("system_prompt", sa.Text(), nullable=True))
op.add_column(
"discord_users",
sa.Column(
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
),
)
op.alter_column(
"scheduled_llm_calls",
"user_id",
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"scheduled_llm_calls",
"user_id",
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
)
op.drop_column("discord_users", "chattiness_threshold")
op.drop_column("discord_users", "system_prompt")
op.drop_column("discord_servers", "chattiness_threshold")
op.drop_column("discord_servers", "system_prompt")
op.drop_column("discord_channels", "chattiness_threshold")
op.drop_column("discord_channels", "system_prompt")

View File

@ -28,6 +28,18 @@ class MessageProcessor:
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
system_prompt = Column(
Text,
nullable=True,
doc="System prompt for this processor. The precedence is user -> channel -> server -> default.",
)
chattiness_threshold = Column(
Integer,
nullable=False,
default=50,
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
)
summary = Column( summary = Column(
Text, Text,
nullable=True, nullable=True,

View File

@ -328,6 +328,32 @@ class DiscordMessage(SourceItem):
not self.allowed_tools or tool in self.allowed_tools not self.allowed_tools or tool in self.allowed_tools
) )
@property
def ignore_messages(self) -> bool:
return (
(self.server and self.server.ignore_messages)
or (self.channel and self.channel.ignore_messages)
or (self.from_user and self.from_user.ignore_messages)
)
@property
def system_prompt(self) -> str:
return (
(self.from_user and self.from_user.system_prompt)
or (self.channel and self.channel.system_prompt)
or (self.server and self.server.system_prompt)
)
@property
def chattiness_threshold(self) -> int:
vals = [
(self.from_user and self.from_user.chattiness_threshold),
(self.channel and self.channel.chattiness_threshold),
(self.server and self.server.chattiness_threshold),
90,
]
return min(val for val in vals if val is not None)
@property @property
def title(self) -> str: def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"

View File

@ -8,6 +8,7 @@ providing HTTP endpoints for sending Discord messages.
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import traceback
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
@ -81,21 +82,22 @@ async def send_dm_endpoint(request: SendDMRequest):
try: try:
success = await app_state.collector.send_dm(request.user, request.message) success = await app_state.collector.send_dm(request.user, request.message)
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to send DM to {request.user}",
)
return {
"success": True,
"message": f"DM sent to {request.user}",
"user": request.user,
}
except Exception as e: except Exception as e:
traceback.print_exc()
logger.error(f"Failed to send DM: {e}") logger.error(f"Failed to send DM: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to send DM to {request.user}",
)
return {
"success": True,
"message": f"DM sent to {request.user}",
"user": request.user,
}
@app.post("/send_channel") @app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest): async def send_channel_endpoint(request: SendChannelRequest):

View File

@ -4,6 +4,7 @@ Celery tasks for Discord message processing.
import hashlib import hashlib
import logging import logging
import re
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@ -50,45 +51,85 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def call_llm(
session,
message: DiscordMessage,
model: str,
msgs: list[str] = [],
allowed_tools: list[str] = [],
) -> str | None:
tools = make_discord_tools(
message.recipient_user.system_user,
message.from_user,
message.channel,
model=model,
)
tools = {
name: tool
for name, tool in tools.items()
if message.tool_allowed(name) and name in allowed_tools
}
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
)
provider = create_provider(model=model)
messages = previous_messages(
session,
message.recipient_user and message.recipient_user.id,
message.channel and message.channel.id,
max_messages=10,
)
return provider.run_with_tools(
messages=provider.as_messages([m.title for m in messages] + msgs),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
def should_process(message: DiscordMessage) -> bool: def should_process(message: DiscordMessage) -> bool:
if not ( if not (
settings.DISCORD_PROCESS_MESSAGES settings.DISCORD_PROCESS_MESSAGES
and settings.DISCORD_NOTIFICATIONS_ENABLED and settings.DISCORD_NOTIFICATIONS_ENABLED
and not ( and not message.ignore_messages
(message.server and message.server.ignore_messages)
or (message.channel and message.channel.ignore_messages)
or (message.from_user and message.from_user.ignore_messages)
)
): ):
return False return False
provider = create_provider(model=settings.SUMMARIZER_MODEL) if message.from_user == message.recipient_user:
logger.info("Skipping message because from_user == recipient_user")
return False
with make_session() as session: with make_session() as session:
system_prompt = comm_channel_prompt(
session, message.recipient_user, message.channel
)
messages = previous_messages(
session,
message.recipient_user and message.recipient_user.id,
message.channel and message.channel.id,
max_messages=10,
)
msg = textwrap.dedent(""" msg = textwrap.dedent("""
Should you continue the conversation with the user? Should you continue the conversation with the user?
Please return "yes" or "no" as: Please return a number between 0 and 100 indicating how much you want to continue the conversation (0 is no, 100 is yes).
Please return the number in the following format:
<response>yes</response>
<response>
or <number>50</number>
<reason>I want to continue the conversation because I think it's important.</reason>
<response>no</response> </response>
""") """)
response = provider.generate( response = call_llm(
messages=provider.as_messages([m.title for m in messages] + [msg]), session,
system_prompt=system_prompt, message,
settings.SUMMARIZER_MODEL,
[msg],
allowed_tools=[
"update_channel_summary",
"update_user_summary",
"update_server_summary",
],
) )
return "<response>yes</response>" in "".join(response.lower().split()) if not response:
return False
if not (res := re.search(r"<number>(.*)</number>", response)):
return False
try:
return int(res.group(1)) > message.chattiness_threshold
except ValueError:
return False
@app.task(name=PROCESS_DISCORD_MESSAGE) @app.task(name=PROCESS_DISCORD_MESSAGE)
@ -111,39 +152,14 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id, "message_id": message_id,
} }
tools = make_discord_tools( response = call_llm(session, discord_message, settings.DISCORD_MODEL)
discord_message.recipient_user,
discord_message.from_user, if not response:
discord_message.channel,
model=settings.DISCORD_MODEL,
)
tools = {
name: tool
for name, tool in tools.items()
if discord_message.tool_allowed(name)
}
system_prompt = comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel
)
messages = previous_messages(
session,
discord_message.recipient_user and discord_message.recipient_user.id,
discord_message.channel and discord_message.channel.id,
max_messages=10,
)
provider = create_provider(model=settings.DISCORD_MODEL)
turn = provider.run_with_tools(
messages=provider.as_messages([m.title for m in messages]),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
)
if not turn.response:
pass pass
elif discord_message.channel.server: elif discord_message.channel.server:
discord.send_to_channel(discord_message.channel.name, turn.response) discord.send_to_channel(discord_message.channel.name, response)
else: else:
discord.send_dm(discord_message.from_user.username, turn.response) discord.send_dm(discord_message.from_user.username, response)
return { return {
"status": "processed", "status": "processed",