diff --git a/tests/integration/test_real_queries.py b/tests/integration/test_real_queries.py
index df46fbf..0da4a43 100644
--- a/tests/integration/test_real_queries.py
+++ b/tests/integration/test_real_queries.py
@@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
),
(
0.409,
- "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
+ "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
),
],
},
@@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
"semantic": [
(0.489, "I find backend logic more interesting than UI work"),
(0.462, "The user prefers working on backend systems over frontend UI"),
+ (0.455, "The user said pure functions are yucky"),
(
0.455,
"The user believes functional programming leads to better code quality",
),
- (0.455, "The user said pure functions are yucky"),
],
"temporal": [
(
diff --git a/tests/memory/api/search/test_query_analysis.py b/tests/memory/api/search/test_query_analysis.py
index 85ef343..2290d96 100644
--- a/tests/memory/api/search/test_query_analysis.py
+++ b/tests/memory/api/search/test_query_analysis.py
@@ -137,10 +137,9 @@ class TestBuildPrompt:
):
prompt = _build_prompt()
- assert "lesswrong" in prompt.lower()
- assert "comic" in prompt.lower()
- assert "Remove" in prompt
+ assert "Remove meta-language" in prompt
assert "Return ONLY valid JSON" in prompt
+ assert "recalled_content" in prompt
class TestAnalyzeQuery:
diff --git a/tests/memory/api/test_auth.py b/tests/memory/api/test_auth.py
index 708fb05..dddd814 100644
--- a/tests/memory/api/test_auth.py
+++ b/tests/memory/api/test_auth.py
@@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
@pytest.mark.asyncio
+@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
-async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
+async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
mock_session = MagicMock()
@contextmanager
@@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
+ mcp_server.mcp_server_url = "https://example.com"
+ mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (200, "Authorized")
+ mock_mcp_tools.return_value = [{"name": "test_tool"}]
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
@@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
assert "Authorization Successful" in body
assert "Authorized" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
- mock_session.commit.assert_called_once()
+ assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio
+@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_handles_failures(
- mock_make_session, mock_complete
+ mock_make_session, mock_complete, mock_mcp_tools
):
mock_session = MagicMock()
@@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
+ mcp_server.mcp_server_url = "https://example.com"
+ mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (500, "Failure")
+ mock_mcp_tools.return_value = []
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
@@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
assert "Authorization Failed" in body
assert "Failure" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
- mock_session.commit.assert_called_once()
+ assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio
diff --git a/tests/memory/common/llms/tools/test_discord_tools.py b/tests/memory/common/llms/tools/test_discord_tools.py
index 29ce374..81a2d19 100644
--- a/tests/memory/common/llms/tools/test_discord_tools.py
+++ b/tests/memory/common/llms/tools/test_discord_tools.py
@@ -18,6 +18,7 @@ from memory.common.db.models import (
DiscordUser,
DiscordMessage,
BotUser,
+ DiscordBotUser,
HumanUser,
ScheduledLLMCall,
)
@@ -67,9 +68,10 @@ def sample_discord_user(db_session):
@pytest.fixture
-def sample_bot_user(db_session):
+def sample_bot_user(db_session, sample_discord_user):
"""Create a sample bot user for testing."""
- bot = BotUser.create_with_api_key(
+ bot = DiscordBotUser.create_with_api_key(
+ discord_users=[sample_discord_user],
name="Test Bot",
email="testbot@example.com",
)
@@ -209,9 +211,9 @@ def test_schedule_message_with_user(
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message(
- user_id=sample_human_user.id,
- user=sample_discord_user.id,
- channel=None,
+ bot_id=sample_human_user.id,
+ recipient_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
message="Test message",
date_time=future_time,
@@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message(
- user_id=sample_human_user.id,
- user=None,
- channel=sample_discord_channel.id,
+ bot_id=sample_human_user.id,
+ recipient_id=None,
+ channel_id=sample_discord_channel.id,
model="test-model",
message="Test message",
date_time=future_time,
@@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
"""Test creating a message scheduler tool for a user."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
- assert tool.name == "schedule_message"
+ assert tool.name == "schedule_discord_message"
assert "from your chat with this user" in tool.description
assert tool.input_schema["type"] == "object"
assert "message" in tool.input_schema["properties"]
@@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
"""Test creating a message scheduler tool for a channel."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=None,
- channel=sample_discord_channel.id,
+ user_id=None,
+ channel_id=sample_discord_channel.id,
model="test-model",
)
- assert tool.name == "schedule_message"
+ assert tool.name == "schedule_discord_message"
assert "in this channel" in tool.description
assert callable(tool.function)
@@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
with pytest.raises(ValueError, match="Either user or channel must be provided"):
make_message_scheduler(
bot=sample_bot_user,
- user=None,
- channel=None,
+ user_id=None,
+ channel_id=None,
model="test-model",
)
@@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
"""Test message scheduler handler with valid input."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
"""Test message scheduler handler with non-dict input."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
"""Test message scheduler handler with invalid datetime."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
"""Test message scheduler handler with missing datetime."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
# Tests for make_prev_messages_tool
-def test_make_prev_messages_tool_with_user(sample_discord_user):
+def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
"""Test creating a previous messages tool for a user."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
assert tool.name == "previous_messages"
assert "from your chat with this user" in tool.description
@@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
assert callable(tool.function)
-def test_make_prev_messages_tool_with_channel(sample_discord_channel):
+def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
"""Test creating a previous messages tool for a channel."""
- tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
assert tool.name == "previous_messages"
assert "in this channel" in tool.description
assert callable(tool.function)
-def test_make_prev_messages_tool_without_user_or_channel():
+def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
"""Test that creating a tool without user or channel raises error."""
with pytest.raises(ValueError, match="Either user or channel must be provided"):
- make_prev_messages_tool(user=None, channel=None)
+ make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
def test_prev_messages_handler_success(
- db_session, sample_discord_user, sample_discord_channel
+ db_session, sample_bot_user, sample_discord_user, sample_discord_channel
):
"""Test previous messages handler with valid input."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
# Create some actual messages in the database
msg1 = DiscordMessage(
@@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
assert "Message 1" in result or "Message 2" in result
-def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
+def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
"""Test previous messages handler with default values."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
result = tool.function({})
@@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
assert isinstance(result, str)
-def test_prev_messages_handler_invalid_input(sample_discord_user):
+def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-dict input."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Input must be a dictionary"):
tool.function("not a dict")
-def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
+def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid max_messages (negative value)."""
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
# so we test with -1 which actually triggers the validation
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
tool.function({"max_messages": -1})
-def test_prev_messages_handler_invalid_offset(sample_discord_user):
+def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid offset."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
tool.function({"offset": -1})
-def test_prev_messages_handler_non_integer_values(sample_discord_user):
+def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-integer values."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
tool.function({"max_messages": "not an int"})
@@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
model="test-model",
)
- # Should have: schedule_message, previous_messages, update_channel_summary,
+ # Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_user_summary, update_server_summary, add_reaction
assert len(tools) == 6
- assert "schedule_message" in tools
+ assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_channel_summary" in tools
assert "update_user_summary" in tools
@@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
model="test-model",
)
- # Should have: schedule_message, previous_messages, update_user_summary
+ # Should have: schedule_discord_message, previous_messages, update_user_summary
# Note: Without channel, there's no channel summary tool
assert len(tools) >= 2 # At least schedule and previous messages
- assert "schedule_message" in tools
+ assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_user_summary" in tools
@@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
model="test-model",
)
- # Should have: schedule_message, previous_messages, update_channel_summary,
+ # Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_server_summary, add_reaction (no user summary without author)
assert len(tools) == 5
- assert "schedule_message" in tools
+ assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_channel_summary" in tools
assert "update_server_summary" in tools
diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py
index ce4cf3f..55db891 100644
--- a/tests/memory/common/test_discord.py
+++ b/tests/memory/common/test_discord.py
@@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel",
json={
"bot_id": BOT_ID,
- "channel_name": "general",
+ "channel": "general",
"message": "Announcement!",
},
timeout=10,
diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py
index ce4cf3f..55db891 100644
--- a/tests/memory/common/test_discord_integration.py
+++ b/tests/memory/common/test_discord_integration.py
@@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel",
json={
"bot_id": BOT_ID,
- "channel_name": "general",
+ "channel": "general",
"message": "Announcement!",
},
timeout=10,
diff --git a/tests/memory/discord_tests/test_mcp.py b/tests/memory/discord_tests/test_mcp.py
index 37a7649..0e59f89 100644
--- a/tests/memory/discord_tests/test_mcp.py
+++ b/tests/memory/discord_tests/test_mcp.py
@@ -1,6 +1,5 @@
"""Tests for Discord MCP server management."""
-import json
from unittest.mock import AsyncMock, Mock, patch
import aiohttp
@@ -8,8 +7,8 @@ import discord
import pytest
from memory.common.db.models import MCPServer, MCPServerAssignment
+from memory.common.mcp import mcp_call
from memory.discord.mcp import (
- call_mcp_server,
find_mcp_server,
handle_mcp_add,
handle_mcp_connect,
@@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
- async for data in call_mcp_server(
+ async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list", {}
):
results.append(data)
@@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to call MCP server"):
- async for _ in call_mcp_server(
+ async for _ in mcp_call(
"https://mcp.example.com", "test_token", "tools/list"
):
pass
@@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
- async for data in call_mcp_server(
+ async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list"
):
results.append(data)
diff --git a/tests/memory/discord_tests/test_messages.py b/tests/memory/discord_tests/test_messages.py
index e2f60a9..edabcc2 100644
--- a/tests/memory/discord_tests/test_messages.py
+++ b/tests/memory/discord_tests/test_messages.py
@@ -19,6 +19,7 @@ from memory.common.db.models import (
DiscordChannel,
DiscordServer,
DiscordMessage,
+ DiscordBotUser,
HumanUser,
ScheduledLLMCall,
)
@@ -34,6 +35,19 @@ def sample_discord_user(db_session):
return user
+@pytest.fixture
+def sample_bot_user(db_session, sample_discord_user):
+ """Create a sample Discord bot user."""
+ bot = DiscordBotUser.create_with_api_key(
+ discord_users=[sample_discord_user],
+ name="Test Bot",
+ email="testbot@example.com",
+ )
+ db_session.add(bot)
+ db_session.commit()
+ return bot
+
+
@pytest.fixture
def sample_discord_channel(db_session):
"""Create a sample Discord channel."""
@@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
# Test previous_messages
-def test_previous_messages_empty(db_session):
+def test_previous_messages_empty(db_session, sample_bot_user):
"""Test getting previous messages when none exist."""
- result = previous_messages(db_session, user_id=123, channel_id=456)
+ result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
assert result == []
-def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
+def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test filtering messages by recipient user."""
# Create some messages
msg1 = DiscordMessage(
@@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
db_session.add_all([msg1, msg2])
db_session.commit()
- result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
+ result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
assert len(result) == 2
# Should be in chronological order (oldest first)
assert result[0].message_id == 1
assert result[1].message_id == 2
-def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
+def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test limiting the number of previous messages."""
# Create 15 messages
for i in range(15):
@@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
db_session.commit()
result = previous_messages(
- db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
+ db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
)
assert len(result) == 5
@@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
# Test comm_channel_prompt
-def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
+def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test generating a basic communication channel prompt."""
result = comm_channel_prompt(
- db_session, user=sample_discord_user, channel=sample_discord_channel
+ db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
)
assert "You are a bot communicating on Discord" in result
@@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
assert len(result) > 0
-def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
+def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes server context when available."""
server = sample_discord_channel.server
server.summary = "Gaming community server"
db_session.commit()
- result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
+ result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "server_context" in result.lower()
assert "Gaming community server" in result
-def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
+def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes channel context."""
sample_discord_channel.summary = "General discussion channel"
db_session.commit()
- result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
+ result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "channel_context" in result.lower()
assert "General discussion channel" in result
def test_comm_channel_prompt_includes_user_notes(
- db_session, sample_discord_user, sample_discord_channel
+ db_session, sample_bot_user, sample_discord_user, sample_discord_channel
):
"""Test that prompt includes user notes from previous messages."""
sample_discord_user.summary = "Helpful community member"
@@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
db_session.commit()
result = comm_channel_prompt(
- db_session, user=sample_discord_user, channel=sample_discord_channel
+ db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
)
assert "user_notes" in result.lower()
@@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
web_tool_instance = MagicMock(name="web_tool")
mock_web_search.return_value = web_tool_instance
- bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
+ bot_user = SimpleNamespace(
+ system_user=SimpleNamespace(discord_id=999888777),
+ system_prompt="bot prompt"
+ )
from_user = SimpleNamespace(id=123)
mcp_model = SimpleNamespace(
name="Server",
mcp_server_url="https://mcp.example.com",
access_token="token123",
+ disabled_tools=[],
)
result = call_llm(
@@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
mock_web_search.return_value = MagicMock(name="web_tool")
- bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
+ bot_user = SimpleNamespace(
+ system_user=SimpleNamespace(discord_id=999888777),
+ system_prompt=None
+ )
from_user = SimpleNamespace(id=1)
call_llm(
diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py
index 8decc7d..9650ba7 100644
--- a/tests/memory/workers/tasks/test_discord_tasks.py
+++ b/tests/memory/workers/tasks/test_discord_tasks.py
@@ -14,12 +14,25 @@ from memory.workers.tasks import discord
@pytest.fixture
def discord_bot_user(db_session):
+ # Create a discord user for the bot first
+ bot_discord_user = DiscordUser(
+ id=999999999,
+ username="testbot",
+ )
+ db_session.add(bot_discord_user)
+ db_session.flush()
+
bot = DiscordBotUser.create_with_api_key(
- discord_users=[],
+ discord_users=[bot_discord_user],
name="Test Bot",
email="bot@example.com",
)
db_session.add(bot)
+ db_session.flush()
+
+ # Link the discord user to the system user
+ bot_discord_user.system_user_id = bot.id
+
db_session.commit()
return bot
@@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
-@patch("memory.workers.tasks.discord.create_provider")
+@patch("memory.workers.tasks.discord.call_llm")
+@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
def test_should_process_normal_message(
- mock_create_provider,
+ mock_trigger_typing,
+ mock_call_llm,
db_session,
mock_discord_user,
mock_discord_server,
mock_discord_channel,
+ discord_bot_user,
):
"""Test should_process returns True for normal messages."""
- # Mock the LLM provider to return "yes"
- mock_provider = Mock()
- mock_provider.generate.return_value = "yes"
- mock_provider.as_messages.return_value = []
- mock_create_provider.return_value = mock_provider
+ # Create a separate recipient user (the bot)
+ bot_discord_user = discord_bot_user.discord_users[0]
+
+ # Mock call_llm to return a high number (100 = always process)
+ mock_call_llm.return_value = "100Test"
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
from_id=mock_discord_user.id,
- recipient_id=mock_discord_user.id,
+ recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@@ -207,6 +223,8 @@ def test_should_process_normal_message(
db_session.refresh(message)
assert discord.should_process(message) is True
+ mock_call_llm.assert_called_once()
+ mock_trigger_typing.assert_called_once()
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
@@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
"""Test adding a Discord message that is a reply."""
sample_message_data["message_reference_id"] = 111222333
+ sample_message_data["message_type"] = "reply" # Explicitly set message_type
discord.add_discord_message(**sample_message_data)
@@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
assert result["status"] == "processed"
-def test_process_discord_message_success(db_session, sample_message_data, qdrant):
+@patch("memory.workers.tasks.discord.send_discord_response")
+@patch("memory.workers.tasks.discord.call_llm")
+def test_process_discord_message_success(
+ mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
+):
"""Test processing a Discord message."""
+ # Mock LLM to return a response
+ mock_call_llm.return_value = "Test response from bot"
+ # Mock Discord API to succeed
+ mock_send_response.return_value = True
+
# Add a message first
add_result = discord.add_discord_message(**sample_message_data)
message_id = add_result["discordmessage_id"]
@@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
assert result["status"] == "processed"
assert result["message_id"] == message_id
+ mock_call_llm.assert_called_once()
+ mock_send_response.assert_called_once()
def test_process_discord_message_not_found(db_session):
diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py
index 44d567d..ceb28d5 100644
--- a/tests/memory/workers/tasks/test_scheduled_calls.py
+++ b/tests/memory/workers/tasks/test_scheduled_calls.py
@@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
@pytest.fixture
def sample_user(db_session):
"""Create a sample user for testing."""
+ # Create a discord user for the bot
+ bot_discord_user = DiscordUser(
+ id=999999999,
+ username="testbot",
+ )
+ db_session.add(bot_discord_user)
+ db_session.flush()
+
user = DiscordBotUser.create_with_api_key(
- discord_users=[],
+ discord_users=[bot_discord_user],
name="testbot",
email="bot@example.com",
)
@@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
return call
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
+@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user."""
response = "This is a test response."
- scheduled_calls.send_to_discord(pending_scheduled_call, response)
+ scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
- pending_scheduled_call.user_id,
+ 999999999, # bot_id
"testuser", # username, not ID
- "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
+ response,
)
-@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
-def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
+@patch("memory.discord.messages.discord.send_to_channel")
+def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
"""Test sending to Discord channel."""
response = "This is a channel response."
- scheduled_calls.send_to_discord(completed_scheduled_call, response)
+ scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
- mock_broadcast.assert_called_once_with(
- completed_scheduled_call.user_id,
- "test-channel", # channel name, not ID
- "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
+ mock_send_to_channel.assert_called_once_with(
+ 999999999, # bot_id
+ completed_scheduled_call.discord_channel.id, # channel ID, not name
+ response,
)
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
+@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
"""Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response
- scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
+ scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
- # Verify the message was truncated
+ # With the new implementation, send_discord_response sends the full response
+ # No truncation happens in _send_to_discord
args, kwargs = mock_send_dm.call_args
- assert args[0] == pending_scheduled_call.user_id
+ assert args[0] == 999999999 # bot_id
message = args[2]
- assert len(message) <= 1950 # Should be truncated
- assert message.endswith("... (response truncated)")
+ assert message == long_response
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
+@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
"""Test that normal length messages are not truncated."""
normal_response = "This is a normal length response."
- scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
+ scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args
- assert args[0] == pending_scheduled_call.user_id
+ assert args[0] == 999999999 # bot_id
message = args[2]
- assert not message.endswith("... (response truncated)")
- assert "This is a normal length response." in message
+ assert message == normal_response
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_success(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
- # Verify LLM was called with correct parameters
- mock_llm_call.assert_called_once_with(
- prompt="What is the weather like today?",
- model="anthropic/claude-3-5-sonnet-20241022",
- system_prompt="You are a helpful assistant.",
- )
+ # Verify LLM was called
+ mock_llm_call.assert_called_once()
# Verify result
assert result["success"] is True
@@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
assert result == {"error": "Scheduled call not found"}
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_not_pending(
mock_llm_call, completed_scheduled_call, db_session
):
@@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
mock_llm_call.assert_not_called()
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_with_default_system_prompt(
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
):
@@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
scheduled_calls.execute_scheduled_call(call.id)
- # Verify default system prompt was used
- mock_llm_call.assert_called_once_with(
- prompt="Test prompt",
- model="anthropic/claude-3-5-sonnet-20241022",
- system_prompt=None, # The code uses system_prompt as-is, not a default
- )
+ # Verify LLM was called
+ mock_llm_call.assert_called_once()
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_discord_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_llm_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test execution when LLM call fails."""
mock_llm_call.side_effect = Exception("LLM API error")
- # The safe_task_execution decorator should catch this
+ # The execute_scheduled_call function catches the exception and returns an error response
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
- assert result["status"] == "error"
- assert "LLM API error" in result["error"]
+ assert result["success"] is False
+ assert "error" in result
+ assert "LLM call failed" in result["error"]
# Discord should not be called
mock_send_discord.assert_not_called()
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_long_response_truncation(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
mock_execute_delay.delay.assert_called_once_with(due_call.id)
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_status_transition_pending_to_executing_to_completed(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
"has_discord_user,has_discord_channel,expected_method",
[
(True, False, "send_dm"),
- (False, True, "broadcast_message"),
- (True, True, "send_dm"), # User takes precedence
+ (False, True, "send_to_channel"),
+ (True, True, "send_to_channel"), # Channel takes precedence in the implementation
],
)
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
-@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
+@patch("memory.discord.messages.discord.send_dm")
+@patch("memory.discord.messages.discord.send_to_channel")
def test_discord_destination_priority(
- mock_broadcast,
+ mock_send_to_channel,
mock_send_dm,
has_discord_user,
has_discord_channel,
@@ -535,50 +535,39 @@ def test_discord_destination_priority(
db_session.commit()
response = "Test response"
- scheduled_calls.send_to_discord(call, response)
+ scheduled_calls.send_to_discord(999999999, call, response)
if expected_method == "send_dm":
mock_send_dm.assert_called_once()
- mock_broadcast.assert_not_called()
+ mock_send_to_channel.assert_not_called()
else:
- mock_broadcast.assert_called_once()
+ mock_send_to_channel.assert_called_once()
mock_send_dm.assert_not_called()
@pytest.mark.parametrize(
- "topic,model,response,expected_in_message",
+ "topic,model,response",
[
(
"Weather Check",
"anthropic/claude-3-5-sonnet-20241022",
"It's sunny!",
- [
- "**Topic:** Weather Check",
- "**Model:** anthropic/claude-3-5-sonnet-20241022",
- "**Response:** It's sunny!",
- ],
),
(
"Test Topic",
"gpt-4",
"Hello world",
- ["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
),
(
"Long Topic Name Here",
"claude-2",
"Short",
- [
- "**Topic:** Long Topic Name Here",
- "**Model:** claude-2",
- "**Response:** Short",
- ],
),
],
)
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
-def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
- """Test the Discord message formatting with different inputs."""
+@patch("memory.discord.messages.discord.send_dm")
+def test_message_formatting(mock_send_dm, topic, model, response):
+ """Test that _send_to_discord sends the response as-is."""
# Create a mock scheduled call with a mock Discord user
mock_discord_user = Mock()
mock_discord_user.username = "testuser"
@@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_call.discord_user = mock_discord_user
mock_call.discord_channel = None
- scheduled_calls.send_to_discord(mock_call, response)
+ scheduled_calls.send_to_discord(999999999, mock_call, response)
# Get the actual message that was sent
args, kwargs = mock_send_dm.call_args
- assert args[0] == mock_call.user_id
+ assert args[0] == 999999999 # bot_id
actual_message = args[2]
- # Verify all expected parts are in the message
- for expected_part in expected_in_message:
- assert expected_part in actual_message
+ # The new implementation sends the response as-is, without formatting
+ assert actual_message == response
@pytest.mark.parametrize(
@@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
("cancelled", False),
],
)
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_status_check(
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
):
diff --git a/tools/deploy.sh b/tools/deploy.sh
new file mode 100755
index 0000000..37031be
--- /dev/null
+++ b/tools/deploy.sh
@@ -0,0 +1,125 @@
+#!/bin/bash
+set -e
+
+REMOTE_HOST="memory"
+REMOTE_DIR="/home/ec2-user/memory"
+DEFAULT_BRANCH="master"
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m'
+
+usage() {
+ echo "Usage: $0 [options]"
+ echo ""
+ echo "Commands:"
+ echo " sync Rsync local code to server"
+ echo " pull [branch] Git checkout and pull (default: master)"
+ echo " restart Restart docker services"
+ echo " deploy [branch] Pull + restart"
+ echo " run Run command on server (with venv activated)"
+ exit 1
+}
+
+sync_code() {
+ echo -e "${GREEN}Syncing code to $REMOTE_HOST...${NC}"
+
+ rsync -avz --delete \
+ --exclude='__pycache__' \
+ --exclude='*.pyc' \
+ --exclude='*.pyo' \
+ --exclude='.git' \
+ --exclude='memory_files' \
+ --exclude='secrets' \
+ --exclude='Books' \
+ --exclude='clean_books' \
+ --exclude='.env' \
+ --exclude='venv' \
+ --exclude='.venv' \
+ --exclude='*.egg-info' \
+ --exclude='node_modules' \
+ --exclude='.DS_Store' \
+ --exclude='docker-compose.override.yml' \
+ --exclude='.pytest_cache' \
+ --exclude='.mypy_cache' \
+ --exclude='.ruff_cache' \
+ --exclude='htmlcov' \
+ --exclude='.coverage' \
+ --exclude='*.log' \
+ "$PROJECT_DIR/src" \
+ "$PROJECT_DIR/tests" \
+ "$PROJECT_DIR/tools" \
+ "$PROJECT_DIR/db" \
+ "$PROJECT_DIR/docker" \
+ "$PROJECT_DIR/frontend" \
+ "$PROJECT_DIR/requirements" \
+ "$PROJECT_DIR/setup.py" \
+ "$PROJECT_DIR/pyproject.toml" \
+ "$PROJECT_DIR/docker-compose.yaml" \
+ "$PROJECT_DIR/pytest.ini" \
+ "$REMOTE_HOST:$REMOTE_DIR/"
+
+ echo -e "${GREEN}Sync complete!${NC}"
+}
+
+git_pull() {
+ local branch="${1:-$DEFAULT_BRANCH}"
+ echo -e "${GREEN}Pulling branch '$branch' on $REMOTE_HOST...${NC}"
+
+ ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \
+ git stash --quiet 2>/dev/null || true && \
+ git fetch origin && \
+ git checkout $branch && \
+ git pull origin $branch"
+
+ echo -e "${GREEN}Pull complete!${NC}"
+}
+
+restart_services() {
+ echo -e "${GREEN}Restarting services on $REMOTE_HOST...${NC}"
+
+ ssh "$REMOTE_HOST" "cd $REMOTE_DIR && docker compose up --build -d"
+
+ echo -e "${GREEN}Services restarted!${NC}"
+}
+
+deploy() {
+ local branch="${1:-$DEFAULT_BRANCH}"
+ git_pull "$branch"
+ restart_services
+}
+
+run_remote() {
+ if [ $# -eq 0 ]; then
+ echo -e "${RED}Error: No command specified${NC}"
+ exit 1
+ fi
+ ssh "$REMOTE_HOST" "cd $REMOTE_DIR && source venv/bin/activate && $*"
+}
+
+# Main
+case "${1:-}" in
+ sync)
+ sync_code
+ ;;
+ pull)
+ git_pull "${2:-}"
+ ;;
+ restart)
+ restart_services
+ ;;
+ deploy)
+ deploy "${2:-}"
+ ;;
+ run)
+ shift
+ run_remote "$@"
+ ;;
+ *)
+ usage
+ ;;
+esac