mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 17:11:19 +01:00
Compare commits
7 Commits
e95a082147
...
9182f15c45
| Author | SHA1 | Date | |
|---|---|---|---|
| 9182f15c45 | |||
|
|
afdff1708b | ||
|
|
64e84b1c89 | ||
| 798b4779da | |||
| 69192f834a | |||
|
|
6bd7df8ee3 | ||
| a4f42e656a |
@ -0,0 +1,29 @@
|
|||||||
|
"""store discord images
|
||||||
|
|
||||||
|
Revision ID: 1954477b25f4
|
||||||
|
Revises: 2024235e37e7
|
||||||
|
Create Date: 2025-11-02 10:14:48.334934
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "1954477b25f4"
|
||||||
|
down_revision: Union[str, None] = "2024235e37e7"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"discord_message", sa.Column("images", sa.ARRAY(sa.Text()), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("discord_message", "images")
|
||||||
@ -5,15 +5,16 @@ SQLAdmin views for the knowledge base database models.
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
|
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
AgentObservation,
|
AgentObservation,
|
||||||
ArticleFeed,
|
ArticleFeed,
|
||||||
BlogPost,
|
BlogPost,
|
||||||
Book,
|
Book,
|
||||||
BookSection,
|
BookSection,
|
||||||
ScheduledLLMCall,
|
|
||||||
Chunk,
|
Chunk,
|
||||||
Comic,
|
Comic,
|
||||||
|
DiscordMessage,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
ForumPost,
|
ForumPost,
|
||||||
@ -21,6 +22,7 @@ from memory.common.db.models import (
|
|||||||
MiscDoc,
|
MiscDoc,
|
||||||
Note,
|
Note,
|
||||||
Photo,
|
Photo,
|
||||||
|
ScheduledLLMCall,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
@ -153,6 +155,17 @@ class BookAdmin(ModelView, model=Book):
|
|||||||
column_searchable_list = ["title", "author", "id"]
|
column_searchable_list = ["title", "author", "id"]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"content",
|
||||||
|
"images",
|
||||||
|
"sent_at",
|
||||||
|
]
|
||||||
|
column_searchable_list = ["content", "id", "images"]
|
||||||
|
column_sortable_list = ["sent_at"]
|
||||||
|
|
||||||
|
|
||||||
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
||||||
column_list = [
|
column_list = [
|
||||||
"id",
|
"id",
|
||||||
@ -310,6 +323,7 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(ForumPostAdmin)
|
admin.add_view(ForumPostAdmin)
|
||||||
admin.add_view(ComicAdmin)
|
admin.add_view(ComicAdmin)
|
||||||
admin.add_view(PhotoAdmin)
|
admin.add_view(PhotoAdmin)
|
||||||
|
admin.add_view(DiscordMessageAdmin)
|
||||||
admin.add_view(UserAdmin)
|
admin.add_view(UserAdmin)
|
||||||
admin.add_view(DiscordUserAdmin)
|
admin.add_view(DiscordUserAdmin)
|
||||||
admin.add_view(DiscordServerAdmin)
|
admin.add_view(DiscordServerAdmin)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ Database models for the knowledge base system.
|
|||||||
import pathlib
|
import pathlib
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from collections.abc import Collection
|
||||||
from typing import Any, Annotated, Sequence, cast
|
from typing import Any, Annotated, Sequence, cast
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -301,6 +302,7 @@ class DiscordMessage(SourceItem):
|
|||||||
BigInteger, nullable=True
|
BigInteger, nullable=True
|
||||||
) # Discord thread snowflake ID if in thread
|
) # Discord thread snowflake ID if in thread
|
||||||
edited_at = Column(DateTime(timezone=True), nullable=True)
|
edited_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
images = Column(ARRAY(Text), nullable=True) # List of image URLs
|
||||||
|
|
||||||
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
|
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
|
||||||
server = relationship("DiscordServer", foreign_keys=[server_id])
|
server = relationship("DiscordServer", foreign_keys=[server_id])
|
||||||
@ -308,16 +310,16 @@ class DiscordMessage(SourceItem):
|
|||||||
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
|
recipient_user = relationship("DiscordUser", foreign_keys=[recipient_id])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allowed_tools(self) -> list[str]:
|
def allowed_tools(self) -> set[str]:
|
||||||
return (
|
return set(
|
||||||
(self.channel.allowed_tools if self.channel else [])
|
(self.channel.allowed_tools if self.channel else [])
|
||||||
+ (self.from_user.allowed_tools if self.from_user else [])
|
+ (self.from_user.allowed_tools if self.from_user else [])
|
||||||
+ (self.server.allowed_tools if self.server else [])
|
+ (self.server.allowed_tools if self.server else [])
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disallowed_tools(self) -> list[str]:
|
def disallowed_tools(self) -> set[str]:
|
||||||
return (
|
return set(
|
||||||
(self.channel.disallowed_tools if self.channel else [])
|
(self.channel.disallowed_tools if self.channel else [])
|
||||||
+ (self.from_user.disallowed_tools if self.from_user else [])
|
+ (self.from_user.disallowed_tools if self.from_user else [])
|
||||||
+ (self.server.disallowed_tools if self.server else [])
|
+ (self.server.disallowed_tools if self.server else [])
|
||||||
@ -328,6 +330,11 @@ class DiscordMessage(SourceItem):
|
|||||||
not self.allowed_tools or tool in self.allowed_tools
|
not self.allowed_tools or tool in self.allowed_tools
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def filter_tools(self, tools: Collection[str] | None = None) -> set[str]:
|
||||||
|
if tools is None:
|
||||||
|
return self.allowed_tools - self.disallowed_tools
|
||||||
|
return set(tools) - self.disallowed_tools & self.allowed_tools
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ignore_messages(self) -> bool:
|
def ignore_messages(self) -> bool:
|
||||||
return (
|
return (
|
||||||
@ -358,6 +365,20 @@ class DiscordMessage(SourceItem):
|
|||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
return f"{self.from_user.username} ({self.sent_at.isoformat()[:19]}): {self.content}"
|
||||||
|
|
||||||
|
def as_content(self) -> dict[str, Any]:
|
||||||
|
"""Return message content ready for LLM (text + images from disk)."""
|
||||||
|
content = {"text": self.title, "images": []}
|
||||||
|
for path in cast(list[str] | None, self.images) or []:
|
||||||
|
try:
|
||||||
|
full_path = settings.FILE_STORAGE_DIR / path
|
||||||
|
if full_path.exists():
|
||||||
|
image = Image.open(full_path)
|
||||||
|
content["images"].append(image)
|
||||||
|
except Exception:
|
||||||
|
pass # Skip failed image loads
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
__mapper_args__ = {
|
__mapper_args__ = {
|
||||||
"polymorphic_identity": "discord_message",
|
"polymorphic_identity": "discord_message",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -55,14 +55,14 @@ def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
||||||
"""Send a DM via the Discord collector API"""
|
"""Send message to a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={
|
json={
|
||||||
"bot_id": bot_id,
|
"bot_id": bot_id,
|
||||||
"channel_name": channel_name,
|
"channel": channel,
|
||||||
"message": message,
|
"message": message,
|
||||||
},
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
@ -73,16 +73,16 @@ def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
|||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to send to channel {channel_name}: {e}")
|
logger.error(f"Failed to send to channel {channel}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
def trigger_typing_channel(bot_id: int, channel: int | str) -> bool:
|
||||||
"""Trigger typing indicator for a channel via the Discord collector API"""
|
"""Trigger typing indicator for a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/typing/channel",
|
f"{get_api_url()}/typing/channel",
|
||||||
json={"bot_id": bot_id, "channel_name": channel_name},
|
json={"bot_id": bot_id, "channel": channel},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -90,18 +90,18 @@ def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
|
|||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
logger.error(f"Failed to trigger typing for channel {channel}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
def broadcast_message(bot_id: int, channel: int | str, message: str) -> bool:
|
||||||
"""Send a message to a channel via the Discord collector API"""
|
"""Send a message to a channel by name or ID (ID supports threads)"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={
|
json={
|
||||||
"bot_id": bot_id,
|
"bot_id": bot_id,
|
||||||
"channel_name": channel_name,
|
"channel": channel,
|
||||||
"message": message,
|
"message": message,
|
||||||
},
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
@ -111,7 +111,7 @@ def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
|||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
logger.error(f"Failed to send message to channel {channel}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -71,15 +71,23 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _convert_message(self, message: Message) -> dict[str, Any]:
|
def _convert_message(self, message: Message) -> dict[str, Any]:
|
||||||
converted = message.to_dict()
|
# Handle string content directly
|
||||||
if converted["role"] == MessageRole.ASSISTANT and isinstance(
|
if isinstance(message.content, str):
|
||||||
converted["content"], list
|
return {"role": message.role.value, "content": message.content}
|
||||||
):
|
|
||||||
content = sorted(
|
# Convert content items, handling ImageContent specially
|
||||||
converted["content"], key=lambda x: x["type"] != "thinking"
|
content_list = []
|
||||||
)
|
for item in message.content:
|
||||||
return converted | {"content": content}
|
if isinstance(item, ImageContent):
|
||||||
return converted
|
content_list.append(self._convert_image_content(item))
|
||||||
|
else:
|
||||||
|
content_list.append(item.to_dict())
|
||||||
|
|
||||||
|
# Sort assistant messages to put thinking last
|
||||||
|
if message.role == MessageRole.ASSISTANT:
|
||||||
|
content_list = sorted(content_list, key=lambda x: x["type"] != "thinking")
|
||||||
|
|
||||||
|
return {"role": message.role.value, "content": content_list}
|
||||||
|
|
||||||
def _should_include_message(self, message: Message) -> bool:
|
def _should_include_message(self, message: Message) -> bool:
|
||||||
"""Filter out system messages (handled separately in Anthropic)."""
|
"""Filter out system messages (handled separately in Anthropic)."""
|
||||||
@ -105,6 +113,7 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
"messages": anthropic_messages,
|
"messages": anthropic_messages,
|
||||||
"temperature": settings.temperature,
|
"temperature": settings.temperature,
|
||||||
"max_tokens": settings.max_tokens,
|
"max_tokens": settings.max_tokens,
|
||||||
|
"extra_headers": {"anthropic-beta": "web-fetch-2025-09-10"},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Only include top_p if explicitly set
|
# Only include top_p if explicitly set
|
||||||
@ -144,7 +153,6 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
Tuple of (StreamEvent or None, updated current_tool_use or None)
|
Tuple of (StreamEvent or None, updated current_tool_use or None)
|
||||||
"""
|
"""
|
||||||
event_type = getattr(event, "type", None)
|
event_type = getattr(event, "type", None)
|
||||||
|
|
||||||
# Handle error events
|
# Handle error events
|
||||||
if event_type == "error":
|
if event_type == "error":
|
||||||
error = getattr(event, "error", None)
|
error = getattr(event, "error", None)
|
||||||
|
|||||||
@ -171,13 +171,17 @@ class Message:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def user(
|
def user(
|
||||||
text: str | None = None, tool_result: ToolResultContent | None = None
|
text: str | None = None,
|
||||||
|
images: list[Image.Image] | None = None,
|
||||||
|
tool_result: ToolResultContent | None = None,
|
||||||
) -> "Message":
|
) -> "Message":
|
||||||
parts = []
|
parts = []
|
||||||
if text:
|
if text:
|
||||||
parts.append(TextContent(text=text))
|
parts.append(TextContent(text=text))
|
||||||
if tool_result:
|
if tool_result:
|
||||||
parts.append(tool_result)
|
parts.append(tool_result)
|
||||||
|
for image in images or []:
|
||||||
|
parts.append(ImageContent(image=image))
|
||||||
return Message(role=MessageRole.USER, content=parts)
|
return Message(role=MessageRole.USER, content=parts)
|
||||||
|
|
||||||
|
|
||||||
@ -418,7 +422,11 @@ class BaseLLMProvider(ABC):
|
|||||||
"""Convert tool definitions to provider format."""
|
"""Convert tool definitions to provider format."""
|
||||||
if not tools:
|
if not tools:
|
||||||
return None
|
return None
|
||||||
return [self._convert_tool(tool) for tool in tools]
|
converted = [
|
||||||
|
tool.provider_format(self.provider) or self._convert_tool(tool)
|
||||||
|
for tool in tools
|
||||||
|
]
|
||||||
|
return [c for c in converted if c is not None]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(
|
def generate(
|
||||||
@ -625,8 +633,16 @@ class BaseLLMProvider(ABC):
|
|||||||
tool_calls=tool_calls or None,
|
tool_calls=tool_calls or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_messages(self, messages) -> list[Message]:
|
def as_messages(self, messages: list[dict[str, Any] | str]) -> list[Message]:
|
||||||
return [Message.user(text=msg) for msg in messages]
|
def make_message(msg: dict[str, Any] | str) -> Message:
|
||||||
|
if isinstance(msg, str):
|
||||||
|
return Message.user(text=msg)
|
||||||
|
elif isinstance(msg, dict):
|
||||||
|
return Message.user(text=msg["text"], images=msg.get("images"))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown message type: {type(msg)}")
|
||||||
|
|
||||||
|
return [make_message(msg) for msg in messages]
|
||||||
|
|
||||||
|
|
||||||
def create_provider(
|
def create_provider(
|
||||||
|
|||||||
@ -151,23 +151,17 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
|
|
||||||
return openai_messages
|
return openai_messages
|
||||||
|
|
||||||
def _convert_tools(
|
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
|
||||||
self, tools: list[ToolDefinition] | None
|
|
||||||
) -> list[dict[str, Any]] | None:
|
|
||||||
"""
|
"""
|
||||||
Convert our tool definitions to OpenAI format.
|
Convert our tool definitions to OpenAI format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tools: List of tool definitions
|
tool: Tool definition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of tools in OpenAI format
|
Tool in OpenAI format
|
||||||
"""
|
"""
|
||||||
if not tools:
|
return {
|
||||||
return None
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
@ -175,8 +169,6 @@ class OpenAIProvider(BaseLLMProvider):
|
|||||||
"parameters": tool.input_schema,
|
"parameters": tool.input_schema,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for tool in tools
|
|
||||||
]
|
|
||||||
|
|
||||||
def _build_request_kwargs(
|
def _build_request_kwargs(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -34,3 +34,6 @@ class ToolDefinition:
|
|||||||
|
|
||||||
def __call__(self, input: ToolInput) -> str:
|
def __call__(self, input: ToolInput) -> str:
|
||||||
return self.function(input)
|
return self.function(input)
|
||||||
|
|
||||||
|
def provider_format(self, provider: str) -> dict[str, Any] | None:
|
||||||
|
return None
|
||||||
|
|||||||
36
src/memory/common/llms/tools/base.py
Normal file
36
src/memory/common/llms/tools/base.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from typing import Any
|
||||||
|
from memory.common.llms.tools import ToolDefinition
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchTool(ToolDefinition):
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
defaults = {
|
||||||
|
"name": "web_search",
|
||||||
|
"description": "Search the web for information",
|
||||||
|
"input_schema": {},
|
||||||
|
"function": lambda input: "result",
|
||||||
|
}
|
||||||
|
super().__init__(**(defaults | kwargs))
|
||||||
|
|
||||||
|
def provider_format(self, provider: str) -> dict[str, Any] | None:
|
||||||
|
if provider == "openai":
|
||||||
|
return {"type": "web_search"}
|
||||||
|
if provider == "anthropic":
|
||||||
|
return {"type": "web_search_20250305", "name": "web_search", "max_uses": 10}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class WebFetchTool(ToolDefinition):
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
defaults = {
|
||||||
|
"name": "web_fetch",
|
||||||
|
"description": "Fetch the contents of a web page",
|
||||||
|
"input_schema": {},
|
||||||
|
"function": lambda input: "result",
|
||||||
|
}
|
||||||
|
super().__init__(**(defaults | kwargs))
|
||||||
|
|
||||||
|
def provider_format(self, provider: str) -> dict[str, Any] | None:
|
||||||
|
if provider == "anthropic":
|
||||||
|
return {"type": "web_fetch_20250910", "name": "web_fetch", "max_uses": 10}
|
||||||
|
return None
|
||||||
@ -73,6 +73,9 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
|
|||||||
NOTES_STORAGE_DIR = pathlib.Path(
|
NOTES_STORAGE_DIR = pathlib.Path(
|
||||||
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
||||||
)
|
)
|
||||||
|
DISCORD_STORAGE_DIR = pathlib.Path(
|
||||||
|
os.getenv("DISCORD_STORAGE_DIR", FILE_STORAGE_DIR / "discord")
|
||||||
|
)
|
||||||
PRIVATE_DIRS = [
|
PRIVATE_DIRS = [
|
||||||
EMAIL_STORAGE_DIR,
|
EMAIL_STORAGE_DIR,
|
||||||
NOTES_STORAGE_DIR,
|
NOTES_STORAGE_DIR,
|
||||||
@ -88,6 +91,7 @@ storage_dirs = [
|
|||||||
PHOTO_STORAGE_DIR,
|
PHOTO_STORAGE_DIR,
|
||||||
WEBPAGE_STORAGE_DIR,
|
WEBPAGE_STORAGE_DIR,
|
||||||
NOTES_STORAGE_DIR,
|
NOTES_STORAGE_DIR,
|
||||||
|
DISCORD_STORAGE_DIR,
|
||||||
]
|
]
|
||||||
for dir in storage_dirs:
|
for dir in storage_dirs:
|
||||||
dir.mkdir(parents=True, exist_ok=True)
|
dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class SendDMRequest(BaseModel):
|
|||||||
|
|
||||||
class SendChannelRequest(BaseModel):
|
class SendChannelRequest(BaseModel):
|
||||||
bot_id: int
|
bot_id: int
|
||||||
channel_name: str # Channel name (e.g., "memory-errors")
|
channel: int | str # Channel name or ID (ID supports threads)
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ class TypingDMRequest(BaseModel):
|
|||||||
|
|
||||||
class TypingChannelRequest(BaseModel):
|
class TypingChannelRequest(BaseModel):
|
||||||
bot_id: int
|
bot_id: int
|
||||||
channel_name: str
|
channel: int | str # Channel name or ID (ID supports threads)
|
||||||
|
|
||||||
|
|
||||||
class Collector:
|
class Collector:
|
||||||
@ -154,7 +154,7 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
success = await collector.collector.send_to_channel(
|
success = await collector.collector.send_to_channel(
|
||||||
request.channel_name, request.message
|
request.channel, request.message
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send channel message: {e}")
|
logger.error(f"Failed to send channel message: {e}")
|
||||||
@ -163,13 +163,13 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
if success:
|
if success:
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"Message sent to channel {request.channel_name}",
|
"message": f"Message sent to channel {request.channel}",
|
||||||
"channel": request.channel_name,
|
"channel": request.channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Failed to send message to channel {request.channel_name}",
|
detail=f"Failed to send message to channel {request.channel}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -181,7 +181,7 @@ async def trigger_channel_typing(request: TypingChannelRequest):
|
|||||||
raise HTTPException(status_code=404, detail="Bot not found")
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = await collector.collector.trigger_typing_channel(request.channel_name)
|
success = await collector.collector.trigger_typing_channel(request.channel)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to trigger channel typing: {e}")
|
logger.error(f"Failed to trigger channel typing: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@ -189,13 +189,13 @@ async def trigger_channel_typing(request: TypingChannelRequest):
|
|||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Failed to trigger typing for channel {request.channel_name}",
|
detail=f"Failed to trigger typing for channel {request.channel}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"channel": request.channel_name,
|
"channel": request.channel,
|
||||||
"message": f"Typing triggered for channel {request.channel_name}",
|
"message": f"Typing triggered for channel {request.channel}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,32 @@ from memory.workers.tasks.discord import add_discord_message, edit_discord_messa
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def process_mentions(session: Session | scoped_session, message: str) -> str:
|
||||||
|
"""Convert username mentions (<@username>) to ID mentions (<@123456>)"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
def replace_mention(match):
|
||||||
|
mention_content = match.group(1)
|
||||||
|
# If it's already numeric, leave it alone
|
||||||
|
if mention_content.isdigit():
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
# Look up username in database
|
||||||
|
user = (
|
||||||
|
session.query(DiscordUser)
|
||||||
|
.filter(DiscordUser.username == mention_content)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user:
|
||||||
|
return f"<@{user.id}>"
|
||||||
|
|
||||||
|
# If user not found, return original
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
return re.sub(r"<@([^>]+)>", replace_mention, message)
|
||||||
|
|
||||||
|
|
||||||
# Pure functions for Discord entity creation/updates
|
# Pure functions for Discord entity creation/updates
|
||||||
def create_or_update_server(
|
def create_or_update_server(
|
||||||
session: Session | scoped_session, guild: discord.Guild | None
|
session: Session | scoped_session, guild: discord.Guild | None
|
||||||
@ -181,6 +207,10 @@ def sync_guild_metadata(guild: discord.Guild) -> None:
|
|||||||
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
|
||||||
create_or_update_channel(session, channel)
|
create_or_update_channel(session, channel)
|
||||||
|
|
||||||
|
# Sync threads
|
||||||
|
for thread in guild.threads:
|
||||||
|
create_or_update_channel(session, thread)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
@ -233,6 +263,18 @@ class MessageCollector(commands.Bot):
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
# Extract image URLs from attachments
|
||||||
|
image_urls = [
|
||||||
|
att.url
|
||||||
|
for att in message.attachments
|
||||||
|
if att.content_type and att.content_type.startswith("image/")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Determine message metadata (type, reply, thread)
|
||||||
|
message_type, reply_to_id, thread_id = determine_message_metadata(
|
||||||
|
message
|
||||||
|
)
|
||||||
|
|
||||||
# Queue the message for processing
|
# Queue the message for processing
|
||||||
add_discord_message.delay(
|
add_discord_message.delay(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
@ -242,9 +284,10 @@ class MessageCollector(commands.Bot):
|
|||||||
server_id=message.guild.id if message.guild else None,
|
server_id=message.guild.id if message.guild else None,
|
||||||
content=message.content or "",
|
content=message.content or "",
|
||||||
sent_at=message.created_at.isoformat(),
|
sent_at=message.created_at.isoformat(),
|
||||||
message_reference_id=message.reference.message_id
|
message_reference_id=reply_to_id,
|
||||||
if message.reference
|
message_type=message_type,
|
||||||
else None,
|
thread_id=thread_id,
|
||||||
|
image_urls=image_urls,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error queuing message {message.id}: {e}")
|
logger.error(f"Error queuing message {message.id}: {e}")
|
||||||
@ -288,6 +331,11 @@ class MessageCollector(commands.Bot):
|
|||||||
create_or_update_channel(session, channel)
|
create_or_update_channel(session, channel)
|
||||||
channels_updated += 1
|
channels_updated += 1
|
||||||
|
|
||||||
|
# Refresh all threads in this server
|
||||||
|
for thread in guild.threads:
|
||||||
|
create_or_update_channel(session, thread)
|
||||||
|
channels_updated += 1
|
||||||
|
|
||||||
# Refresh all members in this server (if members intent is enabled)
|
# Refresh all members in this server (if members intent is enabled)
|
||||||
if self.intents.members:
|
if self.intents.members:
|
||||||
for member in guild.members:
|
for member in guild.members:
|
||||||
@ -373,7 +421,11 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"User {user_identifier} not found")
|
logger.error(f"User {user_identifier} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await user.send(message)
|
# Post-process mentions to convert usernames to IDs
|
||||||
|
with make_session() as session:
|
||||||
|
processed_message = process_mentions(session, message)
|
||||||
|
|
||||||
|
await user.send(processed_message)
|
||||||
logger.info(f"Sent DM to {user_identifier}")
|
logger.info(f"Sent DM to {user_identifier}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -402,34 +454,50 @@ class MessageCollector(commands.Bot):
|
|||||||
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
async def send_to_channel(
|
||||||
"""Send a message to a channel by name across all guilds"""
|
self, channel_identifier: int | str, message: str
|
||||||
|
) -> bool:
|
||||||
|
"""Send a message to a channel by name or ID (supports threads)"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
channel = await self.get_channel_by_name(channel_name)
|
# Get channel by ID or name
|
||||||
|
if isinstance(channel_identifier, int):
|
||||||
|
channel = self.get_channel(channel_identifier)
|
||||||
|
else:
|
||||||
|
channel = await self.get_channel_by_name(channel_identifier)
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.error(f"Channel {channel_name} not found")
|
logger.error(f"Channel {channel_identifier} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await channel.send(message)
|
# Post-process mentions to convert usernames to IDs
|
||||||
logger.info(f"Sent message to channel {channel_name}")
|
with make_session() as session:
|
||||||
|
processed_message = process_mentions(session, message)
|
||||||
|
|
||||||
|
await channel.send(processed_message)
|
||||||
|
logger.info(f"Sent message to channel {channel_identifier}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
logger.error(f"Failed to send message to channel {channel_identifier}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def trigger_typing_channel(self, channel_name: str) -> bool:
|
async def trigger_typing_channel(self, channel_identifier: int | str) -> bool:
|
||||||
"""Trigger typing indicator in a channel"""
|
"""Trigger typing indicator in a channel by name or ID (supports threads)"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
channel = await self.get_channel_by_name(channel_name)
|
# Get channel by ID or name
|
||||||
|
if isinstance(channel_identifier, int):
|
||||||
|
channel = self.get_channel(channel_identifier)
|
||||||
|
else:
|
||||||
|
channel = await self.get_channel_by_name(channel_identifier)
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.error(f"Channel {channel_name} not found")
|
logger.error(f"Channel {channel_identifier} not found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async with channel.typing():
|
async with channel.typing():
|
||||||
@ -437,5 +505,7 @@ class MessageCollector(commands.Bot):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
logger.error(
|
||||||
|
f"Failed to trigger typing for channel {channel_identifier}: {e}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -1,16 +1,20 @@
|
|||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from collections.abc import Collection
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from memory.common import discord, settings
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
|
DiscordMessage,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
DiscordMessage,
|
|
||||||
)
|
)
|
||||||
|
from memory.common.db.models.users import BotUser
|
||||||
|
from memory.common.llms.base import create_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -199,3 +203,98 @@ def comm_channel_prompt(
|
|||||||
You will be given the last {max_messages} messages in the conversation.
|
You will be given the last {max_messages} messages in the conversation.
|
||||||
Please react to them appropriately. You can return an empty response if you don't have anything to say.
|
Please react to them appropriately. You can return an empty response if you don't have anything to say.
|
||||||
""").format(server_context=server_context, max_messages=max_messages)
|
""").format(server_context=server_context, max_messages=max_messages)
|
||||||
|
|
||||||
|
|
||||||
|
def call_llm(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
bot_user: DiscordUser,
|
||||||
|
from_user: DiscordUser | None,
|
||||||
|
channel: DiscordChannel | None,
|
||||||
|
model: str,
|
||||||
|
system_prompt: str = "",
|
||||||
|
messages: list[str | dict[str, Any]] = [],
|
||||||
|
allowed_tools: Collection[str] | None = None,
|
||||||
|
num_previous_messages: int = 10,
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Call LLM with Discord tools support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
bot_user: Bot user making the call
|
||||||
|
from_user: Discord user who initiated the interaction
|
||||||
|
channel: Discord channel (if any)
|
||||||
|
messages: List of message strings or dicts with text/images
|
||||||
|
model: LLM model to use
|
||||||
|
system_prompt: System prompt
|
||||||
|
allowed_tools: List of allowed tool names (None = all tools allowed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLM response or None if failed
|
||||||
|
"""
|
||||||
|
provider = create_provider(model=model)
|
||||||
|
|
||||||
|
if provider.usage_tracker.is_rate_limited(model):
|
||||||
|
logger.error(
|
||||||
|
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
if from_user and not channel:
|
||||||
|
user_id = cast(int, from_user.id)
|
||||||
|
prev_messages = previous_messages(
|
||||||
|
session,
|
||||||
|
user_id,
|
||||||
|
channel and channel.id,
|
||||||
|
max_messages=num_previous_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
from memory.common.llms.tools.discord import make_discord_tools
|
||||||
|
from memory.common.llms.tools.base import WebSearchTool
|
||||||
|
|
||||||
|
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
||||||
|
tools |= {"web_search": WebSearchTool()}
|
||||||
|
|
||||||
|
# Filter to allowed tools if specified
|
||||||
|
if allowed_tools is not None:
|
||||||
|
tools = {name: tool for name, tool in tools.items() if name in allowed_tools}
|
||||||
|
|
||||||
|
if bot_user.system_prompt:
|
||||||
|
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
|
||||||
|
message_content = [m.as_content() for m in prev_messages] + messages
|
||||||
|
return provider.run_with_tools(
|
||||||
|
messages=provider.as_messages(message_content),
|
||||||
|
tools=tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
|
).response
|
||||||
|
|
||||||
|
|
||||||
|
def send_discord_response(
|
||||||
|
bot_id: int,
|
||||||
|
response: str,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
user_identifier: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Send a response to Discord channel or user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot_id: Bot user ID
|
||||||
|
response: Message to send
|
||||||
|
channel_id: Channel ID (for channel messages)
|
||||||
|
user_identifier: Username (for DMs)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if sent successfully
|
||||||
|
"""
|
||||||
|
if channel_id is not None:
|
||||||
|
logger.info(f"Sending message to channel {channel_id}")
|
||||||
|
return discord.send_to_channel(bot_id, channel_id, response)
|
||||||
|
elif user_identifier is not None:
|
||||||
|
logger.info(f"Sending DM to {user_identifier}")
|
||||||
|
return discord.send_dm(bot_id, user_identifier, response)
|
||||||
|
else:
|
||||||
|
logger.error("Neither channel_id nor user_identifier provided")
|
||||||
|
return False
|
||||||
|
|||||||
@ -4,11 +4,13 @@ Celery tasks for Discord message processing.
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import requests
|
||||||
from sqlalchemy import exc as sqlalchemy_exc
|
from sqlalchemy import exc as sqlalchemy_exc
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
@ -21,9 +23,7 @@ from memory.common.celery_app import (
|
|||||||
)
|
)
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import DiscordMessage, DiscordUser
|
from memory.common.db.models import DiscordMessage, DiscordUser
|
||||||
from memory.common.llms.base import create_provider
|
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
|
||||||
from memory.common.llms.tools.discord import make_discord_tools
|
|
||||||
from memory.discord.messages import comm_channel_prompt, previous_messages
|
|
||||||
from memory.workers.tasks.content_processing import (
|
from memory.workers.tasks.content_processing import (
|
||||||
check_content_exists,
|
check_content_exists,
|
||||||
create_task_result,
|
create_task_result,
|
||||||
@ -34,6 +34,37 @@ from memory.workers.tasks.content_processing import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_save_images(image_urls: list[str], message_id: int) -> list[str]:
|
||||||
|
"""Download images from URLs and save to disk. Returns relative file paths."""
|
||||||
|
image_dir = settings.DISCORD_STORAGE_DIR / str(message_id)
|
||||||
|
image_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
saved_paths = []
|
||||||
|
for url in image_urls:
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Generate filename
|
||||||
|
url_hash = hashlib.md5(url.encode()).hexdigest()
|
||||||
|
ext = pathlib.Path(url).suffix or ".jpg"
|
||||||
|
ext = ext.split("?")[0]
|
||||||
|
filename = f"{url_hash}{ext}"
|
||||||
|
local_path = image_dir / filename
|
||||||
|
|
||||||
|
# Save image
|
||||||
|
local_path.write_bytes(response.content)
|
||||||
|
|
||||||
|
# Store relative path from FILE_STORAGE_DIR
|
||||||
|
relative_path = local_path.relative_to(settings.FILE_STORAGE_DIR)
|
||||||
|
saved_paths.append(str(relative_path))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download/save image from {url}: {e}")
|
||||||
|
|
||||||
|
return saved_paths
|
||||||
|
|
||||||
|
|
||||||
def get_prev(
|
def get_prev(
|
||||||
session: Session | scoped_session, channel_id: int, sent_at: datetime
|
session: Session | scoped_session, channel_id: int, sent_at: datetime
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
@ -51,50 +82,6 @@ def get_prev(
|
|||||||
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
|
||||||
|
|
||||||
|
|
||||||
def call_llm(
|
|
||||||
session,
|
|
||||||
message: DiscordMessage,
|
|
||||||
model: str,
|
|
||||||
msgs: list[str] = [],
|
|
||||||
allowed_tools: list[str] | None = None,
|
|
||||||
) -> str | None:
|
|
||||||
provider = create_provider(model=model)
|
|
||||||
if provider.usage_tracker.is_rate_limited(model):
|
|
||||||
logger.error(
|
|
||||||
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
tools = make_discord_tools(
|
|
||||||
message.recipient_user.system_user,
|
|
||||||
message.from_user,
|
|
||||||
message.channel,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
tools = {
|
|
||||||
name: tool
|
|
||||||
for name, tool in tools.items()
|
|
||||||
if message.tool_allowed(name)
|
|
||||||
and (allowed_tools is None or name in allowed_tools)
|
|
||||||
}
|
|
||||||
system_prompt = message.system_prompt or ""
|
|
||||||
system_prompt += comm_channel_prompt(
|
|
||||||
session, message.recipient_user, message.channel
|
|
||||||
)
|
|
||||||
messages = previous_messages(
|
|
||||||
session,
|
|
||||||
message.recipient_user and message.recipient_user.id,
|
|
||||||
message.channel and message.channel.id,
|
|
||||||
max_messages=10,
|
|
||||||
)
|
|
||||||
return provider.run_with_tools(
|
|
||||||
messages=provider.as_messages([m.title for m in messages] + msgs),
|
|
||||||
tools=tools,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
|
||||||
).response
|
|
||||||
|
|
||||||
|
|
||||||
def should_process(message: DiscordMessage) -> bool:
|
def should_process(message: DiscordMessage) -> bool:
|
||||||
if not (
|
if not (
|
||||||
settings.DISCORD_PROCESS_MESSAGES
|
settings.DISCORD_PROCESS_MESSAGES
|
||||||
@ -118,16 +105,26 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
<reason>I want to continue the conversation because I think it's important.</reason>
|
<reason>I want to continue the conversation because I think it's important.</reason>
|
||||||
</response>
|
</response>
|
||||||
""")
|
""")
|
||||||
response = call_llm(
|
|
||||||
session,
|
system_prompt = message.system_prompt or ""
|
||||||
message,
|
system_prompt += comm_channel_prompt(
|
||||||
settings.SUMMARIZER_MODEL,
|
session, message.recipient_user, message.channel
|
||||||
[msg],
|
)
|
||||||
allowed_tools=[
|
allowed_tools = [
|
||||||
"update_channel_summary",
|
"update_channel_summary",
|
||||||
"update_user_summary",
|
"update_user_summary",
|
||||||
"update_server_summary",
|
"update_server_summary",
|
||||||
],
|
]
|
||||||
|
|
||||||
|
response = call_llm(
|
||||||
|
session,
|
||||||
|
bot_user=message.recipient_user,
|
||||||
|
from_user=message.from_user,
|
||||||
|
channel=message.channel,
|
||||||
|
model=settings.SUMMARIZER_MODEL,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
messages=[msg],
|
||||||
|
allowed_tools=message.filter_tools(allowed_tools),
|
||||||
)
|
)
|
||||||
if not response:
|
if not response:
|
||||||
return False
|
return False
|
||||||
@ -143,7 +140,7 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if message.channel and message.channel.server:
|
if message.channel and message.channel.server:
|
||||||
discord.trigger_typing_channel(bot_id, message.channel.name)
|
discord.trigger_typing_channel(bot_id, cast(int, message.channel_id))
|
||||||
else:
|
else:
|
||||||
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
||||||
return True
|
return True
|
||||||
@ -193,21 +190,40 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
response = call_llm(
|
||||||
|
session,
|
||||||
|
bot_user=discord_message.recipient_user,
|
||||||
|
from_user=discord_message.from_user,
|
||||||
|
channel=discord_message.channel,
|
||||||
|
model=settings.DISCORD_MODEL,
|
||||||
|
system_prompt=discord_message.system_prompt,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to generate Discord response")
|
logger.exception("Failed to generate Discord response")
|
||||||
|
return {
|
||||||
print("response:", response)
|
"status": "error",
|
||||||
|
"error": "Failed to generate Discord response",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
if not response:
|
if not response:
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "no-response",
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if discord_message.channel.server:
|
res = send_discord_response(
|
||||||
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
bot_id=bot_id,
|
||||||
else:
|
response=response,
|
||||||
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
channel_id=discord_message.channel_id,
|
||||||
|
user_identifier=discord_message.from_user
|
||||||
|
and discord_message.from_user.username,
|
||||||
|
)
|
||||||
|
if not res:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": "Failed to send Discord response",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "processed",
|
||||||
@ -226,6 +242,9 @@ def add_discord_message(
|
|||||||
server_id: int | None = None,
|
server_id: int | None = None,
|
||||||
recipient_id: int | None = None,
|
recipient_id: int | None = None,
|
||||||
message_reference_id: int | None = None,
|
message_reference_id: int | None = None,
|
||||||
|
message_type: str = "default",
|
||||||
|
thread_id: int | None = None,
|
||||||
|
image_urls: list[str] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Add a Discord message to the database.
|
Add a Discord message to the database.
|
||||||
@ -237,6 +256,11 @@ def add_discord_message(
|
|||||||
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
|
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
|
||||||
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
|
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
# Download and save images to disk
|
||||||
|
saved_image_paths = []
|
||||||
|
if image_urls:
|
||||||
|
saved_image_paths = download_and_save_images(image_urls, message_id)
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
discord_message = DiscordMessage(
|
discord_message = DiscordMessage(
|
||||||
modality="text",
|
modality="text",
|
||||||
@ -248,8 +272,10 @@ def add_discord_message(
|
|||||||
from_id=author_id,
|
from_id=author_id,
|
||||||
recipient_id=recipient_id,
|
recipient_id=recipient_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
message_type="reply" if message_reference_id else "default",
|
message_type=message_type,
|
||||||
reply_to_message_id=message_reference_id,
|
reply_to_message_id=message_reference_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
images=saved_image_paths or None,
|
||||||
)
|
)
|
||||||
existing_msg = check_content_exists(
|
existing_msg = check_content_exists(
|
||||||
session, DiscordMessage, message_id=message_id, sha256=content_hash
|
session, DiscordMessage, message_id=message_id, sha256=content_hash
|
||||||
|
|||||||
@ -2,41 +2,48 @@ import logging
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common import settings
|
||||||
from memory.common.db.models import ScheduledLLMCall
|
|
||||||
from memory.common.celery_app import (
|
from memory.common.celery_app import (
|
||||||
app,
|
|
||||||
EXECUTE_SCHEDULED_CALL,
|
EXECUTE_SCHEDULED_CALL,
|
||||||
RUN_SCHEDULED_CALLS,
|
RUN_SCHEDULED_CALLS,
|
||||||
|
app,
|
||||||
)
|
)
|
||||||
from memory.common import llms, discord
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import ScheduledLLMCall
|
||||||
|
from memory.discord.messages import call_llm, send_discord_response
|
||||||
from memory.workers.tasks.content_processing import safe_task_execution
|
from memory.workers.tasks.content_processing import safe_task_execution
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
|
||||||
|
"""Call LLM with tools support for scheduled calls."""
|
||||||
|
if not scheduled_call.discord_user:
|
||||||
|
logger.warning("No discord_user for scheduled call - cannot execute")
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = cast(str, scheduled_call.model or settings.DISCORD_MODEL)
|
||||||
|
system_prompt = cast(str, scheduled_call.system_prompt or "")
|
||||||
|
message = cast(str, scheduled_call.message)
|
||||||
|
allowed_tools_list = cast(list[str] | None, scheduled_call.allowed_tools)
|
||||||
|
|
||||||
|
bot_user = (
|
||||||
|
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
|
||||||
|
)
|
||||||
|
return call_llm(
|
||||||
|
session=session,
|
||||||
|
bot_user=bot_user,
|
||||||
|
from_user=scheduled_call.discord_user,
|
||||||
|
channel=scheduled_call.discord_channel,
|
||||||
|
messages=[message],
|
||||||
|
model=model,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
allowed_tools=allowed_tools_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
||||||
"""
|
"""Send the LLM response to Discord user or channel."""
|
||||||
Send the LLM response to the specified Discord user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scheduled_call: The scheduled call object
|
|
||||||
response: The LLM response to send
|
|
||||||
"""
|
|
||||||
# Format the message with topic, model, and response
|
|
||||||
message_parts = []
|
|
||||||
if cast(str, scheduled_call.topic):
|
|
||||||
message_parts.append(f"**Topic:** {scheduled_call.topic}")
|
|
||||||
if cast(str, scheduled_call.model):
|
|
||||||
message_parts.append(f"**Model:** {scheduled_call.model}")
|
|
||||||
message_parts.append(f"**Response:** {response}")
|
|
||||||
|
|
||||||
message = "\n".join(message_parts)
|
|
||||||
|
|
||||||
# Discord has a 2000 character limit, so we may need to split the message
|
|
||||||
if len(message) > 1900: # Leave some buffer
|
|
||||||
message = message[:1900] + "\n\n... (response truncated)"
|
|
||||||
|
|
||||||
bot_id_value = scheduled_call.user_id
|
bot_id_value = scheduled_call.user_id
|
||||||
if bot_id_value is None:
|
if bot_id_value is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -47,15 +54,15 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
|||||||
|
|
||||||
bot_id = cast(int, bot_id_value)
|
bot_id = cast(int, bot_id_value)
|
||||||
|
|
||||||
if discord_user := scheduled_call.discord_user:
|
send_discord_response(
|
||||||
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
bot_id=bot_id,
|
||||||
discord.send_dm(bot_id, discord_user.username, message)
|
response=response,
|
||||||
elif discord_channel := scheduled_call.discord_channel:
|
channel_id=cast(int, scheduled_call.discord_channel.id)
|
||||||
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
if scheduled_call.discord_channel
|
||||||
discord.broadcast_message(bot_id, discord_channel.name, message)
|
else None,
|
||||||
else:
|
user_identifier=scheduled_call.discord_user.username
|
||||||
logger.warning(
|
if scheduled_call.discord_user
|
||||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -92,15 +99,29 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
|
|
||||||
logger.info(f"Calling LLM with model {scheduled_call.model}")
|
logger.info(f"Calling LLM with model {scheduled_call.model}")
|
||||||
|
|
||||||
# Make the LLM call
|
# Make the LLM call with tools support
|
||||||
if scheduled_call.model:
|
try:
|
||||||
response = llms.summarize(
|
response = _call_llm_for_scheduled(session, scheduled_call)
|
||||||
prompt=cast(str, scheduled_call.message),
|
except Exception:
|
||||||
model=cast(str, scheduled_call.model),
|
logger.exception("Failed to generate LLM response")
|
||||||
system_prompt=cast(str, scheduled_call.system_prompt),
|
scheduled_call.status = "failed"
|
||||||
)
|
scheduled_call.error_message = "LLM call failed"
|
||||||
else:
|
session.commit()
|
||||||
response = cast(str, scheduled_call.message)
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "LLM call failed",
|
||||||
|
"scheduled_call_id": scheduled_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not response:
|
||||||
|
scheduled_call.status = "failed"
|
||||||
|
scheduled_call.error_message = "No response from LLM"
|
||||||
|
session.commit()
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "No response from LLM",
|
||||||
|
"scheduled_call_id": scheduled_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
# Store the response
|
# Store the response
|
||||||
scheduled_call.response = response
|
scheduled_call.response = response
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user