mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-22 22:56:38 +02:00
better discord integration
This commit is contained in:
parent
1a3cf9c931
commit
aaa0c2c3cd
@ -0,0 +1,70 @@
|
|||||||
|
"""proper user id type
|
||||||
|
|
||||||
|
Revision ID: 7dc03dbf184c
|
||||||
|
Revises: 35a2c1b610b6
|
||||||
|
Create Date: 2025-10-20 22:09:11.243681
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "7dc03dbf184c"
|
||||||
|
down_revision: Union[str, None] = "35a2c1b610b6"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"discord_channels", sa.Column("system_prompt", sa.Text(), nullable=True)
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"discord_channels",
|
||||||
|
sa.Column(
|
||||||
|
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"discord_servers", sa.Column("system_prompt", sa.Text(), nullable=True)
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"discord_servers",
|
||||||
|
sa.Column(
|
||||||
|
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.add_column("discord_users", sa.Column("system_prompt", sa.Text(), nullable=True))
|
||||||
|
op.add_column(
|
||||||
|
"discord_users",
|
||||||
|
sa.Column(
|
||||||
|
"chattiness_threshold", sa.Integer(), nullable=False, server_default="50"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"scheduled_llm_calls",
|
||||||
|
"user_id",
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
type_=sa.BigInteger(),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column(
|
||||||
|
"scheduled_llm_calls",
|
||||||
|
"user_id",
|
||||||
|
existing_type=sa.BigInteger(),
|
||||||
|
type_=sa.INTEGER(),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
op.drop_column("discord_users", "chattiness_threshold")
|
||||||
|
op.drop_column("discord_users", "system_prompt")
|
||||||
|
op.drop_column("discord_servers", "chattiness_threshold")
|
||||||
|
op.drop_column("discord_servers", "system_prompt")
|
||||||
|
op.drop_column("discord_channels", "chattiness_threshold")
|
||||||
|
op.drop_column("discord_channels", "system_prompt")
|
@ -28,6 +28,18 @@ class MessageProcessor:
|
|||||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
system_prompt = Column(
|
||||||
|
Text,
|
||||||
|
nullable=True,
|
||||||
|
doc="System prompt for this processor. The precedence is user -> channel -> server -> default.",
|
||||||
|
)
|
||||||
|
chattiness_threshold = Column(
|
||||||
|
Integer,
|
||||||
|
nullable=False,
|
||||||
|
default=50,
|
||||||
|
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
||||||
|
)
|
||||||
|
|
||||||
summary = Column(
|
summary = Column(
|
||||||
Text,
|
Text,
|
||||||
nullable=True,
|
nullable=True,
|
||||||
|
@ -328,6 +328,32 @@ class DiscordMessage(SourceItem):
|
|||||||
not self.allowed_tools or tool in self.allowed_tools
|
not self.allowed_tools or tool in self.allowed_tools
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_messages(self) -> bool:
|
||||||
|
return (
|
||||||
|
(self.server and self.server.ignore_messages)
|
||||||
|
or (self.channel and self.channel.ignore_messages)
|
||||||
|
or (self.from_user and self.from_user.ignore_messages)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def system_prompt(self) -> str:
|
||||||
|
return (
|
||||||
|
(self.from_user and self.from_user.system_prompt)
|
||||||
|
or (self.channel and self.channel.system_prompt)
|
||||||
|
or (self.server and self.server.system_prompt)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chattiness_threshold(self) -> int:
|
||||||
|
vals = [
|
||||||
|
(self.from_user and self.from_user.chattiness_threshold),
|
||||||
|
(self.channel and self.channel.chattiness_threshold),
|
||||||
|
(self.server and self.server.chattiness_threshold),
|
||||||
|
90,
|
||||||
|
]
|
||||||
|
return min(val for val in vals if val is not None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
||||||
|
@ -8,6 +8,7 @@ providing HTTP endpoints for sending Discord messages.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import traceback
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -81,21 +82,22 @@ async def send_dm_endpoint(request: SendDMRequest):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
success = await app_state.collector.send_dm(request.user, request.message)
|
success = await app_state.collector.send_dm(request.user, request.message)
|
||||||
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Failed to send DM to {request.user}",
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": f"DM sent to {request.user}",
|
|
||||||
"user": request.user,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
logger.error(f"Failed to send DM: {e}")
|
logger.error(f"Failed to send DM: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Failed to send DM to {request.user}",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"DM sent to {request.user}",
|
||||||
|
"user": request.user,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/send_channel")
|
@app.post("/send_channel")
|
||||||
async def send_channel_endpoint(request: SendChannelRequest):
|
async def send_channel_endpoint(request: SendChannelRequest):
|
||||||
|
@ -4,6 +4,7 @@ Celery tasks for Discord message processing.
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -50,45 +51,85 @@ def get_prev(
|
|||||||
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
||||||
|
|
||||||
|
|
||||||
|
def call_llm(
|
||||||
|
session,
|
||||||
|
message: DiscordMessage,
|
||||||
|
model: str,
|
||||||
|
msgs: list[str] = [],
|
||||||
|
allowed_tools: list[str] = [],
|
||||||
|
) -> str | None:
|
||||||
|
tools = make_discord_tools(
|
||||||
|
message.recipient_user.system_user,
|
||||||
|
message.from_user,
|
||||||
|
message.channel,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
tools = {
|
||||||
|
name: tool
|
||||||
|
for name, tool in tools.items()
|
||||||
|
if message.tool_allowed(name) and name in allowed_tools
|
||||||
|
}
|
||||||
|
system_prompt = message.system_prompt or ""
|
||||||
|
system_prompt += comm_channel_prompt(
|
||||||
|
session, message.recipient_user, message.channel
|
||||||
|
)
|
||||||
|
provider = create_provider(model=model)
|
||||||
|
messages = previous_messages(
|
||||||
|
session,
|
||||||
|
message.recipient_user and message.recipient_user.id,
|
||||||
|
message.channel and message.channel.id,
|
||||||
|
max_messages=10,
|
||||||
|
)
|
||||||
|
return provider.run_with_tools(
|
||||||
|
messages=provider.as_messages([m.title for m in messages] + msgs),
|
||||||
|
tools=tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
|
).response
|
||||||
|
|
||||||
|
|
||||||
def should_process(message: DiscordMessage) -> bool:
|
def should_process(message: DiscordMessage) -> bool:
|
||||||
if not (
|
if not (
|
||||||
settings.DISCORD_PROCESS_MESSAGES
|
settings.DISCORD_PROCESS_MESSAGES
|
||||||
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
and settings.DISCORD_NOTIFICATIONS_ENABLED
|
||||||
and not (
|
and not message.ignore_messages
|
||||||
(message.server and message.server.ignore_messages)
|
|
||||||
or (message.channel and message.channel.ignore_messages)
|
|
||||||
or (message.from_user and message.from_user.ignore_messages)
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
provider = create_provider(model=settings.SUMMARIZER_MODEL)
|
if message.from_user == message.recipient_user:
|
||||||
|
logger.info("Skipping message because from_user == recipient_user")
|
||||||
|
return False
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
system_prompt = comm_channel_prompt(
|
|
||||||
session, message.recipient_user, message.channel
|
|
||||||
)
|
|
||||||
messages = previous_messages(
|
|
||||||
session,
|
|
||||||
message.recipient_user and message.recipient_user.id,
|
|
||||||
message.channel and message.channel.id,
|
|
||||||
max_messages=10,
|
|
||||||
)
|
|
||||||
msg = textwrap.dedent("""
|
msg = textwrap.dedent("""
|
||||||
Should you continue the conversation with the user?
|
Should you continue the conversation with the user?
|
||||||
Please return "yes" or "no" as:
|
Please return a number between 0 and 100 indicating how much you want to continue the conversation (0 is no, 100 is yes).
|
||||||
|
Please return the number in the following format:
|
||||||
<response>yes</response>
|
|
||||||
|
<response>
|
||||||
or
|
<number>50</number>
|
||||||
|
<reason>I want to continue the conversation because I think it's important.</reason>
|
||||||
<response>no</response>
|
</response>
|
||||||
|
|
||||||
""")
|
""")
|
||||||
response = provider.generate(
|
response = call_llm(
|
||||||
messages=provider.as_messages([m.title for m in messages] + [msg]),
|
session,
|
||||||
system_prompt=system_prompt,
|
message,
|
||||||
|
settings.SUMMARIZER_MODEL,
|
||||||
|
[msg],
|
||||||
|
allowed_tools=[
|
||||||
|
"update_channel_summary",
|
||||||
|
"update_user_summary",
|
||||||
|
"update_server_summary",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
return "<response>yes</response>" in "".join(response.lower().split())
|
if not response:
|
||||||
|
return False
|
||||||
|
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return int(res.group(1)) > message.chattiness_threshold
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||||
@ -111,39 +152,14 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
tools = make_discord_tools(
|
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||||
discord_message.recipient_user,
|
|
||||||
discord_message.from_user,
|
if not response:
|
||||||
discord_message.channel,
|
|
||||||
model=settings.DISCORD_MODEL,
|
|
||||||
)
|
|
||||||
tools = {
|
|
||||||
name: tool
|
|
||||||
for name, tool in tools.items()
|
|
||||||
if discord_message.tool_allowed(name)
|
|
||||||
}
|
|
||||||
system_prompt = comm_channel_prompt(
|
|
||||||
session, discord_message.recipient_user, discord_message.channel
|
|
||||||
)
|
|
||||||
messages = previous_messages(
|
|
||||||
session,
|
|
||||||
discord_message.recipient_user and discord_message.recipient_user.id,
|
|
||||||
discord_message.channel and discord_message.channel.id,
|
|
||||||
max_messages=10,
|
|
||||||
)
|
|
||||||
provider = create_provider(model=settings.DISCORD_MODEL)
|
|
||||||
turn = provider.run_with_tools(
|
|
||||||
messages=provider.as_messages([m.title for m in messages]),
|
|
||||||
tools=tools,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
|
||||||
)
|
|
||||||
if not turn.response:
|
|
||||||
pass
|
pass
|
||||||
elif discord_message.channel.server:
|
elif discord_message.channel.server:
|
||||||
discord.send_to_channel(discord_message.channel.name, turn.response)
|
discord.send_to_channel(discord_message.channel.name, response)
|
||||||
else:
|
else:
|
||||||
discord.send_dm(discord_message.from_user.username, turn.response)
|
discord.send_dm(discord_message.from_user.username, response)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "processed",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user