mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Add deployer
This commit is contained in:
parent
47180e1e71
commit
d3d71edf1d
@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
),
|
||||
(
|
||||
0.409,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
|
||||
),
|
||||
],
|
||||
},
|
||||
@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"semantic": [
|
||||
(0.489, "I find backend logic more interesting than UI work"),
|
||||
(0.462, "The user prefers working on backend systems over frontend UI"),
|
||||
(0.455, "The user said pure functions are yucky"),
|
||||
(
|
||||
0.455,
|
||||
"The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(0.455, "The user said pure functions are yucky"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
|
||||
@ -137,10 +137,9 @@ class TestBuildPrompt:
|
||||
):
|
||||
prompt = _build_prompt()
|
||||
|
||||
assert "lesswrong" in prompt.lower()
|
||||
assert "comic" in prompt.lower()
|
||||
assert "Remove" in prompt
|
||||
assert "Remove meta-language" in prompt
|
||||
assert "Return ONLY valid JSON" in prompt
|
||||
assert "recalled_content" in prompt
|
||||
|
||||
|
||||
class TestAnalyzeQuery:
|
||||
|
||||
@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mcp_server.mcp_server_url = "https://example.com"
|
||||
mcp_server.access_token = "token123"
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (200, "Authorized")
|
||||
mock_mcp_tools.return_value = [{"name": "test_tool"}]
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
assert "Authorization Successful" in body
|
||||
assert "Authorized" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_handles_failures(
|
||||
mock_make_session, mock_complete
|
||||
mock_make_session, mock_complete, mock_mcp_tools
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mcp_server.mcp_server_url = "https://example.com"
|
||||
mcp_server.access_token = "token123"
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (500, "Failure")
|
||||
mock_mcp_tools.return_value = []
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
|
||||
assert "Authorization Failed" in body
|
||||
assert "Failure" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
||||
DiscordUser,
|
||||
DiscordMessage,
|
||||
BotUser,
|
||||
DiscordBotUser,
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
@ -67,9 +68,10 @@ def sample_discord_user(db_session):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_bot_user(db_session):
|
||||
def sample_bot_user(db_session, sample_discord_user):
|
||||
"""Create a sample bot user for testing."""
|
||||
bot = BotUser.create_with_api_key(
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[sample_discord_user],
|
||||
name="Test Bot",
|
||||
email="testbot@example.com",
|
||||
)
|
||||
@ -209,9 +211,9 @@ def test_schedule_message_with_user(
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_message(
|
||||
user_id=sample_human_user.id,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
bot_id=sample_human_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
message="Test message",
|
||||
date_time=future_time,
|
||||
@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_message(
|
||||
user_id=sample_human_user.id,
|
||||
user=None,
|
||||
channel=sample_discord_channel.id,
|
||||
bot_id=sample_human_user.id,
|
||||
recipient_id=None,
|
||||
channel_id=sample_discord_channel.id,
|
||||
model="test-model",
|
||||
message="Test message",
|
||||
date_time=future_time,
|
||||
@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
|
||||
"""Test creating a message scheduler tool for a user."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert tool.name == "schedule_message"
|
||||
assert tool.name == "schedule_discord_message"
|
||||
assert "from your chat with this user" in tool.description
|
||||
assert tool.input_schema["type"] == "object"
|
||||
assert "message" in tool.input_schema["properties"]
|
||||
@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
|
||||
"""Test creating a message scheduler tool for a channel."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=None,
|
||||
channel=sample_discord_channel.id,
|
||||
user_id=None,
|
||||
channel_id=sample_discord_channel.id,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert tool.name == "schedule_message"
|
||||
assert tool.name == "schedule_discord_message"
|
||||
assert "in this channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
|
||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||
make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=None,
|
||||
channel=None,
|
||||
user_id=None,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
|
||||
"""Test message scheduler handler with valid input."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
|
||||
"""Test message scheduler handler with non-dict input."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
|
||||
"""Test message scheduler handler with invalid datetime."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
|
||||
"""Test message scheduler handler with missing datetime."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
|
||||
|
||||
|
||||
# Tests for make_prev_messages_tool
|
||||
def test_make_prev_messages_tool_with_user(sample_discord_user):
|
||||
def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
|
||||
"""Test creating a previous messages tool for a user."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
assert tool.name == "previous_messages"
|
||||
assert "from your chat with this user" in tool.description
|
||||
@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_prev_messages_tool_with_channel(sample_discord_channel):
|
||||
def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
|
||||
"""Test creating a previous messages tool for a channel."""
|
||||
tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
|
||||
|
||||
assert tool.name == "previous_messages"
|
||||
assert "in this channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_prev_messages_tool_without_user_or_channel():
|
||||
def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
|
||||
"""Test that creating a tool without user or channel raises error."""
|
||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||
make_prev_messages_tool(user=None, channel=None)
|
||||
make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
|
||||
|
||||
|
||||
def test_prev_messages_handler_success(
|
||||
db_session, sample_discord_user, sample_discord_channel
|
||||
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test previous messages handler with valid input."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
# Create some actual messages in the database
|
||||
msg1 = DiscordMessage(
|
||||
@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
|
||||
assert "Message 1" in result or "Message 2" in result
|
||||
|
||||
|
||||
def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
||||
def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with default values."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
result = tool.function({})
|
||||
|
||||
@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_input(sample_discord_user):
|
||||
def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with non-dict input."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Input must be a dictionary"):
|
||||
tool.function("not a dict")
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
|
||||
def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with invalid max_messages (negative value)."""
|
||||
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
|
||||
# so we test with -1 which actually triggers the validation
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
|
||||
tool.function({"max_messages": -1})
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_offset(sample_discord_user):
|
||||
def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with invalid offset."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
|
||||
tool.function({"offset": -1})
|
||||
|
||||
|
||||
def test_prev_messages_handler_non_integer_values(sample_discord_user):
|
||||
def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with non-integer values."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
|
||||
tool.function({"max_messages": "not an int"})
|
||||
@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
|
||||
# update_user_summary, update_server_summary, add_reaction
|
||||
assert len(tools) == 6
|
||||
assert "schedule_message" in tools
|
||||
assert "schedule_discord_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_user_summary" in tools
|
||||
@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_user_summary
|
||||
# Should have: schedule_discord_message, previous_messages, update_user_summary
|
||||
# Note: Without channel, there's no channel summary tool
|
||||
assert len(tools) >= 2 # At least schedule and previous messages
|
||||
assert "schedule_message" in tools
|
||||
assert "schedule_discord_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_user_summary" in tools
|
||||
|
||||
@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
|
||||
# update_server_summary, add_reaction (no user summary without author)
|
||||
assert len(tools) == 5
|
||||
assert "schedule_message" in tools
|
||||
assert "schedule_discord_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
|
||||
@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"http://localhost:8000/send_channel",
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"channel": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
|
||||
@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"http://localhost:8000/send_channel",
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"channel": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Tests for Discord MCP server management."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
@ -8,8 +7,8 @@ import discord
|
||||
import pytest
|
||||
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.common.mcp import mcp_call
|
||||
from memory.discord.mcp import (
|
||||
call_mcp_server,
|
||||
find_mcp_server,
|
||||
handle_mcp_add,
|
||||
handle_mcp_connect,
|
||||
@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
async for data in mcp_call(
|
||||
"https://mcp.example.com", "test_token", "tools/list", {}
|
||||
):
|
||||
results.append(data)
|
||||
@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
||||
async for _ in call_mcp_server(
|
||||
async for _ in mcp_call(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
pass
|
||||
@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
async for data in mcp_call(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
results.append(data)
|
||||
|
||||
@ -19,6 +19,7 @@ from memory.common.db.models import (
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
DiscordMessage,
|
||||
DiscordBotUser,
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
@ -34,6 +35,19 @@ def sample_discord_user(db_session):
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_bot_user(db_session, sample_discord_user):
|
||||
"""Create a sample Discord bot user."""
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[sample_discord_user],
|
||||
name="Test Bot",
|
||||
email="testbot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_channel(db_session):
|
||||
"""Create a sample Discord channel."""
|
||||
@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
|
||||
# Test previous_messages
|
||||
|
||||
|
||||
def test_previous_messages_empty(db_session):
|
||||
def test_previous_messages_empty(db_session, sample_bot_user):
|
||||
"""Test getting previous messages when none exist."""
|
||||
result = previous_messages(db_session, user_id=123, channel_id=456)
|
||||
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
|
||||
def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||
"""Test filtering messages by recipient user."""
|
||||
# Create some messages
|
||||
msg1 = DiscordMessage(
|
||||
@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
|
||||
db_session.add_all([msg1, msg2])
|
||||
db_session.commit()
|
||||
|
||||
result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
|
||||
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
|
||||
assert len(result) == 2
|
||||
# Should be in chronological order (oldest first)
|
||||
assert result[0].message_id == 1
|
||||
assert result[1].message_id == 2
|
||||
|
||||
|
||||
def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
|
||||
def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||
"""Test limiting the number of previous messages."""
|
||||
# Create 15 messages
|
||||
for i in range(15):
|
||||
@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
|
||||
db_session.commit()
|
||||
|
||||
result = previous_messages(
|
||||
db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
||||
db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
||||
)
|
||||
assert len(result) == 5
|
||||
|
||||
@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
|
||||
# Test comm_channel_prompt
|
||||
|
||||
|
||||
def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
|
||||
def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||
"""Test generating a basic communication channel prompt."""
|
||||
result = comm_channel_prompt(
|
||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
||||
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
|
||||
)
|
||||
|
||||
assert "You are a bot communicating on Discord" in result
|
||||
@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
|
||||
def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
|
||||
"""Test that prompt includes server context when available."""
|
||||
server = sample_discord_channel.server
|
||||
server.summary = "Gaming community server"
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
||||
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
|
||||
|
||||
assert "server_context" in result.lower()
|
||||
assert "Gaming community server" in result
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
|
||||
def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
|
||||
"""Test that prompt includes channel context."""
|
||||
sample_discord_channel.summary = "General discussion channel"
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
||||
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
|
||||
|
||||
assert "channel_context" in result.lower()
|
||||
assert "General discussion channel" in result
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_user_notes(
|
||||
db_session, sample_discord_user, sample_discord_channel
|
||||
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test that prompt includes user notes from previous messages."""
|
||||
sample_discord_user.summary = "Helpful community member"
|
||||
@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(
|
||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
||||
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
|
||||
)
|
||||
|
||||
assert "user_notes" in result.lower()
|
||||
@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
|
||||
web_tool_instance = MagicMock(name="web_tool")
|
||||
mock_web_search.return_value = web_tool_instance
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
|
||||
bot_user = SimpleNamespace(
|
||||
system_user=SimpleNamespace(discord_id=999888777),
|
||||
system_prompt="bot prompt"
|
||||
)
|
||||
from_user = SimpleNamespace(id=123)
|
||||
mcp_model = SimpleNamespace(
|
||||
name="Server",
|
||||
mcp_server_url="https://mcp.example.com",
|
||||
access_token="token123",
|
||||
disabled_tools=[],
|
||||
)
|
||||
|
||||
result = call_llm(
|
||||
@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
|
||||
|
||||
mock_web_search.return_value = MagicMock(name="web_tool")
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
|
||||
bot_user = SimpleNamespace(
|
||||
system_user=SimpleNamespace(discord_id=999888777),
|
||||
system_prompt=None
|
||||
)
|
||||
from_user = SimpleNamespace(id=1)
|
||||
|
||||
call_llm(
|
||||
|
||||
@ -14,12 +14,25 @@ from memory.workers.tasks import discord
|
||||
|
||||
@pytest.fixture
|
||||
def discord_bot_user(db_session):
|
||||
# Create a discord user for the bot first
|
||||
bot_discord_user = DiscordUser(
|
||||
id=999999999,
|
||||
username="testbot",
|
||||
)
|
||||
db_session.add(bot_discord_user)
|
||||
db_session.flush()
|
||||
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
discord_users=[bot_discord_user],
|
||||
name="Test Bot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.flush()
|
||||
|
||||
# Link the discord user to the system user
|
||||
bot_discord_user.system_user_id = bot.id
|
||||
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.workers.tasks.discord.create_provider")
|
||||
@patch("memory.workers.tasks.discord.call_llm")
|
||||
@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
|
||||
def test_should_process_normal_message(
|
||||
mock_create_provider,
|
||||
mock_trigger_typing,
|
||||
mock_call_llm,
|
||||
db_session,
|
||||
mock_discord_user,
|
||||
mock_discord_server,
|
||||
mock_discord_channel,
|
||||
discord_bot_user,
|
||||
):
|
||||
"""Test should_process returns True for normal messages."""
|
||||
# Mock the LLM provider to return "yes"
|
||||
mock_provider = Mock()
|
||||
mock_provider.generate.return_value = "<response>yes</response>"
|
||||
mock_provider.as_messages.return_value = []
|
||||
mock_create_provider.return_value = mock_provider
|
||||
# Create a separate recipient user (the bot)
|
||||
bot_discord_user = discord_bot_user.discord_users[0]
|
||||
|
||||
# Mock call_llm to return a high number (100 = always process)
|
||||
mock_call_llm.return_value = "<response><number>100</number><reason>Test</reason></response>"
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
@ -207,6 +223,8 @@ def test_should_process_normal_message(
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is True
|
||||
mock_call_llm.assert_called_once()
|
||||
mock_trigger_typing.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||
@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
|
||||
"""Test adding a Discord message that is a reply."""
|
||||
sample_message_data["message_reference_id"] = 111222333
|
||||
sample_message_data["message_type"] = "reply" # Explicitly set message_type
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
@patch("memory.workers.tasks.discord.send_discord_response")
|
||||
@patch("memory.workers.tasks.discord.call_llm")
|
||||
def test_process_discord_message_success(
|
||||
mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
|
||||
):
|
||||
"""Test processing a Discord message."""
|
||||
# Mock LLM to return a response
|
||||
mock_call_llm.return_value = "Test response from bot"
|
||||
# Mock Discord API to succeed
|
||||
mock_send_response.return_value = True
|
||||
|
||||
# Add a message first
|
||||
add_result = discord.add_discord_message(**sample_message_data)
|
||||
message_id = add_result["discordmessage_id"]
|
||||
@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert result["message_id"] == message_id
|
||||
mock_call_llm.assert_called_once()
|
||||
mock_send_response.assert_called_once()
|
||||
|
||||
|
||||
def test_process_discord_message_not_found(db_session):
|
||||
|
||||
@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
|
||||
@pytest.fixture
|
||||
def sample_user(db_session):
|
||||
"""Create a sample user for testing."""
|
||||
# Create a discord user for the bot
|
||||
bot_discord_user = DiscordUser(
|
||||
id=999999999,
|
||||
username="testbot",
|
||||
)
|
||||
db_session.add(bot_discord_user)
|
||||
db_session.flush()
|
||||
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
discord_users=[bot_discord_user],
|
||||
name="testbot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
|
||||
return call
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
"""Test sending to Discord user."""
|
||||
response = "This is a test response."
|
||||
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, response)
|
||||
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
pending_scheduled_call.user_id,
|
||||
999999999, # bot_id
|
||||
"testuser", # username, not ID
|
||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||
response,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
@patch("memory.discord.messages.discord.send_to_channel")
|
||||
def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
|
||||
"""Test sending to Discord channel."""
|
||||
response = "This is a channel response."
|
||||
|
||||
scheduled_calls.send_to_discord(completed_scheduled_call, response)
|
||||
scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
completed_scheduled_call.user_id,
|
||||
"test-channel", # channel name, not ID
|
||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||
mock_send_to_channel.assert_called_once_with(
|
||||
999999999, # bot_id
|
||||
completed_scheduled_call.discord_channel.id, # channel ID, not name
|
||||
response,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
|
||||
"""Test message truncation for long responses."""
|
||||
long_response = "A" * 2500 # Very long response
|
||||
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
|
||||
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
|
||||
|
||||
# Verify the message was truncated
|
||||
# With the new implementation, send_discord_response sends the full response
|
||||
# No truncation happens in _send_to_discord
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
assert args[0] == 999999999 # bot_id
|
||||
message = args[2]
|
||||
assert len(message) <= 1950 # Should be truncated
|
||||
assert message.endswith("... (response truncated)")
|
||||
assert message == long_response
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
|
||||
"""Test that normal length messages are not truncated."""
|
||||
normal_response = "This is a normal length response."
|
||||
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
|
||||
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
|
||||
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
assert args[0] == 999999999 # bot_id
|
||||
message = args[2]
|
||||
assert not message.endswith("... (response truncated)")
|
||||
assert "This is a normal length response." in message
|
||||
assert message == normal_response
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_success(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
# Verify LLM was called with correct parameters
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="What is the weather like today?",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
# Verify LLM was called
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result["success"] is True
|
||||
@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
|
||||
assert result == {"error": "Scheduled call not found"}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_not_pending(
|
||||
mock_llm_call, completed_scheduled_call, db_session
|
||||
):
|
||||
@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
|
||||
mock_llm_call.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_with_default_system_prompt(
|
||||
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
|
||||
|
||||
scheduled_calls.execute_scheduled_call(call.id)
|
||||
|
||||
# Verify default system prompt was used
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="Test prompt",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt=None, # The code uses system_prompt as-is, not a default
|
||||
)
|
||||
# Verify LLM was called
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_discord_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
|
||||
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_llm_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test execution when LLM call fails."""
|
||||
mock_llm_call.side_effect = Exception("LLM API error")
|
||||
|
||||
# The safe_task_execution decorator should catch this
|
||||
# The execute_scheduled_call function catches the exception and returns an error response
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "LLM API error" in result["error"]
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
assert "LLM call failed" in result["error"]
|
||||
|
||||
# Discord should not be called
|
||||
mock_send_discord.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_long_response_truncation(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
|
||||
mock_execute_delay.delay.assert_called_once_with(due_call.id)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_status_transition_pending_to_executing_to_completed(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
|
||||
"has_discord_user,has_discord_channel,expected_method",
|
||||
[
|
||||
(True, False, "send_dm"),
|
||||
(False, True, "broadcast_message"),
|
||||
(True, True, "send_dm"), # User takes precedence
|
||||
(False, True, "send_to_channel"),
|
||||
(True, True, "send_to_channel"), # Channel takes precedence in the implementation
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_to_channel")
|
||||
def test_discord_destination_priority(
|
||||
mock_broadcast,
|
||||
mock_send_to_channel,
|
||||
mock_send_dm,
|
||||
has_discord_user,
|
||||
has_discord_channel,
|
||||
@ -535,50 +535,39 @@ def test_discord_destination_priority(
|
||||
db_session.commit()
|
||||
|
||||
response = "Test response"
|
||||
scheduled_calls.send_to_discord(call, response)
|
||||
scheduled_calls.send_to_discord(999999999, call, response)
|
||||
|
||||
if expected_method == "send_dm":
|
||||
mock_send_dm.assert_called_once()
|
||||
mock_broadcast.assert_not_called()
|
||||
mock_send_to_channel.assert_not_called()
|
||||
else:
|
||||
mock_broadcast.assert_called_once()
|
||||
mock_send_to_channel.assert_called_once()
|
||||
mock_send_dm.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic,model,response,expected_in_message",
|
||||
"topic,model,response",
|
||||
[
|
||||
(
|
||||
"Weather Check",
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
"It's sunny!",
|
||||
[
|
||||
"**Topic:** Weather Check",
|
||||
"**Model:** anthropic/claude-3-5-sonnet-20241022",
|
||||
"**Response:** It's sunny!",
|
||||
],
|
||||
),
|
||||
(
|
||||
"Test Topic",
|
||||
"gpt-4",
|
||||
"Hello world",
|
||||
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
|
||||
),
|
||||
(
|
||||
"Long Topic Name Here",
|
||||
"claude-2",
|
||||
"Short",
|
||||
[
|
||||
"**Topic:** Long Topic Name Here",
|
||||
"**Model:** claude-2",
|
||||
"**Response:** Short",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
|
||||
"""Test the Discord message formatting with different inputs."""
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_message_formatting(mock_send_dm, topic, model, response):
|
||||
"""Test that _send_to_discord sends the response as-is."""
|
||||
# Create a mock scheduled call with a mock Discord user
|
||||
mock_discord_user = Mock()
|
||||
mock_discord_user.username = "testuser"
|
||||
@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
mock_call.discord_user = mock_discord_user
|
||||
mock_call.discord_channel = None
|
||||
|
||||
scheduled_calls.send_to_discord(mock_call, response)
|
||||
scheduled_calls.send_to_discord(999999999, mock_call, response)
|
||||
|
||||
# Get the actual message that was sent
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == mock_call.user_id
|
||||
assert args[0] == 999999999 # bot_id
|
||||
actual_message = args[2]
|
||||
|
||||
# Verify all expected parts are in the message
|
||||
for expected_part in expected_in_message:
|
||||
assert expected_part in actual_message
|
||||
# The new implementation sends the response as-is, without formatting
|
||||
assert actual_message == response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
("cancelled", False),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_status_check(
|
||||
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
|
||||
125
tools/deploy.sh
Executable file
125
tools/deploy.sh
Executable file
@ -0,0 +1,125 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
REMOTE_HOST="memory"
|
||||
REMOTE_DIR="/home/ec2-user/memory"
|
||||
DEFAULT_BRANCH="master"
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 <command> [options]"
|
||||
echo ""
|
||||
echo "Commands:"
|
||||
echo " sync Rsync local code to server"
|
||||
echo " pull [branch] Git checkout and pull (default: master)"
|
||||
echo " restart Restart docker services"
|
||||
echo " deploy [branch] Pull + restart"
|
||||
echo " run <command> Run command on server (with venv activated)"
|
||||
exit 1
|
||||
}
|
||||
|
||||
sync_code() {
|
||||
echo -e "${GREEN}Syncing code to $REMOTE_HOST...${NC}"
|
||||
|
||||
rsync -avz --delete \
|
||||
--exclude='__pycache__' \
|
||||
--exclude='*.pyc' \
|
||||
--exclude='*.pyo' \
|
||||
--exclude='.git' \
|
||||
--exclude='memory_files' \
|
||||
--exclude='secrets' \
|
||||
--exclude='Books' \
|
||||
--exclude='clean_books' \
|
||||
--exclude='.env' \
|
||||
--exclude='venv' \
|
||||
--exclude='.venv' \
|
||||
--exclude='*.egg-info' \
|
||||
--exclude='node_modules' \
|
||||
--exclude='.DS_Store' \
|
||||
--exclude='docker-compose.override.yml' \
|
||||
--exclude='.pytest_cache' \
|
||||
--exclude='.mypy_cache' \
|
||||
--exclude='.ruff_cache' \
|
||||
--exclude='htmlcov' \
|
||||
--exclude='.coverage' \
|
||||
--exclude='*.log' \
|
||||
"$PROJECT_DIR/src" \
|
||||
"$PROJECT_DIR/tests" \
|
||||
"$PROJECT_DIR/tools" \
|
||||
"$PROJECT_DIR/db" \
|
||||
"$PROJECT_DIR/docker" \
|
||||
"$PROJECT_DIR/frontend" \
|
||||
"$PROJECT_DIR/requirements" \
|
||||
"$PROJECT_DIR/setup.py" \
|
||||
"$PROJECT_DIR/pyproject.toml" \
|
||||
"$PROJECT_DIR/docker-compose.yaml" \
|
||||
"$PROJECT_DIR/pytest.ini" \
|
||||
"$REMOTE_HOST:$REMOTE_DIR/"
|
||||
|
||||
echo -e "${GREEN}Sync complete!${NC}"
|
||||
}
|
||||
|
||||
git_pull() {
|
||||
local branch="${1:-$DEFAULT_BRANCH}"
|
||||
echo -e "${GREEN}Pulling branch '$branch' on $REMOTE_HOST...${NC}"
|
||||
|
||||
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \
|
||||
git stash --quiet 2>/dev/null || true && \
|
||||
git fetch origin && \
|
||||
git checkout $branch && \
|
||||
git pull origin $branch"
|
||||
|
||||
echo -e "${GREEN}Pull complete!${NC}"
|
||||
}
|
||||
|
||||
restart_services() {
|
||||
echo -e "${GREEN}Restarting services on $REMOTE_HOST...${NC}"
|
||||
|
||||
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && docker compose up --build -d"
|
||||
|
||||
echo -e "${GREEN}Services restarted!${NC}"
|
||||
}
|
||||
|
||||
deploy() {
|
||||
local branch="${1:-$DEFAULT_BRANCH}"
|
||||
git_pull "$branch"
|
||||
restart_services
|
||||
}
|
||||
|
||||
run_remote() {
|
||||
if [ $# -eq 0 ]; then
|
||||
echo -e "${RED}Error: No command specified${NC}"
|
||||
exit 1
|
||||
fi
|
||||
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && source venv/bin/activate && $*"
|
||||
}
|
||||
|
||||
# Main
|
||||
case "${1:-}" in
|
||||
sync)
|
||||
sync_code
|
||||
;;
|
||||
pull)
|
||||
git_pull "${2:-}"
|
||||
;;
|
||||
restart)
|
||||
restart_services
|
||||
;;
|
||||
deploy)
|
||||
deploy "${2:-}"
|
||||
;;
|
||||
run)
|
||||
shift
|
||||
run_remote "$@"
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
Loading…
x
Reference in New Issue
Block a user