From d3d71edf1d64da122d2d3239c31cee3ba25bde9e Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Wed, 24 Dec 2025 16:25:53 +0100 Subject: [PATCH] Add deployer --- tests/integration/test_real_queries.py | 4 +- .../memory/api/search/test_query_analysis.py | 5 +- tests/memory/api/test_auth.py | 16 +- .../common/llms/tools/test_discord_tools.py | 98 ++++++------ tests/memory/common/test_discord.py | 2 +- .../memory/common/test_discord_integration.py | 2 +- tests/memory/discord_tests/test_mcp.py | 9 +- tests/memory/discord_tests/test_messages.py | 53 +++++-- .../workers/tasks/test_discord_tasks.py | 50 ++++-- .../workers/tasks/test_scheduled_calls.py | 148 ++++++++---------- tools/deploy.sh | 125 +++++++++++++++ 11 files changed, 342 insertions(+), 170 deletions(-) create mode 100755 tools/deploy.sh diff --git a/tests/integration/test_real_queries.py b/tests/integration/test_real_queries.py index df46fbf..0da4a43 100644 --- a/tests/integration/test_real_queries.py +++ b/tests/integration/test_real_queries.py @@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = { ), ( 0.409, - "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI", + "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches", ), ], }, @@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = { "semantic": [ (0.489, "I find backend logic more interesting than UI work"), (0.462, "The user prefers working on backend systems over frontend UI"), + (0.455, "The user said pure functions are yucky"), ( 0.455, "The user believes functional programming leads to better code quality", ), - (0.455, "The user said pure functions are yucky"), ], "temporal": [ ( diff --git a/tests/memory/api/search/test_query_analysis.py b/tests/memory/api/search/test_query_analysis.py index 85ef343..2290d96 100644 --- a/tests/memory/api/search/test_query_analysis.py +++ b/tests/memory/api/search/test_query_analysis.py @@ -137,10 +137,9 @@ class TestBuildPrompt: ): prompt = _build_prompt() - assert "lesswrong" in prompt.lower() - assert "comic" in prompt.lower() - assert "Remove" in prompt + assert "Remove meta-language" in prompt assert "Return ONLY valid JSON" in prompt + assert "recalled_content" in prompt class TestAnalyzeQuery: diff --git a/tests/memory/api/test_auth.py b/tests/memory/api/test_auth.py index 708fb05..dddd814 100644 --- a/tests/memory/api/test_auth.py +++ b/tests/memory/api/test_auth.py @@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session): @pytest.mark.asyncio +@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock) @patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock) @patch("memory.api.auth.make_session") -async def test_oauth_callback_discord_success(mock_make_session, mock_complete): +async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools): mock_session = MagicMock() @contextmanager @@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete): mock_make_session.return_value = session_cm() mcp_server = MagicMock() + mcp_server.mcp_server_url = "https://example.com" + mcp_server.access_token = "token123" mock_session.query.return_value.filter.return_value.first.return_value = mcp_server mock_complete.return_value = (200, "Authorized") + mock_mcp_tools.return_value = [{"name": "test_tool"}] request = make_request("code=abc123&state=state456") response = await auth.oauth_callback_discord(request) @@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete): assert "Authorization Successful" in body assert "Authorized" in body mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456") - mock_session.commit.assert_called_once() + assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list @pytest.mark.asyncio +@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock) @patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock) @patch("memory.api.auth.make_session") async def test_oauth_callback_discord_handles_failures( - mock_make_session, mock_complete + mock_make_session, mock_complete, mock_mcp_tools ): mock_session = MagicMock() @@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures( mock_make_session.return_value = session_cm() mcp_server = MagicMock() + mcp_server.mcp_server_url = "https://example.com" + mcp_server.access_token = "token123" mock_session.query.return_value.filter.return_value.first.return_value = mcp_server mock_complete.return_value = (500, "Failure") + mock_mcp_tools.return_value = [] request = make_request("code=abc123&state=state456") response = await auth.oauth_callback_discord(request) @@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures( assert "Authorization Failed" in body assert "Failure" in body mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456") - mock_session.commit.assert_called_once() + assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list @pytest.mark.asyncio diff --git a/tests/memory/common/llms/tools/test_discord_tools.py b/tests/memory/common/llms/tools/test_discord_tools.py index 29ce374..81a2d19 100644 --- a/tests/memory/common/llms/tools/test_discord_tools.py +++ b/tests/memory/common/llms/tools/test_discord_tools.py @@ -18,6 +18,7 @@ from memory.common.db.models import ( DiscordUser, DiscordMessage, BotUser, + DiscordBotUser, HumanUser, ScheduledLLMCall, ) @@ -67,9 +68,10 @@ def sample_discord_user(db_session): @pytest.fixture -def sample_bot_user(db_session): +def sample_bot_user(db_session, sample_discord_user): """Create a sample bot user for testing.""" - bot = BotUser.create_with_api_key( + bot = DiscordBotUser.create_with_api_key( + discord_users=[sample_discord_user], name="Test Bot", email="testbot@example.com", ) @@ -209,9 +211,9 @@ def test_schedule_message_with_user( future_time = datetime.now(timezone.utc) + timedelta(hours=1) result = schedule_message( - user_id=sample_human_user.id, - user=sample_discord_user.id, - channel=None, + bot_id=sample_human_user.id, + recipient_id=sample_discord_user.id, + channel_id=None, model="test-model", message="Test message", date_time=future_time, @@ -240,9 +242,9 @@ def test_schedule_message_with_channel( future_time = datetime.now(timezone.utc) + timedelta(hours=1) result = schedule_message( - user_id=sample_human_user.id, - user=None, - channel=sample_discord_channel.id, + bot_id=sample_human_user.id, + recipient_id=None, + channel_id=sample_discord_channel.id, model="test-model", message="Test message", date_time=future_time, @@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user): """Test creating a message scheduler tool for a user.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) - assert tool.name == "schedule_message" + assert tool.name == "schedule_discord_message" assert "from your chat with this user" in tool.description assert tool.input_schema["type"] == "object" assert "message" in tool.input_schema["properties"] @@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha """Test creating a message scheduler tool for a channel.""" tool = make_message_scheduler( bot=sample_bot_user, - user=None, - channel=sample_discord_channel.id, + user_id=None, + channel_id=sample_discord_channel.id, model="test-model", ) - assert tool.name == "schedule_message" + assert tool.name == "schedule_discord_message" assert "in this channel" in tool.description assert callable(tool.function) @@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user): with pytest.raises(ValueError, match="Either user or channel must be provided"): make_message_scheduler( bot=sample_bot_user, - user=None, - channel=None, + user_id=None, + channel_id=None, model="test-model", ) @@ -310,8 +312,8 @@ def test_message_scheduler_handler_success( """Test message scheduler handler with valid input.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord """Test message scheduler handler with non-dict input.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime( """Test message scheduler handler with invalid datetime.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime( """Test message scheduler handler with missing datetime.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime( # Tests for make_prev_messages_tool -def test_make_prev_messages_tool_with_user(sample_discord_user): +def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user): """Test creating a previous messages tool for a user.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) assert tool.name == "previous_messages" assert "from your chat with this user" in tool.description @@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user): assert callable(tool.function) -def test_make_prev_messages_tool_with_channel(sample_discord_channel): +def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel): """Test creating a previous messages tool for a channel.""" - tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id) assert tool.name == "previous_messages" assert "in this channel" in tool.description assert callable(tool.function) -def test_make_prev_messages_tool_without_user_or_channel(): +def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user): """Test that creating a tool without user or channel raises error.""" with pytest.raises(ValueError, match="Either user or channel must be provided"): - make_prev_messages_tool(user=None, channel=None) + make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None) def test_prev_messages_handler_success( - db_session, sample_discord_user, sample_discord_channel + db_session, sample_bot_user, sample_discord_user, sample_discord_channel ): """Test previous messages handler with valid input.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) # Create some actual messages in the database msg1 = DiscordMessage( @@ -440,9 +442,9 @@ def test_prev_messages_handler_success( assert "Message 1" in result or "Message 2" in result -def test_prev_messages_handler_with_defaults(db_session, sample_discord_user): +def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user): """Test previous messages handler with default values.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) result = tool.function({}) @@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user): assert isinstance(result, str) -def test_prev_messages_handler_invalid_input(sample_discord_user): +def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user): """Test previous messages handler with non-dict input.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Input must be a dictionary"): tool.function("not a dict") -def test_prev_messages_handler_invalid_max_messages(sample_discord_user): +def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user): """Test previous messages handler with invalid max_messages (negative value).""" # Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting, # so we test with -1 which actually triggers the validation - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Max messages must be greater than 0"): tool.function({"max_messages": -1}) -def test_prev_messages_handler_invalid_offset(sample_discord_user): +def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user): """Test previous messages handler with invalid offset.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"): tool.function({"offset": -1}) -def test_prev_messages_handler_non_integer_values(sample_discord_user): +def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user): """Test previous messages handler with non-integer values.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Max messages and offset must be integers"): tool.function({"max_messages": "not an int"}) @@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel( model="test-model", ) - # Should have: schedule_message, previous_messages, update_channel_summary, + # Should have: schedule_discord_message, previous_messages, update_channel_summary, # update_user_summary, update_server_summary, add_reaction assert len(tools) == 6 - assert "schedule_message" in tools + assert "schedule_discord_message" in tools assert "previous_messages" in tools assert "update_channel_summary" in tools assert "update_user_summary" in tools @@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user) model="test-model", ) - # Should have: schedule_message, previous_messages, update_user_summary + # Should have: schedule_discord_message, previous_messages, update_user_summary # Note: Without channel, there's no channel summary tool assert len(tools) >= 2 # At least schedule and previous messages - assert "schedule_message" in tools + assert "schedule_discord_message" in tools assert "previous_messages" in tools assert "update_user_summary" in tools @@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch model="test-model", ) - # Should have: schedule_message, previous_messages, update_channel_summary, + # Should have: schedule_discord_message, previous_messages, update_channel_summary, # update_server_summary, add_reaction (no user summary without author) assert len(tools) == 5 - assert "schedule_message" in tools + assert "schedule_discord_message" in tools assert "previous_messages" in tools assert "update_channel_summary" in tools assert "update_server_summary" in tools diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py index ce4cf3f..55db891 100644 --- a/tests/memory/common/test_discord.py +++ b/tests/memory/common/test_discord.py @@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url): "http://localhost:8000/send_channel", json={ "bot_id": BOT_ID, - "channel_name": "general", + "channel": "general", "message": "Announcement!", }, timeout=10, diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py index ce4cf3f..55db891 100644 --- a/tests/memory/common/test_discord_integration.py +++ b/tests/memory/common/test_discord_integration.py @@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url): "http://localhost:8000/send_channel", json={ "bot_id": BOT_ID, - "channel_name": "general", + "channel": "general", "message": "Announcement!", }, timeout=10, diff --git a/tests/memory/discord_tests/test_mcp.py b/tests/memory/discord_tests/test_mcp.py index 37a7649..0e59f89 100644 --- a/tests/memory/discord_tests/test_mcp.py +++ b/tests/memory/discord_tests/test_mcp.py @@ -1,6 +1,5 @@ """Tests for Discord MCP server management.""" -import json from unittest.mock import AsyncMock, Mock, patch import aiohttp @@ -8,8 +7,8 @@ import discord import pytest from memory.common.db.models import MCPServer, MCPServerAssignment +from memory.common.mcp import mcp_call from memory.discord.mcp import ( - call_mcp_server, find_mcp_server, handle_mcp_add, handle_mcp_connect, @@ -142,7 +141,7 @@ async def test_call_mcp_server_success(): with patch("aiohttp.ClientSession", return_value=mock_session_ctx): results = [] - async for data in call_mcp_server( + async for data in mcp_call( "https://mcp.example.com", "test_token", "tools/list", {} ): results.append(data) @@ -172,7 +171,7 @@ async def test_call_mcp_server_error(): with patch("aiohttp.ClientSession", return_value=mock_session_ctx): with pytest.raises(ValueError, match="Failed to call MCP server"): - async for _ in call_mcp_server( + async for _ in mcp_call( "https://mcp.example.com", "test_token", "tools/list" ): pass @@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json(): with patch("aiohttp.ClientSession", return_value=mock_session_ctx): results = [] - async for data in call_mcp_server( + async for data in mcp_call( "https://mcp.example.com", "test_token", "tools/list" ): results.append(data) diff --git a/tests/memory/discord_tests/test_messages.py b/tests/memory/discord_tests/test_messages.py index e2f60a9..edabcc2 100644 --- a/tests/memory/discord_tests/test_messages.py +++ b/tests/memory/discord_tests/test_messages.py @@ -19,6 +19,7 @@ from memory.common.db.models import ( DiscordChannel, DiscordServer, DiscordMessage, + DiscordBotUser, HumanUser, ScheduledLLMCall, ) @@ -34,6 +35,19 @@ def sample_discord_user(db_session): return user +@pytest.fixture +def sample_bot_user(db_session, sample_discord_user): + """Create a sample Discord bot user.""" + bot = DiscordBotUser.create_with_api_key( + discord_users=[sample_discord_user], + name="Test Bot", + email="testbot@example.com", + ) + db_session.add(bot) + db_session.commit() + return bot + + @pytest.fixture def sample_discord_channel(db_session): """Create a sample Discord channel.""" @@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call( # Test previous_messages -def test_previous_messages_empty(db_session): +def test_previous_messages_empty(db_session, sample_bot_user): """Test getting previous messages when none exist.""" - result = previous_messages(db_session, user_id=123, channel_id=456) + result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456) assert result == [] -def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel): +def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel): """Test filtering messages by recipient user.""" # Create some messages msg1 = DiscordMessage( @@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp db_session.add_all([msg1, msg2]) db_session.commit() - result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None) + result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None) assert len(result) == 2 # Should be in chronological order (oldest first) assert result[0].message_id == 1 assert result[1].message_id == 2 -def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel): +def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel): """Test limiting the number of previous messages.""" # Create 15 messages for i in range(15): @@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl db_session.commit() result = previous_messages( - db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5 + db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5 ) assert len(result) == 5 @@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl # Test comm_channel_prompt -def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel): +def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel): """Test generating a basic communication channel prompt.""" result = comm_channel_prompt( - db_session, user=sample_discord_user, channel=sample_discord_channel + db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel ) assert "You are a bot communicating on Discord" in result @@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco assert len(result) > 0 -def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel): +def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel): """Test that prompt includes server context when available.""" server = sample_discord_channel.server server.summary = "Gaming community server" db_session.commit() - result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel) + result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel) assert "server_context" in result.lower() assert "Gaming community server" in result -def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel): +def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel): """Test that prompt includes channel context.""" sample_discord_channel.summary = "General discussion channel" db_session.commit() - result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel) + result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel) assert "channel_context" in result.lower() assert "General discussion channel" in result def test_comm_channel_prompt_includes_user_notes( - db_session, sample_discord_user, sample_discord_channel + db_session, sample_bot_user, sample_discord_user, sample_discord_channel ): """Test that prompt includes user notes from previous messages.""" sample_discord_user.summary = "Helpful community member" @@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes( db_session.commit() result = comm_channel_prompt( - db_session, user=sample_discord_user, channel=sample_discord_channel + db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel ) assert "user_notes" in result.lower() @@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers( web_tool_instance = MagicMock(name="web_tool") mock_web_search.return_value = web_tool_instance - bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt") + bot_user = SimpleNamespace( + system_user=SimpleNamespace(discord_id=999888777), + system_prompt="bot prompt" + ) from_user = SimpleNamespace(id=123) mcp_model = SimpleNamespace( name="Server", mcp_server_url="https://mcp.example.com", access_token="token123", + disabled_tools=[], ) result = call_llm( @@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools( mock_web_search.return_value = MagicMock(name="web_tool") - bot_user = SimpleNamespace(system_user="system-user", system_prompt=None) + bot_user = SimpleNamespace( + system_user=SimpleNamespace(discord_id=999888777), + system_prompt=None + ) from_user = SimpleNamespace(id=1) call_llm( diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py index 8decc7d..9650ba7 100644 --- a/tests/memory/workers/tasks/test_discord_tasks.py +++ b/tests/memory/workers/tasks/test_discord_tasks.py @@ -14,12 +14,25 @@ from memory.workers.tasks import discord @pytest.fixture def discord_bot_user(db_session): + # Create a discord user for the bot first + bot_discord_user = DiscordUser( + id=999999999, + username="testbot", + ) + db_session.add(bot_discord_user) + db_session.flush() + bot = DiscordBotUser.create_with_api_key( - discord_users=[], + discord_users=[bot_discord_user], name="Test Bot", email="bot@example.com", ) db_session.add(bot) + db_session.flush() + + # Link the discord user to the system user + bot_discord_user.system_user_id = bot.id + db_session.commit() return bot @@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel): @patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) -@patch("memory.workers.tasks.discord.create_provider") +@patch("memory.workers.tasks.discord.call_llm") +@patch("memory.workers.tasks.discord.discord.trigger_typing_channel") def test_should_process_normal_message( - mock_create_provider, + mock_trigger_typing, + mock_call_llm, db_session, mock_discord_user, mock_discord_server, mock_discord_channel, + discord_bot_user, ): """Test should_process returns True for normal messages.""" - # Mock the LLM provider to return "yes" - mock_provider = Mock() - mock_provider.generate.return_value = "yes" - mock_provider.as_messages.return_value = [] - mock_create_provider.return_value = mock_provider + # Create a separate recipient user (the bot) + bot_discord_user = discord_bot_user.discord_users[0] + + # Mock call_llm to return a high number (100 = always process) + mock_call_llm.return_value = "100Test" message = DiscordMessage( message_id=1, channel_id=mock_discord_channel.id, from_id=mock_discord_user.id, - recipient_id=mock_discord_user.id, + recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user server_id=mock_discord_server.id, content="Test", sent_at=datetime.now(timezone.utc), @@ -207,6 +223,8 @@ def test_should_process_normal_message( db_session.refresh(message) assert discord.should_process(message) is True + mock_call_llm.assert_called_once() + mock_trigger_typing.assert_called_once() @patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False) @@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant): def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant): """Test adding a Discord message that is a reply.""" sample_message_data["message_reference_id"] = 111222333 + sample_message_data["message_type"] = "reply" # Explicitly set message_type discord.add_discord_message(**sample_message_data) @@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context( assert result["status"] == "processed" -def test_process_discord_message_success(db_session, sample_message_data, qdrant): +@patch("memory.workers.tasks.discord.send_discord_response") +@patch("memory.workers.tasks.discord.call_llm") +def test_process_discord_message_success( + mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant +): """Test processing a Discord message.""" + # Mock LLM to return a response + mock_call_llm.return_value = "Test response from bot" + # Mock Discord API to succeed + mock_send_response.return_value = True + # Add a message first add_result = discord.add_discord_message(**sample_message_data) message_id = add_result["discordmessage_id"] @@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant assert result["status"] == "processed" assert result["message_id"] == message_id + mock_call_llm.assert_called_once() + mock_send_response.assert_called_once() def test_process_discord_message_not_found(db_session): diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py index 44d567d..ceb28d5 100644 --- a/tests/memory/workers/tasks/test_scheduled_calls.py +++ b/tests/memory/workers/tasks/test_scheduled_calls.py @@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls @pytest.fixture def sample_user(db_session): """Create a sample user for testing.""" + # Create a discord user for the bot + bot_discord_user = DiscordUser( + id=999999999, + username="testbot", + ) + db_session.add(bot_discord_user) + db_session.flush() + user = DiscordBotUser.create_with_api_key( - discord_users=[], + discord_users=[bot_discord_user], name="testbot", email="bot@example.com", ) @@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user): return call -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") +@patch("memory.discord.messages.discord.send_dm") def test_send_to_discord_user(mock_send_dm, pending_scheduled_call): """Test sending to Discord user.""" response = "This is a test response." - scheduled_calls.send_to_discord(pending_scheduled_call, response) + scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response) mock_send_dm.assert_called_once_with( - pending_scheduled_call.user_id, + 999999999, # bot_id "testuser", # username, not ID - "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.", + response, ) -@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message") -def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call): +@patch("memory.discord.messages.discord.send_to_channel") +def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call): """Test sending to Discord channel.""" response = "This is a channel response." - scheduled_calls.send_to_discord(completed_scheduled_call, response) + scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response) - mock_broadcast.assert_called_once_with( - completed_scheduled_call.user_id, - "test-channel", # channel name, not ID - "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.", + mock_send_to_channel.assert_called_once_with( + 999999999, # bot_id + completed_scheduled_call.discord_channel.id, # channel ID, not name + response, ) -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") +@patch("memory.discord.messages.discord.send_dm") def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call): """Test message truncation for long responses.""" long_response = "A" * 2500 # Very long response - scheduled_calls.send_to_discord(pending_scheduled_call, long_response) + scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response) - # Verify the message was truncated + # With the new implementation, send_discord_response sends the full response + # No truncation happens in _send_to_discord args, kwargs = mock_send_dm.call_args - assert args[0] == pending_scheduled_call.user_id + assert args[0] == 999999999 # bot_id message = args[2] - assert len(message) <= 1950 # Should be truncated - assert message.endswith("... (response truncated)") + assert message == long_response -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") +@patch("memory.discord.messages.discord.send_dm") def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call): """Test that normal length messages are not truncated.""" normal_response = "This is a normal length response." - scheduled_calls.send_to_discord(pending_scheduled_call, normal_response) + scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response) args, kwargs = mock_send_dm.call_args - assert args[0] == pending_scheduled_call.user_id + assert args[0] == 999999999 # bot_id message = args[2] - assert not message.endswith("... (response truncated)") - assert "This is a normal length response." in message + assert message == normal_response -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_success( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -189,12 +196,8 @@ def test_execute_scheduled_call_success( result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id) - # Verify LLM was called with correct parameters - mock_llm_call.assert_called_once_with( - prompt="What is the weather like today?", - model="anthropic/claude-3-5-sonnet-20241022", - system_prompt="You are a helpful assistant.", - ) + # Verify LLM was called + mock_llm_call.assert_called_once() # Verify result assert result["success"] is True @@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session): assert result == {"error": "Scheduled call not found"} -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_not_pending( mock_llm_call, completed_scheduled_call, db_session ): @@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending( mock_llm_call.assert_not_called() -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_with_default_system_prompt( mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user ): @@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt( scheduled_calls.execute_scheduled_call(call.id) - # Verify default system prompt was used - mock_llm_call.assert_called_once_with( - prompt="Test prompt", - model="anthropic/claude-3-5-sonnet-20241022", - system_prompt=None, # The code uses system_prompt as-is, not a default - ) + # Verify LLM was called + mock_llm_call.assert_called_once() -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_discord_error( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error( assert pending_scheduled_call.data["discord_error"] == "Discord API error" -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_llm_error( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): """Test execution when LLM call fails.""" mock_llm_call.side_effect = Exception("LLM API error") - # The safe_task_execution decorator should catch this + # The execute_scheduled_call function catches the exception and returns an error response result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id) - assert result["status"] == "error" - assert "LLM API error" in result["error"] + assert result["success"] is False + assert "error" in result + assert "LLM call failed" in result["error"] # Discord should not be called mock_send_discord.assert_not_called() -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_long_response_truncation( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling( mock_execute_delay.delay.assert_called_once_with(due_call.id) -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_status_transition_pending_to_executing_to_completed( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed( "has_discord_user,has_discord_channel,expected_method", [ (True, False, "send_dm"), - (False, True, "broadcast_message"), - (True, True, "send_dm"), # User takes precedence + (False, True, "send_to_channel"), + (True, True, "send_to_channel"), # Channel takes precedence in the implementation ], ) -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") -@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message") +@patch("memory.discord.messages.discord.send_dm") +@patch("memory.discord.messages.discord.send_to_channel") def test_discord_destination_priority( - mock_broadcast, + mock_send_to_channel, mock_send_dm, has_discord_user, has_discord_channel, @@ -535,50 +535,39 @@ def test_discord_destination_priority( db_session.commit() response = "Test response" - scheduled_calls.send_to_discord(call, response) + scheduled_calls.send_to_discord(999999999, call, response) if expected_method == "send_dm": mock_send_dm.assert_called_once() - mock_broadcast.assert_not_called() + mock_send_to_channel.assert_not_called() else: - mock_broadcast.assert_called_once() + mock_send_to_channel.assert_called_once() mock_send_dm.assert_not_called() @pytest.mark.parametrize( - "topic,model,response,expected_in_message", + "topic,model,response", [ ( "Weather Check", "anthropic/claude-3-5-sonnet-20241022", "It's sunny!", - [ - "**Topic:** Weather Check", - "**Model:** anthropic/claude-3-5-sonnet-20241022", - "**Response:** It's sunny!", - ], ), ( "Test Topic", "gpt-4", "Hello world", - ["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"], ), ( "Long Topic Name Here", "claude-2", "Short", - [ - "**Topic:** Long Topic Name Here", - "**Model:** claude-2", - "**Response:** Short", - ], ), ], ) -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") -def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message): - """Test the Discord message formatting with different inputs.""" +@patch("memory.discord.messages.discord.send_dm") +def test_message_formatting(mock_send_dm, topic, model, response): + """Test that _send_to_discord sends the response as-is.""" # Create a mock scheduled call with a mock Discord user mock_discord_user = Mock() mock_discord_user.username = "testuser" @@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me mock_call.discord_user = mock_discord_user mock_call.discord_channel = None - scheduled_calls.send_to_discord(mock_call, response) + scheduled_calls.send_to_discord(999999999, mock_call, response) # Get the actual message that was sent args, kwargs = mock_send_dm.call_args - assert args[0] == mock_call.user_id + assert args[0] == 999999999 # bot_id actual_message = args[2] - # Verify all expected parts are in the message - for expected_part in expected_in_message: - assert expected_part in actual_message + # The new implementation sends the response as-is, without formatting + assert actual_message == response @pytest.mark.parametrize( @@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me ("cancelled", False), ], ) -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_status_check( mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user ): diff --git a/tools/deploy.sh b/tools/deploy.sh new file mode 100755 index 0000000..37031be --- /dev/null +++ b/tools/deploy.sh @@ -0,0 +1,125 @@ +#!/bin/bash +set -e + +REMOTE_HOST="memory" +REMOTE_DIR="/home/ec2-user/memory" +DEFAULT_BRANCH="master" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +usage() { + echo "Usage: $0 [options]" + echo "" + echo "Commands:" + echo " sync Rsync local code to server" + echo " pull [branch] Git checkout and pull (default: master)" + echo " restart Restart docker services" + echo " deploy [branch] Pull + restart" + echo " run Run command on server (with venv activated)" + exit 1 +} + +sync_code() { + echo -e "${GREEN}Syncing code to $REMOTE_HOST...${NC}" + + rsync -avz --delete \ + --exclude='__pycache__' \ + --exclude='*.pyc' \ + --exclude='*.pyo' \ + --exclude='.git' \ + --exclude='memory_files' \ + --exclude='secrets' \ + --exclude='Books' \ + --exclude='clean_books' \ + --exclude='.env' \ + --exclude='venv' \ + --exclude='.venv' \ + --exclude='*.egg-info' \ + --exclude='node_modules' \ + --exclude='.DS_Store' \ + --exclude='docker-compose.override.yml' \ + --exclude='.pytest_cache' \ + --exclude='.mypy_cache' \ + --exclude='.ruff_cache' \ + --exclude='htmlcov' \ + --exclude='.coverage' \ + --exclude='*.log' \ + "$PROJECT_DIR/src" \ + "$PROJECT_DIR/tests" \ + "$PROJECT_DIR/tools" \ + "$PROJECT_DIR/db" \ + "$PROJECT_DIR/docker" \ + "$PROJECT_DIR/frontend" \ + "$PROJECT_DIR/requirements" \ + "$PROJECT_DIR/setup.py" \ + "$PROJECT_DIR/pyproject.toml" \ + "$PROJECT_DIR/docker-compose.yaml" \ + "$PROJECT_DIR/pytest.ini" \ + "$REMOTE_HOST:$REMOTE_DIR/" + + echo -e "${GREEN}Sync complete!${NC}" +} + +git_pull() { + local branch="${1:-$DEFAULT_BRANCH}" + echo -e "${GREEN}Pulling branch '$branch' on $REMOTE_HOST...${NC}" + + ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \ + git stash --quiet 2>/dev/null || true && \ + git fetch origin && \ + git checkout $branch && \ + git pull origin $branch" + + echo -e "${GREEN}Pull complete!${NC}" +} + +restart_services() { + echo -e "${GREEN}Restarting services on $REMOTE_HOST...${NC}" + + ssh "$REMOTE_HOST" "cd $REMOTE_DIR && docker compose up --build -d" + + echo -e "${GREEN}Services restarted!${NC}" +} + +deploy() { + local branch="${1:-$DEFAULT_BRANCH}" + git_pull "$branch" + restart_services +} + +run_remote() { + if [ $# -eq 0 ]; then + echo -e "${RED}Error: No command specified${NC}" + exit 1 + fi + ssh "$REMOTE_HOST" "cd $REMOTE_DIR && source venv/bin/activate && $*" +} + +# Main +case "${1:-}" in + sync) + sync_code + ;; + pull) + git_pull "${2:-}" + ;; + restart) + restart_services + ;; + deploy) + deploy "${2:-}" + ;; + run) + shift + run_remote "$@" + ;; + *) + usage + ;; +esac