Add deployer

This commit is contained in:
Daniel O'Connell 2025-12-24 16:25:53 +01:00
parent 47180e1e71
commit d3d71edf1d
11 changed files with 342 additions and 170 deletions

View File

@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
), ),
( (
0.409, 0.409,
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI", "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
), ),
], ],
}, },
@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
"semantic": [ "semantic": [
(0.489, "I find backend logic more interesting than UI work"), (0.489, "I find backend logic more interesting than UI work"),
(0.462, "The user prefers working on backend systems over frontend UI"), (0.462, "The user prefers working on backend systems over frontend UI"),
(0.455, "The user said pure functions are yucky"),
( (
0.455, 0.455,
"The user believes functional programming leads to better code quality", "The user believes functional programming leads to better code quality",
), ),
(0.455, "The user said pure functions are yucky"),
], ],
"temporal": [ "temporal": [
( (

View File

@ -137,10 +137,9 @@ class TestBuildPrompt:
): ):
prompt = _build_prompt() prompt = _build_prompt()
assert "lesswrong" in prompt.lower() assert "Remove meta-language" in prompt
assert "comic" in prompt.lower()
assert "Remove" in prompt
assert "Return ONLY valid JSON" in prompt assert "Return ONLY valid JSON" in prompt
assert "recalled_content" in prompt
class TestAnalyzeQuery: class TestAnalyzeQuery:

View File

@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock) @patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session") @patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_success(mock_make_session, mock_complete): async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
mock_session = MagicMock() mock_session = MagicMock()
@contextmanager @contextmanager
@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
mock_make_session.return_value = session_cm() mock_make_session.return_value = session_cm()
mcp_server = MagicMock() mcp_server = MagicMock()
mcp_server.mcp_server_url = "https://example.com"
mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (200, "Authorized") mock_complete.return_value = (200, "Authorized")
mock_mcp_tools.return_value = [{"name": "test_tool"}]
request = make_request("code=abc123&state=state456") request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request) response = await auth.oauth_callback_discord(request)
@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
assert "Authorization Successful" in body assert "Authorization Successful" in body
assert "Authorized" in body assert "Authorized" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456") mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
mock_session.commit.assert_called_once() assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock) @patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session") @patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_handles_failures( async def test_oauth_callback_discord_handles_failures(
mock_make_session, mock_complete mock_make_session, mock_complete, mock_mcp_tools
): ):
mock_session = MagicMock() mock_session = MagicMock()
@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
mock_make_session.return_value = session_cm() mock_make_session.return_value = session_cm()
mcp_server = MagicMock() mcp_server = MagicMock()
mcp_server.mcp_server_url = "https://example.com"
mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (500, "Failure") mock_complete.return_value = (500, "Failure")
mock_mcp_tools.return_value = []
request = make_request("code=abc123&state=state456") request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request) response = await auth.oauth_callback_discord(request)
@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
assert "Authorization Failed" in body assert "Authorization Failed" in body
assert "Failure" in body assert "Failure" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456") mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
mock_session.commit.assert_called_once() assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -18,6 +18,7 @@ from memory.common.db.models import (
DiscordUser, DiscordUser,
DiscordMessage, DiscordMessage,
BotUser, BotUser,
DiscordBotUser,
HumanUser, HumanUser,
ScheduledLLMCall, ScheduledLLMCall,
) )
@ -67,9 +68,10 @@ def sample_discord_user(db_session):
@pytest.fixture @pytest.fixture
def sample_bot_user(db_session): def sample_bot_user(db_session, sample_discord_user):
"""Create a sample bot user for testing.""" """Create a sample bot user for testing."""
bot = BotUser.create_with_api_key( bot = DiscordBotUser.create_with_api_key(
discord_users=[sample_discord_user],
name="Test Bot", name="Test Bot",
email="testbot@example.com", email="testbot@example.com",
) )
@ -209,9 +211,9 @@ def test_schedule_message_with_user(
future_time = datetime.now(timezone.utc) + timedelta(hours=1) future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message( result = schedule_message(
user_id=sample_human_user.id, bot_id=sample_human_user.id,
user=sample_discord_user.id, recipient_id=sample_discord_user.id,
channel=None, channel_id=None,
model="test-model", model="test-model",
message="Test message", message="Test message",
date_time=future_time, date_time=future_time,
@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
future_time = datetime.now(timezone.utc) + timedelta(hours=1) future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message( result = schedule_message(
user_id=sample_human_user.id, bot_id=sample_human_user.id,
user=None, recipient_id=None,
channel=sample_discord_channel.id, channel_id=sample_discord_channel.id,
model="test-model", model="test-model",
message="Test message", message="Test message",
date_time=future_time, date_time=future_time,
@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
"""Test creating a message scheduler tool for a user.""" """Test creating a message scheduler tool for a user."""
tool = make_message_scheduler( tool = make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=sample_discord_user.id, user_id=sample_discord_user.id,
channel=None, channel_id=None,
model="test-model", model="test-model",
) )
assert tool.name == "schedule_message" assert tool.name == "schedule_discord_message"
assert "from your chat with this user" in tool.description assert "from your chat with this user" in tool.description
assert tool.input_schema["type"] == "object" assert tool.input_schema["type"] == "object"
assert "message" in tool.input_schema["properties"] assert "message" in tool.input_schema["properties"]
@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
"""Test creating a message scheduler tool for a channel.""" """Test creating a message scheduler tool for a channel."""
tool = make_message_scheduler( tool = make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=None, user_id=None,
channel=sample_discord_channel.id, channel_id=sample_discord_channel.id,
model="test-model", model="test-model",
) )
assert tool.name == "schedule_message" assert tool.name == "schedule_discord_message"
assert "in this channel" in tool.description assert "in this channel" in tool.description
assert callable(tool.function) assert callable(tool.function)
@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
with pytest.raises(ValueError, match="Either user or channel must be provided"): with pytest.raises(ValueError, match="Either user or channel must be provided"):
make_message_scheduler( make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=None, user_id=None,
channel=None, channel_id=None,
model="test-model", model="test-model",
) )
@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
"""Test message scheduler handler with valid input.""" """Test message scheduler handler with valid input."""
tool = make_message_scheduler( tool = make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=sample_discord_user.id, user_id=sample_discord_user.id,
channel=None, channel_id=None,
model="test-model", model="test-model",
) )
@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
"""Test message scheduler handler with non-dict input.""" """Test message scheduler handler with non-dict input."""
tool = make_message_scheduler( tool = make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=sample_discord_user.id, user_id=sample_discord_user.id,
channel=None, channel_id=None,
model="test-model", model="test-model",
) )
@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
"""Test message scheduler handler with invalid datetime.""" """Test message scheduler handler with invalid datetime."""
tool = make_message_scheduler( tool = make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=sample_discord_user.id, user_id=sample_discord_user.id,
channel=None, channel_id=None,
model="test-model", model="test-model",
) )
@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
"""Test message scheduler handler with missing datetime.""" """Test message scheduler handler with missing datetime."""
tool = make_message_scheduler( tool = make_message_scheduler(
bot=sample_bot_user, bot=sample_bot_user,
user=sample_discord_user.id, user_id=sample_discord_user.id,
channel=None, channel_id=None,
model="test-model", model="test-model",
) )
@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
# Tests for make_prev_messages_tool # Tests for make_prev_messages_tool
def test_make_prev_messages_tool_with_user(sample_discord_user): def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
"""Test creating a previous messages tool for a user.""" """Test creating a previous messages tool for a user."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
assert tool.name == "previous_messages" assert tool.name == "previous_messages"
assert "from your chat with this user" in tool.description assert "from your chat with this user" in tool.description
@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
assert callable(tool.function) assert callable(tool.function)
def test_make_prev_messages_tool_with_channel(sample_discord_channel): def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
"""Test creating a previous messages tool for a channel.""" """Test creating a previous messages tool for a channel."""
tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
assert tool.name == "previous_messages" assert tool.name == "previous_messages"
assert "in this channel" in tool.description assert "in this channel" in tool.description
assert callable(tool.function) assert callable(tool.function)
def test_make_prev_messages_tool_without_user_or_channel(): def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
"""Test that creating a tool without user or channel raises error.""" """Test that creating a tool without user or channel raises error."""
with pytest.raises(ValueError, match="Either user or channel must be provided"): with pytest.raises(ValueError, match="Either user or channel must be provided"):
make_prev_messages_tool(user=None, channel=None) make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
def test_prev_messages_handler_success( def test_prev_messages_handler_success(
db_session, sample_discord_user, sample_discord_channel db_session, sample_bot_user, sample_discord_user, sample_discord_channel
): ):
"""Test previous messages handler with valid input.""" """Test previous messages handler with valid input."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
# Create some actual messages in the database # Create some actual messages in the database
msg1 = DiscordMessage( msg1 = DiscordMessage(
@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
assert "Message 1" in result or "Message 2" in result assert "Message 1" in result or "Message 2" in result
def test_prev_messages_handler_with_defaults(db_session, sample_discord_user): def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
"""Test previous messages handler with default values.""" """Test previous messages handler with default values."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
result = tool.function({}) result = tool.function({})
@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
assert isinstance(result, str) assert isinstance(result, str)
def test_prev_messages_handler_invalid_input(sample_discord_user): def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-dict input.""" """Test previous messages handler with non-dict input."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Input must be a dictionary"): with pytest.raises(ValueError, match="Input must be a dictionary"):
tool.function("not a dict") tool.function("not a dict")
def test_prev_messages_handler_invalid_max_messages(sample_discord_user): def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid max_messages (negative value).""" """Test previous messages handler with invalid max_messages (negative value)."""
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting, # Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
# so we test with -1 which actually triggers the validation # so we test with -1 which actually triggers the validation
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages must be greater than 0"): with pytest.raises(ValueError, match="Max messages must be greater than 0"):
tool.function({"max_messages": -1}) tool.function({"max_messages": -1})
def test_prev_messages_handler_invalid_offset(sample_discord_user): def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid offset.""" """Test previous messages handler with invalid offset."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"): with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
tool.function({"offset": -1}) tool.function({"offset": -1})
def test_prev_messages_handler_non_integer_values(sample_discord_user): def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-integer values.""" """Test previous messages handler with non-integer values."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages and offset must be integers"): with pytest.raises(ValueError, match="Max messages and offset must be integers"):
tool.function({"max_messages": "not an int"}) tool.function({"max_messages": "not an int"})
@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
model="test-model", model="test-model",
) )
# Should have: schedule_message, previous_messages, update_channel_summary, # Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_user_summary, update_server_summary, add_reaction # update_user_summary, update_server_summary, add_reaction
assert len(tools) == 6 assert len(tools) == 6
assert "schedule_message" in tools assert "schedule_discord_message" in tools
assert "previous_messages" in tools assert "previous_messages" in tools
assert "update_channel_summary" in tools assert "update_channel_summary" in tools
assert "update_user_summary" in tools assert "update_user_summary" in tools
@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
model="test-model", model="test-model",
) )
# Should have: schedule_message, previous_messages, update_user_summary # Should have: schedule_discord_message, previous_messages, update_user_summary
# Note: Without channel, there's no channel summary tool # Note: Without channel, there's no channel summary tool
assert len(tools) >= 2 # At least schedule and previous messages assert len(tools) >= 2 # At least schedule and previous messages
assert "schedule_message" in tools assert "schedule_discord_message" in tools
assert "previous_messages" in tools assert "previous_messages" in tools
assert "update_user_summary" in tools assert "update_user_summary" in tools
@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
model="test-model", model="test-model",
) )
# Should have: schedule_message, previous_messages, update_channel_summary, # Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_server_summary, add_reaction (no user summary without author) # update_server_summary, add_reaction (no user summary without author)
assert len(tools) == 5 assert len(tools) == 5
assert "schedule_message" in tools assert "schedule_discord_message" in tools
assert "previous_messages" in tools assert "previous_messages" in tools
assert "update_channel_summary" in tools assert "update_channel_summary" in tools
assert "update_server_summary" in tools assert "update_server_summary" in tools

View File

@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel", "http://localhost:8000/send_channel",
json={ json={
"bot_id": BOT_ID, "bot_id": BOT_ID,
"channel_name": "general", "channel": "general",
"message": "Announcement!", "message": "Announcement!",
}, },
timeout=10, timeout=10,

View File

@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel", "http://localhost:8000/send_channel",
json={ json={
"bot_id": BOT_ID, "bot_id": BOT_ID,
"channel_name": "general", "channel": "general",
"message": "Announcement!", "message": "Announcement!",
}, },
timeout=10, timeout=10,

View File

@ -1,6 +1,5 @@
"""Tests for Discord MCP server management.""" """Tests for Discord MCP server management."""
import json
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import aiohttp import aiohttp
@ -8,8 +7,8 @@ import discord
import pytest import pytest
from memory.common.db.models import MCPServer, MCPServerAssignment from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.mcp import mcp_call
from memory.discord.mcp import ( from memory.discord.mcp import (
call_mcp_server,
find_mcp_server, find_mcp_server,
handle_mcp_add, handle_mcp_add,
handle_mcp_connect, handle_mcp_connect,
@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx): with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = [] results = []
async for data in call_mcp_server( async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list", {} "https://mcp.example.com", "test_token", "tools/list", {}
): ):
results.append(data) results.append(data)
@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx): with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to call MCP server"): with pytest.raises(ValueError, match="Failed to call MCP server"):
async for _ in call_mcp_server( async for _ in mcp_call(
"https://mcp.example.com", "test_token", "tools/list" "https://mcp.example.com", "test_token", "tools/list"
): ):
pass pass
@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx): with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = [] results = []
async for data in call_mcp_server( async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list" "https://mcp.example.com", "test_token", "tools/list"
): ):
results.append(data) results.append(data)

View File

@ -19,6 +19,7 @@ from memory.common.db.models import (
DiscordChannel, DiscordChannel,
DiscordServer, DiscordServer,
DiscordMessage, DiscordMessage,
DiscordBotUser,
HumanUser, HumanUser,
ScheduledLLMCall, ScheduledLLMCall,
) )
@ -34,6 +35,19 @@ def sample_discord_user(db_session):
return user return user
@pytest.fixture
def sample_bot_user(db_session, sample_discord_user):
"""Create a sample Discord bot user."""
bot = DiscordBotUser.create_with_api_key(
discord_users=[sample_discord_user],
name="Test Bot",
email="testbot@example.com",
)
db_session.add(bot)
db_session.commit()
return bot
@pytest.fixture @pytest.fixture
def sample_discord_channel(db_session): def sample_discord_channel(db_session):
"""Create a sample Discord channel.""" """Create a sample Discord channel."""
@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
# Test previous_messages # Test previous_messages
def test_previous_messages_empty(db_session): def test_previous_messages_empty(db_session, sample_bot_user):
"""Test getting previous messages when none exist.""" """Test getting previous messages when none exist."""
result = previous_messages(db_session, user_id=123, channel_id=456) result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
assert result == [] assert result == []
def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel): def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test filtering messages by recipient user.""" """Test filtering messages by recipient user."""
# Create some messages # Create some messages
msg1 = DiscordMessage( msg1 = DiscordMessage(
@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
db_session.add_all([msg1, msg2]) db_session.add_all([msg1, msg2])
db_session.commit() db_session.commit()
result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None) result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
assert len(result) == 2 assert len(result) == 2
# Should be in chronological order (oldest first) # Should be in chronological order (oldest first)
assert result[0].message_id == 1 assert result[0].message_id == 1
assert result[1].message_id == 2 assert result[1].message_id == 2
def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel): def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test limiting the number of previous messages.""" """Test limiting the number of previous messages."""
# Create 15 messages # Create 15 messages
for i in range(15): for i in range(15):
@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
db_session.commit() db_session.commit()
result = previous_messages( result = previous_messages(
db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5 db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
) )
assert len(result) == 5 assert len(result) == 5
@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
# Test comm_channel_prompt # Test comm_channel_prompt
def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel): def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test generating a basic communication channel prompt.""" """Test generating a basic communication channel prompt."""
result = comm_channel_prompt( result = comm_channel_prompt(
db_session, user=sample_discord_user, channel=sample_discord_channel db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
) )
assert "You are a bot communicating on Discord" in result assert "You are a bot communicating on Discord" in result
@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
assert len(result) > 0 assert len(result) > 0
def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel): def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes server context when available.""" """Test that prompt includes server context when available."""
server = sample_discord_channel.server server = sample_discord_channel.server
server.summary = "Gaming community server" server.summary = "Gaming community server"
db_session.commit() db_session.commit()
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel) result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "server_context" in result.lower() assert "server_context" in result.lower()
assert "Gaming community server" in result assert "Gaming community server" in result
def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel): def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes channel context.""" """Test that prompt includes channel context."""
sample_discord_channel.summary = "General discussion channel" sample_discord_channel.summary = "General discussion channel"
db_session.commit() db_session.commit()
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel) result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "channel_context" in result.lower() assert "channel_context" in result.lower()
assert "General discussion channel" in result assert "General discussion channel" in result
def test_comm_channel_prompt_includes_user_notes( def test_comm_channel_prompt_includes_user_notes(
db_session, sample_discord_user, sample_discord_channel db_session, sample_bot_user, sample_discord_user, sample_discord_channel
): ):
"""Test that prompt includes user notes from previous messages.""" """Test that prompt includes user notes from previous messages."""
sample_discord_user.summary = "Helpful community member" sample_discord_user.summary = "Helpful community member"
@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
db_session.commit() db_session.commit()
result = comm_channel_prompt( result = comm_channel_prompt(
db_session, user=sample_discord_user, channel=sample_discord_channel db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
) )
assert "user_notes" in result.lower() assert "user_notes" in result.lower()
@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
web_tool_instance = MagicMock(name="web_tool") web_tool_instance = MagicMock(name="web_tool")
mock_web_search.return_value = web_tool_instance mock_web_search.return_value = web_tool_instance
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt") bot_user = SimpleNamespace(
system_user=SimpleNamespace(discord_id=999888777),
system_prompt="bot prompt"
)
from_user = SimpleNamespace(id=123) from_user = SimpleNamespace(id=123)
mcp_model = SimpleNamespace( mcp_model = SimpleNamespace(
name="Server", name="Server",
mcp_server_url="https://mcp.example.com", mcp_server_url="https://mcp.example.com",
access_token="token123", access_token="token123",
disabled_tools=[],
) )
result = call_llm( result = call_llm(
@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
mock_web_search.return_value = MagicMock(name="web_tool") mock_web_search.return_value = MagicMock(name="web_tool")
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None) bot_user = SimpleNamespace(
system_user=SimpleNamespace(discord_id=999888777),
system_prompt=None
)
from_user = SimpleNamespace(id=1) from_user = SimpleNamespace(id=1)
call_llm( call_llm(

View File

@ -14,12 +14,25 @@ from memory.workers.tasks import discord
@pytest.fixture @pytest.fixture
def discord_bot_user(db_session): def discord_bot_user(db_session):
# Create a discord user for the bot first
bot_discord_user = DiscordUser(
id=999999999,
username="testbot",
)
db_session.add(bot_discord_user)
db_session.flush()
bot = DiscordBotUser.create_with_api_key( bot = DiscordBotUser.create_with_api_key(
discord_users=[], discord_users=[bot_discord_user],
name="Test Bot", name="Test Bot",
email="bot@example.com", email="bot@example.com",
) )
db_session.add(bot) db_session.add(bot)
db_session.flush()
# Link the discord user to the system user
bot_discord_user.system_user_id = bot.id
db_session.commit() db_session.commit()
return bot return bot
@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) @patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.workers.tasks.discord.create_provider") @patch("memory.workers.tasks.discord.call_llm")
@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
def test_should_process_normal_message( def test_should_process_normal_message(
mock_create_provider, mock_trigger_typing,
mock_call_llm,
db_session, db_session,
mock_discord_user, mock_discord_user,
mock_discord_server, mock_discord_server,
mock_discord_channel, mock_discord_channel,
discord_bot_user,
): ):
"""Test should_process returns True for normal messages.""" """Test should_process returns True for normal messages."""
# Mock the LLM provider to return "yes" # Create a separate recipient user (the bot)
mock_provider = Mock() bot_discord_user = discord_bot_user.discord_users[0]
mock_provider.generate.return_value = "<response>yes</response>"
mock_provider.as_messages.return_value = [] # Mock call_llm to return a high number (100 = always process)
mock_create_provider.return_value = mock_provider mock_call_llm.return_value = "<response><number>100</number><reason>Test</reason></response>"
message = DiscordMessage( message = DiscordMessage(
message_id=1, message_id=1,
channel_id=mock_discord_channel.id, channel_id=mock_discord_channel.id,
from_id=mock_discord_user.id, from_id=mock_discord_user.id,
recipient_id=mock_discord_user.id, recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
server_id=mock_discord_server.id, server_id=mock_discord_server.id,
content="Test", content="Test",
sent_at=datetime.now(timezone.utc), sent_at=datetime.now(timezone.utc),
@ -207,6 +223,8 @@ def test_should_process_normal_message(
db_session.refresh(message) db_session.refresh(message)
assert discord.should_process(message) is True assert discord.should_process(message) is True
mock_call_llm.assert_called_once()
mock_trigger_typing.assert_called_once()
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False) @patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant): def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
"""Test adding a Discord message that is a reply.""" """Test adding a Discord message that is a reply."""
sample_message_data["message_reference_id"] = 111222333 sample_message_data["message_reference_id"] = 111222333
sample_message_data["message_type"] = "reply" # Explicitly set message_type
discord.add_discord_message(**sample_message_data) discord.add_discord_message(**sample_message_data)
@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
assert result["status"] == "processed" assert result["status"] == "processed"
def test_process_discord_message_success(db_session, sample_message_data, qdrant): @patch("memory.workers.tasks.discord.send_discord_response")
@patch("memory.workers.tasks.discord.call_llm")
def test_process_discord_message_success(
mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
):
"""Test processing a Discord message.""" """Test processing a Discord message."""
# Mock LLM to return a response
mock_call_llm.return_value = "Test response from bot"
# Mock Discord API to succeed
mock_send_response.return_value = True
# Add a message first # Add a message first
add_result = discord.add_discord_message(**sample_message_data) add_result = discord.add_discord_message(**sample_message_data)
message_id = add_result["discordmessage_id"] message_id = add_result["discordmessage_id"]
@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
assert result["status"] == "processed" assert result["status"] == "processed"
assert result["message_id"] == message_id assert result["message_id"] == message_id
mock_call_llm.assert_called_once()
mock_send_response.assert_called_once()
def test_process_discord_message_not_found(db_session): def test_process_discord_message_not_found(db_session):

View File

@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
@pytest.fixture @pytest.fixture
def sample_user(db_session): def sample_user(db_session):
"""Create a sample user for testing.""" """Create a sample user for testing."""
# Create a discord user for the bot
bot_discord_user = DiscordUser(
id=999999999,
username="testbot",
)
db_session.add(bot_discord_user)
db_session.flush()
user = DiscordBotUser.create_with_api_key( user = DiscordBotUser.create_with_api_key(
discord_users=[], discord_users=[bot_discord_user],
name="testbot", name="testbot",
email="bot@example.com", email="bot@example.com",
) )
@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
return call return call
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") @patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call): def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user.""" """Test sending to Discord user."""
response = "This is a test response." response = "This is a test response."
scheduled_calls.send_to_discord(pending_scheduled_call, response) scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
mock_send_dm.assert_called_once_with( mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id, 999999999, # bot_id
"testuser", # username, not ID "testuser", # username, not ID
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.", response,
) )
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message") @patch("memory.discord.messages.discord.send_to_channel")
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call): def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
"""Test sending to Discord channel.""" """Test sending to Discord channel."""
response = "This is a channel response." response = "This is a channel response."
scheduled_calls.send_to_discord(completed_scheduled_call, response) scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
mock_broadcast.assert_called_once_with( mock_send_to_channel.assert_called_once_with(
completed_scheduled_call.user_id, 999999999, # bot_id
"test-channel", # channel name, not ID completed_scheduled_call.discord_channel.id, # channel ID, not name
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.", response,
) )
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") @patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call): def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
"""Test message truncation for long responses.""" """Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response long_response = "A" * 2500 # Very long response
scheduled_calls.send_to_discord(pending_scheduled_call, long_response) scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
# Verify the message was truncated # With the new implementation, send_discord_response sends the full response
# No truncation happens in _send_to_discord
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
assert args[0] == pending_scheduled_call.user_id assert args[0] == 999999999 # bot_id
message = args[2] message = args[2]
assert len(message) <= 1950 # Should be truncated assert message == long_response
assert message.endswith("... (response truncated)")
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") @patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call): def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
"""Test that normal length messages are not truncated.""" """Test that normal length messages are not truncated."""
normal_response = "This is a normal length response." normal_response = "This is a normal length response."
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response) scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
assert args[0] == pending_scheduled_call.user_id assert args[0] == 999999999 # bot_id
message = args[2] message = args[2]
assert not message.endswith("... (response truncated)") assert message == normal_response
assert "This is a normal length response." in message
@patch("memory.workers.tasks.scheduled_calls._send_to_discord") @patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_success( def test_execute_scheduled_call_success(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
): ):
@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id) result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
# Verify LLM was called with correct parameters # Verify LLM was called
mock_llm_call.assert_called_once_with( mock_llm_call.assert_called_once()
prompt="What is the weather like today?",
model="anthropic/claude-3-5-sonnet-20241022",
system_prompt="You are a helpful assistant.",
)
# Verify result # Verify result
assert result["success"] is True assert result["success"] is True
@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
assert result == {"error": "Scheduled call not found"} assert result == {"error": "Scheduled call not found"}
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_not_pending( def test_execute_scheduled_call_not_pending(
mock_llm_call, completed_scheduled_call, db_session mock_llm_call, completed_scheduled_call, db_session
): ):
@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
mock_llm_call.assert_not_called() mock_llm_call.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord") @patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_with_default_system_prompt( def test_execute_scheduled_call_with_default_system_prompt(
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
): ):
@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
scheduled_calls.execute_scheduled_call(call.id) scheduled_calls.execute_scheduled_call(call.id)
# Verify default system prompt was used # Verify LLM was called
mock_llm_call.assert_called_once_with( mock_llm_call.assert_called_once()
prompt="Test prompt",
model="anthropic/claude-3-5-sonnet-20241022",
system_prompt=None, # The code uses system_prompt as-is, not a default
)
@patch("memory.workers.tasks.scheduled_calls._send_to_discord") @patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_discord_error( def test_execute_scheduled_call_discord_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
): ):
@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
assert pending_scheduled_call.data["discord_error"] == "Discord API error" assert pending_scheduled_call.data["discord_error"] == "Discord API error"
@patch("memory.workers.tasks.scheduled_calls._send_to_discord") @patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_llm_error( def test_execute_scheduled_call_llm_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
): ):
"""Test execution when LLM call fails.""" """Test execution when LLM call fails."""
mock_llm_call.side_effect = Exception("LLM API error") mock_llm_call.side_effect = Exception("LLM API error")
# The safe_task_execution decorator should catch this # The execute_scheduled_call function catches the exception and returns an error response
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id) result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
assert result["status"] == "error" assert result["success"] is False
assert "LLM API error" in result["error"] assert "error" in result
assert "LLM call failed" in result["error"]
# Discord should not be called # Discord should not be called
mock_send_discord.assert_not_called() mock_send_discord.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord") @patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_long_response_truncation( def test_execute_scheduled_call_long_response_truncation(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
): ):
@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
mock_execute_delay.delay.assert_called_once_with(due_call.id) mock_execute_delay.delay.assert_called_once_with(due_call.id)
@patch("memory.workers.tasks.scheduled_calls._send_to_discord") @patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_status_transition_pending_to_executing_to_completed( def test_status_transition_pending_to_executing_to_completed(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
): ):
@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
"has_discord_user,has_discord_channel,expected_method", "has_discord_user,has_discord_channel,expected_method",
[ [
(True, False, "send_dm"), (True, False, "send_dm"),
(False, True, "broadcast_message"), (False, True, "send_to_channel"),
(True, True, "send_dm"), # User takes precedence (True, True, "send_to_channel"), # Channel takes precedence in the implementation
], ],
) )
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") @patch("memory.discord.messages.discord.send_dm")
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message") @patch("memory.discord.messages.discord.send_to_channel")
def test_discord_destination_priority( def test_discord_destination_priority(
mock_broadcast, mock_send_to_channel,
mock_send_dm, mock_send_dm,
has_discord_user, has_discord_user,
has_discord_channel, has_discord_channel,
@ -535,50 +535,39 @@ def test_discord_destination_priority(
db_session.commit() db_session.commit()
response = "Test response" response = "Test response"
scheduled_calls.send_to_discord(call, response) scheduled_calls.send_to_discord(999999999, call, response)
if expected_method == "send_dm": if expected_method == "send_dm":
mock_send_dm.assert_called_once() mock_send_dm.assert_called_once()
mock_broadcast.assert_not_called() mock_send_to_channel.assert_not_called()
else: else:
mock_broadcast.assert_called_once() mock_send_to_channel.assert_called_once()
mock_send_dm.assert_not_called() mock_send_dm.assert_not_called()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic,model,response,expected_in_message", "topic,model,response",
[ [
( (
"Weather Check", "Weather Check",
"anthropic/claude-3-5-sonnet-20241022", "anthropic/claude-3-5-sonnet-20241022",
"It's sunny!", "It's sunny!",
[
"**Topic:** Weather Check",
"**Model:** anthropic/claude-3-5-sonnet-20241022",
"**Response:** It's sunny!",
],
), ),
( (
"Test Topic", "Test Topic",
"gpt-4", "gpt-4",
"Hello world", "Hello world",
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
), ),
( (
"Long Topic Name Here", "Long Topic Name Here",
"claude-2", "claude-2",
"Short", "Short",
[
"**Topic:** Long Topic Name Here",
"**Model:** claude-2",
"**Response:** Short",
],
), ),
], ],
) )
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") @patch("memory.discord.messages.discord.send_dm")
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message): def test_message_formatting(mock_send_dm, topic, model, response):
"""Test the Discord message formatting with different inputs.""" """Test that _send_to_discord sends the response as-is."""
# Create a mock scheduled call with a mock Discord user # Create a mock scheduled call with a mock Discord user
mock_discord_user = Mock() mock_discord_user = Mock()
mock_discord_user.username = "testuser" mock_discord_user.username = "testuser"
@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_call.discord_user = mock_discord_user mock_call.discord_user = mock_discord_user
mock_call.discord_channel = None mock_call.discord_channel = None
scheduled_calls.send_to_discord(mock_call, response) scheduled_calls.send_to_discord(999999999, mock_call, response)
# Get the actual message that was sent # Get the actual message that was sent
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
assert args[0] == mock_call.user_id assert args[0] == 999999999 # bot_id
actual_message = args[2] actual_message = args[2]
# Verify all expected parts are in the message # The new implementation sends the response as-is, without formatting
for expected_part in expected_in_message: assert actual_message == response
assert expected_part in actual_message
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
("cancelled", False), ("cancelled", False),
], ],
) )
@patch("memory.workers.tasks.scheduled_calls.llms.summarize") @patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_status_check( def test_execute_scheduled_call_status_check(
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
): ):

125
tools/deploy.sh Executable file
View File

@ -0,0 +1,125 @@
#!/bin/bash
set -e
REMOTE_HOST="memory"
REMOTE_DIR="/home/ec2-user/memory"
DEFAULT_BRANCH="master"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
usage() {
echo "Usage: $0 <command> [options]"
echo ""
echo "Commands:"
echo " sync Rsync local code to server"
echo " pull [branch] Git checkout and pull (default: master)"
echo " restart Restart docker services"
echo " deploy [branch] Pull + restart"
echo " run <command> Run command on server (with venv activated)"
exit 1
}
sync_code() {
echo -e "${GREEN}Syncing code to $REMOTE_HOST...${NC}"
rsync -avz --delete \
--exclude='__pycache__' \
--exclude='*.pyc' \
--exclude='*.pyo' \
--exclude='.git' \
--exclude='memory_files' \
--exclude='secrets' \
--exclude='Books' \
--exclude='clean_books' \
--exclude='.env' \
--exclude='venv' \
--exclude='.venv' \
--exclude='*.egg-info' \
--exclude='node_modules' \
--exclude='.DS_Store' \
--exclude='docker-compose.override.yml' \
--exclude='.pytest_cache' \
--exclude='.mypy_cache' \
--exclude='.ruff_cache' \
--exclude='htmlcov' \
--exclude='.coverage' \
--exclude='*.log' \
"$PROJECT_DIR/src" \
"$PROJECT_DIR/tests" \
"$PROJECT_DIR/tools" \
"$PROJECT_DIR/db" \
"$PROJECT_DIR/docker" \
"$PROJECT_DIR/frontend" \
"$PROJECT_DIR/requirements" \
"$PROJECT_DIR/setup.py" \
"$PROJECT_DIR/pyproject.toml" \
"$PROJECT_DIR/docker-compose.yaml" \
"$PROJECT_DIR/pytest.ini" \
"$REMOTE_HOST:$REMOTE_DIR/"
echo -e "${GREEN}Sync complete!${NC}"
}
git_pull() {
local branch="${1:-$DEFAULT_BRANCH}"
echo -e "${GREEN}Pulling branch '$branch' on $REMOTE_HOST...${NC}"
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \
git stash --quiet 2>/dev/null || true && \
git fetch origin && \
git checkout $branch && \
git pull origin $branch"
echo -e "${GREEN}Pull complete!${NC}"
}
restart_services() {
echo -e "${GREEN}Restarting services on $REMOTE_HOST...${NC}"
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && docker compose up --build -d"
echo -e "${GREEN}Services restarted!${NC}"
}
deploy() {
local branch="${1:-$DEFAULT_BRANCH}"
git_pull "$branch"
restart_services
}
run_remote() {
if [ $# -eq 0 ]; then
echo -e "${RED}Error: No command specified${NC}"
exit 1
fi
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && source venv/bin/activate && $*"
}
# Main
case "${1:-}" in
sync)
sync_code
;;
pull)
git_pull "${2:-}"
;;
restart)
restart_services
;;
deploy)
deploy "${2:-}"
;;
run)
shift
run_remote "$@"
;;
*)
usage
;;
esac