Compare commits

...

3 Commits

Author SHA1 Message Date
Daniel O'Connell
99d3843f47 move to general LLM providers 2025-10-13 03:23:20 +02:00
Daniel O'Connell
08d17c28dd run discord collector 2025-10-12 23:43:44 +02:00
Daniel O'Connell
e086b4a3a6 add Discord ingester 2025-10-12 23:13:30 +02:00
39 changed files with 4966 additions and 610 deletions

View File

@ -9,7 +9,6 @@ Create Date: 2025-10-12 10:12:57.421009
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
@ -20,10 +19,8 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Rename prompt column to message in scheduled_llm_calls table
op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message")
def downgrade() -> None:
# Rename message column back to prompt in scheduled_llm_calls table
op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt")

View File

@ -0,0 +1,176 @@
"""add_discord_models
Revision ID: a8c8e8b17179
Revises: c86079073c1d
Create Date: 2025-10-12 22:28:27.856164
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a8c8e8b17179"
down_revision: Union[str, None] = "c86079073c1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"discord_servers",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("member_count", sa.Integer(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_servers_active_idx",
"discord_servers",
["track_messages", "last_sync_at"],
unique=False,
)
op.create_table(
"discord_channels",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("channel_type", sa.Text(), nullable=False),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["server_id"],
["discord_servers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_channels_server_idx", "discord_channels", ["server_id"], unique=False
)
op.create_table(
"discord_users",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("username", sa.Text(), nullable=False),
sa.Column("display_name", sa.Text(), nullable=True),
sa.Column("system_user_id", sa.Integer(), nullable=True),
sa.Column("track_messages", sa.Boolean(), server_default="true", nullable=True),
sa.Column(
"ignore_messages", sa.Boolean(), server_default="false", nullable=True
),
sa.Column("allowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("disallowed_tools", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["system_user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_users_system_user_idx",
"discord_users",
["system_user_id"],
unique=False,
)
op.create_table(
"discord_message",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("server_id", sa.BigInteger(), nullable=True),
sa.Column("channel_id", sa.BigInteger(), nullable=False),
sa.Column("discord_user_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.BigInteger(), nullable=False),
sa.Column("message_type", sa.Text(), server_default="default", nullable=True),
sa.Column("reply_to_message_id", sa.BigInteger(), nullable=True),
sa.Column("thread_id", sa.BigInteger(), nullable=True),
sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["channel_id"],
["discord_channels.id"],
),
sa.ForeignKeyConstraint(
["discord_user_id"],
["discord_users.id"],
),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["server_id"],
["discord_servers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"discord_message_discord_id_idx", "discord_message", ["message_id"], unique=True
)
op.create_index(
"discord_message_server_channel_idx",
"discord_message",
["server_id", "channel_id"],
unique=False,
)
op.create_index(
"discord_message_user_idx", "discord_message", ["discord_user_id"], unique=False
)
def downgrade() -> None:
op.drop_index("discord_message_user_idx", table_name="discord_message")
op.drop_index("discord_message_server_channel_idx", table_name="discord_message")
op.drop_index("discord_message_discord_id_idx", table_name="discord_message")
op.drop_table("discord_message")
op.drop_index("discord_users_system_user_idx", table_name="discord_users")
op.drop_table("discord_users")
op.drop_index("discord_channels_server_idx", table_name="discord_channels")
op.drop_table("discord_channels")
op.drop_index("discord_servers_active_idx", table_name="discord_servers")
op.drop_table("discord_servers")

View File

@ -174,7 +174,7 @@ services:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
ingest-hub:
<<: *worker-base
@ -183,6 +183,10 @@ services:
dockerfile: docker/ingest_hub/Dockerfile
environment:
<<: *worker-env
DISCORD_API_PORT: 8000
DISCORD_BOT_TOKEN: ${DISCORD_BOT_TOKEN}
DISCORD_NOTIFICATIONS_ENABLED: true
DISCORD_COLLECTOR_ENABLED: true
volumes:
- ./memory_files:/app/memory_files:rw
tmpfs:

View File

@ -11,10 +11,10 @@ RUN apt-get update && apt-get install -y \
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"
RUN pip install -e ".[ingesters]"
COPY src/ ./src/
RUN pip install -e ".[common]"
RUN pip install -e ".[ingesters]"
# Create and copy entrypoint script
COPY docker/workers/entry.sh ./entry.sh

View File

@ -14,3 +14,12 @@ stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
autorestart=true
startsecs=10
[program:discord-api]
command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_API_PORT)s
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
autorestart=true
startsecs=10

View File

@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
git config --global user.name "${GIT_USER_NAME}"
# Default queues to process
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"]

View File

@ -5,7 +5,7 @@ alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.18.1
anthropic==0.69.0
openai==1.25.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0

View File

@ -0,0 +1,3 @@
discord.py==2.3.2
uvicorn==0.29.0
fastapi==0.112.2

View File

@ -17,6 +17,7 @@ common_requires = read_requirements("requirements-common.txt")
parsers_requires = read_requirements("requirements-parsers.txt")
api_requires = read_requirements("requirements-api.txt")
dev_requires = read_requirements("requirements-dev.txt")
ingesters_requires = read_requirements("requirements-ingesters.txt")
setup(
name="memory",
@ -28,6 +29,11 @@ setup(
"api": api_requires + common_requires + parsers_requires,
"common": common_requires + parsers_requires,
"dev": dev_requires,
"all": api_requires + common_requires + dev_requires + parsers_requires,
"ingesters": common_requires + parsers_requires + ingesters_requires,
"all": api_requires
+ common_requires
+ dev_requires
+ parsers_requires
+ ingesters_requires,
},
)

View File

@ -40,7 +40,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
try:
response = await asyncio.to_thread(
llms.call,
llms.summarize,
prompt,
settings.RANKER_MODEL,
images=images,

View File

@ -12,6 +12,10 @@ MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
@ -72,17 +76,18 @@ app.conf.update(
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
task_routes={
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
f"{COMIC_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-comic"},
f"{DISCORD_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
f"{EMAIL_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-email"},
f"{FORUMS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-forums"},
f"{MAINTENANCE_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-maintenance"
},
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{PHOTO_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-photo-embed"},
f"{SCHEDULED_CALLS_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
},

View File

@ -11,6 +11,7 @@ from memory.common.db.models.source_items import (
EmailAttachment,
AgentObservation,
ChatMessage,
DiscordMessage,
BlogPost,
Comic,
BookSection,
@ -40,6 +41,9 @@ from memory.common.db.models.sources import (
Book,
ArticleFeed,
EmailAccount,
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.common.db.models.users import (
User,
@ -74,6 +78,7 @@ __all__ = [
"EmailAttachment",
"AgentObservation",
"ChatMessage",
"DiscordMessage",
"BlogPost",
"Comic",
"BookSection",
@ -93,6 +98,9 @@ __all__ = [
"Book",
"ArticleFeed",
"EmailAccount",
"DiscordServer",
"DiscordChannel",
"DiscordUser",
# Users
"User",
"UserSession",

View File

@ -70,7 +70,7 @@ class ScheduledLLMCall(Base):
"created_at": print_datetime(cast(datetime, self.created_at)),
"executed_at": print_datetime(cast(datetime, self.executed_at)),
"model": self.model,
"prompt": self.message,
"message": self.message,
"system_prompt": self.system_prompt,
"allowed_tools": self.allowed_tools,
"discord_channel": self.discord_channel,

View File

@ -262,7 +262,7 @@ class ChatMessage(SourceItem):
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
platform = Column(Text)
channel_id = Column(Text)
channel_id = Column(Text) # Keep as Text for cross-platform compatibility
author = Column(Text)
sent_at = Column(DateTime(timezone=True))
@ -274,6 +274,64 @@ class ChatMessage(SourceItem):
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
class DiscordMessage(SourceItem):
"""Discord-specific chat message with rich metadata"""
__tablename__ = "discord_message"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
sent_at = Column(DateTime(timezone=True), nullable=False)
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
channel_id = Column(BigInteger, ForeignKey("discord_channels.id"), nullable=False)
discord_user_id = Column(BigInteger, ForeignKey("discord_users.id"), nullable=False)
message_id = Column(BigInteger, nullable=False) # Discord message snowflake ID
# Discord-specific metadata
message_type = Column(
Text, server_default="default"
) # "default", "reply", "thread_starter"
reply_to_message_id = Column(
BigInteger, nullable=True
) # Discord message snowflake ID if replying
thread_id = Column(
BigInteger, nullable=True
) # Discord thread snowflake ID if in thread
edited_at = Column(DateTime(timezone=True), nullable=True)
channel = relationship("DiscordChannel", foreign_keys=[channel_id])
server = relationship("DiscordServer", foreign_keys=[server_id])
discord_user = relationship("DiscordUser", foreign_keys=[discord_user_id])
@property
def title(self) -> str:
return f"{self.discord_user.username}: {self.content}"
__mapper_args__ = {
"polymorphic_identity": "discord_message",
}
__table_args__ = (
Index("discord_message_discord_id_idx", "message_id", unique=True),
Index(
"discord_message_server_channel_idx",
"server_id",
"channel_id",
),
Index("discord_message_user_idx", "discord_user_id"),
)
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
content = cast(str | None, self.content)
if not content:
return []
prev = getattr(self, "messages_before", [])
content = "\n\n".join(prev) + "\n\n" + self.title
return extract.extract_text(content)
class GitCommit(SourceItem):
__tablename__ = "git_commit"

View File

@ -10,12 +10,14 @@ from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Text,
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
@ -123,3 +125,74 @@ class EmailAccount(Base):
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
class MessageProcessor:
track_messages = Column(Boolean, nullable=False, server_default="true")
ignore_messages = Column(Boolean, nullable=True, default=False)
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
class DiscordServer(Base, MessageProcessor):
"""Discord server configuration and metadata"""
__tablename__ = "discord_servers"
id = Column(BigInteger, primary_key=True) # Discord guild snowflake ID
name = Column(Text, nullable=False)
description = Column(Text)
member_count = Column(Integer)
# Collection settings
last_sync_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
channels = relationship(
"DiscordChannel", back_populates="server", cascade="all, delete-orphan"
)
__table_args__ = (
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
)
class DiscordChannel(Base, MessageProcessor):
"""Discord channel metadata and configuration"""
__tablename__ = "discord_channels"
id = Column(BigInteger, primary_key=True) # Discord channel snowflake ID
server_id = Column(BigInteger, ForeignKey("discord_servers.id"), nullable=True)
name = Column(Text, nullable=False)
channel_type = Column(Text, nullable=False) # "text", "voice", "dm", "group_dm"
# Collection settings (null = inherit from server)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
server = relationship("DiscordServer", back_populates="channels")
__table_args__ = (Index("discord_channels_server_idx", "server_id"),)
class DiscordUser(Base, MessageProcessor):
"""Discord user metadata and preferences"""
__tablename__ = "discord_users"
id = Column(BigInteger, primary_key=True) # Discord user snowflake ID
username = Column(Text, nullable=False)
display_name = Column(Text)
# Link to system user if registered
system_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
# Basic DM settings
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
system_user = relationship("User", back_populates="discord_users")
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)

View File

@ -50,6 +50,7 @@ class User(Base):
oauth_states = relationship(
"OAuthState", back_populates="user", cascade="all, delete-orphan"
)
discord_users = relationship("DiscordUser", back_populates="system_user")
def serialize(self) -> dict:
return {

View File

@ -1,221 +1,101 @@
"""
Discord integration.
Simple HTTP client that communicates with the Discord collector's API server.
"""
import logging
import requests
import re
from typing import Any
from memory.common import settings
logger = logging.getLogger(__name__)
ERROR_CHANNEL = "memory-errors"
ACTIVITY_CHANNEL = "memory-activity"
DISCOVERY_CHANNEL = "memory-discoveries"
CHAT_CHANNEL = "memory-chat"
def get_api_url() -> str:
"""Get the Discord API server URL"""
host = settings.DISCORD_COLLECTOR_SERVER_URL
port = settings.DISCORD_COLLECTOR_PORT
return f"http://{host}:{port}"
class DiscordServer(requests.Session):
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
self.server_id = server_id
self.server_name = server_name
self.channels = {}
super().__init__(*args, **kwargs)
self.setup_channels()
self.members = self.fetch_all_members()
def setup_channels(self):
resp = self.get(self.channels_url)
resp.raise_for_status()
channels = {channel["name"]: channel["id"] for channel in resp.json()}
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
self.channels[ERROR_CHANNEL] = error_channel
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
self.channels[ACTIVITY_CHANNEL] = activity_channel
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
self.channels[DISCOVERY_CHANNEL] = discovery_channel
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
self.channels[CHAT_CHANNEL] = chat_channel
@property
def error_channel(self) -> str:
return self.channels[ERROR_CHANNEL]
@property
def activity_channel(self) -> str:
return self.channels[ACTIVITY_CHANNEL]
@property
def discovery_channel(self) -> str:
return self.channels[DISCOVERY_CHANNEL]
@property
def chat_channel(self) -> str:
return self.channels[CHAT_CHANNEL]
def channel_id(self, channel_name: str) -> str:
if not (channel_id := self.channels.get(channel_name)):
raise ValueError(f"Channel {channel_name} not found")
return channel_id
def send_message(self, channel_id: str, content: str):
payload: dict[str, Any] = {"content": content}
mentions = re.findall(r"@(\S*)", content)
users = {u: i for u, i in self.members.items() if u in mentions}
if users:
for u, i in users.items():
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
payload["allowed_mentions"] = {
"parse": [],
"users": list(users.values()),
}
return self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json=payload,
)
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
resp = self.post(
self.channels_url, json={"name": channel_name, "type": channel_type}
)
resp.raise_for_status()
return resp.json()["id"]
def __str__(self):
return (
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
)
def request(self, method: str, url: str, **kwargs):
headers = kwargs.get("headers", {})
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
headers["Content-Type"] = "application/json"
kwargs["headers"] = headers
return super().request(method, url, **kwargs)
@property
def channels_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
@property
def members_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/members"
@property
def dm_create_url(self) -> str:
return "https://discord.com/api/v10/users/@me/channels"
def list_members(
self, limit: int = 1000, after: str | None = None
) -> list[dict[str, Any]]:
"""List up to `limit` members in this guild, starting after a user ID.
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
"""
params: dict[str, Any] = {"limit": limit}
if after:
params["after"] = after
resp = self.get(self.members_url, params=params)
resp.raise_for_status()
return resp.json()
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
"""Retrieve all members in the guild by paginating the members list.
Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
"""
members: dict[str, str] = {}
after: str | None = None
while batch := self.list_members(limit=page_size, after=after):
for member in batch:
user = member.get("user", {})
members[user.get("global_name") or user.get("username", "")] = user.get(
"id", ""
)
after = user.get("id", "")
return members
def create_dm_channel(self, user_id: str) -> str:
"""Create (or retrieve) a DM channel with the given user and return the channel ID.
The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
"""
resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
resp.raise_for_status()
data = resp.json()
return data["id"]
def send_dm(self, user_id: str, content: str):
"""Send a direct message to a specific user by ID."""
channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
return self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json={"content": content},
)
def get_bot_servers() -> list[dict[str, Any]]:
"""Get list of servers the bot is in."""
if not settings.DISCORD_BOT_TOKEN:
return []
def send_dm(user_identifier: str, message: str) -> bool:
"""Send a DM via the Discord collector API"""
try:
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
response = requests.get(
"https://discord.com/api/v10/users/@me/guilds", headers=headers
response = requests.post(
f"{get_api_url()}/send_dm",
json={"user_identifier": user_identifier, "message": message},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send DM to {user_identifier}: {e}")
return False
def broadcast_message(channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}")
return False
def is_collector_healthy() -> bool:
"""Check if the Discord collector is running and healthy"""
try:
response = requests.get(f"{get_api_url()}/health", timeout=5)
response.raise_for_status()
result = response.json()
return result.get("status") == "healthy"
except requests.RequestException:
return False
def refresh_discord_metadata() -> dict[str, int] | None:
"""Refresh Discord server/channel/user metadata from Discord API"""
try:
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Failed to get bot servers: {e}")
return []
except requests.RequestException as e:
logger.error(f"Failed to refresh Discord metadata: {e}")
return None
servers: dict[str, DiscordServer] = {}
# Convenience functions
def send_error_message(message: str) -> bool:
"""Send an error message to the error channel"""
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
def load_servers():
for server in get_bot_servers():
servers[server["id"]] = DiscordServer(server["id"], server["name"])
def send_activity_message(message: str) -> bool:
"""Send an activity message to the activity channel"""
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
def broadcast_message(channel: str, message: str):
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return
for server in servers.values():
server.send_message(server.channel_id(channel), message)
def send_discovery_message(message: str) -> bool:
"""Send a discovery message to the discovery channel"""
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
def send_error_message(message: str):
broadcast_message(ERROR_CHANNEL, message)
def send_activity_message(message: str):
broadcast_message(ACTIVITY_CHANNEL, message)
def send_discovery_message(message: str):
broadcast_message(DISCOVERY_CHANNEL, message)
def send_chat_message(message: str):
broadcast_message(CHAT_CHANNEL, message)
def send_dm(user_id: str, message: str):
for server in servers.values():
if not server.members.get(user_id) and user_id not in server.members.values():
continue
server.send_dm(user_id, message)
def send_chat_message(message: str) -> bool:
"""Send a chat message to the chat channel"""
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
def notify_task_failure(
@ -234,9 +114,6 @@ def notify_task_failure(
task_args: Task arguments
task_kwargs: Task keyword arguments
traceback_str: Full traceback string
Returns:
True if notification sent successfully
"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
logger.debug("Discord notifications disabled")

View File

@ -1,122 +0,0 @@
import logging
import base64
import io
from typing import Any
from PIL import Image
from memory.common import settings, tokens
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""
def encode_image(image: Image.Image) -> str:
"""Encode PIL Image to base64 string."""
buffer = io.BytesIO()
# Convert to RGB if necessary (for RGBA, etc.)
if image.mode != "RGB":
image = image.convert("RGB")
image.save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def call_openai(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = SYSTEM_PROMPT,
) -> str:
"""Call OpenAI API for summarization."""
import openai
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
try:
user_content: Any = [{"type": "text", "text": prompt}]
if images:
for image in images:
encoded_image = encode_image(image)
user_content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
}
)
response = client.chat.completions.create(
model=model.split("/")[1],
messages=[
{
"role": "system",
"content": system_prompt,
},
{"role": "user", "content": user_content},
],
temperature=0.3,
max_tokens=2048,
)
return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
def call_anthropic(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = SYSTEM_PROMPT,
) -> str:
"""Call Anthropic API for summarization."""
import anthropic
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
try:
# Prepare the message content
content: Any = [{"type": "text", "text": prompt}]
if images:
# Add images if provided
for image in images:
encoded_image = encode_image(image)
content.append(
{ # type: ignore
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": encoded_image,
},
}
)
response = client.messages.create(
model=model.split("/")[1],
messages=[{"role": "user", "content": content}], # type: ignore
system=system_prompt,
temperature=0.3,
max_tokens=2048,
)
return response.content[0].text
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
def call(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = SYSTEM_PROMPT,
) -> str:
if model.startswith("anthropic"):
return call_anthropic(prompt, model, images, system_prompt)
return call_openai(prompt, model, images, system_prompt)
def truncate(content: str, target_tokens: int) -> str:
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content

View File

@ -0,0 +1,79 @@
"""LLM provider module for unified LLM access."""
# Legacy imports for backwards compatibility
import logging
from PIL import Image
# New provider system
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
LLMSettings,
Message,
MessageContent,
MessageRole,
StreamEvent,
TextContent,
ThinkingContent,
ToolDefinition,
ToolResultContent,
ToolUseContent,
create_provider,
)
from memory.common import tokens
__all__ = [
"BaseLLMProvider",
"Message",
"MessageRole",
"MessageContent",
"TextContent",
"ImageContent",
"ToolUseContent",
"ToolResultContent",
"ThinkingContent",
"ToolDefinition",
"StreamEvent",
"LLMSettings",
"create_provider",
]
logger = logging.getLogger(__name__)
def summarize(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = "",
) -> str:
provider = create_provider(model=model)
try:
# Build message content
content: list[MessageContent] = [TextContent(text=prompt)]
for image in images:
content.append(ImageContent(image=image))
messages = [Message(role=MessageRole.USER, content=content)]
settings_obj = LLMSettings(temperature=0.3, max_tokens=2048)
res = provider.run_with_tools(
messages=messages,
system_prompt=system_prompt
or "You are a helpful assistant that creates concise summaries and identifies key topics.",
settings=settings_obj,
tools={},
)
return res.response or ""
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
def truncate(content: str, target_tokens: int) -> str:
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content

View File

@ -0,0 +1,451 @@
"""Anthropic LLM provider implementation."""
import json
import logging
from typing import Any, AsyncIterator, Iterator, Optional
import anthropic
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
LLMSettings,
Message,
MessageRole,
StreamEvent,
ToolDefinition,
ToolUseContent,
ThinkingContent,
TextContent,
)
logger = logging.getLogger(__name__)
class AnthropicProvider(BaseLLMProvider):
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
# Models that support extended thinking
THINKING_MODELS = {
"claude-opus-4",
"claude-opus-4-1",
"claude-sonnet-4-0",
"claude-sonnet-3-7",
"claude-sonnet-4-5",
}
def __init__(self, api_key: str, model: str, enable_thinking: bool = False):
"""
Initialize the Anthropic provider.
Args:
api_key: Anthropic API key
model: Model identifier
enable_thinking: Enable extended thinking for supported models
"""
super().__init__(api_key, model)
self.enable_thinking = enable_thinking
self._async_client: Optional[anthropic.AsyncAnthropic] = None
def _initialize_client(self) -> anthropic.Anthropic:
"""Initialize the Anthropic client."""
return anthropic.Anthropic(api_key=self.api_key)
@property
def async_client(self) -> anthropic.AsyncAnthropic:
"""Lazy-load the async client."""
if self._async_client is None:
self._async_client = anthropic.AsyncAnthropic(api_key=self.api_key)
return self._async_client
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
"""Convert ImageContent to Anthropic's base64 source format."""
encoded_image = self.encode_image(content.image)
return {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": encoded_image,
},
}
def _convert_message(self, message: Message) -> dict[str, Any]:
converted = message.to_dict()
if converted["role"] == MessageRole.ASSISTANT and isinstance(
converted["content"], list
):
content = sorted(
converted["content"], key=lambda x: x["type"] != "thinking"
)
return converted | {"content": content}
return converted
def _should_include_message(self, message: Message) -> bool:
"""Filter out system messages (handled separately in Anthropic)."""
return message.role != MessageRole.SYSTEM
def _supports_thinking(self) -> bool:
"""Check if the current model supports extended thinking."""
model_lower = self.model.lower()
return any(supported in model_lower for supported in self.THINKING_MODELS)
def _build_request_kwargs(
self,
messages: list[Message],
system_prompt: str | None,
tools: list[ToolDefinition] | None,
settings: LLMSettings,
) -> dict[str, Any]:
"""Build common request kwargs for API calls."""
anthropic_messages = self._convert_messages(messages)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": anthropic_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
}
# Only include top_p if explicitly set
if settings.top_p is not None:
kwargs["top_p"] = settings.top_p
if system_prompt:
kwargs["system"] = system_prompt
if settings.stop_sequences:
kwargs["stop_sequences"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
# Enable extended thinking if requested and model supports it
if self.enable_thinking and self._supports_thinking():
thinking_budget = min(10000, settings.max_tokens - 1024)
if thinking_budget >= 1024:
kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": thinking_budget,
}
# When thinking is enabled: temperature must be 1, can't use top_p
kwargs["temperature"] = 1.0
kwargs.pop("top_p", None)
return kwargs
def _handle_stream_event(
self, event: Any, current_tool_use: dict[str, Any] | None
) -> tuple[StreamEvent | None, dict[str, Any] | None]:
"""
Handle a streaming event and return StreamEvent and updated tool state.
Returns:
Tuple of (StreamEvent or None, updated current_tool_use or None)
"""
event_type = getattr(event, "type", None)
# Handle error events
if event_type == "error":
error = getattr(event, "error", None)
error_msg = str(error) if error else "Unknown error"
return StreamEvent(type="error", data=error_msg), current_tool_use
if event_type == "content_block_start":
block = getattr(event, "content_block", None)
if not block:
return None, current_tool_use
block_type = getattr(block, "type", None)
# Handle various tool types (tool_use, mcp_tool_use, server_tool_use)
if block_type in ("tool_use", "mcp_tool_use", "server_tool_use"):
# In content_block_start, input may already be present (empty dict)
block_input = getattr(block, "input", None)
current_tool_use = {
"id": getattr(block, "id", ""),
"name": getattr(block, "name", ""),
"input": block_input if block_input is not None else "",
"server_name": getattr(block, "server_name", None),
"is_server_call": block_type != "tool_use",
}
# Handle tool result blocks
elif hasattr(block, "tool_use_id"):
tool_result = {
"id": getattr(block, "tool_use_id", ""),
"result": getattr(block, "content", ""),
}
return StreamEvent(
type="tool_result", data=tool_result
), current_tool_use
# For non-tool blocks (text, thinking), we don't need to track state
return None, current_tool_use
elif event_type == "content_block_delta":
delta = getattr(event, "delta", None)
if not delta:
return None, current_tool_use
delta_type = getattr(delta, "type", None)
if delta_type == "text_delta":
text = getattr(delta, "text", "")
return StreamEvent(type="text", data=text), current_tool_use
elif delta_type == "thinking_delta":
thinking = getattr(delta, "thinking", "")
return StreamEvent(type="thinking", data=thinking), current_tool_use
elif delta_type == "signature_delta":
# Handle thinking signature for extended thinking
signature = getattr(delta, "signature", "")
return StreamEvent(
type="thinking", signature=signature
), current_tool_use
elif delta_type == "input_json_delta":
if current_tool_use is None:
# Edge case: received input_json_delta without tool_use start
logger.warning("Received input_json_delta without tool_use context")
return None, None
# Only accumulate if input is still a string (being built up)
if isinstance(current_tool_use.get("input"), str):
partial_json = getattr(delta, "partial_json", "")
current_tool_use["input"] += partial_json
# else: input was already set as a dict in content_block_start
return None, current_tool_use
elif event_type == "content_block_stop":
if current_tool_use:
# Use the parsed input from the content block if available
# This handles empty inputs {} more reliably than parsing
content_block = getattr(event, "content_block", None)
if content_block and hasattr(content_block, "input"):
current_tool_use["input"] = content_block.input
else:
# Fallback: parse accumulated JSON string
input_str = current_tool_use.get("input", "")
if isinstance(input_str, str):
# Need to parse the accumulated string
if not input_str or input_str.isspace():
# Empty or whitespace-only input
current_tool_use["input"] = {}
else:
try:
current_tool_use["input"] = json.loads(input_str)
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse tool input '{input_str}': {e}"
)
current_tool_use["input"] = {}
# else: input is already parsed
tool_data = {
"id": current_tool_use.get("id", ""),
"name": current_tool_use.get("name", ""),
"input": current_tool_use.get("input", {}),
}
# Include server info if present
if current_tool_use.get("server_name"):
tool_data["server_name"] = current_tool_use["server_name"]
if current_tool_use.get("is_server_call"):
tool_data["is_server_call"] = current_tool_use["is_server_call"]
return StreamEvent(type="tool_use", data=tool_data), None
elif event_type == "message_delta":
delta = getattr(event, "delta", None)
if delta:
stop_reason = getattr(delta, "stop_reason", None)
if stop_reason == "max_tokens":
return StreamEvent(
type="error", data="Max tokens reached"
), current_tool_use
# Handle token usage information
usage = getattr(event, "usage", None)
if usage:
usage_data = {
"input_tokens": getattr(usage, "input_tokens", 0),
"output_tokens": getattr(usage, "output_tokens", 0),
"cache_creation_input_tokens": getattr(
usage, "cache_creation_input_tokens", None
),
"cache_read_input_tokens": getattr(
usage, "cache_read_input_tokens", None
),
}
# Could emit this as a separate event type if needed
logger.debug(f"Token usage: {usage_data}")
return None, current_tool_use
elif event_type == "message_stop":
# Final event - clean up any pending state
if current_tool_use:
logger.warning(
f"Message ended with incomplete tool use: {current_tool_use}"
)
return StreamEvent(type="done"), None
# Unknown event type - log but don't fail
if event_type and event_type not in (
"message_start",
"message_delta",
"content_block_start",
"content_block_delta",
"content_block_stop",
"message_stop",
):
logger.debug(f"Unknown event type: {event_type}")
return None, current_tool_use
def generate(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
try:
response = self.client.messages.create(**kwargs)
return "".join(
block.text for block in response.content if block.type == "text"
)
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
def stream(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
try:
with self.client.messages.stream(**kwargs) as stream:
current_tool_use: dict[str, Any] | None = None
for event in stream:
stream_event, current_tool_use = self._handle_stream_event(
event, current_tool_use
)
if stream_event:
yield stream_event
except Exception as e:
logger.error(f"Anthropic streaming error: {e}")
yield StreamEvent(type="error", data=str(e))
async def agenerate(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
try:
response = await self.async_client.messages.create(**kwargs)
return "".join(
block.text for block in response.content if block.type == "text"
)
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
async def astream(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
kwargs = self._build_request_kwargs(messages, system_prompt, tools, settings)
try:
async with self.async_client.messages.stream(**kwargs) as stream:
current_tool_use: dict[str, Any] | None = None
async for event in stream:
stream_event, current_tool_use = self._handle_stream_event(
event, current_tool_use
)
if stream_event:
yield stream_event
except Exception as e:
logger.error(f"Anthropic streaming error: {e}")
yield StreamEvent(type="error", data=str(e))
def stream_with_tools(
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
) -> Iterator[StreamEvent]:
if max_iterations <= 0:
return
response = TextContent(text="")
thinking = ThinkingContent(thinking="", signature="")
for event in self.stream(
messages=messages,
system_prompt=system_prompt,
tools=list(tools.values()),
settings=settings,
):
if event.type == "text":
response.text += event.data
yield event
elif event.type == "thinking" and event.signature:
thinking.signature = event.signature
elif event.type == "thinking":
thinking.thinking += event.data
yield event
elif event.type == "tool_use":
yield event
tool_result = self.execute_tool(event.data, tools)
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
messages.append(
Message.assistant(
response,
thinking,
ToolUseContent(
id=event.data["id"],
name=event.data["name"],
input=event.data["input"],
),
)
)
messages.append(Message.user(tool_result=tool_result))
yield from self.stream_with_tools(
messages, tools, settings, system_prompt, max_iterations - 1
)
elif event.type == "tool_result":
yield event
elif event.type == "error":
logger.error(f"LLM error: {event.data}")
raise RuntimeError(f"LLM error: {event.data}")

View File

@ -0,0 +1,561 @@
"""Base classes and types for LLM providers."""
import base64
import io
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator, Iterator, Literal, Optional, Union
from PIL import Image
from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
logger = logging.getLogger(__name__)
class MessageRole(str, Enum):
"""Message roles for chat history."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class TextContent:
"""Text content in a message."""
type: Literal["text"] = "text"
text: str = ""
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format."""
return {"type": "text", "text": self.text}
@dataclass
class ImageContent:
"""Image content in a message."""
type: Literal["image"] = "image"
image: Image.Image = None # type: ignore
detail: Optional[str] = None # For OpenAI: "low", "high", "auto"
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format."""
# Note: Image will be encoded by provider-specific implementation
return {"type": "image", "image": self.image}
@dataclass
class ToolUseContent:
"""Tool use request from the assistant."""
type: Literal["tool_use"] = "tool_use"
id: str = ""
name: str = ""
input: dict[str, Any] = None # type: ignore
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format."""
return {
"type": "tool_use",
"id": self.id,
"name": self.name,
"input": self.input,
}
@dataclass
class ToolResultContent:
"""Tool result from tool execution."""
type: Literal["tool_result"] = "tool_result"
tool_use_id: str = ""
content: str = ""
is_error: bool = False
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format."""
return {
"type": "tool_result",
"tool_use_id": self.tool_use_id,
"content": self.content,
"is_error": self.is_error,
}
@dataclass
class ThinkingContent:
"""Thinking/reasoning content from the assistant (extended thinking)."""
type: Literal["thinking"] = "thinking"
thinking: str = ""
signature: str | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format."""
return {
"type": "thinking",
"thinking": self.thinking,
"signature": self.signature,
}
MessageContent = Union[
TextContent, ImageContent, ToolUseContent, ToolResultContent, ThinkingContent
]
@dataclass
class Turn:
"""A turn in the conversation."""
response: str | None
thinking: str | None
tool_calls: dict[str, ToolResult] | None
@dataclass
class Message:
"""A message in the conversation history."""
role: MessageRole
content: Union[str, list[MessageContent]]
def to_dict(self) -> dict[str, Any]:
"""Convert message to dictionary format."""
if isinstance(self.content, str):
return {"role": self.role.value, "content": self.content}
content_list = [item.to_dict() for item in self.content]
return {"role": self.role.value, "content": content_list}
@staticmethod
def assistant(
text: TextContent | None = None,
thinking: ThinkingContent | None = None,
tool_use: ToolUseContent | None = None,
) -> "Message":
parts = []
if text:
parts.append(text)
if thinking:
parts.append(thinking)
if tool_use:
parts.append(tool_use)
return Message(role=MessageRole.ASSISTANT, content=parts)
@staticmethod
def user(
text: str | None = None, tool_result: ToolResultContent | None = None
) -> "Message":
parts = []
if text:
parts.append(TextContent(text=text))
if tool_result:
parts.append(tool_result)
return Message(role=MessageRole.USER, content=parts)
@dataclass
class StreamEvent:
"""An event from the streaming response."""
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
data: Any = None
signature: str | None = None
@dataclass
class LLMSettings:
"""Settings for LLM API calls."""
temperature: float = 0.7
max_tokens: int = 2048
# Don't set by default - some models don't allow both temp and top_p
top_p: float | None = None
stop_sequences: list[str] | None = None
stream: bool = False
class BaseLLMProvider(ABC):
"""Base class for LLM providers."""
def __init__(self, api_key: str, model: str):
"""
Initialize the LLM provider.
Args:
api_key: API key for the provider
model: Model identifier
"""
self.api_key = api_key
self.model = model
self._client: Any = None
@abstractmethod
def _initialize_client(self) -> Any:
"""Initialize the provider-specific client."""
pass
@property
def client(self) -> Any:
"""Lazy-load the client."""
if self._client is None:
self._client = self._initialize_client()
return self._client
def execute_tool(
self,
tool_call: ToolCall,
tool_handlers: dict[str, ToolDefinition],
) -> ToolResultContent:
"""
Execute a tool call.
Args:
tool_call: Tool call
tool_handlers: Dict mapping tool names to handler functions
Returns:
ToolResultContent with result or error
"""
name = tool_call.get("name")
tool_use_id = tool_call.get("id")
input = tool_call.get("input")
if not name:
return ToolResultContent(
tool_use_id=tool_use_id,
content="Tool name missing",
is_error=True,
)
if not (tool := tool_handlers.get(name)):
return ToolResultContent(
tool_use_id=tool_use_id,
content=f"Tool '{name}' not found",
is_error=True,
)
try:
return ToolResultContent(
tool_use_id=tool_use_id,
content=tool(input),
is_error=False,
)
except Exception as e:
logger.error(f"Tool '{name}' failed: {e}", exc_info=True)
return ToolResultContent(
tool_use_id=tool_use_id,
content=str(e),
is_error=True,
)
def encode_image(self, image: Image.Image) -> str:
"""
Encode PIL Image to base64 string.
Args:
image: PIL Image to encode
Returns:
Base64 encoded string
"""
buffer = io.BytesIO()
# Convert to RGB if necessary (for RGBA, etc.)
if image.mode != "RGB":
image = image.convert("RGB")
image.save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def _convert_text_content(self, content: TextContent) -> dict[str, Any]:
"""Convert TextContent to provider format. Override for custom format."""
return content.to_dict()
def _convert_image_content(self, content: ImageContent) -> dict[str, Any]:
"""Convert ImageContent to provider format. Override for custom format."""
return content.to_dict()
def _convert_tool_use_content(self, content: ToolUseContent) -> dict[str, Any]:
"""Convert ToolUseContent to provider format. Override for custom format."""
return content.to_dict()
def _convert_tool_result_content(
self, content: ToolResultContent
) -> dict[str, Any]:
"""Convert ToolResultContent to provider format. Override for custom format."""
return content.to_dict()
def _convert_thinking_content(self, content: ThinkingContent) -> dict[str, Any]:
"""Convert ThinkingContent to provider format. Override for custom format."""
return content.to_dict()
def _convert_message_content(self, content: MessageContent) -> dict[str, Any]:
"""
Convert a MessageContent item to provider format.
Dispatches to type-specific converters that can be overridden.
"""
if isinstance(content, TextContent):
return self._convert_text_content(content)
elif isinstance(content, ImageContent):
return self._convert_image_content(content)
elif isinstance(content, ToolUseContent):
return self._convert_tool_use_content(content)
elif isinstance(content, ToolResultContent):
return self._convert_tool_result_content(content)
elif isinstance(content, ThinkingContent):
return self._convert_thinking_content(content)
else:
raise ValueError(f"Unknown content type: {type(content)}")
def _convert_message(self, message: Message) -> dict[str, Any]:
"""
Convert a Message to provider format.
Can be overridden for provider-specific handling (e.g., filtering system messages).
"""
return message.to_dict()
def _should_include_message(self, message: Message) -> bool:
"""
Determine if a message should be included in the request.
Override to filter messages (e.g., Anthropic filters SYSTEM messages).
Args:
message: Message to check
Returns:
True if message should be included
"""
return True
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
"""
Convert a list of messages to provider format.
Uses _should_include_message for filtering and _convert_message for conversion.
"""
return [
self._convert_message(msg)
for msg in messages
if self._should_include_message(msg)
]
def _convert_tool(self, tool: ToolDefinition) -> dict[str, Any]:
"""
Convert a single ToolDefinition to provider format.
Default format matches Anthropic. Override for other providers (e.g., OpenAI uses functions).
"""
return {
"name": tool.name,
"description": tool.description,
"input_schema": tool.input_schema,
}
def _convert_tools(
self, tools: list[ToolDefinition] | None
) -> Optional[list[dict[str, Any]]]:
"""Convert tool definitions to provider format."""
if not tools:
return None
return [self._convert_tool(tool) for tool in tools]
@abstractmethod
def generate(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
Generate a non-streaming response.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Returns:
Generated text response
"""
pass
@abstractmethod
def stream(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> Iterator[StreamEvent]:
"""
Generate a streaming response.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Yields:
StreamEvent objects containing text chunks, tool uses, or errors
"""
pass
@abstractmethod
async def agenerate(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> str:
"""
Generate a non-streaming response asynchronously.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Returns:
Generated text response
"""
pass
@abstractmethod
async def astream(
self,
messages: list[Message],
system_prompt: str | None = None,
tools: list[ToolDefinition] | None = None,
settings: LLMSettings | None = None,
) -> AsyncIterator[StreamEvent]:
"""
Generate a streaming response asynchronously.
Args:
messages: Conversation history
system_prompt: Optional system prompt
tools: Optional list of tools the LLM can use
settings: Optional settings for the generation
Yields:
StreamEvent objects containing text chunks, tool uses, or errors
"""
pass
@abstractmethod
def stream_with_tools(
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
) -> Iterator[StreamEvent]:
pass
def run_with_tools(
self,
messages: list[Message],
tools: dict[str, ToolDefinition],
settings: LLMSettings | None = None,
system_prompt: str | None = None,
max_iterations: int = 10,
) -> Turn:
thinking, response, tool_calls = "", "", {}
for event in self.stream_with_tools(
messages=messages,
tools=tools,
settings=settings,
system_prompt=system_prompt,
max_iterations=max_iterations,
):
if event.type == "thinking":
thinking += event.data
elif event.type == "tool_use":
tool_calls[event.data["id"]] = {
"name": event.data["name"],
"input": event.data["input"],
"output": "",
}
elif event.type == "text":
response += event.data
elif event.type == "tool_result":
current = tool_calls.get(event.data["tool_use_id"]) or {}
tool_calls[event.data["tool_use_id"]] = {
"name": event.data.get("name") or current.get("name"),
"input": event.data.get("input") or current.get("input"),
"output": event.data.get("content"),
}
return Turn(
thinking=thinking or None,
response=response or None,
tool_calls=tool_calls or None,
)
def create_provider(
model: str | None = None,
api_key: str | None = None,
enable_thinking: bool = False,
) -> BaseLLMProvider:
"""
Create an LLM provider based on the model name.
Args:
model: Model identifier (e.g., "claude-3-opus-20240229", "gpt-4").
If not provided, uses SUMMARIZER_MODEL from settings.
api_key: Optional API key. If not provided, uses keys from settings.
enable_thinking: Enable extended thinking for supported models (Claude Opus 4+, Sonnet 4+, Sonnet 3.7)
Returns:
An initialized LLM provider
Raises:
ValueError: If the provider cannot be determined from the model name
"""
# Use default model from settings if not provided
if model is None:
model = settings.SUMMARIZER_MODEL
provider, model = model.split("/", 1)
if provider == "anthropic":
# Anthropic models
if api_key is None:
api_key = settings.ANTHROPIC_API_KEY
if not api_key:
raise ValueError(
"ANTHROPIC_API_KEY not found in settings. "
"Please set it in your .env file."
)
from memory.common.llms.anthropic_provider import AnthropicProvider
return AnthropicProvider(
api_key=api_key, model=model, enable_thinking=enable_thinking
)
# Could add OpenAI support here in the future
# elif "gpt" in model_lower or model.startswith("openai"):
# ...
else:
raise ValueError(
f"Unknown provider for model: {model}. "
f"Supported providers: Anthropic (claude-*)"
)

View File

@ -0,0 +1,388 @@
"""OpenAI LLM provider implementation."""
import logging
from typing import Any, AsyncIterator, Iterator, Optional
import openai
from memory.common.llms.base import (
BaseLLMProvider,
ImageContent,
LLMSettings,
Message,
MessageContent,
MessageRole,
StreamEvent,
TextContent,
ThinkingContent,
ToolDefinition,
ToolResultContent,
ToolUseContent,
)
logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider):
"""OpenAI LLM provider with streaming and tool support."""
def _initialize_client(self) -> openai.OpenAI:
"""Initialize the OpenAI client."""
return openai.OpenAI(api_key=self.api_key)
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
"""
Convert our Message format to OpenAI format.
Args:
messages: List of messages in our format
Returns:
List of messages in OpenAI format
"""
openai_messages = []
for msg in messages:
if isinstance(msg.content, str):
openai_messages.append({"role": msg.role.value, "content": msg.content})
else:
# Handle multi-part content
content_parts = []
for item in msg.content:
if isinstance(item, TextContent):
content_parts.append({"type": "text", "text": item.text})
elif isinstance(item, ImageContent):
encoded_image = self.encode_image(item.image)
image_part: dict[str, Any] = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}"
},
}
if item.detail:
image_part["image_url"]["detail"] = item.detail
content_parts.append(image_part)
elif isinstance(item, ToolUseContent):
# OpenAI doesn't have tool_use in content, it's a separate field
# We'll handle this by adding a tool_calls field to the message
pass
elif isinstance(item, ToolResultContent):
# OpenAI handles tool results as separate "tool" role messages
openai_messages.append(
{
"role": "tool",
"tool_call_id": item.tool_use_id,
"content": item.content,
}
)
continue
elif isinstance(item, ThinkingContent):
# OpenAI doesn't have native thinking support in most models
# We can add it as text with a special marker
content_parts.append(
{
"type": "text",
"text": f"[Thinking: {item.thinking}]",
}
)
# Check if this message has tool calls
tool_calls = [
item for item in msg.content if isinstance(item, ToolUseContent)
]
message_dict: dict[str, Any] = {"role": msg.role.value}
if content_parts:
message_dict["content"] = content_parts
if tool_calls:
message_dict["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {"name": tc.name, "arguments": str(tc.input)},
}
for tc in tool_calls
]
openai_messages.append(message_dict)
return openai_messages
def _convert_tools(
self, tools: Optional[list[ToolDefinition]]
) -> Optional[list[dict[str, Any]]]:
"""
Convert our tool definitions to OpenAI format.
Args:
tools: List of tool definitions
Returns:
List of tools in OpenAI format
"""
if not tools:
return None
return [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema,
},
}
for tool in tools
]
def generate(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
) -> str:
"""Generate a non-streaming response."""
settings = settings or LLMSettings()
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
try:
response = self.client.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
def stream(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
) -> Iterator[StreamEvent]:
"""Generate a streaming response."""
settings = settings or LLMSettings()
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
"stream": True,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
try:
stream = self.client.chat.completions.create(**kwargs)
current_tool_call: Optional[dict[str, Any]] = None
for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
# Handle text content
if delta.content:
yield StreamEvent(type="text", data=delta.content)
# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.id:
# New tool call starting
if current_tool_call:
# Yield the previous one
yield StreamEvent(
type="tool_use", data=current_tool_call
)
current_tool_call = {
"id": tool_call.id,
"name": tool_call.function.name or "",
"arguments": tool_call.function.arguments or "",
}
elif current_tool_call and tool_call.function.arguments:
# Continue building the current tool call
current_tool_call["arguments"] += (
tool_call.function.arguments
)
# Check if stream is finished
if chunk.choices[0].finish_reason:
if current_tool_call:
yield StreamEvent(type="tool_use", data=current_tool_call)
current_tool_call = None
yield StreamEvent(type="done")
except Exception as e:
logger.error(f"OpenAI streaming error: {e}")
yield StreamEvent(type="error", data=str(e))
async def agenerate(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
) -> str:
"""Generate a non-streaming response asynchronously."""
settings = settings or LLMSettings()
# Use async client
async_client = openai.AsyncOpenAI(api_key=self.api_key)
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
try:
response = await async_client.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
async def astream(
self,
messages: list[Message],
system_prompt: Optional[str] = None,
tools: Optional[list[ToolDefinition]] = None,
settings: Optional[LLMSettings] = None,
) -> AsyncIterator[StreamEvent]:
"""Generate a streaming response asynchronously."""
settings = settings or LLMSettings()
# Use async client
async_client = openai.AsyncOpenAI(api_key=self.api_key)
openai_messages = self._convert_messages(messages)
# Add system prompt as first message if provided
if system_prompt:
openai_messages.insert(
0, {"role": "system", "content": system_prompt}
)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": openai_messages,
"temperature": settings.temperature,
"max_tokens": settings.max_tokens,
"top_p": settings.top_p,
"stream": True,
}
if settings.stop_sequences:
kwargs["stop"] = settings.stop_sequences
if tools:
kwargs["tools"] = self._convert_tools(tools)
kwargs["tool_choice"] = "auto"
try:
stream = await async_client.chat.completions.create(**kwargs)
current_tool_call: Optional[dict[str, Any]] = None
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
# Handle text content
if delta.content:
yield StreamEvent(type="text", data=delta.content)
# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.id:
# New tool call starting
if current_tool_call:
# Yield the previous one
yield StreamEvent(
type="tool_use", data=current_tool_call
)
current_tool_call = {
"id": tool_call.id,
"name": tool_call.function.name or "",
"arguments": tool_call.function.arguments or "",
}
elif current_tool_call and tool_call.function.arguments:
# Continue building the current tool call
current_tool_call["arguments"] += (
tool_call.function.arguments
)
# Check if stream is finished
if chunk.choices[0].finish_reason:
if current_tool_call:
yield StreamEvent(type="tool_use", data=current_tool_call)
current_tool_call = None
yield StreamEvent(type="done")
except Exception as e:
logger.error(f"OpenAI streaming error: {e}")
yield StreamEvent(type="error", data=str(e))

View File

@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Any, Callable, TypedDict
ToolInput = str | dict[str, Any] | None
ToolHandler = Callable[[ToolInput], str]
class ToolCall(TypedDict):
"""A call to a tool."""
name: str
id: str
input: ToolInput
class ToolResult(TypedDict):
"""A result from a tool call."""
id: str
name: str
input: ToolInput
output: str
@dataclass
class ToolDefinition:
"""Definition of a tool that can be called by the LLM."""
name: str
description: str
input_schema: dict[str, Any] # JSON Schema for the tool's parameters
function: ToolHandler
def __call__(self, input: ToolInput) -> str:
return self.function(input)

View File

@ -0,0 +1,42 @@
"""Ping tool for testing LLM tool integration."""
from memory.common.llms.tools import ToolDefinition, ToolInput
def handle_ping_call(message: ToolInput = None) -> str:
"""
Handle a ping tool call.
Args:
message: Optional message to include in response
Returns:
Response string
"""
if message:
return f"pong: {message}"
return "pong"
def get_ping_tool() -> ToolDefinition:
"""
Get a ping tool definition for testing tool calls.
Returns a simple tool that takes no required parameters and can be used
to verify that tool calling is working correctly.
"""
return ToolDefinition(
name="ping",
description="A simple test tool that returns 'pong'. Use this to verify tool calling is working.",
input_schema={
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Optional message to echo back",
}
},
"required": [],
},
function=handle_ping_call,
)

View File

@ -0,0 +1,9 @@
def process_message(
msg: str,
history: list[str],
model: str | None = None,
system_prompt: str | None = None,
allowed_tools: list[str] | None = None,
disallowed_tools: list[str] | None = None,
) -> str:
return "asd"

View File

@ -172,3 +172,12 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
DISCORD_PROCESS_MESSAGES = boolean_env("DISCORD_PROCESS_MESSAGES", True)
# Discord collector settings
DISCORD_COLLECTOR_ENABLED = boolean_env("DISCORD_COLLECTOR_ENABLED", True)
DISCORD_COLLECT_DMS = boolean_env("DISCORD_COLLECT_DMS", True)
DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8000))
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "127.0.0.1")
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))

View File

@ -105,7 +105,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
try:
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
response = llms.summarize(prompt, settings.SUMMARIZER_MODEL)
result = parse_response(response)
summary = result.get("summary", "")

166
src/memory/discord/api.py Normal file
View File

@ -0,0 +1,166 @@
"""
Discord API server.
FastAPI server that owns and manages a Discord collector instance,
providing HTTP endpoints for sending Discord messages.
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from memory.common import settings
from memory.discord.collector import MessageCollector
logger = logging.getLogger(__name__)
class SendDMRequest(BaseModel):
user: str # Discord user ID or username
message: str
class SendChannelRequest(BaseModel):
channel_name: str # Channel name (e.g., "memory-errors")
message: str
# Application state
class AppState:
def __init__(self):
self.collector: MessageCollector | None = None
self.collector_task: asyncio.Task | None = None
app_state = AppState()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage Discord collector lifecycle"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
# Create and start the collector
app_state.collector = MessageCollector()
app_state.collector_task = asyncio.create_task(
app_state.collector.start(settings.DISCORD_BOT_TOKEN)
)
logger.info("Discord collector started")
yield
# Cleanup
if app_state.collector and not app_state.collector.is_closed():
await app_state.collector.close()
if app_state.collector_task:
app_state.collector_task.cancel()
try:
await app_state.collector_task
except asyncio.CancelledError:
pass
logger.info("Discord collector stopped")
# FastAPI app with lifespan management
app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
@app.post("/send_dm")
async def send_dm_endpoint(request: SendDMRequest):
"""Send a DM via the collector's Discord client"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
try:
success = await app_state.collector.send_dm(request.user, request.message)
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to send DM to {request.user}",
)
return {
"success": True,
"message": f"DM sent to {request.user}",
"user": request.user,
}
except Exception as e:
logger.error(f"Failed to send DM: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest):
"""Send a message to a channel via the collector's Discord client"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
try:
success = await app_state.collector.send_to_channel(
request.channel_name, request.message
)
if success:
return {
"success": True,
"message": f"Message sent to channel {request.channel_name}",
"channel": request.channel_name,
}
else:
raise HTTPException(
status_code=400,
detail=f"Failed to send message to channel {request.channel_name}",
)
except Exception as e:
logger.error(f"Failed to send channel message: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Check if the Discord collector is running and healthy"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
collector = app_state.collector
return {
"status": "healthy",
"connected": not collector.is_closed(),
"user": str(collector.user) if collector.user else None,
"guilds": len(collector.guilds) if collector.guilds else 0,
}
@app.post("/refresh_metadata")
async def refresh_metadata():
"""Refresh Discord server/channel/user metadata from Discord API"""
if not app_state.collector:
raise HTTPException(status_code=503, detail="Discord collector not running")
try:
result = await app_state.collector.refresh_metadata()
return {"success": True, "message": "Metadata refreshed successfully", **result}
except Exception as e:
logger.error(f"Failed to refresh metadata: {e}")
raise HTTPException(status_code=500, detail=str(e))
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
"""Run the Discord API server"""
uvicorn.run(app, host=host, port=port, log_level="debug")
if __name__ == "__main__":
# For testing the API server standalone
host = settings.DISCORD_COLLECTOR_SERVER_URL
port = settings.DISCORD_COLLECTOR_PORT
run_discord_api_server(host, port)

View File

@ -0,0 +1,410 @@
"""
Discord message collector.
Core message collection functionality - stores Discord messages to database.
"""
import logging
from datetime import datetime, timezone
import discord
from discord.ext import commands
from sqlalchemy.orm import Session, scoped_session
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models.sources import (
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.workers.tasks.discord import add_discord_message, edit_discord_message
logger = logging.getLogger(__name__)
# Pure functions for Discord entity creation/updates
def create_or_update_server(
session: Session | scoped_session, guild: discord.Guild | None
) -> DiscordServer | None:
"""Get or create DiscordServer record (pure DB operation)"""
if not guild:
return None
server = session.query(DiscordServer).get(guild.id)
if not server:
server = DiscordServer(
id=guild.id,
name=guild.name,
description=guild.description,
member_count=guild.member_count,
)
session.add(server)
session.flush() # Get the ID
logger.info(f"Created server record for {guild.name} ({guild.id})")
else:
# Update metadata
server.name = guild.name
server.description = guild.description
server.member_count = guild.member_count
server.last_sync_at = datetime.now(timezone.utc)
return server
def determine_channel_metadata(channel) -> tuple[str, int | None, str]:
"""Pure function to determine channel type, server_id, and name"""
if isinstance(channel, discord.DMChannel):
desc = (
f"DM with {channel.recipient.name}" if channel.recipient else "Unknown DM"
)
return ("dm", None, desc)
elif isinstance(channel, discord.GroupChannel):
return "group_dm", None, channel.name or "Group DM"
elif isinstance(
channel, (discord.TextChannel, discord.VoiceChannel, discord.Thread)
):
return (
channel.__class__.__name__.lower().replace("channel", ""),
channel.guild.id,
channel.name,
)
else:
guild = getattr(channel, "guild", None)
server_id = guild.id if guild else None
name = getattr(channel, "name", f"Unknown-{channel.id}")
return "unknown", server_id, name
def create_or_update_channel(
session: Session | scoped_session, channel
) -> DiscordChannel | None:
"""Get or create DiscordChannel record (pure DB operation)"""
if not channel:
return None
discord_channel = session.query(DiscordChannel).get(channel.id)
if not discord_channel:
channel_type, server_id, name = determine_channel_metadata(channel)
discord_channel = DiscordChannel(
id=channel.id,
server_id=server_id,
name=name,
channel_type=channel_type,
)
session.add(discord_channel)
session.flush()
logger.debug(f"Created channel: {name}")
elif hasattr(channel, "name"):
discord_channel.name = channel.name
return discord_channel
def create_or_update_user(
session: Session | scoped_session, user: discord.User | discord.Member
) -> DiscordUser:
"""Get or create DiscordUser record (pure DB operation)"""
if not user:
return None
discord_user = session.query(DiscordUser).get(user.id)
if not discord_user:
discord_user = DiscordUser(
id=user.id,
username=user.name,
display_name=user.display_name,
)
session.add(discord_user)
session.flush()
logger.debug(f"Created user: {user.name}")
else:
# Update user info in case it changed
discord_user.username = user.name
discord_user.display_name = user.display_name
return discord_user
def determine_message_metadata(
message: discord.Message,
) -> tuple[str, int | None, int | None]:
"""Pure function to determine message type, reply_to_id, and thread_id"""
message_type = "default"
reply_to_id = None
thread_id = None
if message.reference and message.reference.message_id:
message_type = "reply"
reply_to_id = message.reference.message_id
if hasattr(message.channel, "parent") and message.channel.parent:
thread_id = message.channel.id
return message_type, reply_to_id, thread_id
def should_track_message(
server: DiscordServer | None,
channel: DiscordChannel,
user: DiscordUser,
) -> bool:
"""Pure function to determine if we should track this message"""
if server and not server.track_messages: # type: ignore
return False
if not channel.track_messages:
return False
if channel.channel_type in ("dm", "group_dm"):
return bool(user.track_messages)
# Default: track the message
return True
def should_collect_bot_message(message: discord.Message) -> bool:
"""Pure function to determine if we should collect bot messages"""
return not message.author.bot or settings.DISCORD_COLLECT_BOTS
def sync_guild_metadata(guild: discord.Guild) -> None:
"""Sync a single guild's metadata (functional approach)"""
with make_session() as session:
create_or_update_server(session, guild)
for channel in guild.channels:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel)
session.commit()
class MessageCollector(commands.Bot):
"""Discord bot that collects and stores messages (thin event handler)"""
def __init__(self):
intents = discord.Intents.default()
intents.message_content = True
intents.guilds = True
intents.members = True
intents.dm_messages = True
super().__init__(
command_prefix="!memory_", # Prefix to avoid conflicts
intents=intents,
help_command=None, # Disable default help
)
async def on_ready(self):
"""Called when bot connects to Discord"""
logger.info(f"Discord collector connected as {self.user}")
logger.info(f"Connected to {len(self.guilds)} servers")
# Sync server and channel metadata
await self.sync_servers_and_channels()
logger.info("Discord message collector ready")
async def on_message(self, message: discord.Message):
"""Queue incoming message for database storage"""
try:
if should_collect_bot_message(message):
# Ensure Discord entities exist in database first
with make_session() as session:
create_or_update_user(session, message.author)
create_or_update_channel(session, message.channel)
if message.guild:
create_or_update_server(session, message.guild)
session.commit()
# Queue the message for processing
add_discord_message.delay(
message_id=message.id,
channel_id=message.channel.id,
author_id=message.author.id,
server_id=message.guild.id if message.guild else None,
content=message.content or "",
sent_at=message.created_at.isoformat(),
message_reference_id=message.reference.message_id
if message.reference
else None,
)
except Exception as e:
logger.error(f"Error queuing message {message.id}: {e}")
async def on_message_edit(self, before: discord.Message, after: discord.Message):
"""Queue message edit for database update"""
try:
edit_time = after.edited_at or datetime.now(timezone.utc)
edit_discord_message.delay(
message_id=after.id,
content=after.content,
edited_at=edit_time.isoformat(),
)
except Exception as e:
logger.error(f"Error queuing message edit {after.id}: {e}")
async def sync_servers_and_channels(self):
"""Sync server and channel metadata on startup"""
for guild in self.guilds:
sync_guild_metadata(guild)
logger.info(f"Synced {len(self.guilds)} servers and their channels")
async def refresh_metadata(self) -> dict[str, int]:
"""Refresh server and channel metadata from Discord and update database"""
print("🔄 Refreshing Discord metadata...")
servers_updated = 0
channels_updated = 0
users_updated = 0
with make_session() as session:
# Refresh all servers
for guild in self.guilds:
create_or_update_server(session, guild)
servers_updated += 1
# Refresh all channels in this server
for channel in guild.channels:
if isinstance(channel, (discord.TextChannel, discord.VoiceChannel)):
create_or_update_channel(session, channel)
channels_updated += 1
# Refresh all members in this server (if members intent is enabled)
if self.intents.members:
for member in guild.members:
create_or_update_user(session, member)
users_updated += 1
session.commit()
result = {
"servers_updated": servers_updated,
"channels_updated": channels_updated,
"users_updated": users_updated,
}
print(f"✅ Metadata refresh complete: {result}")
logger.info(f"Metadata refresh complete: {result}")
return result
async def get_user(self, user_identifier: int | str) -> discord.User | None:
"""Get a Discord user by ID or username"""
if isinstance(user_identifier, int):
# Direct user ID lookup
if user := super().get_user(user_identifier):
return user
try:
return await self.fetch_user(user_identifier)
except discord.NotFound:
return None
else:
# Username lookup - search through all guilds
for guild in self.guilds:
for member in guild.members:
if (
member.name == user_identifier
or member.display_name == user_identifier
or f"{member.name}#{member.discriminator}" == user_identifier
):
return member
return None
async def get_channel_by_name(
self, channel_name: str
) -> discord.TextChannel | None:
"""Get a Discord channel by name (does not create if missing)"""
# Search all guilds for the channel
for guild in self.guilds:
for ch in guild.channels:
if isinstance(ch, discord.TextChannel) and ch.name == channel_name:
return ch
return None
async def create_channel(
self, channel_name: str, guild_id: int | None = None
) -> discord.TextChannel | None:
"""Create a Discord channel in the specified guild (or first guild if none specified)"""
target_guild = None
if guild_id:
target_guild = self.get_guild(guild_id)
elif self.guilds:
target_guild = self.guilds[0] # Default to first guild
if not target_guild:
logger.error(f"No guild available to create channel {channel_name}")
return None
try:
channel = await target_guild.create_text_channel(channel_name)
logger.info(f"Created channel {channel_name} in {target_guild.name}")
return channel
except Exception as e:
logger.error(
f"Failed to create channel {channel_name} in {target_guild.name}: {e}"
)
return None
async def send_dm(self, user_identifier: int | str, message: str) -> bool:
"""Send a DM using this collector's Discord client"""
try:
user = await self.get_user(user_identifier)
if not user:
logger.error(f"User {user_identifier} not found")
return False
await user.send(message)
logger.info(f"Sent DM to {user_identifier}")
return True
except Exception as e:
logger.error(f"Failed to send DM to {user_identifier}: {e}")
return False
async def send_to_channel(self, channel_name: str, message: str) -> bool:
"""Send a message to a channel by name across all guilds"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
channel = await self.get_channel_by_name(channel_name)
if not channel:
logger.error(f"Channel {channel_name} not found")
return False
await channel.send(message)
logger.info(f"Sent message to channel {channel_name}")
return True
except Exception as e:
logger.error(f"Failed to send message to channel {channel_name}: {e}")
return False
async def run_collector():
"""Run the Discord message collector"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
collector = MessageCollector()
try:
await collector.start(settings.DISCORD_BOT_TOKEN)
except Exception as e:
logger.error(f"Discord collector failed: {e}")
raise
if __name__ == "__main__":
import asyncio
asyncio.run(run_collector())

View File

@ -6,6 +6,7 @@ from memory.workers.tasks import (
email,
comic,
blogs,
discord,
ebook,
forums,
maintenance,
@ -20,6 +21,7 @@ __all__ = [
"comic",
"blogs",
"ebook",
"discord",
"forums",
"maintenance",
"notes",

View File

@ -115,6 +115,9 @@ def sync_article_feed(feed_id: int) -> dict:
try:
for feed_item in parser.parse_feed():
if not feed_item.url:
continue
articles_found += 1
existing = check_content_exists(session, BlogPost, url=feed_item.url)

View File

@ -10,9 +10,9 @@ from collections import defaultdict
import hashlib
import traceback
import logging
from typing import Any, Callable, Iterable, Sequence, cast
from typing import Any, Callable, Sequence, cast
from memory.common import embedding, qdrant, settings
from memory.common import embedding, qdrant
from memory.common.db.models import SourceItem, Chunk
from memory.common.discord import notify_task_failure
@ -38,19 +38,12 @@ def check_content_exists(
Returns:
Existing SourceItem if found, None otherwise
"""
query = session.query(model_class)
for key, value in kwargs.items():
if not hasattr(model_class, key):
continue
if hasattr(model_class, key):
query = query.filter(getattr(model_class, key) == value)
existing = (
session.query(model_class)
.filter(getattr(model_class, key) == value)
.first()
)
if existing:
return existing
return None
return query.first()
def create_content_hash(content: str, *additional_data: str) -> bytes:
@ -286,6 +279,6 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
traceback_str=traceback_str,
)
return {"status": "error", "error": str(e)}
return {"status": "error", "error": str(e), "traceback": traceback_str}
return wrapper

View File

@ -0,0 +1,166 @@
"""
Celery tasks for Discord message processing.
"""
import hashlib
import logging
from datetime import datetime
from typing import Any
from memory.common.celery_app import app
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordMessage, DiscordUser
from memory.workers.tasks.content_processing import (
safe_task_execution,
check_content_exists,
create_task_result,
process_content_item,
)
from memory.common.celery_app import (
ADD_DISCORD_MESSAGE,
EDIT_DISCORD_MESSAGE,
PROCESS_DISCORD_MESSAGE,
)
from memory.common import settings
from sqlalchemy.orm import Session, scoped_session
logger = logging.getLogger(__name__)
def get_prev(
session: Session | scoped_session, channel_id: int, sent_at: datetime
) -> list[str]:
prev = (
session.query(DiscordUser.username, DiscordMessage.content)
.join(DiscordUser, DiscordMessage.discord_user_id == DiscordUser.id)
.filter(
DiscordMessage.channel_id == channel_id,
DiscordMessage.sent_at < sent_at,
)
.order_by(DiscordMessage.sent_at.desc())
.limit(settings.DISCORD_CONTEXT_WINDOW)
.all()
)
return [f"{msg.username}: {msg.content}" for msg in prev[::-1]]
def should_process(message: DiscordMessage) -> bool:
return (
settings.DISCORD_PROCESS_MESSAGES
and settings.DISCORD_NOTIFICATIONS_ENABLED
and not (
(message.server and message.server.ignore_messages)
or (message.channel and message.channel.ignore_messages)
or (message.discord_user and message.discord_user.ignore_messages)
)
)
@app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution
def process_discord_message(message_id: int) -> dict[str, Any]:
logger.info(f"Processing Discord message {message_id}")
with make_session() as session:
discord_message = session.query(DiscordMessage).get(message_id)
if not discord_message:
logger.info(f"Discord message not found: {message_id}")
return {
"status": "error",
"error": "Message not found",
"message_id": message_id,
}
print("Processing message", discord_message.id, discord_message.content)
return {
"status": "processed",
"message_id": message_id,
}
@app.task(name=ADD_DISCORD_MESSAGE)
@safe_task_execution
def add_discord_message(
message_id: int,
channel_id: int,
author_id: int,
content: str,
sent_at: str,
server_id: int | None = None,
message_reference_id: int | None = None,
) -> dict[str, Any]:
"""
Add a Discord message to the database.
This task is queued by the Discord collector when messages are received.
"""
logger.info(f"Adding Discord message {message_id}: {content}")
# Include message_id in hash to ensure uniqueness across duplicate content
content_hash = hashlib.sha256(f"{message_id}:{content}".encode()).digest()
sent_at_dt = datetime.fromisoformat(sent_at.replace("Z", "+00:00"))
with make_session() as session:
discord_message = DiscordMessage(
modality="text",
sha256=content_hash,
content=content,
channel_id=channel_id,
sent_at=sent_at_dt,
server_id=server_id,
discord_user_id=author_id,
message_id=message_id,
message_type="reply" if message_reference_id else "default",
reply_to_message_id=message_reference_id,
)
existing_msg = check_content_exists(
session, DiscordMessage, message_id=message_id, sha256=content_hash
)
if existing_msg:
logger.info(f"Discord message already exists: {existing_msg.message_id}")
return create_task_result(
existing_msg, "already_exists", message_id=message_id
)
if channel_id:
discord_message.messages_before = get_prev(session, channel_id, sent_at_dt)
result = process_content_item(discord_message, session)
if should_process(discord_message):
process_discord_message.delay(discord_message.id)
return result
@app.task(name=EDIT_DISCORD_MESSAGE)
@safe_task_execution
def edit_discord_message(
message_id: int, content: str, edited_at: str
) -> dict[str, Any]:
"""
Edit a Discord message in the database.
This task is queued by the Discord collector when messages are edited.
"""
logger.info(f"Editing Discord message {message_id}: {content}")
with make_session() as session:
existing_msg = check_content_exists(
session, DiscordMessage, message_id=message_id
)
if not existing_msg:
return {
"status": "error",
"error": "Message not found",
"message_id": message_id,
}
existing_msg.content = content # type: ignore
if existing_msg.channel_id:
existing_msg.messages_before = get_prev(
session, existing_msg.channel_id, existing_msg.sent_at
)
existing_msg.edited_at = datetime.fromisoformat(
edited_at.replace("Z", "+00:00")
)
return process_content_item(existing_msg, session)

View File

@ -88,11 +88,10 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Make the LLM call
if scheduled_call.model:
response = llms.call(
response = llms.summarize(
prompt=cast(str, scheduled_call.message),
model=cast(str, scheduled_call.model),
system_prompt=cast(str, scheduled_call.system_prompt)
or llms.SYSTEM_PROMPT,
system_prompt=cast(str, scheduled_call.system_prompt),
)
else:
response = cast(str, scheduled_call.message)

View File

@ -273,6 +273,27 @@ def mock_anthropic_client():
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
client = mock_client()
client.messages = Mock()
# Mock stream as a context manager
mock_stream = Mock()
mock_stream.__enter__ = Mock(
return_value=Mock(
__iter__=lambda self: iter(
[
Mock(
type="content_block_delta",
delta=Mock(
type="text_delta",
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
),
)
]
)
)
)
mock_stream.__exit__ = Mock(return_value=False)
client.messages.stream = Mock(return_value=mock_stream)
client.messages.create = Mock(
return_value=Mock(
content=[

View File

@ -2,318 +2,250 @@ import pytest
from unittest.mock import Mock, patch
import requests
from memory.common import discord, settings
from memory.common import discord
@pytest.fixture
def mock_session_request():
with patch("requests.Session.request") as mock:
yield mock
def mock_api_url():
"""Mock the API URL to avoid using actual settings"""
with patch(
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
):
yield
@pytest.fixture
def mock_get_channels_response():
return [
{"name": "memory-errors", "id": "error_channel_id"},
{"name": "memory-activity", "id": "activity_channel_id"},
{"name": "memory-discoveries", "id": "discovery_channel_id"},
{"name": "memory-chat", "id": "chat_channel_id"},
]
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
def test_get_api_url():
"""Test API URL construction"""
assert discord.get_api_url() == "http://testhost:9999"
def test_discord_server_init(mock_session_request, mock_get_channels_response):
# Mock the channels API call
@patch("requests.post")
def test_send_dm_success(mock_post, mock_api_url):
"""Test successful DM sending"""
mock_response = Mock()
mock_response.json.return_value = mock_get_channels_response
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
mock_post.return_value = mock_response
server = discord.DiscordServer("server123", "Test Server")
result = discord.send_dm("user123", "Hello!")
assert server.server_id == "server123"
assert server.server_name == "Test Server"
assert hasattr(server, "channels")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
# Mock the channels API call
mock_response = Mock()
mock_response.json.return_value = mock_get_channels_response
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer("server123", "Test Server")
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
def test_setup_channels_create_missing(mock_session_request):
# Mock get channels (empty) and create channel calls
get_response = Mock()
get_response.json.return_value = []
get_response.raise_for_status.return_value = None
create_response = Mock()
create_response.json.return_value = {"id": "new_channel_id"}
create_response.raise_for_status.return_value = None
mock_session_request.side_effect = [
get_response,
create_response,
create_response,
create_response,
create_response,
]
server = discord.DiscordServer("server123", "Test Server")
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
def test_channel_properties():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {
discord.ERROR_CHANNEL: "error_id",
discord.ACTIVITY_CHANNEL: "activity_id",
discord.DISCOVERY_CHANNEL: "discovery_id",
discord.CHAT_CHANNEL: "chat_id",
}
assert server.error_channel == "error_id"
assert server.activity_channel == "activity_id"
assert server.discovery_channel == "discovery_id"
assert server.chat_channel == "chat_id"
def test_channel_id_exists():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {"test-channel": "channel123"}
assert server.channel_id("test-channel") == "channel123"
def test_channel_id_not_found():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {}
with pytest.raises(ValueError, match="Channel nonexistent not found"):
server.channel_id("nonexistent")
def test_send_message(mock_session_request):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.send_message("channel123", "Hello World")
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/channels/channel123/messages",
data=None,
json={"content": "Hello World"},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
json={"user_identifier": "user123", "message": "Hello!"},
timeout=10,
)
def test_create_channel(mock_session_request):
@patch("requests.post")
def test_send_dm_api_failure(mock_post, mock_api_url):
"""Test DM sending when API returns failure"""
mock_response = Mock()
mock_response.json.return_value = {"id": "new_channel_id"}
mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
mock_post.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
result = discord.send_dm("user123", "Hello!")
channel_id = server.create_channel("new-channel")
assert channel_id == "new_channel_id"
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/guilds/server123/channels",
data=None,
json={"name": "new-channel", "type": 0},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
)
assert result is False
def test_create_channel_custom_type(mock_session_request):
@patch("requests.post")
def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_send_dm_http_error(mock_post, mock_api_url):
"""Test DM sending when HTTP error occurs"""
mock_response = Mock()
mock_response.json.return_value = {"id": "voice_channel_id"}
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_broadcast_message_success(mock_post, mock_api_url):
"""Test successful channel message broadcast"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
mock_post.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
result = discord.broadcast_message("general", "Announcement!")
channel_id = server.create_channel("voice-channel", channel_type=2)
assert channel_id == "voice_channel_id"
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/guilds/server123/channels",
data=None,
json={"name": "voice-channel", "type": 2},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"},
timeout=10,
)
def test_str_representation():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
server.server_name = "Test Server"
@patch("requests.post")
def test_broadcast_message_failure(mock_post, mock_api_url):
"""Test channel message broadcast failure"""
mock_response = Mock()
mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
result = discord.broadcast_message("general", "Announcement!")
assert result is False
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
def test_request_adds_headers(mock_session_request):
server = discord.DiscordServer.__new__(discord.DiscordServer)
@patch("requests.post")
def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout")
server.request("GET", "https://example.com", headers={"Custom": "header"})
result = discord.broadcast_message("general", "Announcement!")
expected_headers = {
"Custom": "header",
"Authorization": "Bot test_token_123",
"Content-Type": "application/json",
}
mock_session_request.assert_called_once_with(
"GET", "https://example.com", headers=expected_headers
)
assert result is False
def test_channels_url():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
assert (
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
)
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
def test_get_bot_servers_success(mock_get):
def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy"""
mock_response = Mock()
mock_response.json.return_value = [
{"id": "server1", "name": "Server 1"},
{"id": "server2", "name": "Server 2"},
]
mock_response.json.return_value = {"status": "healthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
servers = discord.get_bot_servers()
result = discord.is_collector_healthy()
assert len(servers) == 2
assert servers[0] == {"id": "server1", "name": "Server 1"}
mock_get.assert_called_once_with(
"https://discord.com/api/v10/users/@me/guilds",
headers={"Authorization": "Bot test_token"},
)
assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
def test_get_bot_servers_no_token():
assert discord.get_bot_servers() == []
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
def test_get_bot_servers_exception(mock_get):
mock_get.side_effect = requests.RequestException("API Error")
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status"""
mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
servers = discord.get_bot_servers()
result = discord.is_collector_healthy()
assert servers == []
assert result is False
@patch("memory.common.discord.get_bot_servers")
@patch("memory.common.discord.DiscordServer")
def test_load_servers(mock_discord_server_class, mock_get_servers):
mock_get_servers.return_value = [
{"id": "server1", "name": "Server 1"},
{"id": "server2", "name": "Server 2"},
]
@patch("requests.get")
def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused")
discord.load_servers()
result = discord.is_collector_healthy()
assert mock_discord_server_class.call_count == 2
mock_discord_server_class.assert_any_call("server1", "Server 1")
mock_discord_server_class.assert_any_call("server2", "Server 2")
assert result is False
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_broadcast_message():
mock_server1 = Mock()
mock_server2 = Mock()
discord.servers = {"1": mock_server1, "2": mock_server2}
@patch("requests.post")
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
"""Test successful metadata refresh"""
mock_response = Mock()
mock_response.json.return_value = {
"servers": 5,
"channels": 20,
"users": 100,
}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
discord.broadcast_message("test-channel", "Hello")
result = discord.refresh_discord_metadata()
mock_server1.send_message.assert_called_once_with(
mock_server1.channel_id.return_value, "Hello"
)
mock_server2.send_message.assert_called_once_with(
mock_server2.channel_id.return_value, "Hello"
assert result == {"servers": 5, "channels": 20, "users": 100}
mock_post.assert_called_once_with(
"http://localhost:8000/refresh_metadata", timeout=30
)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_broadcast_message_disabled():
mock_server = Mock()
discord.servers = {"1": mock_server}
@patch("requests.post")
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
"""Test metadata refresh failure"""
mock_post.side_effect = requests.RequestException("Failed to connect")
discord.broadcast_message("test-channel", "Hello")
result = discord.refresh_discord_metadata()
mock_server.send_message.assert_not_called()
assert result is None
@patch("requests.post")
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
"""Test metadata refresh with HTTP error"""
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
mock_post.return_value = mock_response
result = discord.refresh_discord_metadata()
assert result is None
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
def test_send_error_message(mock_broadcast):
discord.send_error_message("Error occurred")
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
"""Test sending error message to error channel"""
mock_broadcast.return_value = True
result = discord.send_error_message("Something broke")
assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
def test_send_activity_message(mock_broadcast):
discord.send_activity_message("Activity update")
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
"""Test sending activity message to activity channel"""
mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in")
assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
def test_send_discovery_message(mock_broadcast):
discord.send_discovery_message("Discovery made")
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
"""Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern")
assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
def test_send_chat_message(mock_broadcast):
discord.send_chat_message("Chat message")
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
"""Test sending chat message to chat channel"""
mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot")
assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong")
mock_send_error.assert_called_once()
@ -323,69 +255,181 @@ def test_notify_task_failure_basic(mock_send_error):
assert "**Error:** Something went wrong" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_args(mock_send_error):
"""Test task failure notification with arguments"""
discord.notify_task_failure(
"test_task",
"Error message",
task_args=("arg1", "arg2"),
task_kwargs={"key": "value"},
"Error occurred",
task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123},
)
message = mock_send_error.call_args[0][0]
assert "**Args:** `('arg1', 'arg2')`" in message
assert "**Kwargs:** `{'key': 'value'}`" in message
assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_traceback(mock_send_error):
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
"""Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
message = mock_send_error.call_args[0][0]
assert "**Traceback:**" in message
assert traceback in message
assert "Exception: test" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_error(mock_send_error):
long_error = "x" * 600 # Longer than 500 char limit
"""Test that long error messages are truncated"""
long_error = "x" * 600
discord.notify_task_failure("test_task", long_error)
message = mock_send_error.call_args[0][0]
assert long_error[:500] in message
# Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message
# The full 600-char error should not be present
error_section = message.split("**Error:** ")[1].split("\n")[0]
assert len(error_section) == 500
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
long_traceback = "x" * 1000 # Longer than 800 char limit
"""Test that long tracebacks are truncated"""
long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
message = mock_send_error.call_args[0][0]
# Traceback should show last 800 chars
assert long_traceback[-800:] in message
# The full 1000-char traceback should not be present
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
assert len(traceback_section) == 800
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated"""
long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args)
message = mock_send_error.call_args[0][0]
# Args should be truncated to 200 chars
assert (
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
) # Some buffer for formatting
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
message = mock_send_error.call_args[0][0]
# Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
discord.notify_task_failure("test_task", "Error message")
"""Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_not_called()
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_send_fails(mock_send_error):
mock_send_error.side_effect = Exception("Discord API error")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_send_error_exception(mock_send_error):
"""Test that exceptions in send_error_message don't propagate"""
mock_send_error.side_effect = Exception("Failed to send")
# Should not raise, just log the error
discord.notify_task_failure("test_task", "Error message")
# Should not raise
discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_called_once()
@pytest.mark.parametrize(
"function,channel_setting,message",
[
(discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
(discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
(discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
(discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
],
)
@patch("memory.common.discord.broadcast_message")
def test_convenience_functions_use_correct_channels(
mock_broadcast, function, channel_setting, message
):
"""Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message)
mock_broadcast.assert_called_once_with("test-channel", message)
@patch("requests.post")
def test_send_dm_with_special_characters(mock_post, mock_api_url):
"""Test sending DM with special characters"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars
@patch("requests.post")
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
"""Test broadcasting a long message"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
long_message = "A" * 2000
result = discord.broadcast_message("general", long_message)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message
@patch("requests.get")
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
"""Test health check when response doesn't have status key"""
mock_response = Mock()
mock_response.json.return_value = {}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
assert result is False

View File

@ -0,0 +1,435 @@
import pytest
from unittest.mock import Mock, patch
import requests
from memory.common import discord
@pytest.fixture
def mock_api_url():
"""Mock the API URL to avoid using actual settings"""
with patch(
"memory.common.discord.get_api_url", return_value="http://localhost:8000"
):
yield
@patch("memory.common.settings.DISCORD_COLLECTOR_SERVER_URL", "testhost")
@patch("memory.common.settings.DISCORD_COLLECTOR_PORT", 9999)
def test_get_api_url():
"""Test API URL construction"""
assert discord.get_api_url() == "http://testhost:9999"
@patch("requests.post")
def test_send_dm_success(mock_post, mock_api_url):
"""Test successful DM sending"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
json={"user_identifier": "user123", "message": "Hello!"},
timeout=10,
)
@patch("requests.post")
def test_send_dm_api_failure(mock_post, mock_api_url):
"""Test DM sending when API returns failure"""
mock_response = Mock()
mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_send_dm_http_error(mock_post, mock_api_url):
"""Test DM sending when HTTP error occurs"""
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
assert result is False
@patch("requests.post")
def test_broadcast_message_success(mock_post, mock_api_url):
"""Test successful channel message broadcast"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!")
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"},
timeout=10,
)
@patch("requests.post")
def test_broadcast_message_failure(mock_post, mock_api_url):
"""Test channel message broadcast failure"""
mock_response = Mock()
mock_response.json.return_value = {"success": False}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!")
assert result is False
@patch("requests.post")
def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!")
assert result is False
@patch("requests.get")
def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy"""
mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@patch("requests.get")
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status"""
mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
assert result is False
@patch("requests.get")
def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy()
assert result is False
@patch("requests.post")
def test_refresh_discord_metadata_success(mock_post, mock_api_url):
"""Test successful metadata refresh"""
mock_response = Mock()
mock_response.json.return_value = {
"servers": 5,
"channels": 20,
"users": 100,
}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.refresh_discord_metadata()
assert result == {"servers": 5, "channels": 20, "users": 100}
mock_post.assert_called_once_with(
"http://localhost:8000/refresh_metadata", timeout=30
)
@patch("requests.post")
def test_refresh_discord_metadata_failure(mock_post, mock_api_url):
"""Test metadata refresh failure"""
mock_post.side_effect = requests.RequestException("Failed to connect")
result = discord.refresh_discord_metadata()
assert result is None
@patch("requests.post")
def test_refresh_discord_metadata_http_error(mock_post, mock_api_url):
"""Test metadata refresh with HTTP error"""
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
mock_post.return_value = mock_response
result = discord.refresh_discord_metadata()
assert result is None
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "errors")
def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel"""
mock_broadcast.return_value = True
result = discord.send_error_message("Something broke")
assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "activity")
def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel"""
mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in")
assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "discoveries")
def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern")
assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
@patch("memory.common.discord.broadcast_message")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "chat")
def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel"""
mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot")
assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong")
mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0]
assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_args(mock_send_error):
"""Test task failure notification with arguments"""
discord.notify_task_failure(
"test_task",
"Error occurred",
task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123},
)
message = mock_send_error.call_args[0][0]
assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
message = mock_send_error.call_args[0][0]
assert "**Traceback:**" in message
assert "Exception: test" in message
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated"""
long_error = "x" * 600
discord.notify_task_failure("test_task", long_error)
message = mock_send_error.call_args[0][0]
# Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message
# The full 600-char error should not be present
error_section = message.split("**Error:** ")[1].split("\n")[0]
assert len(error_section) == 500
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated"""
long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
message = mock_send_error.call_args[0][0]
# Traceback should show last 800 chars
assert long_traceback[-800:] in message
# The full 1000-char traceback should not be present
traceback_section = message.split("**Traceback:**\n```\n")[1].split("\n```")[0]
assert len(traceback_section) == 800
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated"""
long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args)
message = mock_send_error.call_args[0][0]
# Args should be truncated to 200 chars
assert (
len(message.split("**Args:**")[1].split("\n")[0]) <= 210
) # Some buffer for formatting
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
message = mock_send_error.call_args[0][0]
# Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_not_called()
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_send_error_exception(mock_send_error):
"""Test that exceptions in send_error_message don't propagate"""
mock_send_error.side_effect = Exception("Failed to send")
# Should not raise
discord.notify_task_failure("test_task", "Error occurred")
mock_send_error.assert_called_once()
@pytest.mark.parametrize(
"function,channel_setting,message",
[
(discord.send_error_message, "DISCORD_ERROR_CHANNEL", "Error!"),
(discord.send_activity_message, "DISCORD_ACTIVITY_CHANNEL", "Activity!"),
(discord.send_discovery_message, "DISCORD_DISCOVERY_CHANNEL", "Discovery!"),
(discord.send_chat_message, "DISCORD_CHAT_CHANNEL", "Chat!"),
],
)
@patch("memory.common.discord.broadcast_message")
def test_convenience_functions_use_correct_channels(
mock_broadcast, function, channel_setting, message
):
"""Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message)
mock_broadcast.assert_called_once_with("test-channel", message)
@patch("requests.post")
def test_send_dm_with_special_characters(mock_post, mock_api_url):
"""Test sending DM with special characters"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars
@patch("requests.post")
def test_broadcast_message_with_long_message(mock_post, mock_api_url):
"""Test broadcasting a long message"""
mock_response = Mock()
mock_response.json.return_value = {"success": True}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
long_message = "A" * 2000
result = discord.broadcast_message("general", long_message)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message
@patch("requests.get")
def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
"""Test health check when response doesn't have status key"""
mock_response = Mock()
mock_response.json.return_value = {}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
assert result is False

View File

@ -0,0 +1,840 @@
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch, AsyncMock, MagicMock
import discord
from memory.discord.collector import (
create_or_update_server,
determine_channel_metadata,
create_or_update_channel,
create_or_update_user,
determine_message_metadata,
should_track_message,
should_collect_bot_message,
sync_guild_metadata,
MessageCollector,
)
from memory.common.db.models.sources import (
DiscordServer,
DiscordChannel,
DiscordUser,
)
# Fixtures for Discord objects
@pytest.fixture
def mock_guild():
"""Mock Discord Guild object"""
guild = Mock(spec=discord.Guild)
guild.id = 123456789
guild.name = "Test Server"
guild.description = "A test server"
guild.member_count = 42
return guild
@pytest.fixture
def mock_text_channel():
"""Mock Discord TextChannel object"""
channel = Mock(spec=discord.TextChannel)
channel.id = 987654321
channel.name = "general"
guild = Mock()
guild.id = 123456789
channel.guild = guild
return channel
@pytest.fixture
def mock_dm_channel():
"""Mock Discord DMChannel object"""
channel = Mock(spec=discord.DMChannel)
channel.id = 111222333
recipient = Mock()
recipient.name = "TestUser"
channel.recipient = recipient
return channel
@pytest.fixture
def mock_user():
"""Mock Discord User object"""
user = Mock(spec=discord.User)
user.id = 444555666
user.name = "testuser"
user.display_name = "Test User"
user.bot = False
return user
@pytest.fixture
def mock_message(mock_text_channel, mock_user):
"""Mock Discord Message object"""
message = Mock(spec=discord.Message)
message.id = 777888999
message.channel = mock_text_channel
message.author = mock_user
message.guild = mock_text_channel.guild
message.content = "Test message"
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
message.reference = None
return message
# Tests for create_or_update_server
def test_create_or_update_server_creates_new(db_session, mock_guild):
"""Test creating a new server record"""
result = create_or_update_server(db_session, mock_guild)
assert result is not None
assert result.id == mock_guild.id
assert result.name == mock_guild.name
assert result.description == mock_guild.description
assert result.member_count == mock_guild.member_count
def test_create_or_update_server_updates_existing(db_session, mock_guild):
"""Test updating an existing server record"""
# Create initial server
server = DiscordServer(
id=mock_guild.id,
name="Old Name",
description="Old Description",
member_count=10,
)
db_session.add(server)
db_session.commit()
# Update with new data
mock_guild.name = "New Name"
mock_guild.description = "New Description"
mock_guild.member_count = 50
result = create_or_update_server(db_session, mock_guild)
assert result.name == "New Name"
assert result.description == "New Description"
assert result.member_count == 50
assert result.last_sync_at is not None
def test_create_or_update_server_none_guild(db_session):
"""Test with None guild"""
result = create_or_update_server(db_session, None)
assert result is None
# Tests for determine_channel_metadata
def test_determine_channel_metadata_dm():
"""Test metadata for DM channel"""
channel = Mock(spec=discord.DMChannel)
channel.recipient = Mock()
channel.recipient.name = "TestUser"
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "dm"
assert server_id is None
assert "DM with TestUser" in name
def test_determine_channel_metadata_dm_no_recipient():
"""Test metadata for DM channel without recipient"""
channel = Mock(spec=discord.DMChannel)
channel.recipient = None
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "dm"
assert name == "Unknown DM"
def test_determine_channel_metadata_group_dm():
"""Test metadata for group DM channel"""
channel = Mock(spec=discord.GroupChannel)
channel.name = "Group Chat"
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "group_dm"
assert server_id is None
assert name == "Group Chat"
def test_determine_channel_metadata_group_dm_no_name():
"""Test metadata for group DM without name"""
channel = Mock(spec=discord.GroupChannel)
channel.name = None
channel_type, server_id, name = determine_channel_metadata(channel)
assert name == "Group DM"
def test_determine_channel_metadata_text_channel():
"""Test metadata for text channel"""
channel = Mock(spec=discord.TextChannel)
channel.name = "general"
channel.guild = Mock()
channel.guild.id = 123
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "text"
assert server_id == 123
assert name == "general"
def test_determine_channel_metadata_voice_channel():
"""Test metadata for voice channel"""
channel = Mock(spec=discord.VoiceChannel)
channel.name = "voice-chat"
channel.guild = Mock()
channel.guild.id = 456
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "voice"
assert server_id == 456
assert name == "voice-chat"
def test_determine_channel_metadata_thread():
"""Test metadata for thread"""
channel = Mock(spec=discord.Thread)
channel.name = "thread-1"
channel.guild = Mock()
channel.guild.id = 789
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "thread"
assert server_id == 789
assert name == "thread-1"
def test_determine_channel_metadata_unknown():
"""Test metadata for unknown channel type"""
channel = Mock()
channel.id = 999
# Ensure the mock doesn't have a 'name' attribute
del channel.name
channel_type, server_id, name = determine_channel_metadata(channel)
assert channel_type == "unknown"
assert name == "Unknown-999"
# Tests for create_or_update_channel
def test_create_or_update_channel_creates_new(
db_session, mock_text_channel, mock_guild
):
"""Test creating a new channel record"""
# Create the server first to satisfy foreign key constraint
create_or_update_server(db_session, mock_guild)
result = create_or_update_channel(db_session, mock_text_channel)
assert result is not None
assert result.id == mock_text_channel.id
assert result.name == mock_text_channel.name
assert result.channel_type == "text"
def test_create_or_update_channel_updates_existing(db_session, mock_text_channel):
"""Test updating an existing channel record"""
# Create initial channel
channel = DiscordChannel(
id=mock_text_channel.id,
name="old-name",
channel_type="text",
)
db_session.add(channel)
db_session.commit()
# Update with new name
mock_text_channel.name = "new-name"
result = create_or_update_channel(db_session, mock_text_channel)
assert result.name == "new-name"
def test_create_or_update_channel_none_channel(db_session):
"""Test with None channel"""
result = create_or_update_channel(db_session, None)
assert result is None
# Tests for create_or_update_user
def test_create_or_update_user_creates_new(db_session, mock_user):
"""Test creating a new user record"""
result = create_or_update_user(db_session, mock_user)
assert result is not None
assert result.id == mock_user.id
assert result.username == mock_user.name
assert result.display_name == mock_user.display_name
def test_create_or_update_user_updates_existing(db_session, mock_user):
"""Test updating an existing user record"""
# Create initial user
user = DiscordUser(
id=mock_user.id,
username="oldname",
display_name="Old Display Name",
)
db_session.add(user)
db_session.commit()
# Update with new data
mock_user.name = "newname"
mock_user.display_name = "New Display Name"
result = create_or_update_user(db_session, mock_user)
assert result.username == "newname"
assert result.display_name == "New Display Name"
def test_create_or_update_user_none_user(db_session):
"""Test with None user"""
result = create_or_update_user(db_session, None)
assert result is None
# Tests for determine_message_metadata
def test_determine_message_metadata_default():
"""Test metadata for default message"""
message = Mock()
message.reference = None
message.channel = Mock()
# Ensure channel doesn't have parent attribute
del message.channel.parent
message_type, reply_to_id, thread_id = determine_message_metadata(message)
assert message_type == "default"
assert reply_to_id is None
assert thread_id is None
def test_determine_message_metadata_reply():
"""Test metadata for reply message"""
message = Mock()
message.reference = Mock()
message.reference.message_id = 123456
message.channel = Mock()
message_type, reply_to_id, thread_id = determine_message_metadata(message)
assert message_type == "reply"
assert reply_to_id == 123456
def test_determine_message_metadata_thread():
"""Test metadata for message in thread"""
message = Mock()
message.reference = None
message.channel = Mock()
message.channel.id = 999
message.channel.parent = Mock() # Has parent means it's a thread
message_type, reply_to_id, thread_id = determine_message_metadata(message)
assert thread_id == 999
# Tests for should_track_message
def test_should_track_message_server_disabled(db_session):
"""Test when server has tracking disabled"""
server = DiscordServer(id=1, name="Server", track_messages=False)
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
user = DiscordUser(id=3, username="User")
result = should_track_message(server, channel, user)
assert result is False
def test_should_track_message_channel_disabled(db_session):
"""Test when channel has tracking disabled"""
server = DiscordServer(id=1, name="Server", track_messages=True)
channel = DiscordChannel(
id=2, name="Channel", channel_type="text", track_messages=False
)
user = DiscordUser(id=3, username="User")
result = should_track_message(server, channel, user)
assert result is False
def test_should_track_message_dm_allowed(db_session):
"""Test DM tracking when user allows it"""
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
user = DiscordUser(id=3, username="User", track_messages=True)
result = should_track_message(None, channel, user)
assert result is True
def test_should_track_message_dm_not_allowed(db_session):
"""Test DM tracking when user doesn't allow it"""
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
user = DiscordUser(id=3, username="User", track_messages=False)
result = should_track_message(None, channel, user)
assert result is False
def test_should_track_message_default_true(db_session):
"""Test default tracking behavior"""
server = DiscordServer(id=1, name="Server", track_messages=True)
channel = DiscordChannel(
id=2, name="Channel", channel_type="text", track_messages=True
)
user = DiscordUser(id=3, username="User")
result = should_track_message(server, channel, user)
assert result is True
# Tests for should_collect_bot_message
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", False)
def test_should_collect_bot_message_bot_not_allowed():
"""Test bot message collection when disabled"""
message = Mock()
message.author = Mock()
message.author.bot = True
result = should_collect_bot_message(message)
assert result is False
@patch("memory.common.settings.DISCORD_COLLECT_BOTS", True)
def test_should_collect_bot_message_bot_allowed():
"""Test bot message collection when enabled"""
message = Mock()
message.author = Mock()
message.author.bot = True
result = should_collect_bot_message(message)
assert result is True
def test_should_collect_bot_message_human():
"""Test human message collection"""
message = Mock()
message.author = Mock()
message.author.bot = False
result = should_collect_bot_message(message)
assert result is True
# Tests for sync_guild_metadata
@patch("memory.discord.collector.make_session")
def test_sync_guild_metadata(mock_make_session, mock_guild):
"""Test syncing guild metadata"""
mock_session = Mock()
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
mock_make_session.return_value.__exit__ = Mock(return_value=None)
# Mock session.query().get() to return None (new server)
mock_session.query.return_value.get.return_value = None
# Mock channels
text_channel = Mock(spec=discord.TextChannel)
text_channel.id = 1
text_channel.name = "general"
text_channel.guild = mock_guild
voice_channel = Mock(spec=discord.VoiceChannel)
voice_channel.id = 2
voice_channel.name = "voice"
voice_channel.guild = mock_guild
mock_guild.channels = [text_channel, voice_channel]
sync_guild_metadata(mock_guild)
# Verify session.commit was called
mock_session.commit.assert_called_once()
# Tests for MessageCollector class
def test_message_collector_init():
"""Test MessageCollector initialization"""
collector = MessageCollector()
assert collector.command_prefix == "!memory_"
assert collector.help_command is None
assert collector.intents.message_content is True
assert collector.intents.guilds is True
assert collector.intents.members is True
assert collector.intents.dm_messages is True
@pytest.mark.asyncio
async def test_on_ready():
"""Test on_ready event handler"""
collector = MessageCollector()
collector.user = Mock()
collector.user.name = "TestBot"
collector.guilds = [Mock(), Mock()]
collector.sync_servers_and_channels = AsyncMock()
await collector.on_ready()
collector.sync_servers_and_channels.assert_called_once()
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
@patch("memory.discord.collector.add_discord_message")
async def test_on_message_success(mock_add_task, mock_make_session, mock_message):
"""Test successful message handling"""
mock_session = Mock()
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
mock_make_session.return_value.__exit__ = Mock(return_value=None)
mock_session.query.return_value.get.return_value = None # New entities
collector = MessageCollector()
await collector.on_message(mock_message)
# Verify task was queued
mock_add_task.delay.assert_called_once()
call_kwargs = mock_add_task.delay.call_args[1]
assert call_kwargs["message_id"] == mock_message.id
assert call_kwargs["channel_id"] == mock_message.channel.id
assert call_kwargs["author_id"] == mock_message.author.id
assert call_kwargs["content"] == mock_message.content
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_on_message_bot_message_filtered(mock_make_session, mock_message):
"""Test bot message filtering"""
mock_message.author.bot = True
with patch(
"memory.discord.collector.should_collect_bot_message", return_value=False
):
collector = MessageCollector()
await collector.on_message(mock_message)
# Should not create session or queue task
mock_make_session.assert_not_called()
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_on_message_error_handling(mock_make_session, mock_message):
"""Test error handling in on_message"""
mock_make_session.side_effect = Exception("Database error")
collector = MessageCollector()
# Should not raise
await collector.on_message(mock_message)
@pytest.mark.asyncio
@patch("memory.discord.collector.edit_discord_message")
async def test_on_message_edit(mock_edit_task):
"""Test message edit handler"""
before = Mock()
after = Mock()
after.id = 123
after.content = "Edited content"
after.edited_at = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
collector = MessageCollector()
await collector.on_message_edit(before, after)
mock_edit_task.delay.assert_called_once()
call_kwargs = mock_edit_task.delay.call_args[1]
assert call_kwargs["message_id"] == 123
assert call_kwargs["content"] == "Edited content"
@pytest.mark.asyncio
async def test_on_message_edit_error_handling():
"""Test error handling in on_message_edit"""
before = Mock()
after = Mock()
after.id = 123
after.content = "Edited"
after.edited_at = None # Will trigger datetime.now
with patch("memory.discord.collector.edit_discord_message") as mock_edit:
mock_edit.delay.side_effect = Exception("Task error")
collector = MessageCollector()
# Should not raise
await collector.on_message_edit(before, after)
@pytest.mark.asyncio
async def test_sync_servers_and_channels():
"""Test syncing servers and channels"""
guild1 = Mock()
guild2 = Mock()
collector = MessageCollector()
collector.guilds = [guild1, guild2]
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
await collector.sync_servers_and_channels()
assert mock_sync.call_count == 2
mock_sync.assert_any_call(guild1)
mock_sync.assert_any_call(guild2)
@pytest.mark.asyncio
@patch("memory.discord.collector.make_session")
async def test_refresh_metadata(mock_make_session):
"""Test metadata refresh"""
mock_session = Mock()
mock_make_session.return_value.__enter__ = Mock(return_value=mock_session)
mock_make_session.return_value.__exit__ = Mock(return_value=None)
mock_session.query.return_value.get.return_value = None
guild = Mock()
guild.id = 123
guild.name = "Test"
guild.channels = []
guild.members = []
collector = MessageCollector()
collector.guilds = [guild]
collector.intents = Mock()
collector.intents.members = False
result = await collector.refresh_metadata()
assert result["servers_updated"] == 1
assert result["channels_updated"] == 0
assert result["users_updated"] == 0
@pytest.mark.asyncio
async def test_get_user_by_id():
"""Test getting user by ID"""
user = Mock()
user.id = 123
collector = MessageCollector()
collector.get_user = Mock(return_value=user)
result = await collector.get_user(123)
assert result == user
@pytest.mark.asyncio
async def test_get_user_by_username():
"""Test getting user by username"""
member = Mock()
member.name = "testuser"
member.display_name = "Test User"
member.discriminator = "1234"
guild = Mock()
guild.members = [member]
collector = MessageCollector()
collector.guilds = [guild]
result = await collector.get_user("testuser")
assert result == member
@pytest.mark.asyncio
async def test_get_user_not_found():
"""Test getting non-existent user"""
collector = MessageCollector()
collector.guilds = []
with patch.object(collector, "get_user", return_value=None):
with patch.object(
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
):
result = await collector.get_user(999)
assert result is None
@pytest.mark.asyncio
async def test_get_channel_by_name():
"""Test getting channel by name"""
channel = Mock(spec=discord.TextChannel)
channel.name = "general"
guild = Mock()
guild.channels = [channel]
collector = MessageCollector()
collector.guilds = [guild]
result = await collector.get_channel_by_name("general")
assert result == channel
@pytest.mark.asyncio
async def test_get_channel_by_name_not_found():
"""Test getting non-existent channel"""
guild = Mock()
guild.channels = []
collector = MessageCollector()
collector.guilds = [guild]
result = await collector.get_channel_by_name("nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_create_channel():
"""Test creating a channel"""
guild = Mock()
guild.name = "Test Server"
new_channel = Mock()
guild.create_text_channel = AsyncMock(return_value=new_channel)
collector = MessageCollector()
collector.get_guild = Mock(return_value=guild)
result = await collector.create_channel("new-channel", guild_id=123)
assert result == new_channel
guild.create_text_channel.assert_called_once_with("new-channel")
@pytest.mark.asyncio
async def test_create_channel_no_guild():
"""Test creating channel when no guild available"""
collector = MessageCollector()
collector.get_guild = Mock(return_value=None)
collector.guilds = []
result = await collector.create_channel("new-channel")
assert result is None
@pytest.mark.asyncio
async def test_send_dm_success():
"""Test sending DM successfully"""
user = Mock()
user.send = AsyncMock()
collector = MessageCollector()
collector.get_user = AsyncMock(return_value=user)
result = await collector.send_dm(123, "Hello!")
assert result is True
user.send.assert_called_once_with("Hello!")
@pytest.mark.asyncio
async def test_send_dm_user_not_found():
"""Test sending DM when user not found"""
collector = MessageCollector()
collector.get_user = AsyncMock(return_value=None)
result = await collector.send_dm(123, "Hello!")
assert result is False
@pytest.mark.asyncio
async def test_send_dm_exception():
"""Test sending DM with exception"""
user = Mock()
user.send = AsyncMock(side_effect=Exception("Send failed"))
collector = MessageCollector()
collector.get_user = AsyncMock(return_value=user)
result = await collector.send_dm(123, "Hello!")
assert result is False
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
async def test_send_to_channel_success():
"""Test sending to channel successfully"""
channel = Mock()
channel.send = AsyncMock()
collector = MessageCollector()
collector.get_channel_by_name = AsyncMock(return_value=channel)
result = await collector.send_to_channel("general", "Announcement!")
assert result is True
channel.send.assert_called_once_with("Announcement!")
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
async def test_send_to_channel_notifications_disabled():
"""Test sending to channel when notifications disabled"""
collector = MessageCollector()
result = await collector.send_to_channel("general", "Announcement!")
assert result is False
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
async def test_send_to_channel_not_found():
"""Test sending to non-existent channel"""
collector = MessageCollector()
collector.get_channel_by_name = AsyncMock(return_value=None)
result = await collector.send_to_channel("nonexistent", "Message")
assert result is False
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
async def test_run_collector():
"""Test running the collector"""
from memory.discord.collector import run_collector
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
mock_collector = Mock()
mock_collector.start = AsyncMock()
mock_collector_class.return_value = mock_collector
await run_collector()
mock_collector.start.assert_called_once_with("test_token")
@pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
async def test_run_collector_no_token():
"""Test running collector without token"""
from memory.discord.collector import run_collector
# Should return early without raising
await run_collector()

View File

@ -0,0 +1,607 @@
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, patch
from memory.common.db.models import (
DiscordMessage,
DiscordUser,
DiscordServer,
DiscordChannel,
)
from memory.workers.tasks import discord
@pytest.fixture
def mock_discord_user(db_session):
"""Create a Discord user for testing."""
user = DiscordUser(
id=123456789,
username="testuser",
ignore_messages=False,
)
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def mock_discord_server(db_session):
"""Create a Discord server for testing."""
server = DiscordServer(
id=987654321,
name="Test Server",
ignore_messages=False,
)
db_session.add(server)
db_session.commit()
return server
@pytest.fixture
def mock_discord_channel(db_session, mock_discord_server):
"""Create a Discord channel for testing."""
channel = DiscordChannel(
id=111222333,
name="test-channel",
channel_type="text",
server_id=mock_discord_server.id,
ignore_messages=False,
)
db_session.add(channel)
db_session.commit()
return channel
@pytest.fixture
def sample_message_data(mock_discord_user, mock_discord_channel):
"""Sample message data for testing."""
return {
"message_id": 999888777,
"channel_id": mock_discord_channel.id,
"author_id": mock_discord_user.id,
"content": "This is a test Discord message with enough content to be processed.",
"sent_at": "2024-01-01T12:00:00Z",
"server_id": None,
"message_reference_id": None,
}
def test_get_prev_returns_previous_messages(
db_session, mock_discord_user, mock_discord_channel
):
"""Test that get_prev returns previous messages in order."""
# Create previous messages
msg1 = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
discord_user_id=mock_discord_user.id,
content="First message",
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"hash1" + bytes(26),
)
msg2 = DiscordMessage(
message_id=2,
channel_id=mock_discord_channel.id,
discord_user_id=mock_discord_user.id,
content="Second message",
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"hash2" + bytes(26),
)
msg3 = DiscordMessage(
message_id=3,
channel_id=mock_discord_channel.id,
discord_user_id=mock_discord_user.id,
content="Third message",
sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"hash3" + bytes(26),
)
db_session.add_all([msg1, msg2, msg3])
db_session.commit()
# Get previous messages before 10:15
result = discord.get_prev(
db_session,
mock_discord_channel.id,
datetime(2024, 1, 1, 10, 15, 0, tzinfo=timezone.utc),
)
assert len(result) == 3
assert result[0] == "testuser: First message"
assert result[1] == "testuser: Second message"
assert result[2] == "testuser: Third message"
def test_get_prev_limits_context_window(
db_session, mock_discord_user, mock_discord_channel
):
"""Test that get_prev respects DISCORD_CONTEXT_WINDOW setting."""
# Create 15 messages (more than the default context window of 10)
for i in range(15):
msg = DiscordMessage(
message_id=i,
channel_id=mock_discord_channel.id,
discord_user_id=mock_discord_user.id,
content=f"Message {i}",
sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc),
modality="text",
sha256=f"hash{i}".encode() + bytes(27),
)
db_session.add(msg)
db_session.commit()
result = discord.get_prev(
db_session,
mock_discord_channel.id,
datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
)
# Should only return last 10 messages
assert len(result) == 10
assert result[0] == "testuser: Message 5" # Oldest in window
assert result[-1] == "testuser: Message 14" # Most recent
def test_get_prev_empty_channel(db_session, mock_discord_channel):
"""Test get_prev with no previous messages."""
result = discord.get_prev(
db_session,
mock_discord_channel.id,
datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
)
assert result == []
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_normal_message(
db_session, mock_discord_user, mock_discord_server, mock_discord_channel
):
"""Test should_process returns True for normal messages."""
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
discord_user_id=mock_discord_user.id,
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
modality="text",
sha256=b"hash" + bytes(27),
)
db_session.add(message)
db_session.commit()
db_session.refresh(message)
assert discord.should_process(message) is True
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
def test_should_process_disabled():
"""Test should_process returns False when processing is disabled."""
message = Mock()
assert discord.should_process(message) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_should_process_notifications_disabled():
"""Test should_process returns False when notifications are disabled."""
message = Mock()
assert discord.should_process(message) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_server_ignored(
db_session, mock_discord_user, mock_discord_channel
):
"""Test should_process returns False when server has ignore_messages=True."""
server = DiscordServer(
id=123,
name="Ignored Server",
ignore_messages=True,
)
db_session.add(server)
db_session.commit()
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
discord_user_id=mock_discord_user.id,
server_id=server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
modality="text",
sha256=b"hash" + bytes(27),
)
db_session.add(message)
db_session.commit()
db_session.refresh(message)
assert discord.should_process(message) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_channel_ignored(
db_session, mock_discord_user, mock_discord_server
):
"""Test should_process returns False when channel has ignore_messages=True."""
channel = DiscordChannel(
id=456,
name="ignored-channel",
channel_type="text",
server_id=mock_discord_server.id,
ignore_messages=True,
)
db_session.add(channel)
db_session.commit()
message = DiscordMessage(
message_id=1,
channel_id=channel.id,
discord_user_id=mock_discord_user.id,
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
modality="text",
sha256=b"hash" + bytes(27),
)
db_session.add(message)
db_session.commit()
db_session.refresh(message)
assert discord.should_process(message) is False
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_should_process_user_ignored(
db_session, mock_discord_server, mock_discord_channel
):
"""Test should_process returns False when user has ignore_messages=True."""
user = DiscordUser(
id=789,
username="ignoreduser",
ignore_messages=True,
)
db_session.add(user)
db_session.commit()
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
discord_user_id=user.id,
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
modality="text",
sha256=b"hash" + bytes(27),
)
db_session.add(message)
db_session.commit()
db_session.refresh(message)
assert discord.should_process(message) is False
def test_add_discord_message_success(db_session, sample_message_data, qdrant):
"""Test successful Discord message addition."""
result = discord.add_discord_message(**sample_message_data)
assert result["status"] == "processed"
assert "discordmessage_id" in result
# Verify the message was created in the database
message = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.first()
)
assert message is not None
assert message.content == sample_message_data["content"]
assert message.message_type == "default"
assert message.reply_to_message_id is None
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
"""Test adding a Discord message that is a reply."""
sample_message_data["message_reference_id"] = 111222333
discord.add_discord_message(**sample_message_data)
message = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.first()
)
assert message.message_type == "reply"
assert message.reply_to_message_id == 111222333
def test_add_discord_message_already_exists(db_session, sample_message_data, qdrant):
"""Test adding a message that already exists."""
# Add the message once
discord.add_discord_message(**sample_message_data)
# Try to add it again
result = discord.add_discord_message(**sample_message_data)
assert result["status"] == "already_exists"
assert result["message_id"] == sample_message_data["message_id"]
# Verify no duplicate was created
messages = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.all()
)
assert len(messages) == 1
def test_add_discord_message_with_context(
db_session, sample_message_data, mock_discord_user, qdrant
):
"""Test that message is added successfully when previous messages exist."""
# Add a previous message
prev_msg = DiscordMessage(
message_id=111111,
channel_id=sample_message_data["channel_id"],
discord_user_id=mock_discord_user.id,
content="Previous message",
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"prev" + bytes(28),
)
db_session.add(prev_msg)
db_session.commit()
result = discord.add_discord_message(**sample_message_data)
message = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.first()
)
assert message is not None
assert result["status"] == "processed"
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_add_discord_message_triggers_processing(
mock_process,
db_session,
sample_message_data,
mock_discord_server,
mock_discord_channel,
qdrant,
):
"""Test that add_discord_message triggers process_discord_message when conditions are met."""
mock_process.delay = Mock()
sample_message_data["server_id"] = mock_discord_server.id
discord.add_discord_message(**sample_message_data)
# Verify process_discord_message.delay was called
mock_process.delay.assert_called_once()
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
def test_add_discord_message_no_processing_when_disabled(
mock_process, db_session, sample_message_data, qdrant
):
"""Test that process_discord_message is not called when processing is disabled."""
mock_process.delay = Mock()
discord.add_discord_message(**sample_message_data)
mock_process.delay.assert_not_called()
def test_edit_discord_message_success(db_session, sample_message_data, qdrant):
"""Test successful Discord message edit."""
# First add the message
discord.add_discord_message(**sample_message_data)
# Edit it
new_content = (
"This is the edited content with enough text to be meaningful and processed."
)
edited_at = "2024-01-01T13:00:00Z"
result = discord.edit_discord_message(
sample_message_data["message_id"],
new_content,
edited_at,
)
assert result["status"] == "processed"
# Verify the message was updated
message = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.first()
)
assert message.content == new_content
assert message.edited_at is not None
def test_edit_discord_message_not_found(db_session):
"""Test editing a message that doesn't exist."""
result = discord.edit_discord_message(
999999,
"New content",
"2024-01-01T13:00:00Z",
)
assert result["status"] == "error"
assert result["error"] == "Message not found"
assert result["message_id"] == 999999
def test_edit_discord_message_updates_context(
db_session, sample_message_data, mock_discord_user, qdrant
):
"""Test that editing a message works correctly."""
# Add previous message and the message to be edited
prev_msg = DiscordMessage(
message_id=111111,
channel_id=sample_message_data["channel_id"],
discord_user_id=mock_discord_user.id,
content="Previous message",
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"prev" + bytes(28),
)
db_session.add(prev_msg)
db_session.commit()
discord.add_discord_message(**sample_message_data)
# Edit the message
result = discord.edit_discord_message(
sample_message_data["message_id"],
"Edited content that should have context updated properly.",
"2024-01-01T13:00:00Z",
)
# Verify message was updated
message = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.first()
)
assert (
message.content == "Edited content that should have context updated properly."
)
assert result["status"] == "processed"
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
"""Test processing a Discord message."""
# Add a message first
add_result = discord.add_discord_message(**sample_message_data)
message_id = add_result["discordmessage_id"]
# Process it
result = discord.process_discord_message(message_id)
assert result["status"] == "processed"
assert result["message_id"] == message_id
def test_process_discord_message_not_found(db_session):
"""Test processing a message that doesn't exist."""
result = discord.process_discord_message(999999)
assert result["status"] == "error"
assert result["error"] == "Message not found"
assert result["message_id"] == 999999
@pytest.mark.parametrize(
"sent_at_str,expected_hour",
[
("2024-01-01T12:00:00Z", 12),
("2024-01-01T00:00:00+00:00", 0),
("2024-01-01T23:59:59Z", 23),
],
)
def test_add_discord_message_datetime_parsing(
db_session, sample_message_data, sent_at_str, expected_hour, qdrant
):
"""Test that various datetime formats are parsed correctly."""
sample_message_data["sent_at"] = sent_at_str
discord.add_discord_message(**sample_message_data)
message = (
db_session.query(DiscordMessage)
.filter_by(message_id=sample_message_data["message_id"])
.first()
)
assert message.sent_at.hour == expected_hour
def test_add_discord_message_unique_hash(db_session, sample_message_data, qdrant):
"""Test that message hash includes message_id for uniqueness."""
# Add first message
discord.add_discord_message(**sample_message_data)
# Try to add another message with same content but different message_id
sample_message_data["message_id"] = 888777666
result = discord.add_discord_message(**sample_message_data)
# Should succeed because hash includes message_id
assert result["status"] == "processed"
# Verify both messages exist
messages = (
db_session.query(DiscordMessage)
.filter_by(content=sample_message_data["content"])
.all()
)
assert len(messages) == 2
def test_get_prev_only_returns_messages_from_same_channel(
db_session, mock_discord_user, mock_discord_server
):
"""Test that get_prev only returns messages from the specified channel."""
# Create two channels
channel1 = DiscordChannel(
id=111,
name="channel-1",
channel_type="text",
server_id=mock_discord_server.id,
)
channel2 = DiscordChannel(
id=222,
name="channel-2",
channel_type="text",
server_id=mock_discord_server.id,
)
db_session.add_all([channel1, channel2])
db_session.commit()
# Add messages to both channels
msg1 = DiscordMessage(
message_id=1,
channel_id=channel1.id,
discord_user_id=mock_discord_user.id,
content="Message in channel 1",
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"hash1" + bytes(26),
)
msg2 = DiscordMessage(
message_id=2,
channel_id=channel2.id,
discord_user_id=mock_discord_user.id,
content="Message in channel 2",
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
modality="text",
sha256=b"hash2" + bytes(26),
)
db_session.add_all([msg1, msg2])
db_session.commit()
# Get previous messages for channel 1
result = discord.get_prev(
db_session,
channel1.id, # type: ignore
datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
)
# Should only return message from channel 1
assert len(result) == 1
assert "Message in channel 1" in result[0]
assert "Message in channel 2" not in result[0]