add Discord ingester

This commit is contained in:
Daniel O'Connell 2025-10-12 23:13:30 +02:00
parent f454aa9afa
commit e086b4a3a6
18 changed files with 1100 additions and 224 deletions

View File

@ -9,7 +9,6 @@ Create Date: 2025-10-12 10:12:57.421009
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
@ -20,10 +19,8 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Rename prompt column to message in scheduled_llm_calls table
op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message")
def downgrade() -> None:
# Rename message column back to prompt in scheduled_llm_calls table
op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt")

View File

@ -0,0 +1,165 @@
"""add_discord_models
Revision ID: a8c8e8b17179
Revises: c86079073c1d
Create Date: 2025-10-12 22:28:27.856164
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a8c8e8b17179"
down_revision: Union[str, None] = "c86079073c1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"discord_servers",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("member_count", sa.Integer(), nullable=True),
sa.Column(
"track_messages", sa.Boolean(), server_default="true", nullable=False
),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_servers_active_idx",
"discord_servers",
["track_messages", "last_sync_at"],
unique=False,
)
op.create_table(
"discord_channels",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("channel_type", sa.Text(), nullable=False),
sa.Column("track_messages", sa.Boolean(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["server_id"],
["discord_servers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_channels_server_idx", "discord_channels", ["server_id"], unique=False
)
op.create_table(
"discord_users",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("username", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=True),
sa.Column("system_user_id", sa.Integer(), nullable=True),
sa.Column(
"allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["system_user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_users_system_user_idx",
"discord_users",
["system_user_id"],
unique=False,
)
op.create_table(
"discord_message",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("channel_id", sa.BigInteger(), nullable=False),
sa.Column("discord_user_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.BigInteger(), nullable=False),
sa.Column("message_type", sa.Text(), server_default="default", nullable=True),
sa.Column("reply_to_message_id", sa.BigInteger(), nullable=True),
sa.Column("thread_id", sa.BigInteger(), nullable=True),
sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["channel_id"],
["discord_channels.id"],
),
sa.ForeignKeyConstraint(
["discord_user_id"],
["discord_users.id"],
),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["server_id"],
["discord_servers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_message_discord_id_idx", "discord_message", ["message_id"], unique=True
)
op.create_index(
"discord_message_server_channel_idx",
"discord_message",
["server_id", "channel_id"],
unique=False,
)
op.create_index(
"discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
)
def downgrade() -> None:
op.drop_index("discord_message_user_idx", table_name="discord_message")
op.drop_index("discord_message_server_channel_idx", table_name="discord_message")
op.drop_index("discord_message_discord_id_idx", table_name="discord_message")
op.drop_table("discord_message")
op.drop_index("discord_users_system_user_idx", table_name="discord_users")
op.drop_table("discord_users")
op.drop_index("discord_channels_server_idx", table_name="discord_channels")
op.drop_table("discord_channels")
op.drop_index("discord_servers_active_idx", table_name="discord_servers")
op.drop_table("discord_servers")

View File

@ -174,7 +174,7 @@ services:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
ingest-hub:
<<: *worker-base

View File

@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
git config --global user.name "${GIT_USER_NAME}"
# Default queues to process
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"]

View File

@ -5,3 +5,4 @@ python-multipart==0.0.9
sqladmin==0.20.1
mcp==1.10.0
bm25s[full]==0.2.13
discord.py==2.3.2

View File

@ -12,6 +12,9 @@ MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
@ -72,17 +75,18 @@ app.conf.update(
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
task_routes={
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
f"{DISCORD_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
f"{MAINTENANCE_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
},
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
f"{SCHEDULED_CALLS_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
},

View File

@ -11,6 +11,7 @@ from memory.common.db.models.source_items import (
EmailAttachment,
AgentObservation,
ChatMessage,
DiscordMessage,
BlogPost,
Comic,
BookSection,
@ -40,6 +41,9 @@ from memory.common.db.models.sources import (
Book,
ArticleFeed,
EmailAccount,
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.common.db.models.users import (
User,
@ -74,6 +78,7 @@ __all__ = [
"EmailAttachment",
"AgentObservation",
"ChatMessage",
"DiscordMessage",
"BlogPost",
"Comic",
"BookSection",
@ -93,6 +98,9 @@ __all__ = [
"Book",
"ArticleFeed",
"EmailAccount",
"DiscordServer",
"DiscordChannel",
"DiscordUser",
# Users
"User",
"UserSession",

View File

@ -70,7 +70,7 @@ class ScheduledLLMCall(Base):
"created_at": print_datetime(cast(datetime, self.created_at)),
"executed_at": print_datetime(cast(datetime, self.executed_at)),
"model": self.model,
"prompt": self.message,
"message": self.message,
"system_prompt": self.system_prompt,
"allowed_tools": self.allowed_tools,
"discord_channel": self.discord_channel,

View File

@ -262,7 +262,7 @@ class ChatMessage(SourceItem):
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
platform = Column(Text)
channel_id = Column(Text)
channel_id = Column(Text) # Keep as Text for cross-platform compatibility
author = Column(Text)
sent_at = Column(DateTime(timezone=True))
@ -274,6 +274,64 @@ class ChatMessage(SourceItem):
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
class DiscordMessage(SourceItem):
"""Discord-specific chat message with rich metadata"""
__tablename__ = "discord_message"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
sent_at = Column(DateTime(timezone=True), nullable=False)
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID
# Discord-specific metadata
message_type = Column(
Text, server_default="default"
) # "default", "reply", "thread_starter"
reply_to_message_id = Column(
BigInteger, nullable=True
) # Discord message snowflake ID if replying
thread_id = Column(
BigInteger, nullable=True
) # Discord thread snowflake ID if in thread
edited_at = Column(DateTime(timezone=True), nullable=True)
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
server = relationship("DiscordServer", foreign_keys=[server_id])
discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
@property
def title(self) -> str:
return f"{self.discord_user.username}: {self.content}"
__mapper_args__ = {
"polymorphic_identity": "discord_message",
}
__table_args__ = (
Index("discord_message_discord_id_idx", "message_id", unique=True),
Index(
"discord_message_server_channel_idx",
"server_id",
"channel_id",
),
Index("discord_message_user_idx", "discord_user_id"),
)
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
content = cast(str | None, self.content)
if not content:
return []
prev = getattr(self, "messages_before", [])
content = "\n\n".join(prev) + "\n\n" + self.title
return extract.extract_text(content)
class GitCommit(SourceItem):
__tablename__ = "git_commit"

View File

@ -10,12 +10,14 @@ from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Text,
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
@ -123,3 +125,69 @@ class EmailAccount(Base):
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
class DiscordServer(Base):
"""Discord server configuration and metadata"""
__tablename__ = "discord_servers"
id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
name = Column(Text, nullable=False)
description = Column(Text)
member_count = Column(Integer)
# Collection settings
track_messages = Column(Boolean, nullable=False, server_default="true")
last_sync_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
channels = relationship(
"DiscordChannel", back_populates="server", cascade="all, delete-orphan"
)
__table_args__ = (
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
)
class DiscordChannel(Base):
"""Discord channel metadata and configuration"""
__tablename__ = "discord_channels"
id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
name = Column(Text, nullable=False)
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
# Collection settings (null = inherit from server)
track_messages = Column(Boolean, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
server = relationship("DiscordServer", back_populates="channels")
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
class DiscordUser(Base):
"""Discord user metadata and preferences"""
__tablename__ = "discord_users"
id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
username = Column(Text, nullable=False)
display_name = Column(Text)
# Link to system user if registered
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
# Basic DM settings
allow_dm_tracking = Column(Boolean, nullable=False, server_default="true")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)

View File

@ -50,6 +50,7 @@ class User(Base):
oauth_states = relationship(
"OAuthState", back_populates="user", cascade="all, delete-orphan"
)
discord_users = relationship("DiscordUser", back_populates="system_user")
def serialize(self) -> dict:
return {

View File

@ -1,221 +1,101 @@
"""
Discord integration.
Simple HTTP client that communicates with the Discord collector's API server.
"""
import logging
import requests
import re
from typing import Any
from memory.common import settings
logger = logging.getLogger(__name__)
ERROR_CHANNEL = "memory-errors"
ACTIVITY_CHANNEL = "memory-activity"
DISCOVERY_CHANNEL = "memory-discoveries"
CHAT_CHANNEL = "memory-chat"
def get_api_url() -> str:
"""Get the Discord API server URL"""
host = settings.DISCORD_COLLECTOR_SERVER_URL
port = settings.DISCORD_COLLECTOR_PORT
return f"http://{host}:{port}"
class DiscordServer(requests.Session):
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
self.server_id = server_id
self.server_name = server_name
self.channels = {}
super().__init__(*args, **kwargs)
self.setup_channels()
self.members = self.fetch_all_members()
def setup_channels(self):
resp = self.get(self.channels_url)
resp.raise_for_status()
channels = {channel["name"]: channel["id"] for channel in resp.json()}
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
self.channels[ERROR_CHANNEL] = error_channel
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
self.channels[ACTIVITY_CHANNEL] = activity_channel
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
self.channels[DISCOVERY_CHANNEL] = discovery_channel
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
self.channels[CHAT_CHANNEL] = chat_channel
@property
def error_channel(self) -> str:
return self.channels[ERROR_CHANNEL]
@property
def activity_channel(self) -> str:
return self.channels[ACTIVITY_CHANNEL]
@property
def discovery_channel(self) -> str:
return self.channels[DISCOVERY_CHANNEL]
@property
def chat_channel(self) -> str:
return self.channels[CHAT_CHANNEL]
def channel_id(self, channel_name: str) -> str:
if not (channel_id := self.channels.get(channel_name)):
raise ValueError(f"Channel {channel_name} not found")
return channel_id
def send_message(self, channel_id: str, content: str):
payload: dict[str, Any] = {"content": content}
mentions = re.findall(r"@(\S*)", content)
users = {u: i for u, i in self.members.items() if u in mentions}
if users:
for u, i in users.items():
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
payload["allowed_mentions"] = {
"parse": [],
"users": list(users.values()),
}
return self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json=payload,
)
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
resp = self.post(
self.channels_url, json={"name": channel_name, "type": channel_type}
)
resp.raise_for_status()
return resp.json()["id"]
def __str__(self):
return (
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
)
def request(self, method: str, url: str, **kwargs):
headers = kwargs.get("headers", {})
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
headers["Content-Type"] = "application/json"
kwargs["headers"] = headers
return super().request(method, url, **kwargs)
@property
def channels_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
@property
def members_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/members"
@property
def dm_create_url(self) -> str:
return "https://discord.com/api/v10/users/@me/channels"
def list_members(
self, limit: int = 1000, after: str | None = None
) -> list[dict[str, Any]]:
"""List up to `limit` members in this guild, starting after a user ID.
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
"""
params: dict[str, Any] = {"limit": limit}
if after:
params["after"] = after
resp = self.get(self.members_url, params=params)
resp.raise_for_status()
return resp.json()
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
"""Retrieve all members in the guild by paginating the members list.
Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
"""
members: dict[str, str] = {}
after: str | None = None
while batch := self.list_members(limit=page_size, after=after):
for member in batch:
user = member.get("user", {})
members[user.get("global_name") or user.get("username", "")] = user.get(
"id", ""
)
after = user.get("id", "")
return members
def create_dm_channel(self, user_id: str) -> str:
"""Create (or retrieve) a DM channel with the given user and return the channel ID.
The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
"""
resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
resp.raise_for_status()
data = resp.json()
return data["id"]
def send_dm(self, user_id: str, content: str):
"""Send a direct message to a specific user by ID."""
channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
return self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json={"content": content},
)
def get_bot_servers() -> list[dict[str, Any]]:
"""Get list of servers the bot is in."""
if not settings.DISCORD_BOT_TOKEN:
return []
def send_dm(user_identifier: str, message: str) -> bool:
"""Send a DM via the Discord collector API"""
try:
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
response = requests.get(
"https://discord.com/api/v10/users/@me/guilds", headers=headers
response = requests.post(
f"{get_api_url()}/send_dm",
json={"user_identifier": user_identifier, "message": message},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send DM to {user_identifier}: {e}")
return False
def broadcast_message(channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}")
return False
def is_collector_healthy() -> bool:
"""Check if the Discord collector is running and healthy"""
try:
response = requests.get(f"{get_api_url()}/health", timeout=5)
response.raise_for_status()
result = response.json()
return result.get("status") == "healthy"
except requests.RequestException:
return False
def refresh_discord_metadata() -> dict[str, int] | None:
"""Refresh Discord server/channel/user metadata from Discord API"""
try:
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Failed to get bot servers: {e}")
return []
except requests.RequestException as e:
logger.error(f"Failed to refresh Discord metadata: {e}")
return None
servers: dict[str, DiscordServer] = {}
# Convenience functions
def send_error_message(message: str) -> bool:
"""Send an error message to the error channel"""
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
def load_servers():
for server in get_bot_servers():
servers[server["id"]] = DiscordServer(server["id"], server["name"])
def send_activity_message(message: str) -> bool:
"""Send an activity message to the activity channel"""
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
def broadcast_message(channel: str, message: str):
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return
for server in servers.values():
server.send_message(server.channel_id(channel), message)
def send_discovery_message(message: str) -> bool:
"""Send a discovery message to the discovery channel"""
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
def send_error_message(message: str):
broadcast_message(ERROR_CHANNEL, message)
def send_activity_message(message: str):
broadcast_message(ACTIVITY_CHANNEL, message)
def send_discovery_message(message: str):
broadcast_message(DISCOVERY_CHANNEL, message)
def send_chat_message(message: str):
broadcast_message(CHAT_CHANNEL, message)
def send_dm(user_id: str, message: str):
for server in servers.values():
if not server.members.get(user_id) and user_id not in server.members.values():
continue
server.send_dm(user_id, message)
def send_chat_message(message: str) -> bool:
"""Send a chat message to the chat channel"""
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
def notify_task_failure(
@ -234,9 +114,6 @@ def notify_task_failure(
task_args: Task arguments
task_kwargs: Task keyword arguments
traceback_str: Full traceback string
Returns:
True if notification sent successfully
"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
logger.debug("Discord notifications disabled")

View File

@ -172,3 +172,11 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
# Discord collector settings
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8001))
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))

View File

@ -0,0 +1,166 @@
"""
Discord API server.
FastAPI server that owns and manages a Discord collector instance,
providing HTTP endpoints for sending Discord messages.
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from memory.common import settings
from memory.workers.discord.collector import MessageCollector
logger = logging.getLogger(__name__)
class SendDMRequest(BaseModel):
user: str # Discord user ID or username
message: str
class SendChannelRequest(BaseModel):
channel_name: str # Channel name (e.g., "memory-errors")
message: str
# Application state
class AppState:
def __init__(self):
self.collector: MessageCollector | None = None
self.collector_task: asyncio.Task | None = None
app_state = AppState()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage Discord collector lifecycle"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
# Create and start the collector
app_state.collector = MessageCollector()
app_state.collector_task = asyncio.create_task(
app_state.collector.start(settings.DISCORD_BOT_TOKEN)
)
logger.info("Discord collector started")
yield
# Cleanup
if app_state.collector and not app_state.collector.is_closed():
await app_state.collector.close()
if app_state.collector_task:
app_state.collector_task.cancel()
try:
await app_state.collector_task
except asyncio.CancelledError:
pass
logger.info("Discord collector stopped")
# FastAPI app with lifespan management
app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
@app.post("/send_dm")
async def send_dm_endpoint(request: SendDMRequest):
"""Send a DM via the collector's Discord client"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
try:
success = await app_state.collector.send_dm(request.user, request.message)
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to send DM to {request.user}",
)
return {
"success": True,
"message": f"DM sent to {request.user}",
"user": request.user,
}
except Exception as e:
logger.error(f"Failed to send DM: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest):
"""Send a message to a channel via the collector's Discord client"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
try:
success = await app_state.collector.send_to_channel(
request.channel_name, request.message
)
if success:
return {
"success": True,
"message": f"Message sent to channel {request.channel_name}",
"channel": request.channel_name,
}
else:
raise HTTPException(
status_code=400,
detail=f"Failed to send message to channel {request.channel_name}",
)
except Exception as e:
logger.error(f"Failed to send channel message: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Check if the Discord collector is running and healthy"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
collector = app_state.collector
return {
"status": "healthy",
"connected": not collector.is_closed(),
"user": str(collector.user) if collector.user else None,
"guilds": len(collector.guilds) if collector.guilds else 0,
}
@app.post("/refresh_metadata")
async def refresh_metadata():
"""Refresh Discord server/channel/user metadata from Discord API"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
try:
result = await app_state.collector.refresh_metadata()
return {"success": True, "message": "Metadata refreshed successfully", **result}
except Exception as e:
logger.error(f"Failed to refresh metadata: {e}")
raise HTTPException(status_code=500, detail=str(e))
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
"""Run the Discord API server"""
uvicorn.run(app, host=host, port=port, log_level="debug")
if __name__ == "__main__":
# For testing the API server standalone
host = settings.DISCORD_COLLECTOR_SERVER_URL
port = settings.DISCORD_COLLECTOR_PORT
run_discord_api_server(host, port)

View File

@ -0,0 +1,398 @@
"""
Discord message collector.
Core message collection functionality - stores Discord messages to database.
"""
import logging
from datetime import datetime, timezone
import discord
from discord.ext import commands
from sqlalchemy.orm import Session, scoped_session
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.sources import (
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
logger = logging.getLogger(__name__)
# Pure functions for Discord entity creation/updates
def create_or_update_server(
session: Session | scoped_session, guild: discord.Guild
) -> DiscordServer:
"""Get or create DiscordServer record (pure DB operation)"""
server = session.query(DiscordServer).get(guild.id)
if not server:
server = DiscordServer(
id=guild.id,
name=guild.name,
description=guild.description,
member_count=guild.member_count,
)
session.add(server)
session.flush() # Get the ID
logger.info(f"Created server record for {guild.name} ({guild.id})")
else:
# Update metadata
server.name = guild.name
server.description = guild.description
server.member_count = guild.member_count
server.last_sync_at = datetime.now(timezone.utc)
return server
def determine_channel_metadata(channel) -> tuple[str, int | None, str]:
"""Pure function to determine channel type, server_id, and name"""
if isinstance(channel, discord.DMChannel):
return "dm", None, f"DM with {channel.recipient.name}"
elif isinstance(channel, discord.GroupChannel):
return "group_dm", None, channel.name or "Group DM"
elif isinstance(
channel, (discord.TextChannel, discord.VoiceChannel, discord.Thread)
):
return (
channel.__class__.__name__.lower().replace("channel", ""),
channel.guild.id,
channel.name,
)
else:
guild = getattr(channel, "guild", None)
server_id = guild.id if guild else None
name = getattr(channel, "name", f"Unknown-{channel.id}")
return "unknown", server_id, name
def create_or_update_channel(
session: Session | scoped_session, channel
) -> DiscordChannel:
"""Get or create DiscordChannel record (pure DB operation)"""
discord_channel = session.query(DiscordChannel).get(channel.id)
if not discord_channel:
channel_type, server_id, name = determine_channel_metadata(channel)
discord_channel = DiscordChannel(
id=channel.id,
server_id=server_id,
name=name,
channel_type=channel_type,
)
session.add(discord_channel)
session.flush()
logger.debug(f"Created channel: {name}")
elif hasattr(channel, "name"):
discord_channel.name = channel.name
return discord_channel
def create_or_update_user(
session: Session | scoped_session, user: discord.User | discord.Member
) -> DiscordUser:
"""Get or create DiscordUser record (pure DB operation)"""
discord_user = session.query(DiscordUser).get(user.id)
if not discord_user:
discord_user = DiscordUser(
id=user.id,
username=user.name,
display_name=user.display_name,
)
session.add(discord_user)
session.flush()
logger.debug(f"Created user: {user.name}")
else:
# Update user info in case it changed
discord_user.username = user.name
discord_user.display_name = user.display_name
return discord_user
def determine_message_metadata(
message: discord.Message,
) -> tuple[str, int | None, int | None]:
"""Pure function to determine message type, reply_to_id, and thread_id"""
message_type = "default"
reply_to_id = None
thread_id = None
if message.reference and message.reference.message_id:
message_type = "reply"
reply_to_id = message.reference.message_id
if hasattr(message.channel, "parent") and message.channel.parent:
thread_id = message.channel.id
return message_type, reply_to_id, thread_id
def should_track_message(
server: DiscordServer | None,
channel: DiscordChannel,
user: DiscordUser,
) -> bool:
"""Pure function to determine if we should track this message"""
if server and not server.track_messages: # type: ignore
return False
if not channel.track_messages:
return False
if channel.channel_type in ("dm", "group_dm"):
return bool(user.allow_dm_tracking)
# Default: track the message
return True
def should_collect_bot_message(message: discord.Message) -> bool:
"""Pure function to determine if we should collect bot messages"""
return not message.author.bot or settings.DISCORD_COLLECT_BOTS
def sync_guild_metadata(guild: discord.Guild) -> None:
"""Sync a single guild's metadata (functional approach)"""
with make_session() as session:
create_or_update_server(session, guild)
for channel in guild.channels:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel)
session.commit()
class MessageCollector(commands.Bot):
"""Discord bot that collects and stores messages (thin event handler)"""
def __init__(self):
intents = discord.Intents.default()
intents.message_content = True
intents.guilds = True
intents.members = True
intents.dm_messages = True
super().__init__(
command_prefix="!memory_", # Prefix to avoid conflicts
intents=intents,
help_command=None, # Disable default help
)
async def on_ready(self):
"""Called when bot connects to Discord"""
logger.info(f"Discord collector connected as {self.user}")
logger.info(f"Connected to {len(self.guilds)} servers")
# Sync server and channel metadata
await self.sync_servers_and_channels()
logger.info("Discord message collector ready")
async def on_message(self, message: discord.Message):
"""Queue incoming message for database storage"""
try:
if should_collect_bot_message(message):
# Ensure Discord entities exist in database first
with make_session() as session:
create_or_update_user(session, message.author)
create_or_update_channel(session, message.channel)
if message.guild:
create_or_update_server(session, message.guild)
session.commit()
# Queue the message for processing
add_discord_message.delay(
message_id=message.id,
channel_id=message.channel.id,
author_id=message.author.id,
server_id=message.guild.id if message.guild else None,
content=message.content or "",
sent_at=message.created_at.isoformat(),
message_reference_id=message.reference.message_id
if message.reference
else None,
)
except Exception as e:
logger.error(f"Error queuing message {message.id}: {e}")
async def on_message_edit(self, before: discord.Message, after: discord.Message):
"""Queue message edit for database update"""
try:
edit_time = after.edited_at or datetime.now(timezone.utc)
edit_discord_message.delay(
message_id=after.id,
content=after.content,
edited_at=edit_time.isoformat(),
)
except Exception as e:
logger.error(f"Error queuing message edit {after.id}: {e}")
async def sync_servers_and_channels(self):
"""Sync server and channel metadata on startup"""
for guild in self.guilds:
sync_guild_metadata(guild)
logger.info(f"Synced {len(self.guilds)} servers and their channels")
async def refresh_metadata(self) -> dict[str, int]:
"""Refresh server and channel metadata from Discord and update database"""
print("🔄 Refreshing Discord metadata...")
servers_updated = 0
channels_updated = 0
users_updated = 0
with make_session() as session:
# Refresh all servers
for guild in self.guilds:
create_or_update_server(session, guild)
servers_updated += 1
# Refresh all channels in this server
for channel in guild.channels:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel)
channels_updated += 1
# Refresh all members in this server (if members intent is enabled)
if self.intents.members:
for member in guild.members:
create_or_update_user(session, member)
users_updated += 1
session.commit()
result = {
"servers_updated": servers_updated,
"channels_updated": channels_updated,
"users_updated": users_updated,
}
print(f"✅ Metadata refresh complete: {result}")
logger.info(f"Metadata refresh complete: {result}")
return result
async def get_user(self, user_identifier: int | str) -> discord.User | None:
"""Get a Discord user by ID or username"""
if isinstance(user_identifier, int):
# Direct user ID lookup
if user := super().get_user(user_identifier):
return user
try:
return await self.fetch_user(user_identifier)
except discord.NotFound:
return None
else:
# Username lookup - search through all guilds
for guild in self.guilds:
for member in guild.members:
if (
member.name == user_identifier
or member.display_name == user_identifier
or f"{member.name}#{member.discriminator}" == user_identifier
):
return member
return None
async def get_channel_by_name(
self, channel_name: str
) -> discord.TextChannel | None:
"""Get a Discord channel by name (does not create if missing)"""
# Search all guilds for the channel
for guild in self.guilds:
for ch in guild.channels:
if isinstance(ch, discord.TextChannel) and ch.name == channel_name:
return ch
return None
async def create_channel(
self, channel_name: str, guild_id: int | None = None
) -> discord.TextChannel | None:
"""Create a Discord channel in the specified guild (or first guild if none specified)"""
target_guild = None
if guild_id:
target_guild = self.get_guild(guild_id)
elif self.guilds:
target_guild = self.guilds[0] # Default to first guild
if not target_guild:
logger.error(f"No guild available to create channel {channel_name}")
return None
try:
channel = await target_guild.create_text_channel(channel_name)
logger.info(f"Created channel {channel_name} in {target_guild.name}")
return channel
except Exception as e:
logger.error(
f"Failed to create channel {channel_name} in {target_guild.name}: {e}"
)
return None
async def send_dm(self, user_identifier: int | str, message: str) -> bool:
"""Send a DM using this collector's Discord client"""
try:
user = await self.get_user(user_identifier)
if not user:
logger.error(f"User {user_identifier} not found")
return False
await user.send(message)
logger.info(f"Sent DM to {user_identifier}")
return True
except Exception as e:
logger.error(f"Failed to send DM to {user_identifier}: {e}")
return False
async def send_to_channel(self, channel_name: str, message: str) -> bool:
"""Send a message to a channel by name across all guilds"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
channel = await self.get_channel_by_name(channel_name)
if not channel:
logger.error(f"Channel {channel_name} not found")
return False
await channel.send(message)
logger.info(f"Sent message to channel {channel_name}")
return True
except Exception as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}")
return False
async def run_collector():
"""Run the Discord message collector"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
collector = MessageCollector()
try:
await collector.start(settings.DISCORD_BOT_TOKEN)
except Exception as e:
logger.error(f"Discord collector failed: {e}")
raise
if __name__ == "__main__":
import asyncio
asyncio.run(run_collector())

View File

@ -6,6 +6,7 @@ from memory.workers.tasks import (
email,
comic,
blogs,
discord,
ebook,
forums,
maintenance,
@ -20,6 +21,7 @@ __all__ = [
"comic",
"blogs",
"ebook",
"discord",
"forums",
"maintenance",
"notes",

View File

@ -10,9 +10,9 @@ from collections import defaultdict
import hashlib
import traceback
import logging
from typing import Any, Callable, Iterable, Sequence, cast
from typing import Any, Callable, Sequence, cast
from memory.common import embedding, qdrant, settings
from memory.common import embedding, qdrant
from memory.common.db.models import SourceItem, Chunk
from memory.common.discord import notify_task_failure
@ -38,19 +38,12 @@ def check_content_exists(
Returns:
Existing SourceItem if found, None otherwise
"""
query = session.query(model_class)
for key, value in kwargs.items():
if not hasattr(model_class, key):
continue
if hasattr(model_class, key):
query = query.filter(getattr(model_class, key) == value)
existing = (
session.query(model_class)
.filter(getattr(model_class, key) == value)
.first()
)
if existing:
return existing
return None
return query.first()
def create_content_hash(content: str, *additional_data: str) -> bytes:
@ -286,6 +279,6 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
traceback_str=traceback_str,
)
return {"status": "error", "error": str(e)}
return {"status": "error", "error": str(e), "traceback": traceback_str}
return wrapper

View File

@ -0,0 +1,130 @@
"""
Celery tasks for Discord message processing.
"""
import hashlib
import logging
from datetime import datetime
from typing import Any
from memory.common.celery_app import app
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser
from memory.workers.tasks.content_processing import (
safe_task_execution,
check_content_exists,
create_task_result,
process_content_item,
)
from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE
from memory.common import settings
from sqlalchemy.orm import Session, scoped_session
logger = logging.getLogger(__name__)
def get_prev(
session: Session | scoped_session, channel_id: int, sent_at: datetime
) -> list[str]:
prev = (
session.query(DiscordUser.username, DiscordMessage.content)
.join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
.filter(
DiscordMessage.channel_id == channel_id,
DiscordMessage.sent_at < sent_at,
)
.order_by(DiscordMessage.sent_at.desc())
.limit(settings.DISCORD_CONTEXT_WINDOW)
.all()
)
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
@app.task(name=ADD_DISCORD_MESSAGE)
@safe_task_execution
def add_discord_message(
message_id: int,
channel_id: int,
author_id: int,
content: str,
sent_at: str,
server_id: int | None = None,
message_reference_id: int | None = None,
) -> dict[str, Any]:
"""
Add a Discord message to the database.
This task is queued by the Discord collector when messages are received.
"""
logger.info(f"Adding Discord message {message_id}: {content}")
# Include message_id in hash to ensure uniqueness across duplicate content
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
with make_session() as session:
discord_message = DiscordMessage(
modality="text",
sha256=content_hash,
content=content,
channel_id=channel_id,
sent_at=sent_at_dt,
server_id=server_id,
discord_user_id=author_id,
message_id=message_id,
message_type="reply" if message_reference_id else "default",
reply_to_message_id=message_reference_id,
)
existing_msg = check_content_exists(
session, DiscordMessage, message_id=message_id, sha256=content_hash
)
if existing_msg:
logger.info(f"Discord message already exists: {existing_msg.message_id}")
return create_task_result(
existing_msg, "already_exists", message_id=message_id
)
if channel_id:
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
result = process_content_item(discord_message, session)
logger.info(
f"Discord message ID after process_content_item: {discord_message.id}"
)
logger.info(f"Process result: {result}")
return result
@app.task(name=EDIT_DISCORD_MESSAGE)
@safe_task_execution
def edit_discord_message(
message_id: int, content: str, edited_at: str
) -> dict[str, Any]:
"""
Edit a Discord message in the database.
This task is queued by the Discord collector when messages are edited.
"""
logger.info(f"Editing Discord message {message_id}: {content}")
with make_session() as session:
existing_msg = check_content_exists(
session, DiscordMessage, message_id=message_id
)
if not existing_msg:
return {
"status": "error",
"error": "Message not found",
"message_id": message_id,
}
existing_msg.content = content # type: ignore
if existing_msg.channel_id:
existing_msg.messages_before = get_prev(
session, existing_msg.channel_id, existing_msg.sent_at
)
existing_msg.edited_at = datetime.fromisoformat(
edited_at.replace("Z", "+00:00")
)
return process_content_item(existing_msg, session)