mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
use db bots
This commit is contained in:
parent
9639fa3dd7
commit
814090dccb
@ -20,12 +20,12 @@ def get_api_url() -> str:
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def send_dm(user_identifier: str, message: str) -> bool:
|
||||
def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
|
||||
"""Send a DM via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_dm",
|
||||
json={"user": user_identifier, "message": message},
|
||||
json={"bot_id": bot_id, "user": user_identifier, "message": message},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -37,12 +37,16 @@ def send_dm(user_identifier: str, message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def send_to_channel(channel_name: str, message: str) -> bool:
|
||||
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
"""Send a DM via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={"channel_name": channel_name, "message": message},
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel_name": channel_name,
|
||||
"message": message,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -55,12 +59,16 @@ def send_to_channel(channel_name: str, message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def broadcast_message(channel_name: str, message: str) -> bool:
|
||||
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={"channel_name": channel_name, "message": message},
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel_name": channel_name,
|
||||
"message": message,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -72,19 +80,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_collector_healthy() -> bool:
|
||||
def is_collector_healthy(bot_id: int) -> bool:
|
||||
"""Check if the Discord collector is running and healthy"""
|
||||
try:
|
||||
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("status") == "healthy"
|
||||
bot_status = result.get(str(bot_id))
|
||||
if not isinstance(bot_status, dict):
|
||||
return False
|
||||
return bool(bot_status.get("connected"))
|
||||
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
def refresh_discord_metadata() -> dict[str, int] | None:
|
||||
def refresh_discord_metadata() -> dict[str, Any] | None:
|
||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||
try:
|
||||
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||
@ -96,24 +107,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def send_error_message(message: str) -> bool:
|
||||
def send_error_message(bot_id: int, message: str) -> bool:
|
||||
"""Send an error message to the error channel"""
|
||||
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
|
||||
|
||||
|
||||
def send_activity_message(message: str) -> bool:
|
||||
def send_activity_message(bot_id: int, message: str) -> bool:
|
||||
"""Send an activity message to the activity channel"""
|
||||
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_discovery_message(message: str) -> bool:
|
||||
def send_discovery_message(bot_id: int, message: str) -> bool:
|
||||
"""Send a discovery message to the discovery channel"""
|
||||
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_chat_message(message: str) -> bool:
|
||||
def send_chat_message(bot_id: int, message: str) -> bool:
|
||||
"""Send a chat message to the chat channel"""
|
||||
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
|
||||
|
||||
|
||||
def notify_task_failure(
|
||||
@ -122,6 +133,7 @@ def notify_task_failure(
|
||||
task_args: tuple = (),
|
||||
task_kwargs: dict[str, Any] | None = None,
|
||||
traceback_str: str | None = None,
|
||||
bot_id: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send a task failure notification to Discord.
|
||||
@ -137,6 +149,15 @@ def notify_task_failure(
|
||||
logger.debug("Discord notifications disabled")
|
||||
return
|
||||
|
||||
if bot_id is None:
|
||||
bot_id = settings.DISCORD_BOT_ID
|
||||
|
||||
if not bot_id:
|
||||
logger.debug(
|
||||
"No Discord bot ID provided for task failure notification; skipping"
|
||||
)
|
||||
return
|
||||
|
||||
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
||||
message += f"**Error:** {error_message[:500]}\n"
|
||||
|
||||
@ -150,7 +171,7 @@ def notify_task_failure(
|
||||
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
||||
|
||||
try:
|
||||
send_error_message(message)
|
||||
send_error_message(bot_id, message)
|
||||
logger.info(f"Discord error notification sent for task: {task_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Discord notification: {e}")
|
||||
|
||||
@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
import traceback
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
from memory.common import settings
|
||||
from memory.discord.collector import MessageCollector
|
||||
from memory.common.db.models.users import BotUser
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.users import DiscordBotUser
|
||||
from memory.discord.collector import MessageCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -41,37 +42,25 @@ class Collector:
|
||||
bot_token: str
|
||||
bot_name: str
|
||||
|
||||
def __init__(self, collector: MessageCollector, bot: BotUser):
|
||||
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||
self.collector = collector
|
||||
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
|
||||
self.bot_id = bot.id
|
||||
self.bot_token = bot.api_key
|
||||
self.bot_name = bot.name
|
||||
|
||||
|
||||
# Application state
|
||||
class AppState:
|
||||
def __init__(self):
|
||||
self.collector: MessageCollector | None = None
|
||||
self.collector_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
app_state = AppState()
|
||||
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||
self.bot_id = cast(int, bot.id)
|
||||
self.bot_token = str(bot.api_key)
|
||||
self.bot_name = str(bot.name)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage Discord collector lifecycle"""
|
||||
if not settings.DISCORD_BOT_TOKEN:
|
||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||
return
|
||||
|
||||
def make_collector(bot: BotUser):
|
||||
def make_collector(bot: DiscordBotUser):
|
||||
collector = MessageCollector()
|
||||
return Collector(collector=collector, bot=bot)
|
||||
|
||||
with make_session() as session:
|
||||
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
|
||||
bots = session.query(DiscordBotUser).all()
|
||||
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||
|
||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||
|
||||
@ -155,9 +144,8 @@ async def health_check():
|
||||
if not app.bots:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
collector = app_state.collector
|
||||
return {
|
||||
collector.bot_name: {
|
||||
bot.bot_name: {
|
||||
"status": "healthy",
|
||||
"connected": not bot.collector.is_closed(),
|
||||
"user": str(bot.collector.user) if bot.collector.user else None,
|
||||
|
||||
@ -138,6 +138,18 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
|
||||
recipient = discord_message.recipient_user
|
||||
if not recipient:
|
||||
return None
|
||||
|
||||
system_user = recipient.system_user
|
||||
if not system_user:
|
||||
return None
|
||||
|
||||
return getattr(system_user, "id", None)
|
||||
|
||||
|
||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
@ -161,11 +173,26 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||
|
||||
if not response:
|
||||
pass
|
||||
elif discord_message.channel.server:
|
||||
discord.send_to_channel(discord_message.channel.name, response)
|
||||
return {
|
||||
"status": "processed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
bot_id = _resolve_bot_id(discord_message)
|
||||
if not bot_id:
|
||||
logger.warning(
|
||||
"No associated Discord bot user for message %s; skipping send",
|
||||
message_id,
|
||||
)
|
||||
return {
|
||||
"status": "processed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
if discord_message.channel.server:
|
||||
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
||||
else:
|
||||
discord.send_dm(discord_message.from_user.username, response)
|
||||
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
||||
|
||||
return {
|
||||
"status": "processed",
|
||||
|
||||
@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
||||
if len(message) > 1900: # Leave some buffer
|
||||
message = message[:1900] + "\n\n... (response truncated)"
|
||||
|
||||
bot_id_value = scheduled_call.user_id
|
||||
if bot_id_value is None:
|
||||
logger.warning(
|
||||
"Scheduled call %s has no associated bot user; skipping Discord send",
|
||||
scheduled_call.id,
|
||||
)
|
||||
return
|
||||
|
||||
bot_id = cast(int, bot_id_value)
|
||||
|
||||
if discord_user := scheduled_call.discord_user:
|
||||
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
||||
discord.send_dm(discord_user.username, message)
|
||||
discord.send_dm(bot_id, discord_user.username, message)
|
||||
elif discord_channel := scheduled_call.discord_channel:
|
||||
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
||||
discord.broadcast_message(discord_channel.name, message)
|
||||
discord.broadcast_message(bot_id, discord_channel.name, message)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
||||
|
||||
@ -4,6 +4,8 @@ import requests
|
||||
|
||||
from memory.common import discord
|
||||
|
||||
BOT_ID = 42
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_url():
|
||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user": "user123", "message": "Hello!"},
|
||||
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "activity", "User logged in"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "discoveries", "Found interesting pattern"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||
)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
bot_id=BOT_ID,
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert "Exception: test" in message
|
||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
function(BOT_ID, message)
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == message_with_special_chars
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == long_message
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -4,6 +4,8 @@ import requests
|
||||
|
||||
from memory.common import discord
|
||||
|
||||
BOT_ID = 42
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_url():
|
||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user": "user123", "message": "Hello!"},
|
||||
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "activity", "User logged in"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "discoveries", "Found interesting pattern"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||
)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
bot_id=BOT_ID,
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert "Exception: test" in message
|
||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
function(BOT_ID, message)
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == message_with_special_chars
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == long_message
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -3,6 +3,7 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import (
|
||||
DiscordBotUser,
|
||||
DiscordMessage,
|
||||
DiscordUser,
|
||||
DiscordServer,
|
||||
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(db_session):
|
||||
def discord_bot_user(db_session):
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
name="Test Bot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(db_session, discord_bot_user):
|
||||
"""Create a Discord user for testing."""
|
||||
user = DiscordUser(
|
||||
id=123456789,
|
||||
username="testuser",
|
||||
ignore_messages=False,
|
||||
system_user_id=discord_bot_user.id,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
import uuid
|
||||
|
||||
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
|
||||
from memory.common.db.models import (
|
||||
ScheduledLLMCall,
|
||||
DiscordBotUser,
|
||||
DiscordUser,
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
)
|
||||
from memory.workers.tasks import scheduled_calls
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user(db_session):
|
||||
"""Create a sample user for testing."""
|
||||
user = HumanUser.create_with_password(
|
||||
name="testuser",
|
||||
email="test@example.com",
|
||||
password="password",
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
name="testbot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
pending_scheduled_call.user_id,
|
||||
"testuser", # username, not ID
|
||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||
)
|
||||
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
completed_scheduled_call.user_id,
|
||||
"test-channel", # channel name, not ID
|
||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||
)
|
||||
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
|
||||
|
||||
# Verify the message was truncated
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
message = args[1]
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
message = args[2]
|
||||
assert len(message) <= 1950 # Should be truncated
|
||||
assert message.endswith("... (response truncated)")
|
||||
|
||||
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
message = args[1]
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
message = args[2]
|
||||
assert not message.endswith("... (response truncated)")
|
||||
assert "This is a normal length response." in message
|
||||
|
||||
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
mock_discord_user.username = "testuser"
|
||||
|
||||
mock_call = Mock()
|
||||
mock_call.user_id = 987
|
||||
mock_call.topic = topic
|
||||
mock_call.model = model
|
||||
mock_call.discord_user = mock_discord_user
|
||||
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
|
||||
# Get the actual message that was sent
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
actual_message = args[1]
|
||||
assert args[0] == mock_call.user_id
|
||||
actual_message = args[2]
|
||||
|
||||
# Verify all expected parts are in the message
|
||||
for expected_part in expected_in_message:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user