mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-23 15:16:35 +02:00
add Discord ingester
This commit is contained in:
parent
f454aa9afa
commit
e086b4a3a6
@ -9,7 +9,6 @@ Create Date: 2025-10-12 10:12:57.421009
|
|||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
@ -20,10 +19,8 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Rename prompt column to message in scheduled_llm_calls table
|
|
||||||
op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message")
|
op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message")
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Rename message column back to prompt in scheduled_llm_calls table
|
|
||||||
op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt")
|
op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt")
|
||||||
|
165
db/migrations/versions/20251012_222827_add_discord_models.py
Normal file
165
db/migrations/versions/20251012_222827_add_discord_models.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
"""add_discord_models
|
||||||
|
|
||||||
|
Revision ID: a8c8e8b17179
|
||||||
|
Revises: c86079073c1d
|
||||||
|
Create Date: 2025-10-12 22:28:27.856164
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a8c8e8b17179"
|
||||||
|
down_revision: Union[str, None] = "c86079073c1d"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the Discord tables and their indexes.

    Tables are created parents-first so that every foreign-key target already
    exists: ``discord_servers``, then ``discord_channels``, ``discord_users``
    and finally ``discord_message`` (which also references ``source_item``).
    """

    def timestamp_columns() -> list:
        # Build fresh Column objects on every call; the same pair of
        # created_at/updated_at columns is used by three of the tables.
        return [
            sa.Column(
                column_name,
                sa.DateTime(timezone=True),
                server_default=sa.text("now()"),
                nullable=True,
            )
            for column_name in ("created_at", "updated_at")
        ]

    # --- discord_servers -------------------------------------------------
    op.create_table(
        "discord_servers",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("name", sa.Text(), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("member_count", sa.Integer(), nullable=True),
        sa.Column(
            "track_messages", sa.Boolean(), server_default="true", nullable=False
        ),
        sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
        *timestamp_columns(),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["track_messages", "last_sync_at"],
        unique=False,
    )

    # --- discord_channels ------------------------------------------------
    op.create_table(
        "discord_channels",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("server_id", sa.BigInteger(), nullable=True),
        sa.Column("name", sa.Text(), nullable=False),
        sa.Column("channel_type", sa.Text(), nullable=False),
        sa.Column("track_messages", sa.Boolean(), nullable=True),
        *timestamp_columns(),
        sa.ForeignKeyConstraint(["server_id"], ["discord_servers.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "discord_channels_server_idx",
        "discord_channels",
        ["server_id"],
        unique=False,
    )

    # --- discord_users ---------------------------------------------------
    op.create_table(
        "discord_users",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("username", sa.Text(), nullable=False),
        sa.Column("display_name", sa.Text(), nullable=True),
        sa.Column("system_user_id", sa.Integer(), nullable=True),
        sa.Column(
            "allow_dm_tracking", sa.Boolean(), server_default="true", nullable=False
        ),
        *timestamp_columns(),
        sa.ForeignKeyConstraint(["system_user_id"], ["users.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "discord_users_system_user_idx",
        "discord_users",
        ["system_user_id"],
        unique=False,
    )

    # --- discord_message -------------------------------------------------
    # The primary key doubles as a foreign key to source_item.id, and the row
    # is removed when the referenced source_item row is deleted (CASCADE).
    op.create_table(
        "discord_message",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("server_id", sa.BigInteger(), nullable=True),
        sa.Column("channel_id", sa.BigInteger(), nullable=False),
        sa.Column("discord_user_id", sa.BigInteger(), nullable=False),
        sa.Column("message_id", sa.BigInteger(), nullable=False),
        sa.Column("message_type", sa.Text(), server_default="default", nullable=True),
        sa.Column("reply_to_message_id", sa.BigInteger(), nullable=True),
        sa.Column("thread_id", sa.BigInteger(), nullable=True),
        sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(["channel_id"], ["discord_channels.id"]),
        sa.ForeignKeyConstraint(["discord_user_id"], ["discord_users.id"]),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["server_id"], ["discord_servers.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    # message_id is unique across the table.
    op.create_index(
        "discord_message_discord_id_idx",
        "discord_message",
        ["message_id"],
        unique=True,
    )
    op.create_index(
        "discord_message_server_channel_idx",
        "discord_message",
        ["server_id", "channel_id"],
        unique=False,
    )
    op.create_index(
        "discord_message_user_idx",
        "discord_message",
        ["discord_user_id"],
        unique=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the Discord tables and their indexes.

    Tables are removed children-first (the reverse of ``upgrade``); for each
    table its indexes are dropped before the table itself.
    """
    teardown_plan = (
        (
            "discord_message",
            (
                "discord_message_user_idx",
                "discord_message_server_channel_idx",
                "discord_message_discord_id_idx",
            ),
        ),
        ("discord_users", ("discord_users_system_user_idx",)),
        ("discord_channels", ("discord_channels_server_idx",)),
        ("discord_servers", ("discord_servers_active_idx",)),
    )
    for table, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(index_name, table_name=table)
        op.drop_table(table)
|
@ -174,7 +174,7 @@ services:
|
|||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
environment:
|
environment:
|
||||||
<<: *worker-env
|
<<: *worker-env
|
||||||
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
|
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
||||||
|
|
||||||
ingest-hub:
|
ingest-hub:
|
||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
|
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
|
|||||||
git config --global user.name "${GIT_USER_NAME}"
|
git config --global user.name "${GIT_USER_NAME}"
|
||||||
|
|
||||||
# Default queues to process
|
# Default queues to process
|
||||||
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
|
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
||||||
ENV PYTHONPATH="/app"
|
ENV PYTHONPATH="/app"
|
||||||
|
|
||||||
ENTRYPOINT ["./entry.sh"]
|
ENTRYPOINT ["./entry.sh"]
|
@ -5,3 +5,4 @@ python-multipart==0.0.9
|
|||||||
sqladmin==0.20.1
|
sqladmin==0.20.1
|
||||||
mcp==1.10.0
|
mcp==1.10.0
|
||||||
bm25s[full]==0.2.13
|
bm25s[full]==0.2.13
|
||||||
|
discord.py==2.3.2
|
@ -12,6 +12,9 @@ MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
|
|||||||
NOTES_ROOT = "memory.workers.tasks.notes"
|
NOTES_ROOT = "memory.workers.tasks.notes"
|
||||||
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
||||||
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||||
|
DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||||
|
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||||
|
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||||
|
|
||||||
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
||||||
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
||||||
@ -72,17 +75,18 @@ app.conf.update(
|
|||||||
task_reject_on_worker_lost=True,
|
task_reject_on_worker_lost=True,
|
||||||
worker_prefetch_multiplier=1,
|
worker_prefetch_multiplier=1,
|
||||||
task_routes={
|
task_routes={
|
||||||
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
|
|
||||||
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
|
|
||||||
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
|
|
||||||
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
||||||
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
||||||
|
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
|
||||||
|
f"{DISCORD_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
|
||||||
|
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
|
||||||
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
|
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
|
||||||
f"{MAINTENANCE_ROOT}.*": {
|
f"{MAINTENANCE_ROOT}.*": {
|
||||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
|
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
|
||||||
},
|
},
|
||||||
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||||
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||||
|
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
|
||||||
f"{SCHEDULED_CALLS_ROOT}.*": {
|
f"{SCHEDULED_CALLS_ROOT}.*": {
|
||||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
||||||
},
|
},
|
||||||
|
@ -11,6 +11,7 @@ from memory.common.db.models.source_items import (
|
|||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
AgentObservation,
|
AgentObservation,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
DiscordMessage,
|
||||||
BlogPost,
|
BlogPost,
|
||||||
Comic,
|
Comic,
|
||||||
BookSection,
|
BookSection,
|
||||||
@ -40,6 +41,9 @@ from memory.common.db.models.sources import (
|
|||||||
Book,
|
Book,
|
||||||
ArticleFeed,
|
ArticleFeed,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
|
DiscordServer,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordUser,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.users import (
|
from memory.common.db.models.users import (
|
||||||
User,
|
User,
|
||||||
@ -74,6 +78,7 @@ __all__ = [
|
|||||||
"EmailAttachment",
|
"EmailAttachment",
|
||||||
"AgentObservation",
|
"AgentObservation",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
|
"DiscordMessage",
|
||||||
"BlogPost",
|
"BlogPost",
|
||||||
"Comic",
|
"Comic",
|
||||||
"BookSection",
|
"BookSection",
|
||||||
@ -93,6 +98,9 @@ __all__ = [
|
|||||||
"Book",
|
"Book",
|
||||||
"ArticleFeed",
|
"ArticleFeed",
|
||||||
"EmailAccount",
|
"EmailAccount",
|
||||||
|
"DiscordServer",
|
||||||
|
"DiscordChannel",
|
||||||
|
"DiscordUser",
|
||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"UserSession",
|
"UserSession",
|
||||||
|
@ -70,7 +70,7 @@ class ScheduledLLMCall(Base):
|
|||||||
"created_at": print_datetime(cast(datetime, self.created_at)),
|
"created_at": print_datetime(cast(datetime, self.created_at)),
|
||||||
"executed_at": print_datetime(cast(datetime, self.executed_at)),
|
"executed_at": print_datetime(cast(datetime, self.executed_at)),
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": self.message,
|
"message": self.message,
|
||||||
"system_prompt": self.system_prompt,
|
"system_prompt": self.system_prompt,
|
||||||
"allowed_tools": self.allowed_tools,
|
"allowed_tools": self.allowed_tools,
|
||||||
"discord_channel": self.discord_channel,
|
"discord_channel": self.discord_channel,
|
||||||
|
@ -262,7 +262,7 @@ class ChatMessage(SourceItem):
|
|||||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
platform = Column(Text)
|
platform = Column(Text)
|
||||||
channel_id = Column(Text)
|
channel_id = Column(Text) # Keep as Text for cross-platform compatibility
|
||||||
author = Column(Text)
|
author = Column(Text)
|
||||||
sent_at = Column(DateTime(timezone=True))
|
sent_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
@ -274,6 +274,64 @@ class ChatMessage(SourceItem):
|
|||||||
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
|
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMessage(SourceItem):
    """Discord-specific chat message with rich metadata.

    Stored in ``discord_message``; the primary key is also a foreign key to
    ``source_item.id`` and the row is deleted together with its source item.
    """

    __tablename__ = "discord_message"

    id = Column(
        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
    )

    sent_at = Column(DateTime(timezone=True), nullable=False)
    # server_id is nullable while channel_id and discord_user_id are required.
    server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
    channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
    discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
    message_id = Column(BigInteger, nullable=False)  # Discord message snowflake ID

    # Discord-specific metadata
    message_type = Column(
        Text, server_default="default"
    )  # "default", "reply", "thread_starter"
    reply_to_message_id = Column(
        BigInteger, nullable=True
    )  # Discord message snowflake ID if replying
    thread_id = Column(
        BigInteger, nullable=True
    )  # Discord thread snowflake ID if in thread
    edited_at = Column(DateTime(timezone=True), nullable=True)

    channel = relationship("DiscordChannel", foreign_keys=[channel_id])
    server = relationship("DiscordServer", foreign_keys=[server_id])
    discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])

    @property
    def title(self) -> str:
        """Return the message rendered as ``"<username>: <content>"``."""
        return f"{self.discord_user.username}: {self.content}"

    __mapper_args__ = {
        "polymorphic_identity": "discord_message",
    }

    __table_args__ = (
        # One row per Discord message snowflake.
        Index("discord_message_discord_id_idx", "message_id", unique=True),
        Index(
            "discord_message_server_channel_idx",
            "server_id",
            "channel_id",
        ),
        Index("discord_message_user_idx", "discord_user_id"),
    )

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        """Return the text chunks for this message.

        Returns an empty list when the message has no content. Otherwise the
        text is the optional ``messages_before`` context followed by
        :attr:`title`, with blank lines between the parts.
        """
        if not cast(str | None, self.content):
            return []

        # ``messages_before`` is not a mapped column; it is only read when
        # something has attached it to the instance.
        # NOTE(review): presumably set by the ingestion task — confirm.
        context = getattr(self, "messages_before", [])

        # Join context and title in one pass so that a message without any
        # preceding context does not start with a dangling "\n\n" separator.
        text = "\n\n".join([*context, self.title])
        return extract.extract_text(text)
|
||||||
|
|
||||||
|
|
||||||
class GitCommit(SourceItem):
|
class GitCommit(SourceItem):
|
||||||
__tablename__ = "git_commit"
|
__tablename__ = "git_commit"
|
||||||
|
|
||||||
|
@ -10,12 +10,14 @@ from sqlalchemy import (
|
|||||||
Boolean,
|
Boolean,
|
||||||
Column,
|
Column,
|
||||||
DateTime,
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
Index,
|
Index,
|
||||||
Integer,
|
Integer,
|
||||||
Text,
|
Text,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
@ -123,3 +125,69 @@ class EmailAccount(Base):
|
|||||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordServer(Base):
    """Discord server (guild) configuration and metadata.

    Owns its channels: they are reachable through ``channels`` and follow the
    ``all, delete-orphan`` cascade.
    """

    __tablename__ = "discord_servers"

    # Discord guild snowflake ID
    id = Column(BigInteger, primary_key=True)
    name = Column(Text, nullable=False)
    description = Column(Text)
    member_count = Column(Integer)

    # Collection settings
    track_messages = Column(Boolean, server_default="true", nullable=False)
    last_sync_at = Column(DateTime(timezone=True))

    # Both timestamps are filled in by the database on insert.
    # NOTE(review): no onupdate is configured, so updated_at only changes when
    # a caller sets it explicitly — confirm this is intended.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now())

    channels = relationship(
        "DiscordChannel",
        back_populates="server",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordChannel(Base):
    """Discord channel metadata and per-channel configuration."""

    __tablename__ = "discord_channels"

    # Discord channel snowflake ID
    id = Column(BigInteger, primary_key=True)
    # Nullable: a channel row may exist without a server.
    server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
    name = Column(Text, nullable=False)
    # "text", "voice", "dm", "group_dm"
    channel_type = Column(Text, nullable=False)

    # Collection settings (null = inherit from server)
    track_messages = Column(Boolean, nullable=True)

    # Both timestamps are filled in by the database on insert.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now())

    server = relationship("DiscordServer", back_populates="channels")

    __table_args__ = (
        Index("discord_channels_server_idx", "server_id"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordUser(Base):
    """Discord user metadata and preferences."""

    __tablename__ = "discord_users"

    # Discord user snowflake ID
    id = Column(BigInteger, primary_key=True)
    username = Column(Text, nullable=False)
    display_name = Column(Text)

    # Link to system user if registered
    system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)

    # Basic DM settings
    allow_dm_tracking = Column(Boolean, server_default="true", nullable=False)

    # Both timestamps are filled in by the database on insert.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now())

    # Paired with the ``discord_users`` relationship on ``User``.
    system_user = relationship("User", back_populates="discord_users")

    __table_args__ = (
        Index("discord_users_system_user_idx", "system_user_id"),
    )
|
||||||
|
@ -50,6 +50,7 @@ class User(Base):
|
|||||||
oauth_states = relationship(
|
oauth_states = relationship(
|
||||||
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
"OAuthState", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
discord_users = relationship("DiscordUser", back_populates="system_user")
|
||||||
|
|
||||||
def serialize(self) -> dict:
|
def serialize(self) -> dict:
|
||||||
return {
|
return {
|
||||||
|
@ -1,221 +1,101 @@
|
|||||||
|
"""
|
||||||
|
Discord integration.
|
||||||
|
|
||||||
|
Simple HTTP client that communicates with the Discord collector's API server.
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ERROR_CHANNEL = "memory-errors"
|
|
||||||
ACTIVITY_CHANNEL = "memory-activity"
|
def get_api_url() -> str:
|
||||||
DISCOVERY_CHANNEL = "memory-discoveries"
|
"""Get the Discord API server URL"""
|
||||||
CHAT_CHANNEL = "memory-chat"
|
host = settings.DISCORD_COLLECTOR_SERVER_URL
|
||||||
|
port = settings.DISCORD_COLLECTOR_PORT
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
class DiscordServer(requests.Session):
|
def send_dm(user_identifier: str, message: str) -> bool:
|
||||||
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
|
"""Send a DM via the Discord collector API"""
|
||||||
self.server_id = server_id
|
|
||||||
self.server_name = server_name
|
|
||||||
self.channels = {}
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.setup_channels()
|
|
||||||
self.members = self.fetch_all_members()
|
|
||||||
|
|
||||||
def setup_channels(self):
|
|
||||||
resp = self.get(self.channels_url)
|
|
||||||
resp.raise_for_status()
|
|
||||||
channels = {channel["name"]: channel["id"] for channel in resp.json()}
|
|
||||||
|
|
||||||
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
|
|
||||||
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
|
|
||||||
self.channels[ERROR_CHANNEL] = error_channel
|
|
||||||
|
|
||||||
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
|
|
||||||
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
|
|
||||||
self.channels[ACTIVITY_CHANNEL] = activity_channel
|
|
||||||
|
|
||||||
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
|
|
||||||
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
|
|
||||||
self.channels[DISCOVERY_CHANNEL] = discovery_channel
|
|
||||||
|
|
||||||
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
|
|
||||||
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
|
|
||||||
self.channels[CHAT_CHANNEL] = chat_channel
|
|
||||||
|
|
||||||
@property
|
|
||||||
def error_channel(self) -> str:
|
|
||||||
return self.channels[ERROR_CHANNEL]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def activity_channel(self) -> str:
|
|
||||||
return self.channels[ACTIVITY_CHANNEL]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def discovery_channel(self) -> str:
|
|
||||||
return self.channels[DISCOVERY_CHANNEL]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def chat_channel(self) -> str:
|
|
||||||
return self.channels[CHAT_CHANNEL]
|
|
||||||
|
|
||||||
def channel_id(self, channel_name: str) -> str:
|
|
||||||
if not (channel_id := self.channels.get(channel_name)):
|
|
||||||
raise ValueError(f"Channel {channel_name} not found")
|
|
||||||
return channel_id
|
|
||||||
|
|
||||||
def send_message(self, channel_id: str, content: str):
|
|
||||||
payload: dict[str, Any] = {"content": content}
|
|
||||||
mentions = re.findall(r"@(\S*)", content)
|
|
||||||
users = {u: i for u, i in self.members.items() if u in mentions}
|
|
||||||
if users:
|
|
||||||
for u, i in users.items():
|
|
||||||
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
|
|
||||||
payload["allowed_mentions"] = {
|
|
||||||
"parse": [],
|
|
||||||
"users": list(users.values()),
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.post(
|
|
||||||
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
|
||||||
json=payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
|
|
||||||
resp = self.post(
|
|
||||||
self.channels_url, json={"name": channel_name, "type": channel_type}
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()["id"]
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return (
|
|
||||||
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
def request(self, method: str, url: str, **kwargs):
|
|
||||||
headers = kwargs.get("headers", {})
|
|
||||||
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
|
|
||||||
headers["Content-Type"] = "application/json"
|
|
||||||
kwargs["headers"] = headers
|
|
||||||
return super().request(method, url, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def channels_url(self) -> str:
|
|
||||||
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def members_url(self) -> str:
|
|
||||||
return f"https://discord.com/api/v10/guilds/{self.server_id}/members"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dm_create_url(self) -> str:
|
|
||||||
return "https://discord.com/api/v10/users/@me/channels"
|
|
||||||
|
|
||||||
def list_members(
|
|
||||||
self, limit: int = 1000, after: str | None = None
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""List up to `limit` members in this guild, starting after a user ID.
|
|
||||||
|
|
||||||
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
|
|
||||||
"""
|
|
||||||
params: dict[str, Any] = {"limit": limit}
|
|
||||||
if after:
|
|
||||||
params["after"] = after
|
|
||||||
resp = self.get(self.members_url, params=params)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
|
|
||||||
"""Retrieve all members in the guild by paginating the members list.
|
|
||||||
|
|
||||||
Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
|
|
||||||
"""
|
|
||||||
members: dict[str, str] = {}
|
|
||||||
after: str | None = None
|
|
||||||
while batch := self.list_members(limit=page_size, after=after):
|
|
||||||
for member in batch:
|
|
||||||
user = member.get("user", {})
|
|
||||||
members[user.get("global_name") or user.get("username", "")] = user.get(
|
|
||||||
"id", ""
|
|
||||||
)
|
|
||||||
after = user.get("id", "")
|
|
||||||
return members
|
|
||||||
|
|
||||||
def create_dm_channel(self, user_id: str) -> str:
|
|
||||||
"""Create (or retrieve) a DM channel with the given user and return the channel ID.
|
|
||||||
|
|
||||||
The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
|
|
||||||
"""
|
|
||||||
resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
return data["id"]
|
|
||||||
|
|
||||||
def send_dm(self, user_id: str, content: str):
|
|
||||||
"""Send a direct message to a specific user by ID."""
|
|
||||||
channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
|
|
||||||
return self.post(
|
|
||||||
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
|
||||||
json={"content": content},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_bot_servers() -> list[dict[str, Any]]:
|
|
||||||
"""Get list of servers the bot is in."""
|
|
||||||
if not settings.DISCORD_BOT_TOKEN:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
|
response = requests.post(
|
||||||
response = requests.get(
|
f"{get_api_url()}/send_dm",
|
||||||
"https://discord.com/api/v10/users/@me/guilds", headers=headers
|
json={"user_identifier": user_identifier, "message": message},
|
||||||
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return result.get("success", False)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_message(channel_name: str, message: str) -> bool:
    """Send a message to a channel via the Discord collector API.

    Posts ``{"channel_name", "message"}`` to the collector's ``/send_channel``
    endpoint with a 10 second timeout.

    Args:
        channel_name: Name of the channel to post to.
        message: Text to send.

    Returns:
        The truthiness of the ``success`` field in the collector's JSON reply;
        ``False`` when the request fails, the reply is not valid JSON, or the
        reply is not a JSON object.
    """
    try:
        response = requests.post(
            f"{get_api_url()}/send_channel",
            json={"channel_name": channel_name, "message": message},
            timeout=10,
        )
        response.raise_for_status()
        result = response.json()
    # ValueError covers JSON decoding errors on requests versions where the
    # decode error is not a RequestException subclass.
    except (requests.RequestException, ValueError) as e:
        logger.error(f"Failed to send message to channel {channel_name}: {e}")
        return False

    # A non-object payload (list, string, null) has no "success" field; treat
    # it as a failure instead of raising AttributeError on .get().
    if not isinstance(result, dict):
        logger.error(
            f"Failed to send message to channel {channel_name}: "
            f"unexpected response {result!r}"
        )
        return False

    return bool(result.get("success", False))
|
||||||
|
|
||||||
|
|
||||||
|
def is_collector_healthy() -> bool:
    """Check if the Discord collector is running and healthy.

    Issues a GET to the collector's ``/health`` endpoint with a 5 second
    timeout.

    Returns:
        ``True`` only when the reply is a JSON object whose ``status`` field
        equals ``"healthy"``; ``False`` for request failures, error status
        codes, invalid JSON or any other payload.
    """
    try:
        response = requests.get(f"{get_api_url()}/health", timeout=5)
        response.raise_for_status()
        result = response.json()
    # ValueError covers JSON decoding errors on requests versions where the
    # decode error is not a RequestException subclass.
    except (requests.RequestException, ValueError):
        return False

    # Guard against non-object payloads so a health probe never raises.
    return isinstance(result, dict) and result.get("status") == "healthy"
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_discord_metadata() -> dict[str, int] | None:
|
||||||
|
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||||
|
try:
|
||||||
|
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
except Exception as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Failed to get bot servers: {e}")
|
logger.error(f"Failed to refresh Discord metadata: {e}")
|
||||||
return []
|
return None
|
||||||
|
|
||||||
|
|
||||||
servers: dict[str, DiscordServer] = {}
|
# Convenience functions
|
||||||
|
def send_error_message(message: str) -> bool:
|
||||||
|
"""Send an error message to the error channel"""
|
||||||
|
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_activity_message(message: str) -> bool:
    """Send an activity message to the activity channel.

    Returns the delivery result reported by ``broadcast_message``.
    """
    return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_discovery_message(message: str) -> bool:
    """Send a discovery message to the discovery channel.

    Returns the delivery result reported by ``broadcast_message``.
    """
    return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_chat_message(message: str) -> bool:
    """Send a chat message to the chat channel.

    Returns the delivery result reported by ``broadcast_message``.
    """
    return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
||||||
|
|
||||||
def send_activity_message(message: str):
|
|
||||||
broadcast_message(ACTIVITY_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_discovery_message(message: str):
|
|
||||||
broadcast_message(DISCOVERY_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_chat_message(message: str):
|
|
||||||
broadcast_message(CHAT_CHANNEL, message)
|
|
||||||
|
|
||||||
|
|
||||||
def send_dm(user_id: str, message: str):
|
|
||||||
for server in servers.values():
|
|
||||||
if not server.members.get(user_id) and user_id not in server.members.values():
|
|
||||||
continue
|
|
||||||
|
|
||||||
server.send_dm(user_id, message)
|
|
||||||
|
|
||||||
|
|
||||||
def notify_task_failure(
|
def notify_task_failure(
|
||||||
@ -234,9 +114,6 @@ def notify_task_failure(
|
|||||||
task_args: Task arguments
|
task_args: Task arguments
|
||||||
task_kwargs: Task keyword arguments
|
task_kwargs: Task keyword arguments
|
||||||
traceback_str: Full traceback string
|
traceback_str: Full traceback string
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if notification sent successfully
|
|
||||||
"""
|
"""
|
||||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
logger.debug("Discord notifications disabled")
|
logger.debug("Discord notifications disabled")
|
||||||
|
@ -172,3 +172,11 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
|||||||
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
||||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Discord collector settings (each value can be overridden via the environment)
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
# When false, messages authored by bots are not collected.
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
# Port and host the collector's HTTP API server binds to.
DISCORD_COLLECTOR_PORT = int(os.environ.get("DISCORD_COLLECTOR_PORT", 8001))
DISCORD_COLLECTOR_SERVER_URL = os.environ.get(
    "DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1"
)
# Number of preceding channel messages stored as context with each message.
DISCORD_CONTEXT_WINDOW = int(os.environ.get("DISCORD_CONTEXT_WINDOW", 10))
|
||||||
|
166
src/memory/workers/discord/api.py
Normal file
166
src/memory/workers/discord/api.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Discord API server.
|
||||||
|
|
||||||
|
FastAPI server that owns and manages a Discord collector instance,
|
||||||
|
providing HTTP endpoints for sending Discord messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.workers.discord.collector import MessageCollector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SendDMRequest(BaseModel):
    """Request body accepted by the ``/send_dm`` endpoint."""

    # Recipient: a Discord user ID or a username.
    user: str
    # Text of the direct message.
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class SendChannelRequest(BaseModel):
    """Request body accepted by the ``/send_channel`` endpoint."""

    # Target channel, addressed by name (e.g., "memory-errors").
    channel_name: str
    # Text to post in that channel.
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
# Application state
class AppState:
    """Mutable holder for the collector and the asyncio task that runs it."""

    def __init__(self):
        # Both stay None until the lifespan handler starts the collector.
        self.collector_task: asyncio.Task | None = None
        self.collector: MessageCollector | None = None


# Single module-level instance shared by the lifespan handler and endpoints.
app_state = AppState()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the Discord collector for the lifetime of the FastAPI app.

    On startup a MessageCollector is created and its ``start`` coroutine is
    scheduled as a background task; on shutdown the client is closed and the
    task cancelled.  Without a configured bot token the app still starts,
    just with no collector, so the endpoints answer 503.
    """
    if not settings.DISCORD_BOT_TOKEN:
        logger.error("DISCORD_BOT_TOKEN not configured")
        # An @asynccontextmanager generator must yield exactly once; returning
        # before the yield raises RuntimeError("generator didn't yield").
        yield
        return

    # Create and start the collector
    app_state.collector = MessageCollector()
    app_state.collector_task = asyncio.create_task(
        app_state.collector.start(settings.DISCORD_BOT_TOKEN)
    )
    logger.info("Discord collector started")

    try:
        yield
    finally:
        # Cleanup runs even if an error is thrown into the context manager.
        if app_state.collector and not app_state.collector.is_closed():
            await app_state.collector.close()

        if app_state.collector_task:
            app_state.collector_task.cancel()
            try:
                await app_state.collector_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # If the task had already finished with an error, awaiting it
                # re-raises that error; log it instead of failing shutdown.
                logger.error(f"Discord collector task ended with an error: {e}")

        logger.info("Discord collector stopped")


# FastAPI app with lifespan management
app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/send_dm")
async def send_dm_endpoint(request: SendDMRequest):
    """Send a DM via the collector's Discord client.

    Raises:
        HTTPException: 503 when no collector is running, 400 when the
            collector reports the DM could not be delivered, 500 when the
            send call itself raises.
    """
    if not app_state.collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    # Only the collector call is guarded: raising the 400 below inside the
    # try would let `except Exception` rewrite it into a 500.
    try:
        success = await app_state.collector.send_dm(request.user, request.message)
    except Exception as e:
        logger.error(f"Failed to send DM: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not success:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to send DM to {request.user}",
        )

    return {
        "success": True,
        "message": f"DM sent to {request.user}",
        "user": request.user,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest):
    """Send a message to a channel via the collector's Discord client.

    Raises:
        HTTPException: 503 when no collector is running, 400 when the
            collector reports the message could not be sent, 500 when the
            send call itself raises.
    """
    if not app_state.collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    # Only the collector call is guarded: raising the 400 below inside the
    # try would let `except Exception` rewrite it into a 500.
    try:
        success = await app_state.collector.send_to_channel(
            request.channel_name, request.message
        )
    except Exception as e:
        logger.error(f"Failed to send channel message: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not success:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to send message to channel {request.channel_name}",
        )

    return {
        "success": True,
        "message": f"Message sent to channel {request.channel_name}",
        "channel": request.channel_name,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Describe the collector's state, or answer 503 when none is running.

    The reply always carries status "healthy" once a collector exists; the
    connection state is reported separately under "connected".
    """
    collector = app_state.collector
    if not collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    connected = not collector.is_closed()
    bot_user = str(collector.user) if collector.user else None
    guild_count = len(collector.guilds) if collector.guilds else 0
    return {
        "status": "healthy",
        "connected": connected,
        "user": bot_user,
        "guilds": guild_count,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/refresh_metadata")
async def refresh_metadata():
    """Trigger a metadata refresh on the collector and return its counters.

    Answers 503 without a running collector and 500 if the refresh raises.
    """
    collector = app_state.collector
    if not collector:
        raise HTTPException(status_code=503, detail="Discord collector not running")

    try:
        counters = await collector.refresh_metadata()
        # The collector's counters are merged into the top level of the reply.
        return {
            "success": True,
            "message": "Metadata refreshed successfully",
            **counters,
        }
    except Exception as e:
        logger.error(f"Failed to refresh metadata: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
    """Serve the Discord API app with uvicorn on `host`:`port` (blocking)."""
    uvicorn.run(app, host=host, port=port, log_level="debug")


if __name__ == "__main__":
    # For testing the API server standalone
    run_discord_api_server(
        settings.DISCORD_COLLECTOR_SERVER_URL,
        settings.DISCORD_COLLECTOR_PORT,
    )
|
398
src/memory/workers/discord/collector.py
Normal file
398
src/memory/workers/discord/collector.py
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
"""
|
||||||
|
Discord message collector.
|
||||||
|
|
||||||
|
Core message collection functionality - stores Discord messages to database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models.sources import (
|
||||||
|
DiscordServer,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordUser,
|
||||||
|
)
|
||||||
|
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Pure functions for Discord entity creation/updates
|
||||||
|
def create_or_update_server(
    session: Session | scoped_session, guild: discord.Guild
) -> DiscordServer:
    """Return the DiscordServer row for `guild`, inserting it if missing.

    An existing row has its name, description and member count refreshed and
    its ``last_sync_at`` stamped with the current UTC time.  Nothing is
    committed here; a new row is only flushed.
    """
    record = session.query(DiscordServer).get(guild.id)

    if record:
        # Update metadata
        record.name = guild.name
        record.description = guild.description
        record.member_count = guild.member_count
        record.last_sync_at = datetime.now(timezone.utc)
        return record

    record = DiscordServer(
        id=guild.id,
        name=guild.name,
        description=guild.description,
        member_count=guild.member_count,
    )
    session.add(record)
    session.flush()  # Get the ID
    logger.info(f"Created server record for {guild.name} ({guild.id})")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def determine_channel_metadata(channel) -> tuple[str, int | None, str]:
    """Classify a Discord channel object.

    Returns:
        A ``(channel_type, server_id, name)`` tuple.  DMs and group DMs have
        no server id; text/voice/thread channels use their lowercased class
        name with "channel" removed as the type; anything else is "unknown".
    """
    if isinstance(channel, discord.DMChannel):
        # The recipient may be unavailable on a DM channel; fall back to the
        # channel id instead of raising AttributeError on `None.name`.
        recipient = channel.recipient
        name = f"DM with {recipient.name}" if recipient else f"DM {channel.id}"
        return "dm", None, name

    if isinstance(channel, discord.GroupChannel):
        return "group_dm", None, channel.name or "Group DM"

    if isinstance(
        channel, (discord.TextChannel, discord.VoiceChannel, discord.Thread)
    ):
        return (
            channel.__class__.__name__.lower().replace("channel", ""),
            channel.guild.id,
            channel.name,
        )

    # Unrecognised channel class: take whatever attributes it exposes.
    guild = getattr(channel, "guild", None)
    server_id = guild.id if guild else None
    name = getattr(channel, "name", f"Unknown-{channel.id}")
    return "unknown", server_id, name
|
||||||
|
|
||||||
|
|
||||||
|
def create_or_update_channel(
    session: Session | scoped_session, channel
) -> DiscordChannel:
    """Return the DiscordChannel row for `channel`, inserting it if missing.

    For an existing row only the name is refreshed, and only when the channel
    object has a ``name`` attribute.  Nothing is committed here; a new row is
    only flushed.
    """
    record = session.query(DiscordChannel).get(channel.id)

    if record:
        if hasattr(channel, "name"):
            record.name = channel.name
        return record

    channel_type, server_id, name = determine_channel_metadata(channel)
    record = DiscordChannel(
        id=channel.id,
        server_id=server_id,
        name=name,
        channel_type=channel_type,
    )
    session.add(record)
    session.flush()
    logger.debug(f"Created channel: {name}")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def create_or_update_user(
    session: Session | scoped_session, user: discord.User | discord.Member
) -> DiscordUser:
    """Return the DiscordUser row for `user`, inserting it if missing.

    An existing row gets its username and display name refreshed.  Nothing is
    committed here; a new row is only flushed.
    """
    record = session.query(DiscordUser).get(user.id)

    if record:
        # Update user info in case it changed
        record.username = user.name
        record.display_name = user.display_name
        return record

    record = DiscordUser(
        id=user.id,
        username=user.name,
        display_name=user.display_name,
    )
    session.add(record)
    session.flush()
    logger.debug(f"Created user: {user.name}")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def determine_message_metadata(
    message: discord.Message,
) -> tuple[str, int | None, int | None]:
    """Derive ``(message_type, reply_to_id, thread_id)`` for a message.

    The type is "reply" when the message references another message id and
    "default" otherwise.  ``thread_id`` is the channel's own id when the
    channel has a truthy ``parent`` attribute, else None.
    """
    reference = message.reference
    if reference and reference.message_id:
        message_type, reply_to_id = "reply", reference.message_id
    else:
        message_type, reply_to_id = "default", None

    channel = message.channel
    thread_id = channel.id if getattr(channel, "parent", None) else None

    return message_type, reply_to_id, thread_id
|
||||||
|
|
||||||
|
|
||||||
|
def should_track_message(
    server: DiscordServer | None,
    channel: DiscordChannel,
    user: DiscordUser,
) -> bool:
    """Decide whether a message should be tracked.

    Tracking is refused when the server (if there is one) or the channel has
    tracking switched off.  In DMs and group DMs the user's
    ``allow_dm_tracking`` flag decides; everywhere else the answer is True.
    """
    server_opted_out = bool(server) and not server.track_messages  # type: ignore
    if server_opted_out or not channel.track_messages:
        return False

    if channel.channel_type not in ("dm", "group_dm"):
        # Default: track the message
        return True

    return bool(user.allow_dm_tracking)
|
||||||
|
|
||||||
|
|
||||||
|
def should_collect_bot_message(message: discord.Message) -> bool:
    """Decide whether `message` passes the bot filter.

    Messages from non-bot authors always pass; messages from bots pass only
    when ``settings.DISCORD_COLLECT_BOTS`` is enabled.
    """
    if not message.author.bot:
        return True
    return settings.DISCORD_COLLECT_BOTS
|
||||||
|
|
||||||
|
|
||||||
|
def sync_guild_metadata(guild: discord.Guild) -> None:
    """Upsert the guild and its text/voice channels in a single transaction."""
    synced_types = (discord.TextChannel, discord.VoiceChannel)

    with make_session() as session:
        create_or_update_server(session, guild)

        for channel in guild.channels:
            if not isinstance(channel, synced_types):
                continue
            create_or_update_channel(session, channel)

        session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class MessageCollector(commands.Bot):
    """Discord bot that collects and stores messages (thin event handler).

    Incoming messages and edits are handed to Celery tasks; server, channel
    and user rows are created or refreshed directly through a DB session.
    The class also exposes helpers for sending DMs and channel messages.
    """

    def __init__(self):
        intents = discord.Intents.default()
        intents.message_content = True
        intents.guilds = True
        intents.members = True
        intents.dm_messages = True

        super().__init__(
            command_prefix="!memory_",  # Prefix to avoid conflicts
            intents=intents,
            help_command=None,  # Disable default help
        )

    async def on_ready(self):
        """Called when bot connects to Discord"""
        logger.info(f"Discord collector connected as {self.user}")
        logger.info(f"Connected to {len(self.guilds)} servers")

        # Sync server and channel metadata
        await self.sync_servers_and_channels()

        logger.info("Discord message collector ready")

    async def on_message(self, message: discord.Message):
        """Queue incoming message for database storage.

        Errors are logged and swallowed so that one bad message cannot break
        the event handler.
        """
        try:
            if not should_collect_bot_message(message):
                return

            # Ensure Discord entities exist in database first
            with make_session() as session:
                # Server goes first: the channel row carries a server_id and is
                # flushed on creation.
                # NOTE(review): presumably server_id is a foreign key - confirm.
                if message.guild:
                    create_or_update_server(session, message.guild)
                create_or_update_user(session, message.author)
                create_or_update_channel(session, message.channel)

                session.commit()

            # Queue the message for processing
            add_discord_message.delay(
                message_id=message.id,
                channel_id=message.channel.id,
                author_id=message.author.id,
                server_id=message.guild.id if message.guild else None,
                content=message.content or "",
                sent_at=message.created_at.isoformat(),
                message_reference_id=message.reference.message_id
                if message.reference
                else None,
            )
        except Exception as e:
            logger.error(f"Error queuing message {message.id}: {e}")

    async def on_message_edit(self, before: discord.Message, after: discord.Message):
        """Queue message edit for database update"""
        try:
            # Fall back to "now" when Discord supplies no edit timestamp.
            edit_time = after.edited_at or datetime.now(timezone.utc)
            edit_discord_message.delay(
                message_id=after.id,
                content=after.content,
                edited_at=edit_time.isoformat(),
            )
        except Exception as e:
            logger.error(f"Error queuing message edit {after.id}: {e}")

    async def sync_servers_and_channels(self):
        """Sync server and channel metadata on startup"""
        for guild in self.guilds:
            sync_guild_metadata(guild)

        logger.info(f"Synced {len(self.guilds)} servers and their channels")

    async def refresh_metadata(self) -> dict[str, int]:
        """Refresh server, channel and member rows from the bot's guild cache.

        Returns:
            Counters keyed "servers_updated", "channels_updated" and
            "users_updated".
        """
        logger.info("Refreshing Discord metadata...")

        servers_updated = 0
        channels_updated = 0
        users_updated = 0

        with make_session() as session:
            # Refresh all servers
            for guild in self.guilds:
                create_or_update_server(session, guild)
                servers_updated += 1

                # Refresh all channels in this server
                for channel in guild.channels:
                    if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
                        create_or_update_channel(session, channel)
                        channels_updated += 1

                # Refresh all members in this server (if members intent is enabled)
                if self.intents.members:
                    for member in guild.members:
                        create_or_update_user(session, member)
                        users_updated += 1

            session.commit()

        result = {
            "servers_updated": servers_updated,
            "channels_updated": channels_updated,
            "users_updated": users_updated,
        }

        logger.info(f"Metadata refresh complete: {result}")

        return result

    async def get_user(self, user_identifier: int | str) -> discord.User | None:
        """Get a Discord user by ID or username.

        Integers are looked up as user IDs (cache first, then the API).
        Strings are first matched against member name, display name and
        ``name#discriminator`` in every guild; a string of decimal digits that
        matches no member is then retried as a user ID, so IDs that arrive as
        text (as they do through the HTTP API) still resolve.

        NOTE(review): this coroutine shadows the base class's ``get_user``,
        which is reached here via ``super()`` as a plain call - confirm
        nothing else relies on the synchronous version.
        """
        if not isinstance(user_identifier, int):
            # Username lookup - search through all guilds
            for guild in self.guilds:
                for member in guild.members:
                    if (
                        member.name == user_identifier
                        or member.display_name == user_identifier
                        or f"{member.name}#{member.discriminator}" == user_identifier
                    ):
                        return member

            # isdecimal() (not isdigit()) so that int() below cannot fail.
            if not (isinstance(user_identifier, str) and user_identifier.isdecimal()):
                return None
            user_identifier = int(user_identifier)

        # Direct user ID lookup
        if user := super().get_user(user_identifier):
            return user
        try:
            return await self.fetch_user(user_identifier)
        except discord.NotFound:
            return None

    async def get_channel_by_name(
        self, channel_name: str
    ) -> discord.TextChannel | None:
        """Get a Discord channel by name (does not create if missing)"""
        # Search all guilds for the channel; the first text channel wins.
        for guild in self.guilds:
            for ch in guild.channels:
                if isinstance(ch, discord.TextChannel) and ch.name == channel_name:
                    return ch
        return None

    async def create_channel(
        self, channel_name: str, guild_id: int | None = None
    ) -> discord.TextChannel | None:
        """Create a text channel in the given guild (or the first guild if none given).

        Returns the new channel, or None when no guild is available or the
        creation fails (the failure is logged).
        """
        target_guild = None

        if guild_id:
            target_guild = self.get_guild(guild_id)
        elif self.guilds:
            target_guild = self.guilds[0]  # Default to first guild

        if not target_guild:
            logger.error(f"No guild available to create channel {channel_name}")
            return None

        try:
            channel = await target_guild.create_text_channel(channel_name)
            logger.info(f"Created channel {channel_name} in {target_guild.name}")
            return channel
        except Exception as e:
            logger.error(
                f"Failed to create channel {channel_name} in {target_guild.name}: {e}"
            )
            return None

    async def send_dm(self, user_identifier: int | str, message: str) -> bool:
        """Send a DM using this collector's Discord client.

        Returns True on success and False when the user cannot be found or
        the send fails (both cases are logged).
        """
        try:
            user = await self.get_user(user_identifier)
            if not user:
                logger.error(f"User {user_identifier} not found")
                return False

            await user.send(message)
            logger.info(f"Sent DM to {user_identifier}")
            return True

        except Exception as e:
            logger.error(f"Failed to send DM to {user_identifier}: {e}")
            return False

    async def send_to_channel(self, channel_name: str, message: str) -> bool:
        """Send a message to the first text channel with this name.

        Returns False without sending when notifications are disabled, when
        no such channel exists, or when the send fails.
        """
        if not settings.DISCORD_NOTIFICATIONS_ENABLED:
            return False

        try:
            channel = await self.get_channel_by_name(channel_name)
            if not channel:
                logger.error(f"Channel {channel_name} not found")
                return False

            await channel.send(message)
            logger.info(f"Sent message to channel {channel_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to send message to channel {channel_name}: {e}")
            return False
|
||||||
|
|
||||||
|
|
||||||
|
async def run_collector():
    """Start a MessageCollector with the configured bot token.

    Logs and returns when no token is configured; a failure while running is
    logged and re-raised.
    """
    token = settings.DISCORD_BOT_TOKEN
    if not token:
        logger.error("DISCORD_BOT_TOKEN not configured")
        return

    bot = MessageCollector()
    try:
        await bot.start(token)
    except Exception as e:
        logger.error(f"Discord collector failed: {e}")
        raise


if __name__ == "__main__":
    import asyncio

    asyncio.run(run_collector())
|
@ -6,6 +6,7 @@ from memory.workers.tasks import (
|
|||||||
email,
|
email,
|
||||||
comic,
|
comic,
|
||||||
blogs,
|
blogs,
|
||||||
|
discord,
|
||||||
ebook,
|
ebook,
|
||||||
forums,
|
forums,
|
||||||
maintenance,
|
maintenance,
|
||||||
@ -20,6 +21,7 @@ __all__ = [
|
|||||||
"comic",
|
"comic",
|
||||||
"blogs",
|
"blogs",
|
||||||
"ebook",
|
"ebook",
|
||||||
|
"discord",
|
||||||
"forums",
|
"forums",
|
||||||
"maintenance",
|
"maintenance",
|
||||||
"notes",
|
"notes",
|
||||||
|
@ -10,9 +10,9 @@ from collections import defaultdict
|
|||||||
import hashlib
|
import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Iterable, Sequence, cast
|
from typing import Any, Callable, Sequence, cast
|
||||||
|
|
||||||
from memory.common import embedding, qdrant, settings
|
from memory.common import embedding, qdrant
|
||||||
from memory.common.db.models import SourceItem, Chunk
|
from memory.common.db.models import SourceItem, Chunk
|
||||||
from memory.common.discord import notify_task_failure
|
from memory.common.discord import notify_task_failure
|
||||||
|
|
||||||
@ -38,19 +38,12 @@ def check_content_exists(
|
|||||||
Returns:
|
Returns:
|
||||||
Existing SourceItem if found, None otherwise
|
Existing SourceItem if found, None otherwise
|
||||||
"""
|
"""
|
||||||
|
query = session.query(model_class)
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if not hasattr(model_class, key):
|
if hasattr(model_class, key):
|
||||||
continue
|
query = query.filter(getattr(model_class, key) == value)
|
||||||
|
|
||||||
existing = (
|
return query.first()
|
||||||
session.query(model_class)
|
|
||||||
.filter(getattr(model_class, key) == value)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
return existing
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
||||||
@ -286,6 +279,6 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
|||||||
traceback_str=traceback_str,
|
traceback_str=traceback_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "error", "error": str(e)}
|
return {"status": "error", "error": str(e), "traceback": traceback_str}
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
130
src/memory/workers/tasks/discord.py
Normal file
130
src/memory/workers/tasks/discord.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
"""
|
||||||
|
Celery tasks for Discord message processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from memory.common.celery_app import app
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import DiscordMessage, DiscordUser
|
||||||
|
from memory.workers.tasks.content_processing import (
|
||||||
|
safe_task_execution,
|
||||||
|
check_content_exists,
|
||||||
|
create_task_result,
|
||||||
|
process_content_item,
|
||||||
|
)
|
||||||
|
from memory.common.celery_app import ADD_DISCORD_MESSAGE, EDIT_DISCORD_MESSAGE
|
||||||
|
from memory.common import settings
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prev(
    session: Session | scoped_session, channel_id: int, sent_at: datetime
) -> list[str]:
    """Return the messages preceding `sent_at` in a channel, oldest first.

    At most ``settings.DISCORD_CONTEXT_WINDOW`` messages are returned, each
    formatted as "username: content".
    """
    newest_first = (
        session.query(DiscordUser.username, DiscordMessage.content)
        .join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
        .filter(
            DiscordMessage.channel_id == channel_id,
            DiscordMessage.sent_at < sent_at,
        )
        .order_by(DiscordMessage.sent_at.desc())
        .limit(settings.DISCORD_CONTEXT_WINDOW)
        .all()
    )
    # The query picks the most recent rows; flip them into chronological order.
    return [f"{row.username}: {row.content}" for row in reversed(newest_first)]
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=ADD_DISCORD_MESSAGE)
@safe_task_execution
def add_discord_message(
    message_id: int,
    channel_id: int,
    author_id: int,
    content: str,
    sent_at: str,
    server_id: int | None = None,
    message_reference_id: int | None = None,
) -> dict[str, Any]:
    """
    Store a newly received Discord message.

    Queued by the Discord collector for every collected message.  A message
    whose id and content hash are already stored is reported as
    "already_exists" instead of being processed again.
    """
    logger.info(f"Adding Discord message {message_id}: {content}")
    # Include message_id in hash to ensure uniqueness across duplicate content
    digest = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
    sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))

    with make_session() as session:
        discord_message = DiscordMessage(
            modality="text",
            sha256=digest,
            content=content,
            channel_id=channel_id,
            sent_at=sent_at_dt,
            server_id=server_id,
            discord_user_id=author_id,
            message_id=message_id,
            message_type="reply" if message_reference_id else "default",
            reply_to_message_id=message_reference_id,
        )

        duplicate = check_content_exists(
            session, DiscordMessage, message_id=message_id, sha256=digest
        )
        if duplicate:
            logger.info(f"Discord message already exists: {duplicate.message_id}")
            return create_task_result(duplicate, "already_exists", message_id=message_id)

        if channel_id:
            # Attach the preceding channel messages as context.
            discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)

        result = process_content_item(discord_message, session)

        logger.info(
            f"Discord message ID after process_content_item: {discord_message.id}"
        )
        logger.info(f"Process result: {result}")

        return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=EDIT_DISCORD_MESSAGE)
@safe_task_execution
def edit_discord_message(
    message_id: int, content: str, edited_at: str
) -> dict[str, Any]:
    """
    Edit a Discord message in the database.

    This task is queued by the Discord collector when messages are edited.

    Returns:
        The result of re-processing the updated message, or an error dict
        with status "error" when no message with this id is stored.
    """
    logger.info(f"Editing Discord message {message_id}: {content}")
    # Parse first so a malformed timestamp fails before the row is touched.
    edited_at_dt = datetime.fromisoformat(edited_at.replace("Z", "+00:00"))

    with make_session() as session:
        existing_msg = check_content_exists(
            session, DiscordMessage, message_id=message_id
        )
        if not existing_msg:
            return {
                "status": "error",
                "error": "Message not found",
                "message_id": message_id,
            }

        existing_msg.content = content  # type: ignore
        # Keep the stored hash in step with the content, using the same
        # "message_id:content" recipe that add_discord_message hashes.
        existing_msg.sha256 = hashlib.sha256(  # type: ignore
            f"{message_id}:{content}".encode()
        ).digest()
        if existing_msg.channel_id:
            existing_msg.messages_before = get_prev(
                session, existing_msg.channel_id, existing_msg.sent_at
            )
        existing_msg.edited_at = edited_at_dt

        return process_content_item(existing_msg, session)
|
Loading…
x
Reference in New Issue
Block a user