Compare commits

...

7 Commits

Author SHA1 Message Date
9182f15c45 properly handle bot prompts 2025-11-02 15:51:30 +00:00
Daniel O'Connell
afdff1708b prompt from bot user 2025-11-02 16:46:26 +01:00
Daniel O'Connell
64e84b1c89 basic tools 2025-11-02 16:34:38 +01:00
798b4779da unify discord callers 2025-11-02 14:46:43 +00:00
69192f834a handle discord threads 2025-11-02 11:23:31 +00:00
Daniel O'Connell
6bd7df8ee3 properly handle images by anthropic 2025-11-02 12:08:46 +01:00
a4f42e656a save images 2025-11-02 10:25:23 +00:00
15 changed files with 525 additions and 186 deletions

View File

@ -0,0 +1,29 @@
"""store discord images
Revision ID: 1954477b25f4
Revises: 2024235e37e7
Create Date: 2025-11-02 10:14:48.334934
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "1954477b25f4"
down_revision: Union[str, None] = "2024235e37e7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the nullable ``images`` text-array column to ``discord_message``."""
    images_column = sa.Column("images", sa.ARRAY(sa.Text()), nullable=True)
    op.add_column("discord_message", images_column)
def downgrade() -> None:
    """Drop the ``images`` column from ``discord_message`` again."""
    table, column = "discord_message", "images"
    op.drop_column(table, column)

View File

@ -5,15 +5,16 @@ SQLAdmin views for the knowledge base database models.
import logging import logging
from sqladmin import Admin, ModelView from sqladmin import Admin, ModelView
from memory.common.db.models import ( from memory.common.db.models import (
AgentObservation, AgentObservation,
ArticleFeed, ArticleFeed,
BlogPost, BlogPost,
Book, Book,
BookSection, BookSection,
ScheduledLLMCall,
Chunk, Chunk,
Comic, Comic,
DiscordMessage,
EmailAccount, EmailAccount,
EmailAttachment, EmailAttachment,
ForumPost, ForumPost,
@ -21,6 +22,7 @@ from memory.common.db.models import (
MiscDoc, MiscDoc,
Note, Note,
Photo, Photo,
ScheduledLLMCall,
SourceItem, SourceItem,
User, User,
) )
@ -153,6 +155,17 @@ class BookAdmin(ModelView, model=Book):
column_searchable_list = ["title", "author", "id"] column_searchable_list = ["title", "author", "id"]
class DiscordMessageAdmin(ModelView, model=DiscordMessage):
    """Admin view configuration for stored Discord messages."""

    # Columns shown in the list view.
    column_list = ["id", "content", "images", "sent_at"]
    # Columns the search box matches against.
    column_searchable_list = [
        "content",
        "id",
        "images",
    ]
    # Columns that may be used for ordering.
    column_sortable_list = [
        "sent_at",
    ]
class ArticleFeedAdmin(ModelView, model=ArticleFeed): class ArticleFeedAdmin(ModelView, model=ArticleFeed):
column_list = [ column_list = [
"id", "id",
@ -310,6 +323,7 @@ def setup_admin(admin: Admin):
admin.add_view(ForumPostAdmin) admin.add_view(ForumPostAdmin)
admin.add_view(ComicAdmin) admin.add_view(ComicAdmin)
admin.add_view(PhotoAdmin) admin.add_view(PhotoAdmin)
admin.add_view(DiscordMessageAdmin)
admin.add_view(UserAdmin) admin.add_view(UserAdmin)
admin.add_view(DiscordUserAdmin) admin.add_view(DiscordUserAdmin)
admin.add_view(DiscordServerAdmin) admin.add_view(DiscordServerAdmin)

View File

@ -5,6 +5,7 @@ Database models for the knowledge base system.
import pathlib import pathlib
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from collections.abc import Collection
from typing import Any, Annotated, Sequence, cast from typing import Any, Annotated, Sequence, cast
from PIL import Image from PIL import Image
@ -301,6 +302,7 @@ class DiscordMessage(SourceItem):
BigInteger, nullable=True BigInteger, nullable=True
) # Discord thread snowflake ID if in thread ) # Discord thread snowflake ID if in thread
edited_at = Column(DateTime(timezone=True), nullable=True) edited_at = Column(DateTime(timezone=True), nullable=True)
images = Column(ARRAY(Text), nullable=True) # List of image URLs
channel = relationship("DiscordChannel", foreign_keys=[channel_id]) channel = relationship("DiscordChannel", foreign_keys=[channel_id])
server = relationship("DiscordServer", foreign_keys=[server_id]) server = relationship("DiscordServer", foreign_keys=[server_id])
@ -308,16 +310,16 @@ class DiscordMessage(SourceItem):
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id]) recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
@property @property
def allowed_tools(self) -> list[str]: def allowed_tools(self) -> set[str]:
return ( return set(
(self.channel.allowed_tools if self.channel else []) (self.channel.allowed_tools if self.channel else [])
+ (self.from_user.allowed_tools if self.from_user else []) + (self.from_user.allowed_tools if self.from_user else [])
+ (self.server.allowed_tools if self.server else []) + (self.server.allowed_tools if self.server else [])
) )
@property @property
def disallowed_tools(self) -> list[str]: def disallowed_tools(self) -> set[str]:
return ( return set(
(self.channel.disallowed_tools if self.channel else []) (self.channel.disallowed_tools if self.channel else [])
+ (self.from_user.disallowed_tools if self.from_user else []) + (self.from_user.disallowed_tools if self.from_user else [])
+ (self.server.disallowed_tools if self.server else []) + (self.server.disallowed_tools if self.server else [])
@ -328,6 +330,11 @@ class DiscordMessage(SourceItem):
not self.allowed_tools or tool in self.allowed_tools not self.allowed_tools or tool in self.allowed_tools
) )
def filter_tools(self, tools: Collection[str] | None = None) -> set[str]:
if tools is None:
return self.allowed_tools - self.disallowed_tools
return set(tools) - self.disallowed_tools & self.allowed_tools
@property @property
def ignore_messages(self) -> bool: def ignore_messages(self) -> bool:
return ( return (
@ -358,6 +365,20 @@ class DiscordMessage(SourceItem):
def title(self) -> str: def title(self) -> str:
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}" return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
def as_content(self) -> dict[str, Any]:
    """Return message content ready for LLM (text + images from disk).

    Returns:
        A dict with ``"text"`` set to ``self.title`` and ``"images"`` set to
        the list of PIL images that could be loaded from the relative paths
        in ``self.images`` (resolved against ``settings.FILE_STORAGE_DIR``).
        Missing or unreadable files are skipped.
    """
    import logging

    content: dict[str, Any] = {"text": self.title, "images": []}
    for path in cast(list[str] | None, self.images) or []:
        try:
            full_path = settings.FILE_STORAGE_DIR / path
            if not full_path.exists():
                continue
            image = Image.open(full_path)
            # Image.open only reads the header; force the decode now so a
            # corrupt file is caught here rather than later in the LLM
            # pipeline, and so the underlying file handle can be released.
            image.load()
        except Exception:
            # Best effort: one bad image must not prevent the message text
            # (and the remaining images) from being used.
            logging.getLogger(__name__).warning(
                "Skipping unreadable Discord image %s", path, exc_info=True
            )
            continue
        content["images"].append(image)
    return content
__mapper_args__ = { __mapper_args__ = {
"polymorphic_identity": "discord_message", "polymorphic_identity": "discord_message",
} }

View File

@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
return False return False
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a DM via the Discord collector API""" """Send message to a channel by name or ID (ID supports threads)"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={ json={
"bot_id": bot_id, "bot_id": bot_id,
"channel_name": channel_name, "channel": channel,
"message": message, "message": message,
}, },
timeout=10, timeout=10,
@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to send to channel {channel_name}: {e}") logger.error(f"Failed to send to channel {channel}: {e}")
return False return False
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool: def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
"""Trigger typing indicator for a channel via the Discord collector API""" """Trigger typing indicator for a channel by name or ID (ID supports threads)"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/typing/channel", f"{get_api_url()}/typing/channel",
json={"bot_id": bot_id, "channel_name": channel_name}, json={"bot_id": bot_id, "channel": channel},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") logger.error(f"Failed to trigger typing for channel {channel}: {e}")
return False return False
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool: def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API""" """Send a message to a channel by name or ID (ID supports threads)"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={ json={
"bot_id": bot_id, "bot_id": bot_id,
"channel_name": channel_name, "channel": channel,
"message": message, "message": message,
}, },
timeout=10, timeout=10,
@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
return result.get("success", False) return result.get("success", False)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}") logger.error(f"Failed to send message to channel {channel}: {e}")
return False return False

View File

@ -71,15 +71,23 @@ class AnthropicProvider(BaseLLMProvider):
} }
def _convert_message(self, message: Message) -> dict[str, Any]: def _convert_message(self, message: Message) -> dict[str, Any]:
converted = message.to_dict() # Handle string content directly
if converted["role"] == MessageRole.ASSISTANT and isinstance( if isinstance(message.content, str):
converted["content"], list return {"role": message.role.value, "content": message.content}
):
content = sorted( # Convert content items, handling ImageContent specially
converted["content"], key=lambda x: x["type"] != "thinking" content_list = []
) for item in message.content:
return converted | {"content": content} if isinstance(item, ImageContent):
return converted content_list.append(self._convert_image_content(item))
else:
content_list.append(item.to_dict())
# Sort assistant messages to put thinking first
if message.role == MessageRole.ASSISTANT:
content_list = sorted(content_list, key=lambda x: x["type"] != "thinking")
return {"role": message.role.value, "content": content_list}
def _should_include_message(self, message: Message) -> bool: def _should_include_message(self, message: Message) -> bool:
"""Filter out system messages (handled separately in Anthropic).""" """Filter out system messages (handled separately in Anthropic)."""
@ -105,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
"messages": anthropic_messages, "messages": anthropic_messages,
"temperature": settings.temperature, "temperature": settings.temperature,
"max_tokens": settings.max_tokens, "max_tokens": settings.max_tokens,
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
} }
# Only include top_p if explicitly set # Only include top_p if explicitly set
@ -144,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
Tuple of (StreamEvent or None, updated current_tool_use or None) Tuple of (StreamEvent or None, updated current_tool_use or None)
""" """
event_type = getattr(event, "type", None) event_type = getattr(event, "type", None)
# Handle error events # Handle error events
if event_type == "error": if event_type == "error":
error = getattr(event, "error", None) error = getattr(event, "error", None)

View File

@ -171,13 +171,17 @@ class Message:
@staticmethod @staticmethod
def user( def user(
text: str | None = None, tool_result: ToolResultContent | None = None text: str | None = None,
images: list[Image.Image] | None = None,
tool_result: ToolResultContent | None = None,
) -> "Message": ) -> "Message":
parts = [] parts = []
if text: if text:
parts.append(TextContent(text=text)) parts.append(TextContent(text=text))
if tool_result: if tool_result:
parts.append(tool_result) parts.append(tool_result)
for image in images or []:
parts.append(ImageContent(image=image))
return Message(role=MessageRole.USER, content=parts) return Message(role=MessageRole.USER, content=parts)
@ -418,7 +422,11 @@ class BaseLLMProvider(ABC):
"""Convert tool definitions to provider format.""" """Convert tool definitions to provider format."""
if not tools: if not tools:
return None return None
return [self._convert_tool(tool) for tool in tools] converted = [
tool.provider_format(self.provider) or self._convert_tool(tool)
for tool in tools
]
return [c for c in converted if c is not None]
@abstractmethod @abstractmethod
def generate( def generate(
@ -625,8 +633,16 @@ class BaseLLMProvider(ABC):
tool_calls=tool_calls or None, tool_calls=tool_calls or None,
) )
def as_messages(self, messages) -> list[Message]: def as_messages(self, messages: list[dict[str, Any] | str]) -> list[Message]:
return [Message.user(text=msg) for msg in messages] def make_message(msg: dict[str, Any] | str) -> Message:
if isinstance(msg, str):
return Message.user(text=msg)
elif isinstance(msg, dict):
return Message.user(text=msg["text"], images=msg.get("images"))
else:
raise ValueError(f"Unknown message type: {type(msg)}")
return [make_message(msg) for msg in messages]
def create_provider( def create_provider(

View File

@ -151,32 +151,24 @@ class OpenAIProvider(BaseLLMProvider):
return openai_messages return openai_messages
def _convert_tools( def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
self, tools: list[ToolDefinition] | None
) -> list[dict[str, Any]] | None:
""" """
Convert our tool definitions to OpenAI format. Convert our tool definitions to OpenAI format.
Args: Args:
tools: List of tool definitions tool: Tool definition
Returns: Returns:
List of tools in OpenAI format Tool in OpenAI format
""" """
if not tools: return {
return None "type": "function",
"function": {
return [ "name": tool.name,
{ "description": tool.description,
"type": "function", "parameters": tool.input_schema,
"function": { },
"name": tool.name, }
"description": tool.description,
"parameters": tool.input_schema,
},
}
for tool in tools
]
def _build_request_kwargs( def _build_request_kwargs(
self, self,

View File

@ -34,3 +34,6 @@ class ToolDefinition:
def __call__(self, input: ToolInput) -> str: def __call__(self, input: ToolInput) -> str:
return self.function(input) return self.function(input)
def provider_format(self, provider: str) -> dict[str, Any] | None:
    """Return a provider-specific tool payload, or ``None`` if there is none.

    This base implementation always returns ``None``; subclasses may
    override it to supply a payload for a given ``provider`` name.
    """
    return None

View File

@ -0,0 +1,36 @@
from typing import Any
from memory.common.llms.tools import ToolDefinition
class WebSearchTool(ToolDefinition):
    """Tool definition for web search.

    The local ``function`` is a stub that returns a fixed string; the
    payload actually sent to a provider comes from ``provider_format``.
    """

    def __init__(self, **kwargs: Any):
        # Caller-supplied keyword arguments take precedence over the defaults.
        options: dict[str, Any] = {
            "name": "web_search",
            "description": "Search the web for information",
            "input_schema": {},
            "function": lambda input: "result",
        }
        options.update(kwargs)
        super().__init__(**options)

    def provider_format(self, provider: str) -> dict[str, Any] | None:
        """Return the search tool payload for ``provider``, or ``None``."""
        payloads: dict[str, dict[str, Any]] = {
            "openai": {"type": "web_search"},
            "anthropic": {
                "type": "web_search_20250305",
                "name": "web_search",
                "max_uses": 10,
            },
        }
        return payloads.get(provider)
class WebFetchTool(ToolDefinition):
    """Tool definition for fetching a web page.

    The local ``function`` is a stub that returns a fixed string; only the
    ``"anthropic"`` provider gets a payload from ``provider_format``.
    """

    def __init__(self, **kwargs: Any):
        # Caller-supplied keyword arguments take precedence over the defaults.
        options: dict[str, Any] = {
            "name": "web_fetch",
            "description": "Fetch the contents of a web page",
            "input_schema": {},
            "function": lambda input: "result",
        }
        options.update(kwargs)
        super().__init__(**options)

    def provider_format(self, provider: str) -> dict[str, Any] | None:
        """Return the fetch tool payload for ``provider``, or ``None``."""
        if provider != "anthropic":
            return None
        return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}

View File

@ -73,6 +73,9 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
NOTES_STORAGE_DIR = pathlib.Path( NOTES_STORAGE_DIR = pathlib.Path(
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes") os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
) )
DISCORD_STORAGE_DIR = pathlib.Path(
os.getenv("DISCORD_STORAGE_DIR", FILE_STORAGE_DIR / "discord")
)
PRIVATE_DIRS = [ PRIVATE_DIRS = [
EMAIL_STORAGE_DIR, EMAIL_STORAGE_DIR,
NOTES_STORAGE_DIR, NOTES_STORAGE_DIR,
@ -88,6 +91,7 @@ storage_dirs = [
PHOTO_STORAGE_DIR, PHOTO_STORAGE_DIR,
WEBPAGE_STORAGE_DIR, WEBPAGE_STORAGE_DIR,
NOTES_STORAGE_DIR, NOTES_STORAGE_DIR,
DISCORD_STORAGE_DIR,
] ]
for dir in storage_dirs: for dir in storage_dirs:
dir.mkdir(parents=True, exist_ok=True) dir.mkdir(parents=True, exist_ok=True)

View File

@ -31,7 +31,7 @@ class SendDMRequest(BaseModel):
class SendChannelRequest(BaseModel): class SendChannelRequest(BaseModel):
bot_id: int bot_id: int
channel_name: str # Channel name (e.g., "memory-errors") channel: int | str # Channel name or ID (ID supports threads)
message: str message: str
@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel):
class TypingChannelRequest(BaseModel): class TypingChannelRequest(BaseModel):
bot_id: int bot_id: int
channel_name: str channel: int | str # Channel name or ID (ID supports threads)
class Collector: class Collector:
@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest):
try: try:
success = await collector.collector.send_to_channel( success = await collector.collector.send_to_channel(
request.channel_name, request.message request.channel, request.message
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send channel message: {e}") logger.error(f"Failed to send channel message: {e}")
@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest):
if success: if success:
return { return {
"success": True, "success": True,
"message": f"Message sent to channel {request.channel_name}", "message": f"Message sent to channel {request.channel}",
"channel": request.channel_name, "channel": request.channel,
} }
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Failed to send message to channel {request.channel_name}", detail=f"Failed to send message to channel {request.channel}",
) )
@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest):
raise HTTPException(status_code=404, detail="Bot not found") raise HTTPException(status_code=404, detail="Bot not found")
try: try:
success = await collector.collector.trigger_typing_channel(request.channel_name) success = await collector.collector.trigger_typing_channel(request.channel)
except Exception as e: except Exception as e:
logger.error(f"Failed to trigger channel typing: {e}") logger.error(f"Failed to trigger channel typing: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest):
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Failed to trigger typing for channel {request.channel_name}", detail=f"Failed to trigger typing for channel {request.channel}",
) )
return { return {
"success": True, "success": True,
"channel": request.channel_name, "channel": request.channel,
"message": f"Typing triggered for channel {request.channel_name}", "message": f"Typing triggered for channel {request.channel}",
} }

View File

@ -24,6 +24,32 @@ from memory.workers.tasks.discord import add_discord_message, edit_discord_messa
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def process_mentions(session: Session | scoped_session, message: str) -> str:
    """Rewrite ``<@username>`` mentions in ``message`` as ``<@id>`` mentions.

    Mentions whose content is already all digits are left untouched, as are
    usernames with no matching ``DiscordUser`` row.
    """
    import re

    def to_id_mention(found):
        original, name = found.group(0), found.group(1)
        # Already an ID mention: nothing to resolve.
        if name.isdigit():
            return original

        # Resolve the username against the stored Discord users.
        query = session.query(DiscordUser).filter(DiscordUser.username == name)
        row = query.first()
        return f"<@{row.id}>" if row else original

    return re.sub(r"<@([^>]+)>", to_id_mention, message)
# Pure functions for Discord entity creation/updates # Pure functions for Discord entity creation/updates
def create_or_update_server( def create_or_update_server(
session: Session | scoped_session, guild: discord.Guild | None session: Session | scoped_session, guild: discord.Guild | None
@ -181,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)): if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel) create_or_update_channel(session, channel)
# Sync threads
for thread in guild.threads:
create_or_update_channel(session, thread)
session.commit() session.commit()
@ -233,6 +263,18 @@ class MessageCollector(commands.Bot):
session.commit() session.commit()
# Extract image URLs from attachments
image_urls = [
att.url
for att in message.attachments
if att.content_type and att.content_type.startswith("image/")
]
# Determine message metadata (type, reply, thread)
message_type, reply_to_id, thread_id = determine_message_metadata(
message
)
# Queue the message for processing # Queue the message for processing
add_discord_message.delay( add_discord_message.delay(
message_id=message.id, message_id=message.id,
@ -242,9 +284,10 @@ class MessageCollector(commands.Bot):
server_id=message.guild.id if message.guild else None, server_id=message.guild.id if message.guild else None,
content=message.content or "", content=message.content or "",
sent_at=message.created_at.isoformat(), sent_at=message.created_at.isoformat(),
message_reference_id=message.reference.message_id message_reference_id=reply_to_id,
if message.reference message_type=message_type,
else None, thread_id=thread_id,
image_urls=image_urls,
) )
except Exception as e: except Exception as e:
logger.error(f"Error queuing message {message.id}: {e}") logger.error(f"Error queuing message {message.id}: {e}")
@ -288,6 +331,11 @@ class MessageCollector(commands.Bot):
create_or_update_channel(session, channel) create_or_update_channel(session, channel)
channels_updated += 1 channels_updated += 1
# Refresh all threads in this server
for thread in guild.threads:
create_or_update_channel(session, thread)
channels_updated += 1
# Refresh all members in this server (if members intent is enabled) # Refresh all members in this server (if members intent is enabled)
if self.intents.members: if self.intents.members:
for member in guild.members: for member in guild.members:
@ -373,7 +421,11 @@ class MessageCollector(commands.Bot):
logger.error(f"User {user_identifier} not found") logger.error(f"User {user_identifier} not found")
return False return False
await user.send(message) # Post-process mentions to convert usernames to IDs
with make_session() as session:
processed_message = process_mentions(session, message)
await user.send(processed_message)
logger.info(f"Sent DM to {user_identifier}") logger.info(f"Sent DM to {user_identifier}")
return True return True
@ -402,34 +454,50 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}") logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False return False
async def send_to_channel(self, channel_name: str, message: str) -> bool: async def send_to_channel(
"""Send a message to a channel by name across all guilds""" self, channel_identifier: int | str, message: str
) -> bool:
"""Send a message to a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED: if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False return False
try: try:
channel = await self.get_channel_by_name(channel_name) # Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel: if not channel:
logger.error(f"Channel {channel_name} not found") logger.error(f"Channel {channel_identifier} not found")
return False return False
await channel.send(message) # Post-process mentions to convert usernames to IDs
logger.info(f"Sent message to channel {channel_name}") with make_session() as session:
processed_message = process_mentions(session, message)
await channel.send(processed_message)
logger.info(f"Sent message to channel {channel_identifier}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}") logger.error(f"Failed to send message to channel {channel_identifier}: {e}")
return False return False
async def trigger_typing_channel(self, channel_name: str) -> bool: async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
"""Trigger typing indicator in a channel""" """Trigger typing indicator in a channel by name or ID (supports threads)"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED: if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False return False
try: try:
channel = await self.get_channel_by_name(channel_name) # Get channel by ID or name
if isinstance(channel_identifier, int):
channel = self.get_channel(channel_identifier)
else:
channel = await self.get_channel_by_name(channel_identifier)
if not channel: if not channel:
logger.error(f"Channel {channel_name} not found") logger.error(f"Channel {channel_identifier} not found")
return False return False
async with channel.typing(): async with channel.typing():
@ -437,5 +505,7 @@ class MessageCollector(commands.Bot):
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}") logger.error(
f"Failed to trigger typing for channel {channel_identifier}: {e}"
)
return False return False

View File

@ -1,16 +1,20 @@
import logging import logging
import textwrap import textwrap
from collections.abc import Collection
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, cast from typing import Any, cast
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
from memory.common import discord, settings
from memory.common.db.models import ( from memory.common.db.models import (
DiscordChannel, DiscordChannel,
DiscordMessage,
DiscordUser, DiscordUser,
ScheduledLLMCall, ScheduledLLMCall,
DiscordMessage,
) )
from memory.common.db.models.users import BotUser
from memory.common.llms.base import create_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -199,3 +203,98 @@ def comm_channel_prompt(
You will be given the last {max_messages} messages in the conversation. You will be given the last {max_messages} messages in the conversation.
Please react to them appropriately. You can return an empty response if you don't have anything to say. Please react to them appropriately. You can return an empty response if you don't have anything to say.
""").format(server_context=server_context, max_messages=max_messages) """).format(server_context=server_context, max_messages=max_messages)
def call_llm(
    session: Session | scoped_session,
    bot_user: DiscordUser,
    from_user: DiscordUser | None,
    channel: DiscordChannel | None,
    model: str,
    system_prompt: str = "",
    messages: list[str | dict[str, Any]] | None = None,
    allowed_tools: Collection[str] | None = None,
    num_previous_messages: int = 10,
) -> str | None:
    """
    Call LLM with Discord tools support.

    Args:
        session: Database session
        bot_user: Bot user making the call
        from_user: Discord user who initiated the interaction
        channel: Discord channel (if any)
        model: LLM model to use
        system_prompt: System prompt; appended after the bot user's own
            system prompt when the bot has one
        messages: Extra message strings or dicts with text/images, appended
            after the previous conversation messages (None = no extras)
        allowed_tools: Names of allowed tools (None = all tools allowed)
        num_previous_messages: Maximum number of previous messages to include

    Returns:
        LLM response, or None if the model is currently rate limited
    """
    provider = create_provider(model=model)

    if provider.usage_tracker.is_rate_limited(model):
        logger.error(
            f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
        )
        return None

    # Only filter history by user when there is no channel to filter by.
    user_id = None
    if from_user and not channel:
        user_id = cast(int, from_user.id)

    prev_messages = previous_messages(
        session,
        user_id,
        channel and channel.id,
        max_messages=num_previous_messages,
    )

    # Imported here rather than at module level, as in the original.
    from memory.common.llms.tools.discord import make_discord_tools
    from memory.common.llms.tools.base import WebSearchTool

    tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
    tools |= {"web_search": WebSearchTool()}

    # Filter to allowed tools if specified
    if allowed_tools is not None:
        tools = {name: tool for name, tool in tools.items() if name in allowed_tools}

    if bot_user.system_prompt:
        system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")

    # `messages` defaults to None instead of a shared mutable list.
    extra_messages = list(messages) if messages else []
    message_content = [m.as_content() for m in prev_messages] + extra_messages

    return provider.run_with_tools(
        messages=provider.as_messages(message_content),
        tools=tools,
        system_prompt=system_prompt,
        max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
    ).response
def send_discord_response(
    bot_id: int,
    response: str,
    channel_id: int | None = None,
    user_identifier: str | None = None,
) -> bool:
    """
    Deliver ``response`` to a Discord channel or, failing that, to a user.

    A channel takes priority: the DM route is used only when ``channel_id``
    is None.

    Args:
        bot_id: Bot user ID
        response: Message to send
        channel_id: Channel ID (for channel messages)
        user_identifier: Username (for DMs)

    Returns:
        True if sent successfully; False when sending fails or when neither
        target is given
    """
    if channel_id is not None:
        logger.info(f"Sending message to channel {channel_id}")
        return discord.send_to_channel(bot_id, channel_id, response)

    if user_identifier is None:
        logger.error("Neither channel_id nor user_identifier provided")
        return False

    logger.info(f"Sending DM to {user_identifier}")
    return discord.send_dm(bot_id, user_identifier, response)

View File

@ -4,11 +4,13 @@ Celery tasks for Discord message processing.
import hashlib import hashlib
import logging import logging
import pathlib
import re import re
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any, cast
import requests
from sqlalchemy import exc as sqlalchemy_exc from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
@ -21,9 +23,7 @@ from memory.common.celery_app import (
) )
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser from memory.common.db.models import DiscordMessage, DiscordUser
from memory.common.llms.base import create_provider from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
from memory.common.llms.tools.discord import make_discord_tools
from memory.discord.messages import comm_channel_prompt, previous_messages
from memory.workers.tasks.content_processing import ( from memory.workers.tasks.content_processing import (
check_content_exists, check_content_exists,
create_task_result, create_task_result,
@ -34,6 +34,37 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def download_and_save_images(image_urls: list[str], message_id: int) -> list[str]:
    """Download images from URLs and save to disk. Returns relative file paths.

    Files are written to ``settings.DISCORD_STORAGE_DIR / <message_id>`` and
    named after the MD5 hex digest of the URL plus the extension taken from
    the URL path (``.jpg`` when the path has none).

    Args:
        image_urls: URLs of the images to download.
        message_id: ID used as the name of the per-message directory.

    Returns:
        Paths of the saved files relative to ``settings.FILE_STORAGE_DIR``.
        URLs that fail to download or save are logged and skipped.
    """
    from urllib.parse import urlparse

    # Nothing to download: avoid leaving an empty per-message directory.
    if not image_urls:
        return []

    image_dir = settings.DISCORD_STORAGE_DIR / str(message_id)
    image_dir.mkdir(parents=True, exist_ok=True)

    saved_paths = []
    for url in image_urls:
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()

            # Derive the extension from the URL *path* only, so dots or
            # slashes in the query string cannot be mistaken for a suffix.
            url_hash = hashlib.md5(url.encode()).hexdigest()
            ext = pathlib.Path(urlparse(url).path).suffix or ".jpg"
            local_path = image_dir / f"{url_hash}{ext}"

            # Save image
            local_path.write_bytes(response.content)

            # Store relative path from FILE_STORAGE_DIR
            relative_path = local_path.relative_to(settings.FILE_STORAGE_DIR)
            saved_paths.append(str(relative_path))
        except Exception as e:
            # Best effort: one failed image must not abort the others.
            logger.error(f"Failed to download/save image from {url}: {e}")

    return saved_paths
def get_prev( def get_prev(
session: Session | scoped_session, channel_id: int, sent_at: datetime session: Session | scoped_session, channel_id: int, sent_at: datetime
) -> list[str]: ) -> list[str]:
@ -51,50 +82,6 @@ def get_prev(
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def call_llm(
    session,
    message: DiscordMessage,
    model: str,
    msgs: list[str] | None = None,
    allowed_tools: list[str] | None = None,
) -> str | None:
    """Run the LLM with Discord tools for the given message.

    Args:
        session: Database session passed to the prompt and history helpers.
        message: The Discord message being handled; supplies the users,
            channel, system prompt and per-message tool permissions.
        model: Model identifier used to create the provider.
        msgs: Extra message strings appended after the channel history.
        allowed_tools: Optional whitelist of tool names; ``None`` means no
            additional restriction beyond ``message.tool_allowed``.

    Returns:
        The ``response`` of the provider's tool run, or ``None`` when the
        model is currently rate limited.
    """
    provider = create_provider(model=model)
    if provider.usage_tracker.is_rate_limited(model):
        logger.error(
            f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
        )
        return None

    tools = make_discord_tools(
        message.recipient_user.system_user,
        message.from_user,
        message.channel,
        model=model,
    )
    # Keep only tools the message permits and, if given, the caller's whitelist.
    tools = {
        name: tool
        for name, tool in tools.items()
        if message.tool_allowed(name)
        and (allowed_tools is None or name in allowed_tools)
    }

    system_prompt = message.system_prompt or ""
    system_prompt += comm_channel_prompt(
        session, message.recipient_user, message.channel
    )
    messages = previous_messages(
        session,
        message.recipient_user and message.recipient_user.id,
        message.channel and message.channel.id,
        max_messages=10,
    )
    return provider.run_with_tools(
        # `msgs or []` replaces the former mutable `[]` default argument.
        messages=provider.as_messages([m.title for m in messages] + (msgs or [])),
        tools=tools,
        system_prompt=system_prompt,
        max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
    ).response
def should_process(message: DiscordMessage) -> bool: def should_process(message: DiscordMessage) -> bool:
if not ( if not (
settings.DISCORD_PROCESS_MESSAGES settings.DISCORD_PROCESS_MESSAGES
@ -118,16 +105,26 @@ def should_process(message: DiscordMessage) -> bool:
<reason>I want to continue the conversation because I think it's important.</reason> <reason>I want to continue the conversation because I think it's important.</reason>
</response> </response>
""") """)
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
)
allowed_tools = [
"update_channel_summary",
"update_user_summary",
"update_server_summary",
]
response = call_llm( response = call_llm(
session, session,
message, bot_user=message.recipient_user,
settings.SUMMARIZER_MODEL, from_user=message.from_user,
[msg], channel=message.channel,
allowed_tools=[ model=settings.SUMMARIZER_MODEL,
"update_channel_summary", system_prompt=system_prompt,
"update_user_summary", messages=[msg],
"update_server_summary", allowed_tools=message.filter_tools(allowed_tools),
],
) )
if not response: if not response:
return False return False
@ -143,7 +140,7 @@ def should_process(message: DiscordMessage) -> bool:
return False return False
if message.channel and message.channel.server: if message.channel and message.channel.server:
discord.trigger_typing_channel(bot_id, message.channel.name) discord.trigger_typing_channel(bot_id, cast(int, message.channel_id))
else: else:
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id)) discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
return True return True
@ -193,21 +190,40 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
} }
try: try:
response = call_llm(session, discord_message, settings.DISCORD_MODEL) response = call_llm(
session,
bot_user=discord_message.recipient_user,
from_user=discord_message.from_user,
channel=discord_message.channel,
model=settings.DISCORD_MODEL,
system_prompt=discord_message.system_prompt,
)
except Exception: except Exception:
logger.exception("Failed to generate Discord response") logger.exception("Failed to generate Discord response")
return {
print("response:", response) "status": "error",
"error": "Failed to generate Discord response",
"message_id": message_id,
}
if not response: if not response:
return { return {
"status": "processed", "status": "no-response",
"message_id": message_id, "message_id": message_id,
} }
if discord_message.channel.server: res = send_discord_response(
discord.send_to_channel(bot_id, discord_message.channel.name, response) bot_id=bot_id,
else: response=response,
discord.send_dm(bot_id, discord_message.from_user.username, response) channel_id=discord_message.channel_id,
user_identifier=discord_message.from_user
and discord_message.from_user.username,
)
if not res:
return {
"status": "error",
"error": "Failed to send Discord response",
"message_id": message_id,
}
return { return {
"status": "processed", "status": "processed",
@ -226,6 +242,9 @@ def add_discord_message(
server_id: int | None = None, server_id: int | None = None,
recipient_id: int | None = None, recipient_id: int | None = None,
message_reference_id: int | None = None, message_reference_id: int | None = None,
message_type: str = "default",
thread_id: int | None = None,
image_urls: list[str] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Add a Discord message to the database. Add a Discord message to the database.
@ -237,6 +256,11 @@ def add_discord_message(
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest() content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00")) sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
# Download and save images to disk
saved_image_paths = []
if image_urls:
saved_image_paths = download_and_save_images(image_urls, message_id)
with make_session() as session: with make_session() as session:
discord_message = DiscordMessage( discord_message = DiscordMessage(
modality="text", modality="text",
@ -248,8 +272,10 @@ def add_discord_message(
from_id=author_id, from_id=author_id,
recipient_id=recipient_id, recipient_id=recipient_id,
message_id=message_id, message_id=message_id,
message_type="reply" if message_reference_id else "default", message_type=message_type,
reply_to_message_id=message_reference_id, reply_to_message_id=message_reference_id,
thread_id=thread_id,
images=saved_image_paths or None,
) )
existing_msg = check_content_exists( existing_msg = check_content_exists(
session, DiscordMessage, message_id=message_id, sha256=content_hash session, DiscordMessage, message_id=message_id, sha256=content_hash

View File

@ -2,41 +2,48 @@ import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import cast from typing import cast
from memory.common.db.connection import make_session from memory.common import settings
from memory.common.db.models import ScheduledLLMCall
from memory.common.celery_app import ( from memory.common.celery_app import (
app,
EXECUTE_SCHEDULED_CALL, EXECUTE_SCHEDULED_CALL,
RUN_SCHEDULED_CALLS, RUN_SCHEDULED_CALLS,
app,
) )
from memory.common import llms, discord from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.discord.messages import call_llm, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
    """Call the LLM with tools support for a scheduled call.

    Args:
        session: Database session forwarded to ``call_llm``.
        scheduled_call: The scheduled call supplying the message, model,
            system prompt, allowed tools and Discord user/channel.

    Returns:
        Whatever ``call_llm`` returns, or ``None`` (after logging a warning)
        when the scheduled call has no ``discord_user``.
    """
    if not scheduled_call.discord_user:
        logger.warning("No discord_user for scheduled call - cannot execute")
        return None

    model = cast(str, scheduled_call.model or settings.DISCORD_MODEL)
    system_prompt = cast(str, scheduled_call.system_prompt or "")
    message = cast(str, scheduled_call.message)
    allowed_tools_list = cast(list[str] | None, scheduled_call.allowed_tools)

    # The first Discord account linked to the owning user is used as the bot
    # user. Use an explicit conditional so an empty collection becomes None
    # rather than leaking an empty list through as `bot_user`.
    discord_users = scheduled_call.user.discord_users
    bot_user = discord_users[0] if discord_users else None

    return call_llm(
        session=session,
        bot_user=bot_user,
        from_user=scheduled_call.discord_user,
        channel=scheduled_call.discord_channel,
        messages=[message],
        model=model,
        system_prompt=system_prompt,
        allowed_tools=allowed_tools_list,
    )
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
""" """Send the LLM response to Discord user or channel."""
Send the LLM response to the specified Discord user.
Args:
scheduled_call: The scheduled call object
response: The LLM response to send
"""
# Format the message with topic, model, and response
message_parts = []
if cast(str, scheduled_call.topic):
message_parts.append(f"**Topic:** {scheduled_call.topic}")
if cast(str, scheduled_call.model):
message_parts.append(f"**Model:** {scheduled_call.model}")
message_parts.append(f"**Response:** {response}")
message = "\n".join(message_parts)
# Discord has a 2000 character limit, so we may need to split the message
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
bot_id_value = scheduled_call.user_id bot_id_value = scheduled_call.user_id
if bot_id_value is None: if bot_id_value is None:
logger.warning( logger.warning(
@ -47,16 +54,16 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
bot_id = cast(int, bot_id_value) bot_id = cast(int, bot_id_value)
if discord_user := scheduled_call.discord_user: send_discord_response(
logger.info(f"Sending DM to {discord_user.username}: {message}") bot_id=bot_id,
discord.send_dm(bot_id, discord_user.username, message) response=response,
elif discord_channel := scheduled_call.discord_channel: channel_id=cast(int, scheduled_call.discord_channel.id)
logger.info(f"Broadcasting message to {discord_channel.name}: {message}") if scheduled_call.discord_channel
discord.broadcast_message(bot_id, discord_channel.name, message) else None,
else: user_identifier=scheduled_call.discord_user.username
logger.warning( if scheduled_call.discord_user
f"No Discord user or channel found for scheduled call {scheduled_call.id}" else None,
) )
@app.task(bind=True, name=EXECUTE_SCHEDULED_CALL) @app.task(bind=True, name=EXECUTE_SCHEDULED_CALL)
@ -92,15 +99,29 @@ def execute_scheduled_call(self, scheduled_call_id: str):
logger.info(f"Calling LLM with model {scheduled_call.model}") logger.info(f"Calling LLM with model {scheduled_call.model}")
# Make the LLM call # Make the LLM call with tools support
if scheduled_call.model: try:
response = llms.summarize( response = _call_llm_for_scheduled(session, scheduled_call)
prompt=cast(str, scheduled_call.message), except Exception:
model=cast(str, scheduled_call.model), logger.exception("Failed to generate LLM response")
system_prompt=cast(str, scheduled_call.system_prompt), scheduled_call.status = "failed"
) scheduled_call.error_message = "LLM call failed"
else: session.commit()
response = cast(str, scheduled_call.message) return {
"success": False,
"error": "LLM call failed",
"scheduled_call_id": scheduled_call_id,
}
if not response:
scheduled_call.status = "failed"
scheduled_call.error_message = "No response from LLM"
session.commit()
return {
"success": False,
"error": "No response from LLM",
"scheduled_call_id": scheduled_call_id,
}
# Store the response # Store the response
scheduled_call.response = response scheduled_call.response = response