mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-22 22:56:38 +02:00
better discord integration
This commit is contained in:
parent
1a3cf9c931
commit
aaa0c2c3cd
@ -0,0 +1,70 @@
|
||||
"""proper user id type
|
||||
|
||||
Revision ID: 7dc03dbf184c
|
||||
Revises: 35a2c1b610b6
|
||||
Create Date: 2025-10-20 22:09:11.243681
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7dc03dbf184c"
|
||||
down_revision: Union[str, None] = "35a2c1b610b6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"discord_channels", sa.Column("system_prompt", sa.Text(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"discord_channels",
|
||||
sa.Column(
|
||||
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"discord_servers", sa.Column("system_prompt", sa.Text(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"discord_servers",
|
||||
sa.Column(
|
||||
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
|
||||
),
|
||||
)
|
||||
op.add_column("discord_users", sa.Column("system_prompt", sa.Text(), nullable=True))
|
||||
op.add_column(
|
||||
"discord_users",
|
||||
sa.Column(
|
||||
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"scheduled_llm_calls",
|
||||
"user_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
type_=sa.BigInteger(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"scheduled_llm_calls",
|
||||
"user_id",
|
||||
existing_type=sa.BigInteger(),
|
||||
type_=sa.INTEGER(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.drop_column("discord_users", "chattiness_threshold")
|
||||
op.drop_column("discord_users", "system_prompt")
|
||||
op.drop_column("discord_servers", "chattiness_threshold")
|
||||
op.drop_column("discord_servers", "system_prompt")
|
||||
op.drop_column("discord_channels", "chattiness_threshold")
|
||||
op.drop_column("discord_channels", "system_prompt")
|
@ -28,6 +28,18 @@ class MessageProcessor:
|
||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
system_prompt = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
doc="System prompt for this processor. The precedence is user -> channel -> server -> default.",
|
||||
)
|
||||
chattiness_threshold = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=50,
|
||||
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
||||
)
|
||||
|
||||
summary = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
|
@ -328,6 +328,32 @@ class DiscordMessage(SourceItem):
|
||||
not self.allowed_tools or tool in self.allowed_tools
|
||||
)
|
||||
|
||||
@property
|
||||
def ignore_messages(self) -> bool:
|
||||
return (
|
||||
(self.server and self.server.ignore_messages)
|
||||
or (self.channel and self.channel.ignore_messages)
|
||||
or (self.from_user and self.from_user.ignore_messages)
|
||||
)
|
||||
|
||||
@property
|
||||
def system_prompt(self) -> str:
|
||||
return (
|
||||
(self.from_user and self.from_user.system_prompt)
|
||||
or (self.channel and self.channel.system_prompt)
|
||||
or (self.server and self.server.system_prompt)
|
||||
)
|
||||
|
||||
@property
|
||||
def chattiness_threshold(self) -> int:
|
||||
vals = [
|
||||
(self.from_user and self.from_user.chattiness_threshold),
|
||||
(self.channel and self.channel.chattiness_threshold),
|
||||
(self.server and self.server.chattiness_threshold),
|
||||
90,
|
||||
]
|
||||
return min(val for val in vals if val is not None)
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
||||
|
@ -8,6 +8,7 @@ providing HTTP endpoints for sending Discord messages.
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
import traceback
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
@ -81,21 +82,22 @@ async def send_dm_endpoint(request: SendDMRequest):
|
||||
|
||||
try:
|
||||
success = await app_state.collector.send_dm(request.user, request.message)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send DM to {request.user}",
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"DM sent to {request.user}",
|
||||
"user": request.user,
|
||||
}
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"Failed to send DM: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send DM to {request.user}",
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"DM sent to {request.user}",
|
||||
"user": request.user,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/send_channel")
|
||||
async def send_channel_endpoint(request: SendChannelRequest):
|
||||
|
@ -4,6 +4,7 @@ Celery tasks for Discord message processing.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@ -50,45 +51,85 @@ def get_prev(
|
||||
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
||||
|
||||
|
||||
def call_llm(
|
||||
session,
|
||||
message: DiscordMessage,
|
||||
model: str,
|
||||
msgs: list[str] = [],
|
||||
allowed_tools: list[str] = [],
|
||||
) -> str | None:
|
||||
tools = make_discord_tools(
|
||||
message.recipient_user.system_user,
|
||||
message.from_user,
|
||||
message.channel,
|
||||
model=model,
|
||||
)
|
||||
tools = {
|
||||
name: tool
|
||||
for name, tool in tools.items()
|
||||
if message.tool_allowed(name) and name in allowed_tools
|
||||
}
|
||||
system_prompt = message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, message.recipient_user, message.channel
|
||||
)
|
||||
provider = create_provider(model=model)
|
||||
messages = previous_messages(
|
||||
session,
|
||||
message.recipient_user and message.recipient_user.id,
|
||||
message.channel and message.channel.id,
|
||||
max_messages=10,
|
||||
)
|
||||
return provider.run_with_tools(
|
||||
messages=provider.as_messages([m.title for m in messages] + msgs),
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||
).response
|
||||
|
||||
|
||||
def should_process(message: DiscordMessage) -> bool:
|
||||
if not (
|
||||
settings.DISCORD_PROCESS_MESSAGES
|
||||
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
||||
and not (
|
||||
(message.server and message.server.ignore_messages)
|
||||
or (message.channel and message.channel.ignore_messages)
|
||||
or (message.from_user and message.from_user.ignore_messages)
|
||||
)
|
||||
and not message.ignore_messages
|
||||
):
|
||||
return False
|
||||
|
||||
provider = create_provider(model=settings.SUMMARIZER_MODEL)
|
||||
if message.from_user == message.recipient_user:
|
||||
logger.info("Skipping message because from_user == recipient_user")
|
||||
return False
|
||||
|
||||
with make_session() as session:
|
||||
system_prompt = comm_channel_prompt(
|
||||
session, message.recipient_user, message.channel
|
||||
)
|
||||
messages = previous_messages(
|
||||
session,
|
||||
message.recipient_user and message.recipient_user.id,
|
||||
message.channel and message.channel.id,
|
||||
max_messages=10,
|
||||
)
|
||||
msg = textwrap.dedent("""
|
||||
Should you continue the conversation with the user?
|
||||
Please return "yes" or "no" as:
|
||||
|
||||
<response>yes</response>
|
||||
|
||||
or
|
||||
|
||||
<response>no</response>
|
||||
|
||||
Please return a number between 0 and 100 indicating how much you want to continue the conversation (0 is no, 100 is yes).
|
||||
Please return the number in the following format:
|
||||
|
||||
<response>
|
||||
<number>50</number>
|
||||
<reason>I want to continue the conversation because I think it's important.</reason>
|
||||
</response>
|
||||
""")
|
||||
response = provider.generate(
|
||||
messages=provider.as_messages([m.title for m in messages] + [msg]),
|
||||
system_prompt=system_prompt,
|
||||
response = call_llm(
|
||||
session,
|
||||
message,
|
||||
settings.SUMMARIZER_MODEL,
|
||||
[msg],
|
||||
allowed_tools=[
|
||||
"update_channel_summary",
|
||||
"update_user_summary",
|
||||
"update_server_summary",
|
||||
],
|
||||
)
|
||||
return "<response>yes</response>" in "".join(response.lower().split())
|
||||
if not response:
|
||||
return False
|
||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||
return False
|
||||
try:
|
||||
return int(res.group(1)) > message.chattiness_threshold
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||
@ -111,39 +152,14 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
tools = make_discord_tools(
|
||||
discord_message.recipient_user,
|
||||
discord_message.from_user,
|
||||
discord_message.channel,
|
||||
model=settings.DISCORD_MODEL,
|
||||
)
|
||||
tools = {
|
||||
name: tool
|
||||
for name, tool in tools.items()
|
||||
if discord_message.tool_allowed(name)
|
||||
}
|
||||
system_prompt = comm_channel_prompt(
|
||||
session, discord_message.recipient_user, discord_message.channel
|
||||
)
|
||||
messages = previous_messages(
|
||||
session,
|
||||
discord_message.recipient_user and discord_message.recipient_user.id,
|
||||
discord_message.channel and discord_message.channel.id,
|
||||
max_messages=10,
|
||||
)
|
||||
provider = create_provider(model=settings.DISCORD_MODEL)
|
||||
turn = provider.run_with_tools(
|
||||
messages=provider.as_messages([m.title for m in messages]),
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||
)
|
||||
if not turn.response:
|
||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||
|
||||
if not response:
|
||||
pass
|
||||
elif discord_message.channel.server:
|
||||
discord.send_to_channel(discord_message.channel.name, turn.response)
|
||||
discord.send_to_channel(discord_message.channel.name, response)
|
||||
else:
|
||||
discord.send_dm(discord_message.from_user.username, turn.response)
|
||||
discord.send_dm(discord_message.from_user.username, response)
|
||||
|
||||
return {
|
||||
"status": "processed",
|
||||
|
Loading…
x
Reference in New Issue
Block a user