diff --git a/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py b/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py index a11ae23..f83626f 100644 --- a/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py +++ b/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py @@ -9,7 +9,6 @@ Create Date: 2025-10-12 10:12:57.421009 from typing import Sequence, Union from alembic import op -import sqlalchemy as sa # revision identifiers, used by Alembic. @@ -20,10 +19,8 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - # Rename prompt column to message in scheduled_llm_calls table op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message") def downgrade() -> None: - # Rename message column back to prompt in scheduled_llm_calls table op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt") diff --git a/db/migrations/versions/20251012_222827_add_discord_models.py b/db/migrations/versions/20251012_222827_add_discord_models.py new file mode 100644 index 0000000..91d661b --- /dev/null +++ b/db/migrations/versions/20251012_222827_add_discord_models.py @@ -0,0 +1,165 @@ +"""add_discord_models + +Revision ID: a8c8e8b17179 +Revises: c86079073c1d +Create Date: 2025-10-12 22:28:27.856164 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "a8c8e8b17179" +down_revision: Union[str, None] = "c86079073c1d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "discord_servers", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("member_count", sa.Integer(), nullable=True), + sa.Column( + "track_messages", sa.Boolean(), server_default="true", nullable=False + ), + sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "discord_servers_active_idx", + "discord_servers", + ["track_messages", "last_sync_at"], + unique=False, + ) + op.create_table( + "discord_channels", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("server_id", sa.BigInteger(), nullable=True), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("channel_type", sa.Text(), nullable=False), + sa.Column("track_messages", sa.Boolean(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["server_id"], + ["discord_servers.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "discord_channels_server_idx", "discord_channels", ["server_id"], unique=False + ) + op.create_table( + "discord_users", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("username", sa.Text(), nullable=False), + sa.Column("display_name", sa.Text(), nullable=True), + sa.Column("system_user_id", sa.Integer(), nullable=True), + sa.Column( + "allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False + ), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["system_user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "discord_users_system_user_idx", + "discord_users", + ["system_user_id"], + unique=False, + ) + op.create_table( + "discord_message", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("server_id", sa.BigInteger(), nullable=True), + sa.Column("channel_id", sa.BigInteger(), nullable=False), + sa.Column("discord_user_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.BigInteger(), nullable=False), + sa.Column("message_type", sa.Text(), server_default="default", nullable=True), + sa.Column("reply_to_message_id", sa.BigInteger(), nullable=True), + sa.Column("thread_id", sa.BigInteger(), nullable=True), + sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["channel_id"], + ["discord_channels.id"], + ), + sa.ForeignKeyConstraint( + ["discord_user_id"], + ["discord_users.id"], + ), + sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["server_id"], + ["discord_servers.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "discord_message_discord_id_idx", "discord_message", ["message_id"], unique=True + ) + op.create_index( + "discord_message_server_channel_idx", + "discord_message", + ["server_id", "channel_id"], + unique=False, + ) + op.create_index( + "discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False + ) + + +def downgrade() -> None: + op.drop_index("discord_message_user_idx", table_name="discord_message") + op.drop_index("discord_message_server_channel_idx", table_name="discord_message") + op.drop_index("discord_message_discord_id_idx", table_name="discord_message") + op.drop_table("discord_message") + op.drop_index("discord_users_system_user_idx", table_name="discord_users") + op.drop_table("discord_users") + op.drop_index("discord_channels_server_idx", table_name="discord_channels") + op.drop_table("discord_channels") + op.drop_index("discord_servers_active_idx", table_name="discord_servers") + op.drop_table("discord_servers") diff --git a/docker-compose.yaml b/docker-compose.yaml index 2817b22..18af07e 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -174,7 +174,7 @@ services: <<: *worker-base environment: <<: *worker-env - QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler" + QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler" ingest-hub: <<: *worker-base diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile index 1ff8ad5..c616fe6 100644 --- a/docker/workers/Dockerfile +++ b/docker/workers/Dockerfile @@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \ git config --global user.name "${GIT_USER_NAME}" # Default queues to process -ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance" +ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["./entry.sh"] \ No newline at end of file diff --git a/requirements/requirements-api.txt b/requirements/requirements-api.txt index 9b01a9b..e96bea8 100644 --- a/requirements/requirements-api.txt +++ b/requirements/requirements-api.txt @@ -4,4 +4,5 @@ python-jose==3.3.0 python-multipart==0.0.9 sqladmin==0.20.1 mcp==1.10.0 -bm25s[full]==0.2.13 \ No newline at end of file +bm25s[full]==0.2.13 +discord.py==2.3.2 \ No newline at end of file diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index 12516f7..75a3861 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -12,6 +12,9 @@ MAINTENANCE_ROOT = "memory.workers.tasks.maintenance" NOTES_ROOT = "memory.workers.tasks.notes" OBSERVATIONS_ROOT = "memory.workers.tasks.observations" SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls" +DISCORD_ROOT = "memory.workers.tasks.discord" +ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message" +EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message" SYNC_NOTES = f"{NOTES_ROOT}.sync_notes" SYNC_NOTE = f"{NOTES_ROOT}.sync_note" @@ -72,17 +75,18 @@ app.conf.update( task_reject_on_worker_lost=True, worker_prefetch_multiplier=1, task_routes={ - f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"}, - f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"}, - f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"}, f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"}, f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"}, + f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"}, + f"{DISCORD_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"}, + f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"}, f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"}, f"{MAINTENANCE_ROOT}.*": { "queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance" }, f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"}, f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"}, + f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"}, f"{SCHEDULED_CALLS_ROOT}.*": { "queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler" }, diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 9afc2b2..7b1f778 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -11,6 +11,7 @@ from memory.common.db.models.source_items import ( EmailAttachment, AgentObservation, ChatMessage, + DiscordMessage, BlogPost, Comic, BookSection, @@ -40,6 +41,9 @@ from memory.common.db.models.sources import ( Book, ArticleFeed, EmailAccount, + DiscordServer, + DiscordChannel, + DiscordUser, ) from memory.common.db.models.users import ( User, @@ -74,6 +78,7 @@ __all__ = [ "EmailAttachment", "AgentObservation", "ChatMessage", + "DiscordMessage", "BlogPost", "Comic", "BookSection", @@ -93,6 +98,9 @@ __all__ = [ "Book", "ArticleFeed", "EmailAccount", + "DiscordServer", + "DiscordChannel", + "DiscordUser", # Users "User", "UserSession", diff --git a/src/memory/common/db/models/scheduled_calls.py b/src/memory/common/db/models/scheduled_calls.py index 097fdd9..bbd23d1 100644 --- a/src/memory/common/db/models/scheduled_calls.py +++ b/src/memory/common/db/models/scheduled_calls.py @@ -70,7 +70,7 @@ class ScheduledLLMCall(Base): "created_at": print_datetime(cast(datetime, self.created_at)), "executed_at": print_datetime(cast(datetime, self.executed_at)), "model": self.model, - "prompt": self.message, + "message": self.message, "system_prompt": self.system_prompt, "allowed_tools": self.allowed_tools, "discord_channel": self.discord_channel, diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index 98ee028..0b4ca96 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -262,7 +262,7 @@ class ChatMessage(SourceItem): BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True ) platform = Column(Text) - channel_id = Column(Text) + channel_id = Column(Text) # Keep as Text for cross-platform compatibility author = Column(Text) sent_at = Column(DateTime(timezone=True)) @@ -274,6 +274,64 @@ class ChatMessage(SourceItem): __table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),) +class DiscordMessage(SourceItem): + """Discord-specific chat message with rich metadata""" + + __tablename__ = "discord_message" + + id = Column( + BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True + ) + + sent_at = Column(DateTime(timezone=True), nullable=False) + server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True) + channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False) + discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False) + message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID + + # Discord-specific metadata + message_type = Column( + Text, server_default="default" + ) # "default", "reply", "thread_starter" + reply_to_message_id = Column( + BigInteger, nullable=True + ) # Discord message snowflake ID if replying + thread_id = Column( + BigInteger, nullable=True + ) # Discord thread snowflake ID if in thread + edited_at = Column(DateTime(timezone=True), nullable=True) + + channel = relationship("DiscordChannel", foreign_keys=[channel_id]) + server = relationship("DiscordServer", foreign_keys=[server_id]) + discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id]) + + @property + def title(self) -> str: + return f"{self.discord_user.username}: {self.content}" + + __mapper_args__ = { + "polymorphic_identity": "discord_message", + } + + __table_args__ = ( + Index("discord_message_discord_id_idx", "message_id", unique=True), + Index( + "discord_message_server_channel_idx", + "server_id", + "channel_id", + ), + Index("discord_message_user_idx", "discord_user_id"), + ) + + def _chunk_contents(self) -> Sequence[extract.DataChunk]: + content = cast(str | None, self.content) + if not content: + return [] + prev = getattr(self, "messages_before", []) + content = "\n\n".join(prev) + "\n\n" + self.title + return extract.extract_text(content) + + class GitCommit(SourceItem): __tablename__ = "git_commit" diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py index 40171d2..9d018e5 100644 --- a/src/memory/common/db/models/sources.py +++ b/src/memory/common/db/models/sources.py @@ -10,12 +10,14 @@ from sqlalchemy import ( Boolean, Column, DateTime, + ForeignKey, Index, Integer, Text, func, ) from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import relationship from memory.common.db.models.base import Base @@ -123,3 +125,69 @@ class EmailAccount(Base): Index("email_accounts_active_idx", "active", "last_sync_at"), Index("email_accounts_tags_idx", "tags", postgresql_using="gin"), ) + + +class DiscordServer(Base): + """Discord server configuration and metadata""" + + __tablename__ = "discord_servers" + + id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID + name = Column(Text, nullable=False) + description = Column(Text) + member_count = Column(Integer) + + # Collection settings + track_messages = Column(Boolean, nullable=False, server_default="true") + last_sync_at = Column(DateTime(timezone=True)) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + channels = relationship( + "DiscordChannel", back_populates="server", cascade="all, delete-orphan" + ) + + __table_args__ = ( + Index("discord_servers_active_idx", "track_messages", "last_sync_at"), + ) + + +class DiscordChannel(Base): + """Discord channel metadata and configuration""" + + __tablename__ = "discord_channels" + + id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID + server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True) + name = Column(Text, nullable=False) + channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm" + + # Collection settings (null = inherit from server) + track_messages = Column(Boolean, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + server = relationship("DiscordServer", back_populates="channels") + __table_args__ = (Index("discord_channels_server_idx", "server_id"),) + + +class DiscordUser(Base): + """Discord user metadata and preferences""" + + __tablename__ = "discord_users" + + id = Column(BigInteger, primary_key=True) # Discord user snowflake ID + username = Column(Text, nullable=False) + display_name = Column(Text) + + # Link to system user if registered + system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + + # Basic DM settings + allow_dm_tracking = Column(Boolean, nullable=False, server_default="true") + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + system_user = relationship("User", back_populates="discord_users") + + __table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),) diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index e83d0c9..8228cee 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -50,6 +50,7 @@ class User(Base): oauth_states = relationship( "OAuthState", back_populates="user", cascade="all, delete-orphan" ) + discord_users = relationship("DiscordUser", back_populates="system_user") def serialize(self) -> dict: return { diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index e041c73..92482bc 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -1,221 +1,101 @@ +""" +Discord integration. + +Simple HTTP client that communicates with the Discord collector's API server. +""" + import logging import requests -import re from typing import Any from memory.common import settings logger = logging.getLogger(__name__) -ERROR_CHANNEL = "memory-errors" -ACTIVITY_CHANNEL = "memory-activity" -DISCOVERY_CHANNEL = "memory-discoveries" -CHAT_CHANNEL = "memory-chat" + +def get_api_url() -> str: + """Get the Discord API server URL""" + host = settings.DISCORD_COLLECTOR_SERVER_URL + port = settings.DISCORD_COLLECTOR_PORT + return f"http://{host}:{port}" -class DiscordServer(requests.Session): - def __init__(self, server_id: str, server_name: str, *args, **kwargs): - self.server_id = server_id - self.server_name = server_name - self.channels = {} - super().__init__(*args, **kwargs) - self.setup_channels() - self.members = self.fetch_all_members() - - def setup_channels(self): - resp = self.get(self.channels_url) - resp.raise_for_status() - channels = {channel["name"]: channel["id"] for channel in resp.json()} - - if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)): - error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL) - self.channels[ERROR_CHANNEL] = error_channel - - if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)): - activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL) - self.channels[ACTIVITY_CHANNEL] = activity_channel - - if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)): - discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL) - self.channels[DISCOVERY_CHANNEL] = discovery_channel - - if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)): - chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL) - self.channels[CHAT_CHANNEL] = chat_channel - - @property - def error_channel(self) -> str: - return self.channels[ERROR_CHANNEL] - - @property - def activity_channel(self) -> str: - return self.channels[ACTIVITY_CHANNEL] - - @property - def discovery_channel(self) -> str: - return self.channels[DISCOVERY_CHANNEL] - - @property - def chat_channel(self) -> str: - return self.channels[CHAT_CHANNEL] - - def channel_id(self, channel_name: str) -> str: - if not (channel_id := self.channels.get(channel_name)): - raise ValueError(f"Channel {channel_name} not found") - return channel_id - - def send_message(self, channel_id: str, content: str): - payload: dict[str, Any] = {"content": content} - mentions = re.findall(r"@(\S*)", content) - users = {u: i for u, i in self.members.items() if u in mentions} - if users: - for u, i in users.items(): - payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>") - payload["allowed_mentions"] = { - "parse": [], - "users": list(users.values()), - } - - return self.post( - f"https://discord.com/api/v10/channels/{channel_id}/messages", - json=payload, - ) - - def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None: - resp = self.post( - self.channels_url, json={"name": channel_name, "type": channel_type} - ) - resp.raise_for_status() - return resp.json()["id"] - - def __str__(self): - return ( - f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})" - ) - - def request(self, method: str, url: str, **kwargs): - headers = kwargs.get("headers", {}) - headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}" - headers["Content-Type"] = "application/json" - kwargs["headers"] = headers - return super().request(method, url, **kwargs) - - @property - def channels_url(self) -> str: - return f"https://discord.com/api/v10/guilds/{self.server_id}/channels" - - @property - def members_url(self) -> str: - return f"https://discord.com/api/v10/guilds/{self.server_id}/members" - - @property - def dm_create_url(self) -> str: - return "https://discord.com/api/v10/users/@me/channels" - - def list_members( - self, limit: int = 1000, after: str | None = None - ) -> list[dict[str, Any]]: - """List up to `limit` members in this guild, starting after a user ID. - - Requires the bot to have the Server Members Intent enabled in the Discord developer portal. - """ - params: dict[str, Any] = {"limit": limit} - if after: - params["after"] = after - resp = self.get(self.members_url, params=params) - resp.raise_for_status() - return resp.json() - - def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]: - """Retrieve all members in the guild by paginating the members list. - - Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically. - """ - members: dict[str, str] = {} - after: str | None = None - while batch := self.list_members(limit=page_size, after=after): - for member in batch: - user = member.get("user", {}) - members[user.get("global_name") or user.get("username", "")] = user.get( - "id", "" - ) - after = user.get("id", "") - return members - - def create_dm_channel(self, user_id: str) -> str: - """Create (or retrieve) a DM channel with the given user and return the channel ID. - - The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members. - """ - resp = self.post(self.dm_create_url, json={"recipient_id": user_id}) - resp.raise_for_status() - data = resp.json() - return data["id"] - - def send_dm(self, user_id: str, content: str): - """Send a direct message to a specific user by ID.""" - channel_id = self.create_dm_channel(self.members.get(user_id) or user_id) - return self.post( - f"https://discord.com/api/v10/channels/{channel_id}/messages", - json={"content": content}, - ) - - -def get_bot_servers() -> list[dict[str, Any]]: - """Get list of servers the bot is in.""" - if not settings.DISCORD_BOT_TOKEN: - return [] - +def send_dm(user_identifier: str, message: str) -> bool: + """Send a DM via the Discord collector API""" try: - headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} - response = requests.get( - "https://discord.com/api/v10/users/@me/guilds", headers=headers + response = requests.post( + f"{get_api_url()}/send_dm", + json={"user_identifier": user_identifier, "message": message}, + timeout=10, ) response.raise_for_status() + result = response.json() + return result.get("success", False) + + except requests.RequestException as e: + logger.error(f"Failed to send DM to {user_identifier}: {e}") + return False + + +def broadcast_message(channel_name: str, message: str) -> bool: + """Send a message to a channel via the Discord collector API""" + try: + response = requests.post( + f"{get_api_url()}/send_channel", + json={"channel_name": channel_name, "message": message}, + timeout=10, + ) + response.raise_for_status() + result = response.json() + return result.get("success", False) + + except requests.RequestException as e: + logger.error(f"Failed to send message to channel {channel_name}: {e}") + return False + + +def is_collector_healthy() -> bool: + """Check if the Discord collector is running and healthy""" + try: + response = requests.get(f"{get_api_url()}/health", timeout=5) + response.raise_for_status() + result = response.json() + return result.get("status") == "healthy" + + except requests.RequestException: + return False + + +def refresh_discord_metadata() -> dict[str, int] | None: + """Refresh Discord server/channel/user metadata from Discord API""" + try: + response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30) + response.raise_for_status() return response.json() - except Exception as e: - logger.error(f"Failed to get bot servers: {e}") - return [] + except requests.RequestException as e: + logger.error(f"Failed to refresh Discord metadata: {e}") + return None -servers: dict[str, DiscordServer] = {} +# Convenience functions +def send_error_message(message: str) -> bool: + """Send an error message to the error channel""" + return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message) -def load_servers(): - for server in get_bot_servers(): - servers[server["id"]] = DiscordServer(server["id"], server["name"]) +def send_activity_message(message: str) -> bool: + """Send an activity message to the activity channel""" + return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message) -def broadcast_message(channel: str, message: str): - if not settings.DISCORD_NOTIFICATIONS_ENABLED: - return - - for server in servers.values(): - server.send_message(server.channel_id(channel), message) +def send_discovery_message(message: str) -> bool: + """Send a discovery message to the discovery channel""" + return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message) -def send_error_message(message: str): - broadcast_message(ERROR_CHANNEL, message) - - -def send_activity_message(message: str): - broadcast_message(ACTIVITY_CHANNEL, message) - - -def send_discovery_message(message: str): - broadcast_message(DISCOVERY_CHANNEL, message) - - -def send_chat_message(message: str): - broadcast_message(CHAT_CHANNEL, message) - - -def send_dm(user_id: str, message: str): - for server in servers.values(): - if not server.members.get(user_id) and user_id not in server.members.values(): - continue - - server.send_dm(user_id, message) +def send_chat_message(message: str) -> bool: + """Send a chat message to the chat channel""" + return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message) def notify_task_failure( @@ -234,9 +114,6 @@ def notify_task_failure( task_args: Task arguments task_kwargs: Task keyword arguments traceback_str: Full traceback string - - Returns: - True if notification sent successfully """ if not settings.DISCORD_NOTIFICATIONS_ENABLED: logger.debug("Discord notifications disabled") diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 7fcf47c..86d2490 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -172,3 +172,11 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat") DISCORD_NOTIFICATIONS_ENABLED = bool( boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN ) + +# Discord collector settings +DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True) +DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True) +DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True) +DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8001)) +DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1") +DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10)) diff --git a/src/memory/workers/discord/api.py b/src/memory/workers/discord/api.py new file mode 100644 index 0000000..346ebb7 --- /dev/null +++ b/src/memory/workers/discord/api.py @@ -0,0 +1,166 @@ +""" +Discord API server. + +FastAPI server that owns and manages a Discord collector instance, +providing HTTP endpoints for sending Discord messages. +""" + +import asyncio +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import uvicorn + +from memory.common import settings +from memory.workers.discord.collector import MessageCollector + +logger = logging.getLogger(__name__) + + +class SendDMRequest(BaseModel): + user: str # Discord user ID or username + message: str + + +class SendChannelRequest(BaseModel): + channel_name: str # Channel name (e.g., "memory-errors") + message: str + + +# Application state +class AppState: + def __init__(self): + self.collector: MessageCollector | None = None + self.collector_task: asyncio.Task | None = None + + +app_state = AppState() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage Discord collector lifecycle""" + if not settings.DISCORD_BOT_TOKEN: + logger.error("DISCORD_BOT_TOKEN not configured") + return + + # Create and start the collector + app_state.collector = MessageCollector() + app_state.collector_task = asyncio.create_task( + app_state.collector.start(settings.DISCORD_BOT_TOKEN) + ) + logger.info("Discord collector started") + + yield + + # Cleanup + if app_state.collector and not app_state.collector.is_closed(): + await app_state.collector.close() + + if app_state.collector_task: + app_state.collector_task.cancel() + try: + await app_state.collector_task + except asyncio.CancelledError: + pass + + logger.info("Discord collector stopped") + + +# FastAPI app with lifespan management +app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan) + + +@app.post("/send_dm") +async def send_dm_endpoint(request: SendDMRequest): + """Send a DM via the collector's Discord client""" + if not app_state.collector: + raise HTTPException(status_code=503, detail="Discord collector not running") + + try: + success = await app_state.collector.send_dm(request.user, request.message) + + if not success: + raise HTTPException( + status_code=400, + detail=f"Failed to send DM to {request.user}", + ) + return { + "success": True, + "message": f"DM sent to {request.user}", + "user": request.user, + } + except Exception as e: + logger.error(f"Failed to send DM: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/send_channel") +async def send_channel_endpoint(request: SendChannelRequest): + """Send a message to a channel via the collector's Discord client""" + if not app_state.collector: + raise HTTPException(status_code=503, detail="Discord collector not running") + + try: + success = await app_state.collector.send_to_channel( + request.channel_name, request.message + ) + + if success: + return { + "success": True, + "message": f"Message sent to channel {request.channel_name}", + "channel": request.channel_name, + } + else: + raise HTTPException( + status_code=400, + detail=f"Failed to send message to channel {request.channel_name}", + ) + + except Exception as e: + logger.error(f"Failed to send channel message: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/health") +async def health_check(): + """Check if the Discord collector is running and healthy""" + if not app_state.collector: + raise HTTPException(status_code=503, detail="Discord collector not running") + + collector = app_state.collector + return { + "status": "healthy", + "connected": not collector.is_closed(), + "user": str(collector.user) if collector.user else None, + "guilds": len(collector.guilds) if collector.guilds else 0, + } + + +@app.post("/refresh_metadata") +async def refresh_metadata(): + """Refresh Discord server/channel/user metadata from Discord API""" + if not app_state.collector: + raise HTTPException(status_code=503, detail="Discord collector not running") + + try: + result = await app_state.collector.refresh_metadata() + return {"success": True, "message": "Metadata refreshed successfully", **result} + except Exception as e: + logger.error(f"Failed to refresh metadata: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001): + """Run the Discord API server""" + uvicorn.run(app, host=host, port=port, log_level="debug") + + +if __name__ == "__main__": + # For testing the API server standalone + host = settings.DISCORD_COLLECTOR_SERVER_URL + port = settings.DISCORD_COLLECTOR_PORT + run_discord_api_server(host, port) diff --git a/src/memory/workers/discord/collector.py b/src/memory/workers/discord/collector.py new file mode 100644 index 0000000..bd92425 --- /dev/null +++ b/src/memory/workers/discord/collector.py @@ -0,0 +1,398 @@ +""" +Discord message collector. + +Core message collection functionality - stores Discord messages to database. +""" + +import logging +from datetime import datetime, timezone + +import discord +from discord.ext import commands +from sqlalchemy.orm import Session, scoped_session + +from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models.sources import ( + DiscordServer, + DiscordChannel, + DiscordUser, +) +from memory.workers.tasks.discord import add_discord_message, edit_discord_message + +logger = logging.getLogger(__name__) + + +# Pure functions for Discord entity creation/updates +def create_or_update_server( + session: Session | scoped_session, guild: discord.Guild +) -> DiscordServer: + """Get or create DiscordServer record (pure DB operation)""" + server = session.query(DiscordServer).get(guild.id) + + if not server: + server = DiscordServer( + id=guild.id, + name=guild.name, + description=guild.description, + member_count=guild.member_count, + ) + session.add(server) + session.flush() # Get the ID + logger.info(f"Created server record for {guild.name} ({guild.id})") + else: + # Update metadata + server.name = guild.name + server.description = guild.description + server.member_count = guild.member_count + server.last_sync_at = datetime.now(timezone.utc) + + return server + + +def determine_channel_metadata(channel) -> tuple[str, int | None, str]: + """Pure function to determine channel type, server_id, and name""" + if isinstance(channel, discord.DMChannel): + return "dm", None, f"DM with {channel.recipient.name}" + elif isinstance(channel, discord.GroupChannel): + return "group_dm", None, channel.name or "Group DM" + elif isinstance( + channel, (discord.TextChannel, discord.VoiceChannel, discord.Thread) + ): + return ( + channel.__class__.__name__.lower().replace("channel", ""), + channel.guild.id, + channel.name, + ) + else: + guild = getattr(channel, "guild", None) + server_id = guild.id if guild else None + name = getattr(channel, "name", f"Unknown-{channel.id}") + return "unknown", server_id, name + + +def create_or_update_channel( + session: Session | scoped_session, channel +) -> DiscordChannel: + """Get or create DiscordChannel record (pure DB operation)""" + discord_channel = session.query(DiscordChannel).get(channel.id) + + if not discord_channel: + channel_type, server_id, name = determine_channel_metadata(channel) + discord_channel = DiscordChannel( + id=channel.id, + server_id=server_id, + name=name, + channel_type=channel_type, + ) + session.add(discord_channel) + session.flush() + logger.debug(f"Created channel: {name}") + elif hasattr(channel, "name"): + discord_channel.name = channel.name + + return discord_channel + + +def create_or_update_user( + session: Session | scoped_session, user: discord.User | discord.Member +) -> DiscordUser: + """Get or create DiscordUser record (pure DB operation)""" + discord_user = session.query(DiscordUser).get(user.id) + + if not discord_user: + discord_user = DiscordUser( + id=user.id, + username=user.name, + display_name=user.display_name, + ) + session.add(discord_user) + session.flush() + logger.debug(f"Created user: {user.name}") + else: + # Update user info in case it changed + discord_user.username = user.name + discord_user.display_name = user.display_name + + return discord_user + + +def determine_message_metadata( + message: discord.Message, +) -> tuple[str, int | None, int | None]: + """Pure function to determine message type, reply_to_id, and thread_id""" + message_type = "default" + reply_to_id = None + thread_id = None + + if message.reference and message.reference.message_id: + message_type = "reply" + reply_to_id = message.reference.message_id + + if hasattr(message.channel, "parent") and message.channel.parent: + thread_id = message.channel.id + + return message_type, reply_to_id, thread_id + + +def should_track_message( + server: DiscordServer | None, + channel: DiscordChannel, + user: DiscordUser, +) -> bool: + """Pure function to determine if we should track this message""" + if server and not server.track_messages: # type: ignore + return False + + if not channel.track_messages: + return False + + if channel.channel_type in ("dm", "group_dm"): + return bool(user.allow_dm_tracking) + + # Default: track the message + return True + + +def should_collect_bot_message(message: discord.Message) -> bool: + """Pure function to determine if we should collect bot messages""" + return not message.author.bot or settings.DISCORD_COLLECT_BOTS + + +def sync_guild_metadata(guild: discord.Guild) -> None: + """Sync a single guild's metadata (functional approach)""" + with make_session() as session: + create_or_update_server(session, guild) + + for channel in guild.channels: + if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)): + create_or_update_channel(session, channel) + + session.commit() + + +class MessageCollector(commands.Bot): + """Discord bot that collects and stores messages (thin event handler)""" + + def __init__(self): + intents = discord.Intents.default() + intents.message_content = True + intents.guilds = True + intents.members = True + intents.dm_messages = True + + super().__init__( + command_prefix="!memory_", # Prefix to avoid conflicts + intents=intents, + help_command=None, # Disable default help + ) + + async def on_ready(self): + """Called when bot connects to Discord""" + logger.info(f"Discord collector connected as {self.user}") + logger.info(f"Connected to {len(self.guilds)} servers") + + # Sync server and channel metadata + await self.sync_servers_and_channels() + + logger.info("Discord message collector ready") + + async def on_message(self, message: discord.Message): + """Queue incoming message for database storage""" + try: + if should_collect_bot_message(message): + # Ensure Discord entities exist in database first + with make_session() as session: + create_or_update_user(session, message.author) + create_or_update_channel(session, message.channel) + if message.guild: + create_or_update_server(session, message.guild) + + session.commit() + + # Queue the message for processing + add_discord_message.delay( + message_id=message.id, + channel_id=message.channel.id, + author_id=message.author.id, + server_id=message.guild.id if message.guild else None, + content=message.content or "", + sent_at=message.created_at.isoformat(), + message_reference_id=message.reference.message_id + if message.reference + else None, + ) + except Exception as e: + logger.error(f"Error queuing message {message.id}: {e}") + + async def on_message_edit(self, before: discord.Message, after: discord.Message): + """Queue message edit for database update""" + try: + edit_time = after.edited_at or datetime.now(timezone.utc) + edit_discord_message.delay( + message_id=after.id, + content=after.content, + edited_at=edit_time.isoformat(), + ) + except Exception as e: + logger.error(f"Error queuing message edit {after.id}: {e}") + + async def sync_servers_and_channels(self): + """Sync server and channel metadata on startup""" + for guild in self.guilds: + sync_guild_metadata(guild) + + logger.info(f"Synced {len(self.guilds)} servers and their channels") + + async def refresh_metadata(self) -> dict[str, int]: + """Refresh server and channel metadata from Discord and update database""" + print("🔄 Refreshing Discord metadata...") + + servers_updated = 0 + channels_updated = 0 + users_updated = 0 + + with make_session() as session: + # Refresh all servers + for guild in self.guilds: + create_or_update_server(session, guild) + servers_updated += 1 + + # Refresh all channels in this server + for channel in guild.channels: + if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)): + create_or_update_channel(session, channel) + channels_updated += 1 + + # Refresh all members in this server (if members intent is enabled) + if self.intents.members: + for member in guild.members: + create_or_update_user(session, member) + users_updated += 1 + + session.commit() + + result = { + "servers_updated": servers_updated, + "channels_updated": channels_updated, + "users_updated": users_updated, + } + + print(f"✅ Metadata refresh complete: {result}") + logger.info(f"Metadata refresh complete: {result}") + + return result + + async def get_user(self, user_identifier: int | str) -> discord.User | None: + """Get a Discord user by ID or username""" + if isinstance(user_identifier, int): + # Direct user ID lookup + if user := super().get_user(user_identifier): + return user + try: + return await self.fetch_user(user_identifier) + except discord.NotFound: + return None + else: + # Username lookup - search through all guilds + for guild in self.guilds: + for member in guild.members: + if ( + member.name == user_identifier + or member.display_name == user_identifier + or f"{member.name}#{member.discriminator}" == user_identifier + ): + return member + return None + + async def get_channel_by_name( + self, channel_name: str + ) -> discord.TextChannel | None: + """Get a Discord channel by name (does not create if missing)""" + # Search all guilds for the channel + for guild in self.guilds: + for ch in guild.channels: + if isinstance(ch, discord.TextChannel) and ch.name == channel_name: + return ch + return None + + async def create_channel( + self, channel_name: str, guild_id: int | None = None + ) -> discord.TextChannel | None: + """Create a Discord channel in the specified guild (or first guild if none specified)""" + target_guild = None + + if guild_id: + target_guild = self.get_guild(guild_id) + elif self.guilds: + target_guild = self.guilds[0] # Default to first guild + + if not target_guild: + logger.error(f"No guild available to create channel {channel_name}") + return None + + try: + channel = await target_guild.create_text_channel(channel_name) + logger.info(f"Created channel {channel_name} in {target_guild.name}") + return channel + except Exception as e: + logger.error( + f"Failed to create channel {channel_name} in {target_guild.name}: {e}" + ) + return None + + async def send_dm(self, user_identifier: int | str, message: str) -> bool: + """Send a DM using this collector's Discord client""" + try: + user = await self.get_user(user_identifier) + if not user: + logger.error(f"User {user_identifier} not found") + return False + + await user.send(message) + logger.info(f"Sent DM to {user_identifier}") + return True + + except Exception as e: + logger.error(f"Failed to send DM to {user_identifier}: {e}") + return False + + async def send_to_channel(self, channel_name: str, message: str) -> bool: + """Send a message to a channel by name across all guilds""" + if not settings.DISCORD_NOTIFICATIONS_ENABLED: + return False + + try: + channel = await self.get_channel_by_name(channel_name) + if not channel: + logger.error(f"Channel {channel_name} not found") + return False + + await channel.send(message) + logger.info(f"Sent message to channel {channel_name}") + return True + + except Exception as e: + logger.error(f"Failed to send message to channel {channel_name}: {e}") + return False + + +async def run_collector(): + """Run the Discord message collector""" + if not settings.DISCORD_BOT_TOKEN: + logger.error("DISCORD_BOT_TOKEN not configured") + return + + collector = MessageCollector() + + try: + await collector.start(settings.DISCORD_BOT_TOKEN) + except Exception as e: + logger.error(f"Discord collector failed: {e}") + raise + + +if __name__ == "__main__": + import asyncio + + asyncio.run(run_collector()) diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index 40c88e6..80cc6e9 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -6,6 +6,7 @@ from memory.workers.tasks import ( email, comic, blogs, + discord, ebook, forums, maintenance, @@ -20,6 +21,7 @@ __all__ = [ "comic", "blogs", "ebook", + "discord", "forums", "maintenance", "notes", diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py index e781983..34dbb41 100644 --- a/src/memory/workers/tasks/content_processing.py +++ b/src/memory/workers/tasks/content_processing.py @@ -10,9 +10,9 @@ from collections import defaultdict import hashlib import traceback import logging -from typing import Any, Callable, Iterable, Sequence, cast +from typing import Any, Callable, Sequence, cast -from memory.common import embedding, qdrant, settings +from memory.common import embedding, qdrant from memory.common.db.models import SourceItem, Chunk from memory.common.discord import notify_task_failure @@ -38,19 +38,12 @@ def check_content_exists( Returns: Existing SourceItem if found, None otherwise """ + query = session.query(model_class) for key, value in kwargs.items(): - if not hasattr(model_class, key): - continue + if hasattr(model_class, key): + query = query.filter(getattr(model_class, key) == value) - existing = ( - session.query(model_class) - .filter(getattr(model_class, key) == value) - .first() - ) - if existing: - return existing - - return None + return query.first() def create_content_hash(content: str, *additional_data: str) -> bytes: @@ -286,6 +279,6 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: traceback_str=traceback_str, ) - return {"status": "error", "error": str(e)} + return {"status": "error", "error": str(e), "traceback": traceback_str} return wrapper diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py new file mode 100644 index 0000000..764d4f2 --- /dev/null +++ b/src/memory/workers/tasks/discord.py @@ -0,0 +1,130 @@ +""" +Celery tasks for Discord message processing. +""" + +import hashlib +import logging +from datetime import datetime +from typing import Any + +from memory.common.celery_app import app +from memory.common.db.connection import make_session +from memory.common.db.models import DiscordMessage, DiscordUser +from memory.workers.tasks.content_processing import ( + safe_task_execution, + check_content_exists, + create_task_result, + process_content_item, +) +from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE +from memory.common import settings +from sqlalchemy.orm import Session, scoped_session + +logger = logging.getLogger(__name__) + + +def get_prev( + session: Session | scoped_session, channel_id: int, sent_at: datetime +) -> list[str]: + prev = ( + session.query(DiscordUser.username, DiscordMessage.content) + .join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id) + .filter( + DiscordMessage.channel_id == channel_id, + DiscordMessage.sent_at < sent_at, + ) + .order_by(DiscordMessage.sent_at.desc()) + .limit(settings.DISCORD_CONTEXT_WINDOW) + .all() + ) + return [f"{msg.username}: {msg.content}" for msg in prev[::-1]] + + +@app.task(name=ADD_DISCORD_MESSAGE) +@safe_task_execution +def add_discord_message( + message_id: int, + channel_id: int, + author_id: int, + content: str, + sent_at: str, + server_id: int | None = None, + message_reference_id: int | None = None, +) -> dict[str, Any]: + """ + Add a Discord message to the database. + + This task is queued by the Discord collector when messages are received. + """ + logger.info(f"Adding Discord message {message_id}: {content}") + # Include message_id in hash to ensure uniqueness across duplicate content + content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest() + sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00")) + + with make_session() as session: + discord_message = DiscordMessage( + modality="text", + sha256=content_hash, + content=content, + channel_id=channel_id, + sent_at=sent_at_dt, + server_id=server_id, + discord_user_id=author_id, + message_id=message_id, + message_type="reply" if message_reference_id else "default", + reply_to_message_id=message_reference_id, + ) + existing_msg = check_content_exists( + session, DiscordMessage, message_id=message_id, sha256=content_hash + ) + if existing_msg: + logger.info(f"Discord message already exists: {existing_msg.message_id}") + return create_task_result( + existing_msg, "already_exists", message_id=message_id + ) + + if channel_id: + discord_message.messages_before = get_prev(session, channel_id, sent_at_dt) + + result = process_content_item(discord_message, session) + + logger.info( + f"Discord message ID after process_content_item: {discord_message.id}" + ) + logger.info(f"Process result: {result}") + + return result + + +@app.task(name=EDIT_DISCORD_MESSAGE) +@safe_task_execution +def edit_discord_message( + message_id: int, content: str, edited_at: str +) -> dict[str, Any]: + """ + Edit a Discord message in the database. + + This task is queued by the Discord collector when messages are edited. + """ + logger.info(f"Editing Discord message {message_id}: {content}") + with make_session() as session: + existing_msg = check_content_exists( + session, DiscordMessage, message_id=message_id + ) + if not existing_msg: + return { + "status": "error", + "error": "Message not found", + "message_id": message_id, + } + + existing_msg.content = content # type: ignore + if existing_msg.channel_id: + existing_msg.messages_before = get_prev( + session, existing_msg.channel_id, existing_msg.sent_at + ) + existing_msg.edited_at = datetime.fromisoformat( + edited_at.replace("Z", "+00:00") + ) + + return process_content_item(existing_msg, session)