use db bots

This commit is contained in:
Daniel O'Connell 2025-11-01 18:52:37 +01:00
parent 9639fa3dd7
commit 814090dccb
8 changed files with 271 additions and 147 deletions

View File

@ -20,12 +20,12 @@ def get_api_url() -> str:
return f"http://{host}:{port}" return f"http://{host}:{port}"
def send_dm(user_identifier: str, message: str) -> bool: def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
"""Send a DM via the Discord collector API""" """Send a DM via the Discord collector API"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_dm", f"{get_api_url()}/send_dm",
json={"user": user_identifier, "message": message}, json={"bot_id": bot_id, "user": user_identifier, "message": message},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -37,12 +37,16 @@ def send_dm(user_identifier: str, message: str) -> bool:
return False return False
def send_to_channel(channel_name: str, message: str) -> bool: def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a DM via the Discord collector API""" """Send a DM via the Discord collector API"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message}, json={
"bot_id": bot_id,
"channel_name": channel_name,
"message": message,
},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -55,12 +59,16 @@ def send_to_channel(channel_name: str, message: str) -> bool:
return False return False
def broadcast_message(channel_name: str, message: str) -> bool: def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API""" """Send a message to a channel via the Discord collector API"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message}, json={
"bot_id": bot_id,
"channel_name": channel_name,
"message": message,
},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -72,19 +80,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
return False return False
def is_collector_healthy() -> bool: def is_collector_healthy(bot_id: int) -> bool:
"""Check if the Discord collector is running and healthy""" """Check if the Discord collector is running and healthy"""
try: try:
response = requests.get(f"{get_api_url()}/health", timeout=5) response = requests.get(f"{get_api_url()}/health", timeout=5)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
return result.get("status") == "healthy" bot_status = result.get(str(bot_id))
if not isinstance(bot_status, dict):
return False
return bool(bot_status.get("connected"))
except requests.RequestException: except requests.RequestException:
return False return False
def refresh_discord_metadata() -> dict[str, int] | None: def refresh_discord_metadata() -> dict[str, Any] | None:
"""Refresh Discord server/channel/user metadata from Discord API""" """Refresh Discord server/channel/user metadata from Discord API"""
try: try:
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30) response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
@ -96,24 +107,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
# Convenience functions # Convenience functions
def send_error_message(message: str) -> bool: def send_error_message(bot_id: int, message: str) -> bool:
"""Send an error message to the error channel""" """Send an error message to the error channel"""
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
def send_activity_message(message: str) -> bool: def send_activity_message(bot_id: int, message: str) -> bool:
"""Send an activity message to the activity channel""" """Send an activity message to the activity channel"""
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
def send_discovery_message(message: str) -> bool: def send_discovery_message(bot_id: int, message: str) -> bool:
"""Send a discovery message to the discovery channel""" """Send a discovery message to the discovery channel"""
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
def send_chat_message(message: str) -> bool: def send_chat_message(bot_id: int, message: str) -> bool:
"""Send a chat message to the chat channel""" """Send a chat message to the chat channel"""
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
def notify_task_failure( def notify_task_failure(
@ -122,6 +133,7 @@ def notify_task_failure(
task_args: tuple = (), task_args: tuple = (),
task_kwargs: dict[str, Any] | None = None, task_kwargs: dict[str, Any] | None = None,
traceback_str: str | None = None, traceback_str: str | None = None,
bot_id: int | None = None,
) -> None: ) -> None:
""" """
Send a task failure notification to Discord. Send a task failure notification to Discord.
@ -137,6 +149,15 @@ def notify_task_failure(
logger.debug("Discord notifications disabled") logger.debug("Discord notifications disabled")
return return
if bot_id is None:
bot_id = settings.DISCORD_BOT_ID
if not bot_id:
logger.debug(
"No Discord bot ID provided for task failure notification; skipping"
)
return
message = f"🚨 **Task Failed: {task_name}**\n\n" message = f"🚨 **Task Failed: {task_name}**\n\n"
message += f"**Error:** {error_message[:500]}\n" message += f"**Error:** {error_message[:500]}\n"
@ -150,7 +171,7 @@ def notify_task_failure(
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```" message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
try: try:
send_error_message(message) send_error_message(bot_id, message)
logger.info(f"Discord error notification sent for task: {task_name}") logger.info(f"Discord error notification sent for task: {task_name}")
except Exception as e: except Exception as e:
logger.error(f"Failed to send Discord notification: {e}") logger.error(f"Failed to send Discord notification: {e}")

View File

@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager
import traceback import traceback
from contextlib import asynccontextmanager
from typing import cast
import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import uvicorn
from memory.common import settings from memory.common import settings
from memory.discord.collector import MessageCollector
from memory.common.db.models.users import BotUser
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser
from memory.discord.collector import MessageCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,37 +42,25 @@ class Collector:
bot_token: str bot_token: str
bot_name: str bot_name: str
def __init__(self, collector: MessageCollector, bot: BotUser): def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
self.collector = collector self.collector = collector
self.collector_task = asyncio.create_task(collector.start(bot.api_key)) self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
self.bot_id = bot.id self.bot_id = cast(int, bot.id)
self.bot_token = bot.api_key self.bot_token = str(bot.api_key)
self.bot_name = bot.name self.bot_name = str(bot.name)
# Application state
class AppState:
def __init__(self):
self.collector: MessageCollector | None = None
self.collector_task: asyncio.Task | None = None
app_state = AppState()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Manage Discord collector lifecycle""" """Manage Discord collector lifecycle"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
def make_collector(bot: BotUser): def make_collector(bot: DiscordBotUser):
collector = MessageCollector() collector = MessageCollector()
return Collector(collector=collector, bot=bot) return Collector(collector=collector, bot=bot)
with make_session() as session: with make_session() as session:
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()} bots = session.query(DiscordBotUser).all()
app.bots = {bot.id: make_collector(bot) for bot in bots}
logger.info(f"Discord collectors started for {len(app.bots)} bots") logger.info(f"Discord collectors started for {len(app.bots)} bots")
@ -155,9 +144,8 @@ async def health_check():
if not app.bots: if not app.bots:
raise HTTPException(status_code=503, detail="Discord collector not running") raise HTTPException(status_code=503, detail="Discord collector not running")
collector = app_state.collector
return { return {
collector.bot_name: { bot.bot_name: {
"status": "healthy", "status": "healthy",
"connected": not bot.collector.is_closed(), "connected": not bot.collector.is_closed(),
"user": str(bot.collector.user) if bot.collector.user else None, "user": str(bot.collector.user) if bot.collector.user else None,

View File

@ -138,6 +138,18 @@ def should_process(message: DiscordMessage) -> bool:
return False return False
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
recipient = discord_message.recipient_user
if not recipient:
return None
system_user = recipient.system_user
if not system_user:
return None
return getattr(system_user, "id", None)
@app.task(name=PROCESS_DISCORD_MESSAGE) @app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution @safe_task_execution
def process_discord_message(message_id: int) -> dict[str, Any]: def process_discord_message(message_id: int) -> dict[str, Any]:
@ -161,11 +173,26 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
response = call_llm(session, discord_message, settings.DISCORD_MODEL) response = call_llm(session, discord_message, settings.DISCORD_MODEL)
if not response: if not response:
pass return {
elif discord_message.channel.server: "status": "processed",
discord.send_to_channel(discord_message.channel.name, response) "message_id": message_id,
}
bot_id = _resolve_bot_id(discord_message)
if not bot_id:
logger.warning(
"No associated Discord bot user for message %s; skipping send",
message_id,
)
return {
"status": "processed",
"message_id": message_id,
}
if discord_message.channel.server:
discord.send_to_channel(bot_id, discord_message.channel.name, response)
else: else:
discord.send_dm(discord_message.from_user.username, response) discord.send_dm(bot_id, discord_message.from_user.username, response)
return { return {
"status": "processed", "status": "processed",

View File

@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
if len(message) > 1900: # Leave some buffer if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)" message = message[:1900] + "\n\n... (response truncated)"
bot_id_value = scheduled_call.user_id
if bot_id_value is None:
logger.warning(
"Scheduled call %s has no associated bot user; skipping Discord send",
scheduled_call.id,
)
return
bot_id = cast(int, bot_id_value)
if discord_user := scheduled_call.discord_user: if discord_user := scheduled_call.discord_user:
logger.info(f"Sending DM to {discord_user.username}: {message}") logger.info(f"Sending DM to {discord_user.username}: {message}")
discord.send_dm(discord_user.username, message) discord.send_dm(bot_id, discord_user.username, message)
elif discord_channel := scheduled_call.discord_channel: elif discord_channel := scheduled_call.discord_channel:
logger.info(f"Broadcasting message to {discord_channel.name}: {message}") logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
discord.broadcast_message(discord_channel.name, message) discord.broadcast_message(bot_id, discord_channel.name, message)
else: else:
logger.warning( logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}" f"No Discord user or channel found for scheduled call {scheduled_call.id}"

View File

@ -4,6 +4,8 @@ import requests
from memory.common import discord from memory.common import discord
BOT_ID = 42
@pytest.fixture @pytest.fixture
def mock_api_url(): def mock_api_url():
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_dm", "http://localhost:8000/send_dm",
json={"user": "user123", "message": "Hello!"}, json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
timeout=10, timeout=10,
) )
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception""" """Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error") mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_channel", "http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"}, json={
"bot_id": BOT_ID,
"channel_name": "general",
"message": "Announcement!",
},
timeout=10, timeout=10,
) )
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception""" """Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout") mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
def test_is_collector_healthy_true(mock_get, mock_api_url): def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy""" """Test health check when collector is healthy"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is True assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
def test_is_collector_healthy_false_status(mock_get, mock_api_url): def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status""" """Test health check when collector returns unhealthy status"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails""" """Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused") mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel""" """Test sending error message to error channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_error_message("Something broke") result = discord.send_error_message(BOT_ID, "Something broke")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke") mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel""" """Test sending activity message to activity channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in") result = discord.send_activity_message(BOT_ID, "User logged in")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in") mock_broadcast.assert_called_once_with(
BOT_ID, "activity", "User logged in"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel""" """Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern") result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") mock_broadcast.assert_called_once_with(
BOT_ID, "discoveries", "Found interesting pattern"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel""" """Test sending chat message to chat channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot") result = discord.send_chat_message(BOT_ID, "Hello from bot")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot") mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
@patch("memory.common.discord.send_error_message") @patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error): def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification""" """Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong") discord.notify_task_failure(
"test_task", "Something went wrong", bot_id=BOT_ID
)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0] assert mock_send_error.call_args[0][0] == BOT_ID
message = mock_send_error.call_args[0][1]
assert "🚨 **Task Failed: test_task**" in message assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message assert "**Error:** Something went wrong" in message
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
"Error occurred", "Error occurred",
task_args=("arg1", 42), task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123}, task_kwargs={"key": "value", "number": 123},
bot_id=BOT_ID,
) )
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Args:** `('arg1', 42)" in message assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback""" """Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) discord.notify_task_failure(
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Traceback:**" in message assert "**Traceback:**" in message
assert "Exception: test" in message assert "Exception: test" in message
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated""" """Test that long error messages are truncated"""
long_error = "x" * 600 long_error = "x" * 600
discord.notify_task_failure("test_task", long_error) discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Error should be truncated to 500 chars - check that the full 600 char string is not there # Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message assert "**Error:** " + long_error[:500] in message
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated""" """Test that long tracebacks are truncated"""
long_traceback = "x" * 1000 long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) discord.notify_task_failure(
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Traceback should show last 800 chars # Traceback should show last 800 chars
assert long_traceback[-800:] in message assert long_traceback[-800:] in message
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated""" """Test that long task arguments are truncated"""
long_args = ("x" * 300,) long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args) discord.notify_task_failure(
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Args should be truncated to 200 chars # Args should be truncated to 200 chars
assert ( assert (
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated""" """Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300} long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) discord.notify_task_failure(
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Kwargs should be truncated to 200 chars # Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error): def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled""" """Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_not_called() mock_send_error.assert_not_called()
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
mock_send_error.side_effect = Exception("Failed to send") mock_send_error.side_effect = Exception("Failed to send")
# Should not raise # Should not raise
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
): ):
"""Test that convenience functions use the correct channel settings""" """Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"): with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message) function(BOT_ID, message)
mock_broadcast.assert_called_once_with("test-channel", message) mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
@patch("requests.post") @patch("requests.post")
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general" message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars) result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars json_payload = call_args[1]["json"]
assert json_payload["message"] == message_with_special_chars
assert json_payload["bot_id"] == BOT_ID
@patch("requests.post") @patch("requests.post")
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
long_message = "A" * 2000 long_message = "A" * 2000
result = discord.broadcast_message("general", long_message) result = discord.broadcast_message(BOT_ID, "general", long_message)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message json_payload = call_args[1]["json"]
assert json_payload["message"] == long_message
assert json_payload["bot_id"] == BOT_ID
@patch("requests.get") @patch("requests.get")
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False

View File

@ -4,6 +4,8 @@ import requests
from memory.common import discord from memory.common import discord
BOT_ID = 42
@pytest.fixture @pytest.fixture
def mock_api_url(): def mock_api_url():
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_dm", "http://localhost:8000/send_dm",
json={"user": "user123", "message": "Hello!"}, json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
timeout=10, timeout=10,
) )
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception""" """Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error") mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_channel", "http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"}, json={
"bot_id": BOT_ID,
"channel_name": "general",
"message": "Announcement!",
},
timeout=10, timeout=10,
) )
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception""" """Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout") mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
def test_is_collector_healthy_true(mock_get, mock_api_url): def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy""" """Test health check when collector is healthy"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is True assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
def test_is_collector_healthy_false_status(mock_get, mock_api_url): def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status""" """Test health check when collector returns unhealthy status"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails""" """Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused") mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel""" """Test sending error message to error channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_error_message("Something broke") result = discord.send_error_message(BOT_ID, "Something broke")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke") mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel""" """Test sending activity message to activity channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in") result = discord.send_activity_message(BOT_ID, "User logged in")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in") mock_broadcast.assert_called_once_with(
BOT_ID, "activity", "User logged in"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel""" """Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern") result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") mock_broadcast.assert_called_once_with(
BOT_ID, "discoveries", "Found interesting pattern"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel""" """Test sending chat message to chat channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot") result = discord.send_chat_message(BOT_ID, "Hello from bot")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot") mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
@patch("memory.common.discord.send_error_message") @patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error): def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification""" """Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong") discord.notify_task_failure(
"test_task", "Something went wrong", bot_id=BOT_ID
)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0] assert mock_send_error.call_args[0][0] == BOT_ID
message = mock_send_error.call_args[0][1]
assert "🚨 **Task Failed: test_task**" in message assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message assert "**Error:** Something went wrong" in message
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
"Error occurred", "Error occurred",
task_args=("arg1", 42), task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123}, task_kwargs={"key": "value", "number": 123},
bot_id=BOT_ID,
) )
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Args:** `('arg1', 42)" in message assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback""" """Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) discord.notify_task_failure(
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Traceback:**" in message assert "**Traceback:**" in message
assert "Exception: test" in message assert "Exception: test" in message
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated""" """Test that long error messages are truncated"""
long_error = "x" * 600 long_error = "x" * 600
discord.notify_task_failure("test_task", long_error) discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Error should be truncated to 500 chars - check that the full 600 char string is not there # Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message assert "**Error:** " + long_error[:500] in message
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated""" """Test that long tracebacks are truncated"""
long_traceback = "x" * 1000 long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) discord.notify_task_failure(
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Traceback should show last 800 chars # Traceback should show last 800 chars
assert long_traceback[-800:] in message assert long_traceback[-800:] in message
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated""" """Test that long task arguments are truncated"""
long_args = ("x" * 300,) long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args) discord.notify_task_failure(
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Args should be truncated to 200 chars # Args should be truncated to 200 chars
assert ( assert (
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated""" """Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300} long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) discord.notify_task_failure(
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Kwargs should be truncated to 200 chars # Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error): def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled""" """Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_not_called() mock_send_error.assert_not_called()
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
mock_send_error.side_effect = Exception("Failed to send") mock_send_error.side_effect = Exception("Failed to send")
# Should not raise # Should not raise
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
): ):
"""Test that convenience functions use the correct channel settings""" """Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"): with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message) function(BOT_ID, message)
mock_broadcast.assert_called_once_with("test-channel", message) mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
@patch("requests.post") @patch("requests.post")
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general" message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars) result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars json_payload = call_args[1]["json"]
assert json_payload["message"] == message_with_special_chars
assert json_payload["bot_id"] == BOT_ID
@patch("requests.post") @patch("requests.post")
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
long_message = "A" * 2000 long_message = "A" * 2000
result = discord.broadcast_message("general", long_message) result = discord.broadcast_message(BOT_ID, "general", long_message)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message json_payload = call_args[1]["json"]
assert json_payload["message"] == long_message
assert json_payload["bot_id"] == BOT_ID
@patch("requests.get") @patch("requests.get")
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False

View File

@ -3,6 +3,7 @@ from datetime import datetime, timezone
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from memory.common.db.models import ( from memory.common.db.models import (
DiscordBotUser,
DiscordMessage, DiscordMessage,
DiscordUser, DiscordUser,
DiscordServer, DiscordServer,
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
@pytest.fixture @pytest.fixture
def mock_discord_user(db_session): def discord_bot_user(db_session):
bot = DiscordBotUser.create_with_api_key(
discord_users=[],
name="Test Bot",
email="bot@example.com",
)
db_session.add(bot)
db_session.commit()
return bot
@pytest.fixture
def mock_discord_user(db_session, discord_bot_user):
"""Create a Discord user for testing.""" """Create a Discord user for testing."""
user = DiscordUser( user = DiscordUser(
id=123456789, id=123456789,
username="testuser", username="testuser",
ignore_messages=False, ignore_messages=False,
system_user_id=discord_bot_user.id,
) )
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()

View File

@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import uuid import uuid
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer from memory.common.db.models import (
ScheduledLLMCall,
DiscordBotUser,
DiscordUser,
DiscordChannel,
DiscordServer,
)
from memory.workers.tasks import scheduled_calls from memory.workers.tasks import scheduled_calls
@pytest.fixture @pytest.fixture
def sample_user(db_session): def sample_user(db_session):
"""Create a sample user for testing.""" """Create a sample user for testing."""
user = HumanUser.create_with_password( user = DiscordBotUser.create_with_api_key(
name="testuser", discord_users=[],
email="test@example.com", name="testbot",
password="password", email="bot@example.com",
) )
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
scheduled_calls._send_to_discord(pending_scheduled_call, response) scheduled_calls._send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with( mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id,
"testuser", # username, not ID "testuser", # username, not ID
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.", "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
) )
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
scheduled_calls._send_to_discord(completed_scheduled_call, response) scheduled_calls._send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with( mock_broadcast.assert_called_once_with(
completed_scheduled_call.user_id,
"test-channel", # channel name, not ID "test-channel", # channel name, not ID
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.", "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
) )
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
# Verify the message was truncated # Verify the message was truncated
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
message = args[1] assert args[0] == pending_scheduled_call.user_id
message = args[2]
assert len(message) <= 1950 # Should be truncated assert len(message) <= 1950 # Should be truncated
assert message.endswith("... (response truncated)") assert message.endswith("... (response truncated)")
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response) scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
message = args[1] assert args[0] == pending_scheduled_call.user_id
message = args[2]
assert not message.endswith("... (response truncated)") assert not message.endswith("... (response truncated)")
assert "This is a normal length response." in message assert "This is a normal length response." in message
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_discord_user.username = "testuser" mock_discord_user.username = "testuser"
mock_call = Mock() mock_call = Mock()
mock_call.user_id = 987
mock_call.topic = topic mock_call.topic = topic
mock_call.model = model mock_call.model = model
mock_call.discord_user = mock_discord_user mock_call.discord_user = mock_discord_user
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
# Get the actual message that was sent # Get the actual message that was sent
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
actual_message = args[1] assert args[0] == mock_call.user_id
actual_message = args[2]
# Verify all expected parts are in the message # Verify all expected parts are in the message
for expected_part in expected_in_message: for expected_part in expected_in_message: