mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
use db bots
This commit is contained in:
parent
9639fa3dd7
commit
814090dccb
@ -20,12 +20,12 @@ def get_api_url() -> str:
|
|||||||
return f"http://{host}:{port}"
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
def send_dm(user_identifier: str, message: str) -> bool:
|
def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
|
||||||
"""Send a DM via the Discord collector API"""
|
"""Send a DM via the Discord collector API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_dm",
|
f"{get_api_url()}/send_dm",
|
||||||
json={"user": user_identifier, "message": message},
|
json={"bot_id": bot_id, "user": user_identifier, "message": message},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -37,12 +37,16 @@ def send_dm(user_identifier: str, message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_to_channel(channel_name: str, message: str) -> bool:
|
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
||||||
"""Send a DM via the Discord collector API"""
|
"""Send a DM via the Discord collector API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={"channel_name": channel_name, "message": message},
|
json={
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"channel_name": channel_name,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -55,12 +59,16 @@ def send_to_channel(channel_name: str, message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(channel_name: str, message: str) -> bool:
|
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
||||||
"""Send a message to a channel via the Discord collector API"""
|
"""Send a message to a channel via the Discord collector API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{get_api_url()}/send_channel",
|
f"{get_api_url()}/send_channel",
|
||||||
json={"channel_name": channel_name, "message": message},
|
json={
|
||||||
|
"bot_id": bot_id,
|
||||||
|
"channel_name": channel_name,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -72,19 +80,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_collector_healthy() -> bool:
|
def is_collector_healthy(bot_id: int) -> bool:
|
||||||
"""Check if the Discord collector is running and healthy"""
|
"""Check if the Discord collector is running and healthy"""
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result.get("status") == "healthy"
|
bot_status = result.get(str(bot_id))
|
||||||
|
if not isinstance(bot_status, dict):
|
||||||
|
return False
|
||||||
|
return bool(bot_status.get("connected"))
|
||||||
|
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def refresh_discord_metadata() -> dict[str, int] | None:
|
def refresh_discord_metadata() -> dict[str, Any] | None:
|
||||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||||
@ -96,24 +107,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
|
|||||||
|
|
||||||
|
|
||||||
# Convenience functions
|
# Convenience functions
|
||||||
def send_error_message(message: str) -> bool:
|
def send_error_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send an error message to the error channel"""
|
"""Send an error message to the error channel"""
|
||||||
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_activity_message(message: str) -> bool:
|
def send_activity_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send an activity message to the activity channel"""
|
"""Send an activity message to the activity channel"""
|
||||||
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_discovery_message(message: str) -> bool:
|
def send_discovery_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send a discovery message to the discovery channel"""
|
"""Send a discovery message to the discovery channel"""
|
||||||
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def send_chat_message(message: str) -> bool:
|
def send_chat_message(bot_id: int, message: str) -> bool:
|
||||||
"""Send a chat message to the chat channel"""
|
"""Send a chat message to the chat channel"""
|
||||||
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
def notify_task_failure(
|
def notify_task_failure(
|
||||||
@ -122,6 +133,7 @@ def notify_task_failure(
|
|||||||
task_args: tuple = (),
|
task_args: tuple = (),
|
||||||
task_kwargs: dict[str, Any] | None = None,
|
task_kwargs: dict[str, Any] | None = None,
|
||||||
traceback_str: str | None = None,
|
traceback_str: str | None = None,
|
||||||
|
bot_id: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send a task failure notification to Discord.
|
Send a task failure notification to Discord.
|
||||||
@ -137,6 +149,15 @@ def notify_task_failure(
|
|||||||
logger.debug("Discord notifications disabled")
|
logger.debug("Discord notifications disabled")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if bot_id is None:
|
||||||
|
bot_id = settings.DISCORD_BOT_ID
|
||||||
|
|
||||||
|
if not bot_id:
|
||||||
|
logger.debug(
|
||||||
|
"No Discord bot ID provided for task failure notification; skipping"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
||||||
message += f"**Error:** {error_message[:500]}\n"
|
message += f"**Error:** {error_message[:500]}\n"
|
||||||
|
|
||||||
@ -150,7 +171,7 @@ def notify_task_failure(
|
|||||||
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
send_error_message(message)
|
send_error_message(bot_id, message)
|
||||||
logger.info(f"Discord error notification sent for task: {task_name}")
|
logger.info(f"Discord error notification sent for task: {task_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send Discord notification: {e}")
|
logger.error(f"Failed to send Discord notification: {e}")
|
||||||
|
|||||||
@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import traceback
|
import traceback
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.discord.collector import MessageCollector
|
|
||||||
from memory.common.db.models.users import BotUser
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models.users import DiscordBotUser
|
||||||
|
from memory.discord.collector import MessageCollector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -41,37 +42,25 @@ class Collector:
|
|||||||
bot_token: str
|
bot_token: str
|
||||||
bot_name: str
|
bot_name: str
|
||||||
|
|
||||||
def __init__(self, collector: MessageCollector, bot: BotUser):
|
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||||
self.collector = collector
|
self.collector = collector
|
||||||
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
|
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||||
self.bot_id = bot.id
|
self.bot_id = cast(int, bot.id)
|
||||||
self.bot_token = bot.api_key
|
self.bot_token = str(bot.api_key)
|
||||||
self.bot_name = bot.name
|
self.bot_name = str(bot.name)
|
||||||
|
|
||||||
|
|
||||||
# Application state
|
|
||||||
class AppState:
|
|
||||||
def __init__(self):
|
|
||||||
self.collector: MessageCollector | None = None
|
|
||||||
self.collector_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
|
|
||||||
app_state = AppState()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Manage Discord collector lifecycle"""
|
"""Manage Discord collector lifecycle"""
|
||||||
if not settings.DISCORD_BOT_TOKEN:
|
|
||||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
|
||||||
return
|
|
||||||
|
|
||||||
def make_collector(bot: BotUser):
|
def make_collector(bot: DiscordBotUser):
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
return Collector(collector=collector, bot=bot)
|
return Collector(collector=collector, bot=bot)
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
|
bots = session.query(DiscordBotUser).all()
|
||||||
|
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||||
|
|
||||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||||
|
|
||||||
@ -155,9 +144,8 @@ async def health_check():
|
|||||||
if not app.bots:
|
if not app.bots:
|
||||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||||
|
|
||||||
collector = app_state.collector
|
|
||||||
return {
|
return {
|
||||||
collector.bot_name: {
|
bot.bot_name: {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"connected": not bot.collector.is_closed(),
|
"connected": not bot.collector.is_closed(),
|
||||||
"user": str(bot.collector.user) if bot.collector.user else None,
|
"user": str(bot.collector.user) if bot.collector.user else None,
|
||||||
|
|||||||
@ -138,6 +138,18 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
|
||||||
|
recipient = discord_message.recipient_user
|
||||||
|
if not recipient:
|
||||||
|
return None
|
||||||
|
|
||||||
|
system_user = recipient.system_user
|
||||||
|
if not system_user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return getattr(system_user, "id", None)
|
||||||
|
|
||||||
|
|
||||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||||
@safe_task_execution
|
@safe_task_execution
|
||||||
def process_discord_message(message_id: int) -> dict[str, Any]:
|
def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||||
@ -161,11 +173,26 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
pass
|
return {
|
||||||
elif discord_message.channel.server:
|
"status": "processed",
|
||||||
discord.send_to_channel(discord_message.channel.name, response)
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
bot_id = _resolve_bot_id(discord_message)
|
||||||
|
if not bot_id:
|
||||||
|
logger.warning(
|
||||||
|
"No associated Discord bot user for message %s; skipping send",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "processed",
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if discord_message.channel.server:
|
||||||
|
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
||||||
else:
|
else:
|
||||||
discord.send_dm(discord_message.from_user.username, response)
|
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "processed",
|
"status": "processed",
|
||||||
|
|||||||
@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
|||||||
if len(message) > 1900: # Leave some buffer
|
if len(message) > 1900: # Leave some buffer
|
||||||
message = message[:1900] + "\n\n... (response truncated)"
|
message = message[:1900] + "\n\n... (response truncated)"
|
||||||
|
|
||||||
|
bot_id_value = scheduled_call.user_id
|
||||||
|
if bot_id_value is None:
|
||||||
|
logger.warning(
|
||||||
|
"Scheduled call %s has no associated bot user; skipping Discord send",
|
||||||
|
scheduled_call.id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
bot_id = cast(int, bot_id_value)
|
||||||
|
|
||||||
if discord_user := scheduled_call.discord_user:
|
if discord_user := scheduled_call.discord_user:
|
||||||
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
||||||
discord.send_dm(discord_user.username, message)
|
discord.send_dm(bot_id, discord_user.username, message)
|
||||||
elif discord_channel := scheduled_call.discord_channel:
|
elif discord_channel := scheduled_call.discord_channel:
|
||||||
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
||||||
discord.broadcast_message(discord_channel.name, message)
|
discord.broadcast_message(bot_id, discord_channel.name, message)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import requests
|
|||||||
|
|
||||||
from memory.common import discord
|
from memory.common import discord
|
||||||
|
|
||||||
|
BOT_ID = 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api_url():
|
def mock_api_url():
|
||||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_dm",
|
"http://localhost:8000/send_dm",
|
||||||
json={"user": "user123", "message": "Hello!"},
|
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
|||||||
"""Test DM sending when request raises exception"""
|
"""Test DM sending when request raises exception"""
|
||||||
mock_post.side_effect = requests.RequestException("Network error")
|
mock_post.side_effect = requests.RequestException("Network error")
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_channel",
|
"http://localhost:8000/send_channel",
|
||||||
json={"channel_name": "general", "message": "Announcement!"},
|
json={
|
||||||
|
"bot_id": BOT_ID,
|
||||||
|
"channel_name": "general",
|
||||||
|
"message": "Announcement!",
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
"""Test channel message broadcast with exception"""
|
"""Test channel message broadcast with exception"""
|
||||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||||
"""Test health check when collector is healthy"""
|
"""Test health check when collector is healthy"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "healthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
|||||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||||
"""Test health check when collector returns unhealthy status"""
|
"""Test health check when collector returns unhealthy status"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "unhealthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
|||||||
"""Test health check when request fails"""
|
"""Test health check when request fails"""
|
||||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
|||||||
"""Test sending error message to error channel"""
|
"""Test sending error message to error channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_error_message("Something broke")
|
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
|||||||
"""Test sending activity message to activity channel"""
|
"""Test sending activity message to activity channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_activity_message("User logged in")
|
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "activity", "User logged in"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
|||||||
"""Test sending discovery message to discovery channel"""
|
"""Test sending discovery message to discovery channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_discovery_message("Found interesting pattern")
|
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "discoveries", "Found interesting pattern"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
|||||||
"""Test sending chat message to chat channel"""
|
"""Test sending chat message to chat channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_chat_message("Hello from bot")
|
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_basic(mock_send_error):
|
def test_notify_task_failure_basic(mock_send_error):
|
||||||
"""Test basic task failure notification"""
|
"""Test basic task failure notification"""
|
||||||
discord.notify_task_failure("test_task", "Something went wrong")
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
message = mock_send_error.call_args[0][0]
|
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||||
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "🚨 **Task Failed: test_task**" in message
|
assert "🚨 **Task Failed: test_task**" in message
|
||||||
assert "**Error:** Something went wrong" in message
|
assert "**Error:** Something went wrong" in message
|
||||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
|||||||
"Error occurred",
|
"Error occurred",
|
||||||
task_args=("arg1", 42),
|
task_args=("arg1", 42),
|
||||||
task_kwargs={"key": "value", "number": 123},
|
task_kwargs={"key": "value", "number": 123},
|
||||||
|
bot_id=BOT_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Args:** `('arg1', 42)" in message
|
assert "**Args:** `('arg1', 42)" in message
|
||||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
|||||||
"""Test task failure notification with traceback"""
|
"""Test task failure notification with traceback"""
|
||||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Traceback:**" in message
|
assert "**Traceback:**" in message
|
||||||
assert "Exception: test" in message
|
assert "Exception: test" in message
|
||||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
|||||||
"""Test that long error messages are truncated"""
|
"""Test that long error messages are truncated"""
|
||||||
long_error = "x" * 600
|
long_error = "x" * 600
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", long_error)
|
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||||
assert "**Error:** " + long_error[:500] in message
|
assert "**Error:** " + long_error[:500] in message
|
||||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
|||||||
"""Test that long tracebacks are truncated"""
|
"""Test that long tracebacks are truncated"""
|
||||||
long_traceback = "x" * 1000
|
long_traceback = "x" * 1000
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Traceback should show last 800 chars
|
# Traceback should show last 800 chars
|
||||||
assert long_traceback[-800:] in message
|
assert long_traceback[-800:] in message
|
||||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
|||||||
"""Test that long task arguments are truncated"""
|
"""Test that long task arguments are truncated"""
|
||||||
long_args = ("x" * 300,)
|
long_args = ("x" * 300,)
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Args should be truncated to 200 chars
|
# Args should be truncated to 200 chars
|
||||||
assert (
|
assert (
|
||||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
"""Test that long task kwargs are truncated"""
|
"""Test that long task kwargs are truncated"""
|
||||||
long_kwargs = {"key": "x" * 300}
|
long_kwargs = {"key": "x" * 300}
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Kwargs should be truncated to 200 chars
|
# Kwargs should be truncated to 200 chars
|
||||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||||
def test_notify_task_failure_disabled(mock_send_error):
|
def test_notify_task_failure_disabled(mock_send_error):
|
||||||
"""Test that notifications are not sent when disabled"""
|
"""Test that notifications are not sent when disabled"""
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_not_called()
|
mock_send_error.assert_not_called()
|
||||||
|
|
||||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
|||||||
mock_send_error.side_effect = Exception("Failed to send")
|
mock_send_error.side_effect = Exception("Failed to send")
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
|
|
||||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
|||||||
):
|
):
|
||||||
"""Test that convenience functions use the correct channel settings"""
|
"""Test that convenience functions use the correct channel settings"""
|
||||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||||
function(message)
|
function(BOT_ID, message)
|
||||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||||
result = discord.send_dm("user123", message_with_special_chars)
|
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == message_with_special_chars
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
long_message = "A" * 2000
|
long_message = "A" * 2000
|
||||||
result = discord.broadcast_message("general", long_message)
|
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == long_message
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == long_message
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("requests.get")
|
||||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import requests
|
|||||||
|
|
||||||
from memory.common import discord
|
from memory.common import discord
|
||||||
|
|
||||||
|
BOT_ID = 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api_url():
|
def mock_api_url():
|
||||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_dm",
|
"http://localhost:8000/send_dm",
|
||||||
json={"user": "user123", "message": "Hello!"},
|
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
|||||||
"""Test DM sending when request raises exception"""
|
"""Test DM sending when request raises exception"""
|
||||||
mock_post.side_effect = requests.RequestException("Network error")
|
mock_post.side_effect = requests.RequestException("Network error")
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.send_dm("user123", "Hello!")
|
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"http://localhost:8000/send_channel",
|
"http://localhost:8000/send_channel",
|
||||||
json={"channel_name": "general", "message": "Announcement!"},
|
json={
|
||||||
|
"bot_id": BOT_ID,
|
||||||
|
"channel_name": "general",
|
||||||
|
"message": "Announcement!",
|
||||||
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
"""Test channel message broadcast with exception"""
|
"""Test channel message broadcast with exception"""
|
||||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||||
|
|
||||||
result = discord.broadcast_message("general", "Announcement!")
|
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
|||||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||||
"""Test health check when collector is healthy"""
|
"""Test health check when collector is healthy"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "healthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
|||||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||||
"""Test health check when collector returns unhealthy status"""
|
"""Test health check when collector returns unhealthy status"""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.json.return_value = {"status": "unhealthy"}
|
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
|||||||
"""Test health check when request fails"""
|
"""Test health check when request fails"""
|
||||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
|||||||
"""Test sending error message to error channel"""
|
"""Test sending error message to error channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_error_message("Something broke")
|
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
|||||||
"""Test sending activity message to activity channel"""
|
"""Test sending activity message to activity channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_activity_message("User logged in")
|
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "activity", "User logged in"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
|||||||
"""Test sending discovery message to discovery channel"""
|
"""Test sending discovery message to discovery channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_discovery_message("Found interesting pattern")
|
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
mock_broadcast.assert_called_once_with(
|
||||||
|
BOT_ID, "discoveries", "Found interesting pattern"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.broadcast_message")
|
@patch("memory.common.discord.broadcast_message")
|
||||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
|||||||
"""Test sending chat message to chat channel"""
|
"""Test sending chat message to chat channel"""
|
||||||
mock_broadcast.return_value = True
|
mock_broadcast.return_value = True
|
||||||
|
|
||||||
result = discord.send_chat_message("Hello from bot")
|
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.discord.send_error_message")
|
@patch("memory.common.discord.send_error_message")
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
def test_notify_task_failure_basic(mock_send_error):
|
def test_notify_task_failure_basic(mock_send_error):
|
||||||
"""Test basic task failure notification"""
|
"""Test basic task failure notification"""
|
||||||
discord.notify_task_failure("test_task", "Something went wrong")
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
message = mock_send_error.call_args[0][0]
|
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||||
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "🚨 **Task Failed: test_task**" in message
|
assert "🚨 **Task Failed: test_task**" in message
|
||||||
assert "**Error:** Something went wrong" in message
|
assert "**Error:** Something went wrong" in message
|
||||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
|||||||
"Error occurred",
|
"Error occurred",
|
||||||
task_args=("arg1", 42),
|
task_args=("arg1", 42),
|
||||||
task_kwargs={"key": "value", "number": 123},
|
task_kwargs={"key": "value", "number": 123},
|
||||||
|
bot_id=BOT_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Args:** `('arg1', 42)" in message
|
assert "**Args:** `('arg1', 42)" in message
|
||||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
|||||||
"""Test task failure notification with traceback"""
|
"""Test task failure notification with traceback"""
|
||||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
assert "**Traceback:**" in message
|
assert "**Traceback:**" in message
|
||||||
assert "Exception: test" in message
|
assert "Exception: test" in message
|
||||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
|||||||
"""Test that long error messages are truncated"""
|
"""Test that long error messages are truncated"""
|
||||||
long_error = "x" * 600
|
long_error = "x" * 600
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", long_error)
|
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||||
assert "**Error:** " + long_error[:500] in message
|
assert "**Error:** " + long_error[:500] in message
|
||||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
|||||||
"""Test that long tracebacks are truncated"""
|
"""Test that long tracebacks are truncated"""
|
||||||
long_traceback = "x" * 1000
|
long_traceback = "x" * 1000
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Traceback should show last 800 chars
|
# Traceback should show last 800 chars
|
||||||
assert long_traceback[-800:] in message
|
assert long_traceback[-800:] in message
|
||||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
|||||||
"""Test that long task arguments are truncated"""
|
"""Test that long task arguments are truncated"""
|
||||||
long_args = ("x" * 300,)
|
long_args = ("x" * 300,)
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Args should be truncated to 200 chars
|
# Args should be truncated to 200 chars
|
||||||
assert (
|
assert (
|
||||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
"""Test that long task kwargs are truncated"""
|
"""Test that long task kwargs are truncated"""
|
||||||
long_kwargs = {"key": "x" * 300}
|
long_kwargs = {"key": "x" * 300}
|
||||||
|
|
||||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
discord.notify_task_failure(
|
||||||
|
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
message = mock_send_error.call_args[0][0]
|
message = mock_send_error.call_args[0][1]
|
||||||
|
|
||||||
# Kwargs should be truncated to 200 chars
|
# Kwargs should be truncated to 200 chars
|
||||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
|||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||||
def test_notify_task_failure_disabled(mock_send_error):
|
def test_notify_task_failure_disabled(mock_send_error):
|
||||||
"""Test that notifications are not sent when disabled"""
|
"""Test that notifications are not sent when disabled"""
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_not_called()
|
mock_send_error.assert_not_called()
|
||||||
|
|
||||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
|||||||
mock_send_error.side_effect = Exception("Failed to send")
|
mock_send_error.side_effect = Exception("Failed to send")
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
discord.notify_task_failure("test_task", "Error occurred")
|
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||||
|
|
||||||
mock_send_error.assert_called_once()
|
mock_send_error.assert_called_once()
|
||||||
|
|
||||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
|||||||
):
|
):
|
||||||
"""Test that convenience functions use the correct channel settings"""
|
"""Test that convenience functions use the correct channel settings"""
|
||||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||||
function(message)
|
function(BOT_ID, message)
|
||||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||||
result = discord.send_dm("user123", message_with_special_chars)
|
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == message_with_special_chars
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
|||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
long_message = "A" * 2000
|
long_message = "A" * 2000
|
||||||
result = discord.broadcast_message("general", long_message)
|
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
call_args = mock_post.call_args
|
call_args = mock_post.call_args
|
||||||
assert call_args[1]["json"]["message"] == long_message
|
json_payload = call_args[1]["json"]
|
||||||
|
assert json_payload["message"] == long_message
|
||||||
|
assert json_payload["bot_id"] == BOT_ID
|
||||||
|
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("requests.get")
|
||||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_get.return_value = mock_response
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
result = discord.is_collector_healthy()
|
result = discord.is_collector_healthy(BOT_ID)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from datetime import datetime, timezone
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
|
DiscordBotUser,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
DiscordServer,
|
DiscordServer,
|
||||||
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_discord_user(db_session):
|
def discord_bot_user(db_session):
|
||||||
|
bot = DiscordBotUser.create_with_api_key(
|
||||||
|
discord_users=[],
|
||||||
|
name="Test Bot",
|
||||||
|
email="bot@example.com",
|
||||||
|
)
|
||||||
|
db_session.add(bot)
|
||||||
|
db_session.commit()
|
||||||
|
return bot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_discord_user(db_session, discord_bot_user):
|
||||||
"""Create a Discord user for testing."""
|
"""Create a Discord user for testing."""
|
||||||
user = DiscordUser(
|
user = DiscordUser(
|
||||||
id=123456789,
|
id=123456789,
|
||||||
username="testuser",
|
username="testuser",
|
||||||
ignore_messages=False,
|
ignore_messages=False,
|
||||||
|
system_user_id=discord_bot_user.id,
|
||||||
)
|
)
|
||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|||||||
@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
|
from memory.common.db.models import (
|
||||||
|
ScheduledLLMCall,
|
||||||
|
DiscordBotUser,
|
||||||
|
DiscordUser,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordServer,
|
||||||
|
)
|
||||||
from memory.workers.tasks import scheduled_calls
|
from memory.workers.tasks import scheduled_calls
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_user(db_session):
|
def sample_user(db_session):
|
||||||
"""Create a sample user for testing."""
|
"""Create a sample user for testing."""
|
||||||
user = HumanUser.create_with_password(
|
user = DiscordBotUser.create_with_api_key(
|
||||||
name="testuser",
|
discord_users=[],
|
||||||
email="test@example.com",
|
name="testbot",
|
||||||
password="password",
|
email="bot@example.com",
|
||||||
)
|
)
|
||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
|||||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||||
|
|
||||||
mock_send_dm.assert_called_once_with(
|
mock_send_dm.assert_called_once_with(
|
||||||
|
pending_scheduled_call.user_id,
|
||||||
"testuser", # username, not ID
|
"testuser", # username, not ID
|
||||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||||
)
|
)
|
||||||
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
|||||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||||
|
|
||||||
mock_broadcast.assert_called_once_with(
|
mock_broadcast.assert_called_once_with(
|
||||||
|
completed_scheduled_call.user_id,
|
||||||
"test-channel", # channel name, not ID
|
"test-channel", # channel name, not ID
|
||||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||||
)
|
)
|
||||||
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
|
|||||||
|
|
||||||
# Verify the message was truncated
|
# Verify the message was truncated
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
message = args[1]
|
assert args[0] == pending_scheduled_call.user_id
|
||||||
|
message = args[2]
|
||||||
assert len(message) <= 1950 # Should be truncated
|
assert len(message) <= 1950 # Should be truncated
|
||||||
assert message.endswith("... (response truncated)")
|
assert message.endswith("... (response truncated)")
|
||||||
|
|
||||||
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
|||||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||||
|
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
message = args[1]
|
assert args[0] == pending_scheduled_call.user_id
|
||||||
|
message = args[2]
|
||||||
assert not message.endswith("... (response truncated)")
|
assert not message.endswith("... (response truncated)")
|
||||||
assert "This is a normal length response." in message
|
assert "This is a normal length response." in message
|
||||||
|
|
||||||
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
mock_discord_user.username = "testuser"
|
mock_discord_user.username = "testuser"
|
||||||
|
|
||||||
mock_call = Mock()
|
mock_call = Mock()
|
||||||
|
mock_call.user_id = 987
|
||||||
mock_call.topic = topic
|
mock_call.topic = topic
|
||||||
mock_call.model = model
|
mock_call.model = model
|
||||||
mock_call.discord_user = mock_discord_user
|
mock_call.discord_user = mock_discord_user
|
||||||
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
|
|
||||||
# Get the actual message that was sent
|
# Get the actual message that was sent
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
actual_message = args[1]
|
assert args[0] == mock_call.user_id
|
||||||
|
actual_message = args[2]
|
||||||
|
|
||||||
# Verify all expected parts are in the message
|
# Verify all expected parts are in the message
|
||||||
for expected_part in expected_in_message:
|
for expected_part in expected_in_message:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user