From 814090dccb9dcbf565692dcfa24bffb5b126765e Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sat, 1 Nov 2025 18:52:37 +0100 Subject: [PATCH] use db bots --- src/memory/common/discord.py | 57 ++++++--- src/memory/discord/api.py | 40 +++--- src/memory/workers/tasks/discord.py | 35 +++++- src/memory/workers/tasks/scheduled_calls.py | 14 ++- tests/memory/common/test_discord.py | 114 +++++++++++------- .../memory/common/test_discord_integration.py | 114 +++++++++++------- .../workers/tasks/test_discord_tasks.py | 16 ++- .../workers/tasks/test_scheduled_calls.py | 28 +++-- 8 files changed, 271 insertions(+), 147 deletions(-) diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py index 2a1a026..e850059 100644 --- a/src/memory/common/discord.py +++ b/src/memory/common/discord.py @@ -20,12 +20,12 @@ def get_api_url() -> str: return f"http://{host}:{port}" -def send_dm(user_identifier: str, message: str) -> bool: +def send_dm(bot_id: int, user_identifier: str, message: str) -> bool: """Send a DM via the Discord collector API""" try: response = requests.post( f"{get_api_url()}/send_dm", - json={"user": user_identifier, "message": message}, + json={"bot_id": bot_id, "user": user_identifier, "message": message}, timeout=10, ) response.raise_for_status() @@ -37,12 +37,16 @@ def send_dm(user_identifier: str, message: str) -> bool: return False -def send_to_channel(channel_name: str, message: str) -> bool: +def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool: """Send a DM via the Discord collector API""" try: response = requests.post( f"{get_api_url()}/send_channel", - json={"channel_name": channel_name, "message": message}, + json={ + "bot_id": bot_id, + "channel_name": channel_name, + "message": message, + }, timeout=10, ) response.raise_for_status() @@ -55,12 +59,16 @@ def send_to_channel(channel_name: str, message: str) -> bool: return False -def broadcast_message(channel_name: str, message: str) -> bool: +def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool: """Send a message to a channel via the Discord collector API""" try: response = requests.post( f"{get_api_url()}/send_channel", - json={"channel_name": channel_name, "message": message}, + json={ + "bot_id": bot_id, + "channel_name": channel_name, + "message": message, + }, timeout=10, ) response.raise_for_status() @@ -72,19 +80,22 @@ def broadcast_message(channel_name: str, message: str) -> bool: return False -def is_collector_healthy() -> bool: +def is_collector_healthy(bot_id: int) -> bool: """Check if the Discord collector is running and healthy""" try: response = requests.get(f"{get_api_url()}/health", timeout=5) response.raise_for_status() result = response.json() - return result.get("status") == "healthy" + bot_status = result.get(str(bot_id)) + if not isinstance(bot_status, dict): + return False + return bool(bot_status.get("connected")) except requests.RequestException: return False -def refresh_discord_metadata() -> dict[str, int] | None: +def refresh_discord_metadata() -> dict[str, Any] | None: """Refresh Discord server/channel/user metadata from Discord API""" try: response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30) @@ -96,24 +107,24 @@ def refresh_discord_metadata() -> dict[str, int] | None: # Convenience functions -def send_error_message(message: str) -> bool: +def send_error_message(bot_id: int, message: str) -> bool: """Send an error message to the error channel""" - return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message) + return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message) -def send_activity_message(message: str) -> bool: +def send_activity_message(bot_id: int, message: str) -> bool: """Send an activity message to the activity channel""" - return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message) + return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message) -def send_discovery_message(message: str) -> bool: +def send_discovery_message(bot_id: int, message: str) -> bool: """Send a discovery message to the discovery channel""" - return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message) + return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message) -def send_chat_message(message: str) -> bool: +def send_chat_message(bot_id: int, message: str) -> bool: """Send a chat message to the chat channel""" - return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message) + return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message) def notify_task_failure( @@ -122,6 +133,7 @@ def notify_task_failure( task_args: tuple = (), task_kwargs: dict[str, Any] | None = None, traceback_str: str | None = None, + bot_id: int | None = None, ) -> None: """ Send a task failure notification to Discord. @@ -137,6 +149,15 @@ def notify_task_failure( logger.debug("Discord notifications disabled") return + if bot_id is None: + bot_id = settings.DISCORD_BOT_ID + + if not bot_id: + logger.debug( + "No Discord bot ID provided for task failure notification; skipping" + ) + return + message = f"🚨 **Task Failed: {task_name}**\n\n" message += f"**Error:** {error_message[:500]}\n" @@ -150,7 +171,7 @@ def notify_task_failure( message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```" try: - send_error_message(message) + send_error_message(bot_id, message) logger.info(f"Discord error notification sent for task: {task_name}") except Exception as e: logger.error(f"Failed to send Discord notification: {e}") diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index fabb8d2..a5e21b6 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages. import asyncio import logging -from contextlib import asynccontextmanager import traceback +from contextlib import asynccontextmanager +from typing import cast +import uvicorn from fastapi import FastAPI, HTTPException from pydantic import BaseModel -import uvicorn from memory.common import settings -from memory.discord.collector import MessageCollector -from memory.common.db.models.users import BotUser from memory.common.db.connection import make_session +from memory.common.db.models.users import DiscordBotUser +from memory.discord.collector import MessageCollector logger = logging.getLogger(__name__) @@ -41,37 +42,25 @@ class Collector: bot_token: str bot_name: str - def __init__(self, collector: MessageCollector, bot: BotUser): + def __init__(self, collector: MessageCollector, bot: DiscordBotUser): self.collector = collector - self.collector_task = asyncio.create_task(collector.start(bot.api_key)) - self.bot_id = bot.id - self.bot_token = bot.api_key - self.bot_name = bot.name - - -# Application state -class AppState: - def __init__(self): - self.collector: MessageCollector | None = None - self.collector_task: asyncio.Task | None = None - - -app_state = AppState() + self.collector_task = asyncio.create_task(collector.start(str(bot.api_key))) + self.bot_id = cast(int, bot.id) + self.bot_token = str(bot.api_key) + self.bot_name = str(bot.name) @asynccontextmanager async def lifespan(app: FastAPI): """Manage Discord collector lifecycle""" - if not settings.DISCORD_BOT_TOKEN: - logger.error("DISCORD_BOT_TOKEN not configured") - return - def make_collector(bot: BotUser): + def make_collector(bot: DiscordBotUser): collector = MessageCollector() return Collector(collector=collector, bot=bot) with make_session() as session: - app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()} + bots = session.query(DiscordBotUser).all() + app.bots = {bot.id: make_collector(bot) for bot in bots} logger.info(f"Discord collectors started for {len(app.bots)} bots") @@ -155,9 +144,8 @@ async def health_check(): if not app.bots: raise HTTPException(status_code=503, detail="Discord collector not running") - collector = app_state.collector return { - collector.bot_name: { + bot.bot_name: { "status": "healthy", "connected": not bot.collector.is_closed(), "user": str(bot.collector.user) if bot.collector.user else None, diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index 102fbd7..c1cdda0 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -138,6 +138,18 @@ def should_process(message: DiscordMessage) -> bool: return False +def _resolve_bot_id(discord_message: DiscordMessage) -> int | None: + recipient = discord_message.recipient_user + if not recipient: + return None + + system_user = recipient.system_user + if not system_user: + return None + + return getattr(system_user, "id", None) + + @app.task(name=PROCESS_DISCORD_MESSAGE) @safe_task_execution def process_discord_message(message_id: int) -> dict[str, Any]: @@ -161,11 +173,26 @@ def process_discord_message(message_id: int) -> dict[str, Any]: response = call_llm(session, discord_message, settings.DISCORD_MODEL) if not response: - pass - elif discord_message.channel.server: - discord.send_to_channel(discord_message.channel.name, response) + return { + "status": "processed", + "message_id": message_id, + } + + bot_id = _resolve_bot_id(discord_message) + if not bot_id: + logger.warning( + "No associated Discord bot user for message %s; skipping send", + message_id, + ) + return { + "status": "processed", + "message_id": message_id, + } + + if discord_message.channel.server: + discord.send_to_channel(bot_id, discord_message.channel.name, response) else: - discord.send_dm(discord_message.from_user.username, response) + discord.send_dm(bot_id, discord_message.from_user.username, response) return { "status": "processed", diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index 3ca6cad..76746b9 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): if len(message) > 1900: # Leave some buffer message = message[:1900] + "\n\n... (response truncated)" + bot_id_value = scheduled_call.user_id + if bot_id_value is None: + logger.warning( + "Scheduled call %s has no associated bot user; skipping Discord send", + scheduled_call.id, + ) + return + + bot_id = cast(int, bot_id_value) + if discord_user := scheduled_call.discord_user: logger.info(f"Sending DM to {discord_user.username}: {message}") - discord.send_dm(discord_user.username, message) + discord.send_dm(bot_id, discord_user.username, message) elif discord_channel := scheduled_call.discord_channel: logger.info(f"Broadcasting message to {discord_channel.name}: {message}") - discord.broadcast_message(discord_channel.name, message) + discord.broadcast_message(bot_id, discord_channel.name, message) else: logger.warning( f"No Discord user or channel found for scheduled call {scheduled_call.id}" diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py index 274ba8e..ce4cf3f 100644 --- a/tests/memory/common/test_discord.py +++ b/tests/memory/common/test_discord.py @@ -4,6 +4,8 @@ import requests from memory.common import discord +BOT_ID = 42 + @pytest.fixture def mock_api_url(): @@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is True mock_post.assert_called_once_with( "http://localhost:8000/send_dm", - json={"user": "user123", "message": "Hello!"}, + json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"}, timeout=10, ) @@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is False @@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url): """Test DM sending when request raises exception""" mock_post.side_effect = requests.RequestException("Network error") - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is False @@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url): mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") mock_post.return_value = mock_response - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is False @@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.broadcast_message("general", "Announcement!") + result = discord.broadcast_message(BOT_ID, "general", "Announcement!") assert result is True mock_post.assert_called_once_with( "http://localhost:8000/send_channel", - json={"channel_name": "general", "message": "Announcement!"}, + json={ + "bot_id": BOT_ID, + "channel_name": "general", + "message": "Announcement!", + }, timeout=10, ) @@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.broadcast_message("general", "Announcement!") + result = discord.broadcast_message(BOT_ID, "general", "Announcement!") assert result is False @@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url): """Test channel message broadcast with exception""" mock_post.side_effect = requests.Timeout("Request timeout") - result = discord.broadcast_message("general", "Announcement!") + result = discord.broadcast_message(BOT_ID, "general", "Announcement!") assert result is False @@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url): def test_is_collector_healthy_true(mock_get, mock_api_url): """Test health check when collector is healthy""" mock_response = Mock() - mock_response.json.return_value = {"status": "healthy"} + mock_response.json.return_value = {str(BOT_ID): {"connected": True}} mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is True mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) @@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url): def test_is_collector_healthy_false_status(mock_get, mock_api_url): """Test health check when collector returns unhealthy status""" mock_response = Mock() - mock_response.json.return_value = {"status": "unhealthy"} + mock_response.json.return_value = {str(BOT_ID): {"connected": False}} mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is False @@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url): """Test health check when request fails""" mock_get.side_effect = requests.ConnectionError("Connection refused") - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is False @@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast): """Test sending error message to error channel""" mock_broadcast.return_value = True - result = discord.send_error_message("Something broke") + result = discord.send_error_message(BOT_ID, "Something broke") assert result is True - mock_broadcast.assert_called_once_with("errors", "Something broke") + mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke") @patch("memory.common.discord.broadcast_message") @@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast): """Test sending activity message to activity channel""" mock_broadcast.return_value = True - result = discord.send_activity_message("User logged in") + result = discord.send_activity_message(BOT_ID, "User logged in") assert result is True - mock_broadcast.assert_called_once_with("activity", "User logged in") + mock_broadcast.assert_called_once_with( + BOT_ID, "activity", "User logged in" + ) @patch("memory.common.discord.broadcast_message") @@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast): """Test sending discovery message to discovery channel""" mock_broadcast.return_value = True - result = discord.send_discovery_message("Found interesting pattern") + result = discord.send_discovery_message(BOT_ID, "Found interesting pattern") assert result is True - mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") + mock_broadcast.assert_called_once_with( + BOT_ID, "discoveries", "Found interesting pattern" + ) @patch("memory.common.discord.broadcast_message") @@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast): """Test sending chat message to chat channel""" mock_broadcast.return_value = True - result = discord.send_chat_message("Hello from bot") + result = discord.send_chat_message(BOT_ID, "Hello from bot") assert result is True - mock_broadcast.assert_called_once_with("chat", "Hello from bot") + mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot") @patch("memory.common.discord.send_error_message") @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_basic(mock_send_error): """Test basic task failure notification""" - discord.notify_task_failure("test_task", "Something went wrong") + discord.notify_task_failure( + "test_task", "Something went wrong", bot_id=BOT_ID + ) mock_send_error.assert_called_once() - message = mock_send_error.call_args[0][0] + assert mock_send_error.call_args[0][0] == BOT_ID + message = mock_send_error.call_args[0][1] assert "🚨 **Task Failed: test_task**" in message assert "**Error:** Something went wrong" in message @@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error): "Error occurred", task_args=("arg1", 42), task_kwargs={"key": "value", "number": 123}, + bot_id=BOT_ID, ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] assert "**Args:** `('arg1', 42)" in message assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message @@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error): """Test task failure notification with traceback""" traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" - discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) + discord.notify_task_failure( + "test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] assert "**Traceback:**" in message assert "Exception: test" in message @@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error): """Test that long error messages are truncated""" long_error = "x" * 600 - discord.notify_task_failure("test_task", long_error) + discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Error should be truncated to 500 chars - check that the full 600 char string is not there assert "**Error:** " + long_error[:500] in message @@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error): """Test that long tracebacks are truncated""" long_traceback = "x" * 1000 - discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) + discord.notify_task_failure( + "test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Traceback should show last 800 chars assert long_traceback[-800:] in message @@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error): """Test that long task arguments are truncated""" long_args = ("x" * 300,) - discord.notify_task_failure("test_task", "Error", task_args=long_args) + discord.notify_task_failure( + "test_task", "Error", task_args=long_args, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Args should be truncated to 200 chars assert ( @@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error): """Test that long task kwargs are truncated""" long_kwargs = {"key": "x" * 300} - discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) + discord.notify_task_failure( + "test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Kwargs should be truncated to 200 chars assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 @@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error): @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) def test_notify_task_failure_disabled(mock_send_error): """Test that notifications are not sent when disabled""" - discord.notify_task_failure("test_task", "Error occurred") + discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID) mock_send_error.assert_not_called() @@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error): mock_send_error.side_effect = Exception("Failed to send") # Should not raise - discord.notify_task_failure("test_task", "Error occurred") + discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID) mock_send_error.assert_called_once() @@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels( ): """Test that convenience functions use the correct channel settings""" with patch(f"memory.common.settings.{channel_setting}", "test-channel"): - function(message) - mock_broadcast.assert_called_once_with("test-channel", message) + function(BOT_ID, message) + mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message) @patch("requests.post") @@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url): mock_post.return_value = mock_response message_with_special_chars = "Hello! 🎉 <@123> #general" - result = discord.send_dm("user123", message_with_special_chars) + result = discord.send_dm(BOT_ID, "user123", message_with_special_chars) assert result is True call_args = mock_post.call_args - assert call_args[1]["json"]["message"] == message_with_special_chars + json_payload = call_args[1]["json"] + assert json_payload["message"] == message_with_special_chars + assert json_payload["bot_id"] == BOT_ID @patch("requests.post") @@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url): mock_post.return_value = mock_response long_message = "A" * 2000 - result = discord.broadcast_message("general", long_message) + result = discord.broadcast_message(BOT_ID, "general", long_message) assert result is True call_args = mock_post.call_args - assert call_args[1]["json"]["message"] == long_message + json_payload = call_args[1]["json"] + assert json_payload["message"] == long_message + assert json_payload["bot_id"] == BOT_ID @patch("requests.get") @@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url): mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is False diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py index 274ba8e..ce4cf3f 100644 --- a/tests/memory/common/test_discord_integration.py +++ b/tests/memory/common/test_discord_integration.py @@ -4,6 +4,8 @@ import requests from memory.common import discord +BOT_ID = 42 + @pytest.fixture def mock_api_url(): @@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is True mock_post.assert_called_once_with( "http://localhost:8000/send_dm", - json={"user": "user123", "message": "Hello!"}, + json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"}, timeout=10, ) @@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is False @@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url): """Test DM sending when request raises exception""" mock_post.side_effect = requests.RequestException("Network error") - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is False @@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url): mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") mock_post.return_value = mock_response - result = discord.send_dm("user123", "Hello!") + result = discord.send_dm(BOT_ID, "user123", "Hello!") assert result is False @@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.broadcast_message("general", "Announcement!") + result = discord.broadcast_message(BOT_ID, "general", "Announcement!") assert result is True mock_post.assert_called_once_with( "http://localhost:8000/send_channel", - json={"channel_name": "general", "message": "Announcement!"}, + json={ + "bot_id": BOT_ID, + "channel_name": "general", + "message": "Announcement!", + }, timeout=10, ) @@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url): mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - result = discord.broadcast_message("general", "Announcement!") + result = discord.broadcast_message(BOT_ID, "general", "Announcement!") assert result is False @@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url): """Test channel message broadcast with exception""" mock_post.side_effect = requests.Timeout("Request timeout") - result = discord.broadcast_message("general", "Announcement!") + result = discord.broadcast_message(BOT_ID, "general", "Announcement!") assert result is False @@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url): def test_is_collector_healthy_true(mock_get, mock_api_url): """Test health check when collector is healthy""" mock_response = Mock() - mock_response.json.return_value = {"status": "healthy"} + mock_response.json.return_value = {str(BOT_ID): {"connected": True}} mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is True mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) @@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url): def test_is_collector_healthy_false_status(mock_get, mock_api_url): """Test health check when collector returns unhealthy status""" mock_response = Mock() - mock_response.json.return_value = {"status": "unhealthy"} + mock_response.json.return_value = {str(BOT_ID): {"connected": False}} mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is False @@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url): """Test health check when request fails""" mock_get.side_effect = requests.ConnectionError("Connection refused") - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is False @@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast): """Test sending error message to error channel""" mock_broadcast.return_value = True - result = discord.send_error_message("Something broke") + result = discord.send_error_message(BOT_ID, "Something broke") assert result is True - mock_broadcast.assert_called_once_with("errors", "Something broke") + mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke") @patch("memory.common.discord.broadcast_message") @@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast): """Test sending activity message to activity channel""" mock_broadcast.return_value = True - result = discord.send_activity_message("User logged in") + result = discord.send_activity_message(BOT_ID, "User logged in") assert result is True - mock_broadcast.assert_called_once_with("activity", "User logged in") + mock_broadcast.assert_called_once_with( + BOT_ID, "activity", "User logged in" + ) @patch("memory.common.discord.broadcast_message") @@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast): """Test sending discovery message to discovery channel""" mock_broadcast.return_value = True - result = discord.send_discovery_message("Found interesting pattern") + result = discord.send_discovery_message(BOT_ID, "Found interesting pattern") assert result is True - mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") + mock_broadcast.assert_called_once_with( + BOT_ID, "discoveries", "Found interesting pattern" + ) @patch("memory.common.discord.broadcast_message") @@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast): """Test sending chat message to chat channel""" mock_broadcast.return_value = True - result = discord.send_chat_message("Hello from bot") + result = discord.send_chat_message(BOT_ID, "Hello from bot") assert result is True - mock_broadcast.assert_called_once_with("chat", "Hello from bot") + mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot") @patch("memory.common.discord.send_error_message") @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) def test_notify_task_failure_basic(mock_send_error): """Test basic task failure notification""" - discord.notify_task_failure("test_task", "Something went wrong") + discord.notify_task_failure( + "test_task", "Something went wrong", bot_id=BOT_ID + ) mock_send_error.assert_called_once() - message = mock_send_error.call_args[0][0] + assert mock_send_error.call_args[0][0] == BOT_ID + message = mock_send_error.call_args[0][1] assert "🚨 **Task Failed: test_task**" in message assert "**Error:** Something went wrong" in message @@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error): "Error occurred", task_args=("arg1", 42), task_kwargs={"key": "value", "number": 123}, + bot_id=BOT_ID, ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] assert "**Args:** `('arg1', 42)" in message assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message @@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error): """Test task failure notification with traceback""" traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" - discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) + discord.notify_task_failure( + "test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] assert "**Traceback:**" in message assert "Exception: test" in message @@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error): """Test that long error messages are truncated""" long_error = "x" * 600 - discord.notify_task_failure("test_task", long_error) + discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Error should be truncated to 500 chars - check that the full 600 char string is not there assert "**Error:** " + long_error[:500] in message @@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error): """Test that long tracebacks are truncated""" long_traceback = "x" * 1000 - discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) + discord.notify_task_failure( + "test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Traceback should show last 800 chars assert long_traceback[-800:] in message @@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error): """Test that long task arguments are truncated""" long_args = ("x" * 300,) - discord.notify_task_failure("test_task", "Error", task_args=long_args) + discord.notify_task_failure( + "test_task", "Error", task_args=long_args, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Args should be truncated to 200 chars assert ( @@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error): """Test that long task kwargs are truncated""" long_kwargs = {"key": "x" * 300} - discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) + discord.notify_task_failure( + "test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID + ) - message = mock_send_error.call_args[0][0] + message = mock_send_error.call_args[0][1] # Kwargs should be truncated to 200 chars assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 @@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error): @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) def test_notify_task_failure_disabled(mock_send_error): """Test that notifications are not sent when disabled""" - discord.notify_task_failure("test_task", "Error occurred") + discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID) mock_send_error.assert_not_called() @@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error): mock_send_error.side_effect = Exception("Failed to send") # Should not raise - discord.notify_task_failure("test_task", "Error occurred") + discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID) mock_send_error.assert_called_once() @@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels( ): """Test that convenience functions use the correct channel settings""" with patch(f"memory.common.settings.{channel_setting}", "test-channel"): - function(message) - mock_broadcast.assert_called_once_with("test-channel", message) + function(BOT_ID, message) + mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message) @patch("requests.post") @@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url): mock_post.return_value = mock_response message_with_special_chars = "Hello! 🎉 <@123> #general" - result = discord.send_dm("user123", message_with_special_chars) + result = discord.send_dm(BOT_ID, "user123", message_with_special_chars) assert result is True call_args = mock_post.call_args - assert call_args[1]["json"]["message"] == message_with_special_chars + json_payload = call_args[1]["json"] + assert json_payload["message"] == message_with_special_chars + assert json_payload["bot_id"] == BOT_ID @patch("requests.post") @@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url): mock_post.return_value = mock_response long_message = "A" * 2000 - result = discord.broadcast_message("general", long_message) + result = discord.broadcast_message(BOT_ID, "general", long_message) assert result is True call_args = mock_post.call_args - assert call_args[1]["json"]["message"] == long_message + json_payload = call_args[1]["json"] + assert json_payload["message"] == long_message + assert json_payload["bot_id"] == BOT_ID @patch("requests.get") @@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url): mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - result = discord.is_collector_healthy() + result = discord.is_collector_healthy(BOT_ID) assert result is False diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py index 61e985d..8decc7d 100644 --- a/tests/memory/workers/tasks/test_discord_tasks.py +++ b/tests/memory/workers/tasks/test_discord_tasks.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from unittest.mock import Mock, patch from memory.common.db.models import ( + DiscordBotUser, DiscordMessage, DiscordUser, DiscordServer, @@ -12,12 +13,25 @@ from memory.workers.tasks import discord @pytest.fixture -def mock_discord_user(db_session): +def discord_bot_user(db_session): + bot = DiscordBotUser.create_with_api_key( + discord_users=[], + name="Test Bot", + email="bot@example.com", + ) + db_session.add(bot) + db_session.commit() + return bot + + +@pytest.fixture +def mock_discord_user(db_session, discord_bot_user): """Create a Discord user for testing.""" user = DiscordUser( id=123456789, username="testuser", ignore_messages=False, + system_user_id=discord_bot_user.id, ) db_session.add(user) db_session.commit() diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py index decf6a6..5b49a8a 100644 --- a/tests/memory/workers/tasks/test_scheduled_calls.py +++ b/tests/memory/workers/tasks/test_scheduled_calls.py @@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta from unittest.mock import Mock, patch import uuid -from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer +from memory.common.db.models import ( + ScheduledLLMCall, + DiscordBotUser, + DiscordUser, + DiscordChannel, + DiscordServer, +) from memory.workers.tasks import scheduled_calls @pytest.fixture def sample_user(db_session): """Create a sample user for testing.""" - user = HumanUser.create_with_password( - name="testuser", - email="test@example.com", - password="password", + user = DiscordBotUser.create_with_api_key( + discord_users=[], + name="testbot", + email="bot@example.com", ) db_session.add(user) db_session.commit() @@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call): scheduled_calls._send_to_discord(pending_scheduled_call, response) mock_send_dm.assert_called_once_with( + pending_scheduled_call.user_id, "testuser", # username, not ID "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.", ) @@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call): scheduled_calls._send_to_discord(completed_scheduled_call, response) mock_broadcast.assert_called_once_with( + completed_scheduled_call.user_id, "test-channel", # channel name, not ID "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.", ) @@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled # Verify the message was truncated args, kwargs = mock_send_dm.call_args - message = args[1] + assert args[0] == pending_scheduled_call.user_id + message = args[2] assert len(message) <= 1950 # Should be truncated assert message.endswith("... (response truncated)") @@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c scheduled_calls._send_to_discord(pending_scheduled_call, normal_response) args, kwargs = mock_send_dm.call_args - message = args[1] + assert args[0] == pending_scheduled_call.user_id + message = args[2] assert not message.endswith("... (response truncated)") assert "This is a normal length response." in message @@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me mock_discord_user.username = "testuser" mock_call = Mock() + mock_call.user_id = 987 mock_call.topic = topic mock_call.model = model mock_call.discord_user = mock_discord_user @@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me # Get the actual message that was sent args, kwargs = mock_send_dm.call_args - actual_message = args[1] + assert args[0] == mock_call.user_id + actual_message = args[2] # Verify all expected parts are in the message for expected_part in expected_in_message: