Compare commits

...

7 Commits

Author SHA1 Message Date
9182f15c45 properly handle bot prompts 2025-11-02 15:51:30 +00:00
Daniel O'Connell
afdff1708b prompt from bot user 2025-11-02 16:46:26 +01:00
Daniel O'Connell
64e84b1c89 basic tools 2025-11-02 16:34:38 +01:00
798b4779da unify discord callers 2025-11-02 14:46:43 +00:00
69192f834a handle discord threads 2025-11-02 11:23:31 +00:00
Daniel O'Connell
6bd7df8ee3 properly handle images by anthropic 2025-11-02 12:08:46 +01:00
a4f42e656a save images 2025-11-02 10:25:23 +00:00
15 changed files with 525 additions and 186 deletions

View File

@ -0,0 +1,29 @@
"""store discord images
Revision ID: 1954477b25f4
Revises: 2024235e37e7
Create Date: 2025-11-02 10:14:48.334934
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "1954477b25f4"
down_revision: Union[str, None] = "2024235e37e7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"discord_message", sa.Column("images", sa.ARRAY(sa.Text()), nullable=True)
)
def downgrade() -> None:
op.drop_column("discord_message", "images")

View File

@ -5,15 +5,16 @@ SQLAdmin views for the knowledge base database models.
import logging
from sqladmin import Admin, ModelView
from memory.common.db.models import (
AgentObservation,
ArticleFeed,
BlogPost,
Book,
BookSection,
ScheduledLLMCall,
Chunk,
Comic,
DiscordMessage,
EmailAccount,
EmailAttachment,
ForumPost,
@ -21,6 +22,7 @@ from memory.common.db.models import (
MiscDoc,
Note,
Photo,
ScheduledLLMCall,
SourceItem,
User,
)
@ -153,6 +155,17 @@ class BookAdmin(ModelView, model=Book):
column_searchable_list = ["title", "author", "id"]
class DiscordMessageAdmin(ModelView, model=DiscordMessage):
column_list = [
"id",
"content",
"images",
"sent_at",
]
column_searchable_list = ["content", "id", "images"]
column_sortable_list = ["sent_at"]
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
column_list = [
"id",
@ -310,6 +323,7 @@ def setup_admin(admin: Admin):
admin.add_view(ForumPostAdmin)
admin.add_view(ComicAdmin)
admin.add_view(PhotoAdmin)
admin.add_view(DiscordMessageAdmin)
admin.add_view(UserAdmin)
admin.add_view(DiscordUserAdmin)
admin.add_view(DiscordServerAdmin)

View File

@ -5,6 +5,7 @@ Database models for the knowledge base system.
import pathlib
import textwrap
from datetime import datetime
from collections.abc import Collection
from typing import Any, Annotated, Sequence, cast
from PIL import Image
@ -301,6 +302,7 @@ class DiscordMessage(SourceItem):
BigInteger, nullable=True
) # Discord thread snowflake ID if in thread
edited_at = Column(DateTime(timezone=True), nullable=True)
images = Column(ARRAY(Text), nullable=True) # List of image URLs
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
server = relationship("DiscordServer", foreign_keys=[server_id])
@ -308,16 +310,16 @@ class DiscordMessage(SourceItem):
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
@property
def allowed_tools(self) -> list[str]:
return (
def allowed_tools(self) -> set[str]:
return set(
(self.channel.allowed_tools if self.channel else [])
+ (self.from_user.allowed_tools if self.from_user else [])
+ (self.server.allowed_tools if self.server else [])
)
@property
def disallowed_tools(self) -> list[str]:
return (
def disallowed_tools(self) -> set[str]:
return set(
(self.channel.disallowed_tools if self.channel else [])
+ (self.from_user.disallowed_tools if self.from_user else [])
+ (self.server.disallowed_tools if self.server else [])
@ -328,6 +330,11 @@ class DiscordMessage(SourceItem):
not self.allowed_tools or tool in self.allowed_tools
)
def filter_tools(self, tools: Collection[str] | None = None) -> set[str]:
if tools is None:
return self.allowed_tools - self.disallowed_tools
return set(tools) - self.disallowed_tools & self.allowed_tools
@property
def ignore_messages(self) -> bool:
return (
@ -358,6 +365,20 @@ class DiscordMessage(SourceItem):
def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
def as_content(self) -> dict[str, Any]:
"""Return message content ready for LLM (text + images from disk)."""
content = {"text": self.title, "images": []}
for path in cast(list[str] | None, self.images) or []:
try:
full_path = settings.FILE_STORAGE_DIR / path
if full_path.exists():
image = Image.open(full_path)
content["images"].append(image)
except Exception:
pass # Skip failed image loads
return content
__mapper_args__ = {
"polymorphic_identity": "discord_message",
}

View File

@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
return False
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a DM via the Discord collector API"""
def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
"""Send message to a channel by name or ID (ID supports threads)"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={
"bot_id": bot_id,
"channel_name": channel_name,
"channel": channel,
"message": message,
},
timeout=10,
@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send to channel {channel_name}: {e}")
logger.error(f"Failed to send to channel {channel}: {e}")
return False
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
"""Trigger typing indicator for a channel via the Discord collector API"""
def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
"""Trigger typing indicator for a channel by name or ID (ID supports threads)"""
try:
response = requests.post(
f"{get_api_url()}/typing/channel",
json={"bot_id": bot_id, "channel_name": channel_name},
json={"bot_id": bot_id, "channel": channel},
timeout=10,
)
response.raise_for_status()
@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
logger.error(f"Failed to trigger typing for channel {channel}: {e}")
return False
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API"""
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a message to a channel by name or ID (ID supports threads)"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={
"bot_id": bot_id,
"channel_name": channel_name,
"channel": channel,
"message": message,
},
timeout=10,
@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}")
logger.error(f"Failed to send message to channel {channel}: {e}")
return False

View File

@ -71,15 +71,23 @@ class AnthropicProvider(BaseLLMProvider):
}
def _convert_message(self, message: Message) -> dict[str, Any]:
converted = message.to_dict()
if converted["role"] == MessageRole.ASSISTANT and isinstance(
converted["content"], list
):
content = sorted(
converted["content"], key=lambda x: x["type"] != "thinking"
)
return converted | {"content": content}
return converted
# Handle string content directly
if isinstance(message.content, str):
return {"role": message.role.value, "content": message.content}
# Convert content items, handling ImageContent specially
content_list = []
for item in message.content:
if isinstance(item, ImageContent):
content_list.append(self._convert_image_content(item))
else:
content_list.append(item.to_dict())
# Sort assistant messages to put thinking last
if message.role == MessageRole.ASSISTANT:
content_list = sorted(content_list, key=lambda x: x["type"] != "thinking")
return {"role": message.role.value, "content": content_list}
def _should_include_message(self, message: Message) -> bool:
"""Filter out system messages (handled separately in Anthropic)."""
@ -105,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
}
# Only include top_p if explicitly set
@ -144,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
Tuple of (StreamEvent or None, updated current_tool_use or None)
"""
event_type = getattr(event, "type", None)
# Handle error events
if event_type == "error":
error = getattr(event, "error", None)

View File

@ -171,13 +171,17 @@ class Message:
@staticmethod
def user(
text: str | None = None, tool_result: ToolResultContent | None = None
text: str | None = None,
images: list[Image.Image] | None = None,
tool_result: ToolResultContent | None = None,
) -> "Message":
parts = []
if text:
parts.append(TextContent(text=text))
if tool_result:
parts.append(tool_result)
for image in images or []:
parts.append(ImageContent(image=image))
return Message(role=MessageRole.USER, content=parts)
@ -418,7 +422,11 @@ class BaseLLMProvider(ABC):
"""Convert tool definitions to provider format."""
if not tools:
return None
return [self._convert_tool(tool) for tool in tools]
converted = [
tool.provider_format(self.provider) or self._convert_tool(tool)
for tool in tools
]
return [c for c in converted if c is not None]
@abstractmethod
def generate(
@ -625,8 +633,16 @@ class BaseLLMProvider(ABC):
tool_calls=tool_calls or None,
)
def as_messages(self, messages) -> list[Message]:
return [Message.user(text=msg) for msg in messages]
def as_messages(self, messages: list[dict[str, Any] | str]) -> list[Message]:
def make_message(msg: dict[str, Any] | str) -> Message:
if isinstance(msg, str):
return Message.user(text=msg)
elif isinstance(msg, dict):
return Message.user(text=msg["text"], images=msg.get("images"))
else:
raise ValueError(f"Unknown message type: {type(msg)}")
return [make_message(msg) for msg in messages]
def create_provider(

View File

@ -151,23 +151,17 @@ class OpenAIProvider(BaseLLMProvider):
return openai_messages
def _convert_tools(
self, tools: list[ToolDefinition] | None
) -> list[dict[str, Any]] | None:
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
"""
Convert our tool definitions to OpenAI format.
Args:
tools: List of tool definitions
tool: Tool definition
Returns:
List of tools in OpenAI format
Tool in OpenAI format
"""
if not tools:
return None
return [
{
return {
"type": "function",
"function": {
"name": tool.name,
@ -175,8 +169,6 @@ class OpenAIProvider(BaseLLMProvider):
"parameters": tool.input_schema,
},
}
for tool in tools
]
def _build_request_kwargs(
self,

View File

@ -34,3 +34,6 @@ class ToolDefinition:
def __call__(self, input: ToolInput) -> str:
return self.function(input)
def provider_format(self, provider: str) -> dict[str, Any] | None:
return None

View File

@ -0,0 +1,36 @@
from typing import Any
from memory.common.llms.tools import ToolDefinition
class WebSearchTool(ToolDefinition):
def __init__(self, **kwargs: Any):
defaults = {
"name": "web_search",
"description": "Search the web for information",
"input_schema": {},
"function": lambda input: "result",
}
super().__init__(**(defaults | kwargs))
def provider_format(self, provider: str) -> dict[str, Any] | None:
if provider == "openai":
return {"type": "web_search"}
if provider == "anthropic":
return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
return None
class WebFetchTool(ToolDefinition):
def __init__(self, **kwargs: Any):
defaults = {
"name": "web_fetch",
"description": "Fetch the contents of a web page",
"input_schema": {},
"function": lambda input: "result",
}
super().__init__(**(defaults | kwargs))
def provider_format(self, provider: str) -> dict[str, Any] | None:
if provider == "anthropic":
return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
return None

View File

@ -73,6 +73,9 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
NOTES_STORAGE_DIR = pathlib.Path(
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
)
DISCORD_STORAGE_DIR = pathlib.Path(
os.getenv("DISCORD_STORAGE_DIR", FILE_STORAGE_DIR / "discord")
)
PRIVATE_DIRS = [
EMAIL_STORAGE_DIR,
NOTES_STORAGE_DIR,
@ -88,6 +91,7 @@ storage_dirs = [
PHOTO_STORAGE_DIR,
WEBPAGE_STORAGE_DIR,
NOTES_STORAGE_DIR,
DISCORD_STORAGE_DIR,
]
for dir in storage_dirs:
dir.mkdir(parents=True, exist_ok=True)

View File

@ -31,7 +31,7 @@ class SendDMRequest(BaseModel):
class SendChannelRequest(BaseModel):
bot_id: int
channel_name: str # Channel name (e.g., "memory-errors")
channel: int | str # Channel name or ID (ID supports threads)
message: str
@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel):
class TypingChannelRequest(BaseModel):
bot_id: int
channel_name: str
channel: int | str # Channel name or ID (ID supports threads)
class Collector:
@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest):
try:
success = await collector.collector.send_to_channel(
request.channel_name, request.message
request.channel, request.message
)
except Exception as e:
logger.error(f"Failed to send channel message: {e}")
@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest):
if success:
return {
"success": True,
"message": f"Message sent to channel {request.channel_name}",
"channel": request.channel_name,
"message": f"Message sent to channel {request.channel}",
"channel": request.channel,
}
raise HTTPException(
status_code=400,
detail=f"Failed to send message to channel {request.channel_name}",
detail=f"Failed to send message to channel {request.channel}",
)
@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest):
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.trigger_typing_channel(request.channel_name)
success = await collector.collector.trigger_typing_channel(request.channel)
except Exception as e:
logger.error(f"Failed to trigger channel typing: {e}")
raise HTTPException(status_code=500, detail=str(e))
@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest):
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to trigger typing for channel {request.channel_name}",
detail=f"Failed to trigger typing for channel {request.channel}",
)
return {
"success": True,
"channel": request.channel_name,
"message": f"Typing triggered for channel {request.channel_name}",
"channel": request.channel,
"message": f"Typing triggered for channel {request.channel}",
}

View File

@ -24,6 +24,32 @@ from memory.workers.tasks.discord import add_discord_message, edit_discord_messa
logger = logging.getLogger(__name__)
def process_mentions(session: Session | scoped_session, message: str) -> str:
"""Convert username mentions (<@username>) to ID mentions (<@123456>)"""
import re
def replace_mention(match):
mention_content = match.group(1)
# If it's already numeric, leave it alone
if mention_content.isdigit():
return match.group(0)
# Look up username in database
user = (
session.query(DiscordUser)
.filter(DiscordUser.username == mention_content)
.first()
)
if user:
return f"<@{user.id}>"
# If user not found, return original
return match.group(0)
return re.sub(r"<@([^>]+)>", replace_mention, message)
# Pure functions for Discord entity creation/updates
def create_or_update_server(
session: Session | scoped_session, guild: discord.Guild | None
@ -181,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel)
# Sync threads
for thread in guild.threads:
create_or_update_channel(session, thread)
session.commit()
@ -233,6 +263,18 @@ class MessageCollector(commands.Bot):
session.commit()
# Extract image URLs from attachments
image_urls = [
att.url
for att in message.attachments
if att.content_type and att.content_type.startswith("image/")
]
# Determine message metadata (type, reply, thread)
message_type, reply_to_id, thread_id = determine_message_metadata(
message
)
# Queue the message for processing
add_discord_message.delay(
message_id=message.id,
@ -242,9 +284,10 @@ class MessageCollector(commands.Bot):
server_id=message.guild.id if message.guild else None,
content=message.content or "",
sent_at=message.created_at.isoformat(),
message_reference_id=message.reference.message_id
if message.reference
else None,
message_reference_id=reply_to_id,
message_type=message_type,
thread_id=thread_id,
image_urls=image_urls,
)
except Exception as e:
logger.error(f"Error queuing message {message.id}: {e}")
@ -288,6 +331,11 @@ class MessageCollector(commands.Bot):
create_or_update_channel(session, channel)
channels_updated += 1
# Refresh all threads in this server
for thread in guild.threads:
create_or_update_channel(session, thread)
channels_updated += 1
# Refresh all members in this server (if members intent is enabled)
if self.intents.members:
for member in guild.members:
@ -373,7 +421,11 @@ class MessageCollector(commands.Bot):
logger.error(f"User {user_identifier} not found")
return False
await user.send(message)
# Post-process mentions to convert usernames to IDs
with make_session() as session:
processed_message = process_mentions(session, message)
await user.send(processed_message)
logger.info(f"Sent DM to {user_identifier}")
return True
@ -402,34 +454,50 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
async def send_to_channel(self, channel_name: str, message: str) -> bool:
"""Send a message to a channel by name across all guilds"""
async def send_to_channel(
self, channel_identifier: int | str, message: str
) -> bool:
"""Send a message to a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
channel = await self.get_channel_by_name(channel_name)
# Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel:
logger.error(f"Channel {channel_name} not found")
logger.error(f"Channel {channel_identifier} not found")
return False
await channel.send(message)
logger.info(f"Sent message to channel {channel_name}")
# Post-process mentions to convert usernames to IDs
with make_session() as session:
processed_message = process_mentions(session, message)
await channel.send(processed_message)
logger.info(f"Sent message to channel {channel_identifier}")
return True
except Exception as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}")
logger.error(f"Failed to send message to channel {channel_identifier}: {e}")
return False
async def trigger_typing_channel(self, channel_name: str) -> bool:
"""Trigger typing indicator in a channel"""
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
channel = await self.get_channel_by_name(channel_name)
# Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel:
logger.error(f"Channel {channel_name} not found")
logger.error(f"Channel {channel_identifier} not found")
return False
async with channel.typing():
@ -437,5 +505,7 @@ class MessageCollector(commands.Bot):
return True
except Exception as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
logger.error(
f"Failed to trigger typing for channel {channel_identifier}: {e}"
)
return False

View File

@ -1,16 +1,20 @@
import logging
import textwrap
from collections.abc import Collection
from datetime import datetime, timezone
from typing import Any, cast
from sqlalchemy.orm import Session, scoped_session
from memory.common import discord, settings
from memory.common.db.models import (
DiscordChannel,
DiscordMessage,
DiscordUser,
ScheduledLLMCall,
DiscordMessage,
)
from memory.common.db.models.users import BotUser
from memory.common.llms.base import create_provider
logger = logging.getLogger(__name__)
@ -199,3 +203,98 @@ def comm_channel_prompt(
You will be given the last {max_messages} messages in the conversation.
Please react to them appropriately. You can return an empty response if you don't have anything to say.
""").format(server_context=server_context, max_messages=max_messages)
def call_llm(
session: Session | scoped_session,
bot_user: DiscordUser,
from_user: DiscordUser | None,
channel: DiscordChannel | None,
model: str,
system_prompt: str = "",
messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None,
num_previous_messages: int = 10,
) -> str | None:
"""
Call LLM with Discord tools support.
Args:
session: Database session
bot_user: Bot user making the call
from_user: Discord user who initiated the interaction
channel: Discord channel (if any)
messages: List of message strings or dicts with text/images
model: LLM model to use
system_prompt: System prompt
allowed_tools: List of allowed tool names (None = all tools allowed)
Returns:
LLM response or None if failed
"""
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
user_id = None
if from_user and not channel:
user_id = cast(int, from_user.id)
prev_messages = previous_messages(
session,
user_id,
channel and channel.id,
max_messages=num_previous_messages,
)
from memory.common.llms.tools.discord import make_discord_tools
from memory.common.llms.tools.base import WebSearchTool
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
tools |= {"web_search": WebSearchTool()}
# Filter to allowed tools if specified
if allowed_tools is not None:
tools = {name: tool for name, tool in tools.items() if name in allowed_tools}
if bot_user.system_prompt:
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
message_content = [m.as_content() for m in prev_messages] + messages
return provider.run_with_tools(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
def send_discord_response(
bot_id: int,
response: str,
channel_id: int | None = None,
user_identifier: str | None = None,
) -> bool:
"""
Send a response to Discord channel or user.
Args:
bot_id: Bot user ID
response: Message to send
channel_id: Channel ID (for channel messages)
user_identifier: Username (for DMs)
Returns:
True if sent successfully
"""
if channel_id is not None:
logger.info(f"Sending message to channel {channel_id}")
return discord.send_to_channel(bot_id, channel_id, response)
elif user_identifier is not None:
logger.info(f"Sending DM to {user_identifier}")
return discord.send_dm(bot_id, user_identifier, response)
else:
logger.error("Neither channel_id nor user_identifier provided")
return False

View File

@ -4,11 +4,13 @@ Celery tasks for Discord message processing.
import hashlib
import logging
import pathlib
import re
import textwrap
from datetime import datetime
from typing import Any, cast
import requests
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session
@ -21,9 +23,7 @@ from memory.common.celery_app import (
)
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser
from memory.common.llms.base import create_provider
from memory.common.llms.tools.discord import make_discord_tools
from memory.discord.messages import comm_channel_prompt, previous_messages
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
from memory.workers.tasks.content_processing import (
check_content_exists,
create_task_result,
@ -34,6 +34,37 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
def download_and_save_images(image_urls: list[str], message_id: int) -> list[str]:
"""Download images from URLs and save to disk. Returns relative file paths."""
image_dir = settings.DISCORD_STORAGE_DIR / str(message_id)
image_dir.mkdir(parents=True, exist_ok=True)
saved_paths = []
for url in image_urls:
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
# Generate filename
url_hash = hashlib.md5(url.encode()).hexdigest()
ext = pathlib.Path(url).suffix or ".jpg"
ext = ext.split("?")[0]
filename = f"{url_hash}{ext}"
local_path = image_dir / filename
# Save image
local_path.write_bytes(response.content)
# Store relative path from FILE_STORAGE_DIR
relative_path = local_path.relative_to(settings.FILE_STORAGE_DIR)
saved_paths.append(str(relative_path))
except Exception as e:
logger.error(f"Failed to download/save image from {url}: {e}")
return saved_paths
def get_prev(
session: Session | scoped_session, channel_id: int, sent_at: datetime
) -> list[str]:
@ -51,50 +82,6 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def call_llm(
session,
message: DiscordMessage,
model: str,
msgs: list[str] = [],
allowed_tools: list[str] | None = None,
) -> str | None:
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
tools = make_discord_tools(
message.recipient_user.system_user,
message.from_user,
message.channel,
model=model,
)
tools = {
name: tool
for name, tool in tools.items()
if message.tool_allowed(name)
and (allowed_tools is None or name in allowed_tools)
}
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
)
messages = previous_messages(
session,
message.recipient_user and message.recipient_user.id,
message.channel and message.channel.id,
max_messages=10,
)
return provider.run_with_tools(
messages=provider.as_messages([m.title for m in messages] + msgs),
tools=tools,
system_prompt=system_prompt,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
def should_process(message: DiscordMessage) -> bool:
if not (
settings.DISCORD_PROCESS_MESSAGES
@ -118,16 +105,26 @@ def should_process(message: DiscordMessage) -> bool:
<reason>I want to continue the conversation because I think it's important.</reason>
</response>
""")
response = call_llm(
session,
message,
settings.SUMMARIZER_MODEL,
[msg],
allowed_tools=[
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
)
allowed_tools = [
"update_channel_summary",
"update_user_summary",
"update_server_summary",
],
]
response = call_llm(
session,
bot_user=message.recipient_user,
from_user=message.from_user,
channel=message.channel,
model=settings.SUMMARIZER_MODEL,
system_prompt=system_prompt,
messages=[msg],
allowed_tools=message.filter_tools(allowed_tools),
)
if not response:
return False
@ -143,7 +140,7 @@ def should_process(message: DiscordMessage) -> bool:
return False
if message.channel and message.channel.server:
discord.trigger_typing_channel(bot_id, message.channel.name)
discord.trigger_typing_channel(bot_id, cast(int, message.channel_id))
else:
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
return True
@ -193,21 +190,40 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
}
try:
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
response = call_llm(
session,
bot_user=discord_message.recipient_user,
from_user=discord_message.from_user,
channel=discord_message.channel,
model=settings.DISCORD_MODEL,
system_prompt=discord_message.system_prompt,
)
except Exception:
logger.exception("Failed to generate Discord response")
print("response:", response)
return {
"status": "error",
"error": "Failed to generate Discord response",
"message_id": message_id,
}
if not response:
return {
"status": "processed",
"status": "no-response",
"message_id": message_id,
}
if discord_message.channel.server:
discord.send_to_channel(bot_id, discord_message.channel.name, response)
else:
discord.send_dm(bot_id, discord_message.from_user.username, response)
res = send_discord_response(
bot_id=bot_id,
response=response,
channel_id=discord_message.channel_id,
user_identifier=discord_message.from_user
and discord_message.from_user.username,
)
if not res:
return {
"status": "error",
"error": "Failed to send Discord response",
"message_id": message_id,
}
return {
"status": "processed",
@ -226,6 +242,9 @@ def add_discord_message(
server_id: int | None = None,
recipient_id: int | None = None,
message_reference_id: int | None = None,
message_type: str = "default",
thread_id: int | None = None,
image_urls: list[str] | None = None,
) -> dict[str, Any]:
"""
Add a Discord message to the database.
@ -237,6 +256,11 @@ def add_discord_message(
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
# Download and save images to disk
saved_image_paths = []
if image_urls:
saved_image_paths = download_and_save_images(image_urls, message_id)
with make_session() as session:
discord_message = DiscordMessage(
modality="text",
@ -248,8 +272,10 @@ def add_discord_message(
from_id=author_id,
recipient_id=recipient_id,
message_id=message_id,
message_type="reply" if message_reference_id else "default",
message_type=message_type,
reply_to_message_id=message_reference_id,
thread_id=thread_id,
images=saved_image_paths or None,
)
existing_msg = check_content_exists(
session, DiscordMessage, message_id=message_id, sha256=content_hash

View File

@ -2,41 +2,48 @@ import logging
from datetime import datetime, timezone
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.common import settings
from memory.common.celery_app import (
app,
EXECUTE_SCHEDULED_CALL,
RUN_SCHEDULED_CALLS,
app,
)
from memory.common import llms, discord
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.discord.messages import call_llm, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__)
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
"""Call LLM with tools support for scheduled calls."""
if not scheduled_call.discord_user:
logger.warning("No discord_user for scheduled call - cannot execute")
return None
model = cast(str, scheduled_call.model or settings.DISCORD_MODEL)
system_prompt = cast(str, scheduled_call.system_prompt or "")
message = cast(str, scheduled_call.message)
allowed_tools_list = cast(list[str] | None, scheduled_call.allowed_tools)
bot_user = (
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
)
return call_llm(
session=session,
bot_user=bot_user,
from_user=scheduled_call.discord_user,
channel=scheduled_call.discord_channel,
messages=[message],
model=model,
system_prompt=system_prompt,
allowed_tools=allowed_tools_list,
)
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
"""
Send the LLM response to the specified Discord user.
Args:
scheduled_call: The scheduled call object
response: The LLM response to send
"""
# Format the message with topic, model, and response
message_parts = []
if cast(str, scheduled_call.topic):
message_parts.append(f"**Topic:** {scheduled_call.topic}")
if cast(str, scheduled_call.model):
message_parts.append(f"**Model:** {scheduled_call.model}")
message_parts.append(f"**Response:** {response}")
message = "\n".join(message_parts)
# Discord has a 2000 character limit, so we may need to split the message
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
"""Send the LLM response to Discord user or channel."""
bot_id_value = scheduled_call.user_id
if bot_id_value is None:
logger.warning(
@ -47,15 +54,15 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
bot_id = cast(int, bot_id_value)
if discord_user := scheduled_call.discord_user:
logger.info(f"Sending DM to {discord_user.username}: {message}")
discord.send_dm(bot_id, discord_user.username, message)
elif discord_channel := scheduled_call.discord_channel:
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
discord.broadcast_message(bot_id, discord_channel.name, message)
else:
logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
send_discord_response(
bot_id=bot_id,
response=response,
channel_id=cast(int, scheduled_call.discord_channel.id)
if scheduled_call.discord_channel
else None,
user_identifier=scheduled_call.discord_user.username
if scheduled_call.discord_user
else None,
)
@ -92,15 +99,29 @@ def execute_scheduled_call(self, scheduled_call_id: str):
logger.info(f"Calling LLM with model {scheduled_call.model}")
# Make the LLM call
if scheduled_call.model:
response = llms.summarize(
prompt=cast(str, scheduled_call.message),
model=cast(str, scheduled_call.model),
system_prompt=cast(str, scheduled_call.system_prompt),
)
else:
response = cast(str, scheduled_call.message)
# Make the LLM call with tools support
try:
response = _call_llm_for_scheduled(session, scheduled_call)
except Exception:
logger.exception("Failed to generate LLM response")
scheduled_call.status = "failed"
scheduled_call.error_message = "LLM call failed"
session.commit()
return {
"success": False,
"error": "LLM call failed",
"scheduled_call_id": scheduled_call_id,
}
if not response:
scheduled_call.status = "failed"
scheduled_call.error_message = "No response from LLM"
session.commit()
return {
"success": False,
"error": "No response from LLM",
"scheduled_call_id": scheduled_call_id,
}
# Store the response
scheduled_call.response = response