mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Add deployer
This commit is contained in:
parent
47180e1e71
commit
d3d71edf1d
@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
|
|||||||
),
|
),
|
||||||
(
|
(
|
||||||
0.409,
|
0.409,
|
||||||
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
|
"Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
|
|||||||
"semantic": [
|
"semantic": [
|
||||||
(0.489, "I find backend logic more interesting than UI work"),
|
(0.489, "I find backend logic more interesting than UI work"),
|
||||||
(0.462, "The user prefers working on backend systems over frontend UI"),
|
(0.462, "The user prefers working on backend systems over frontend UI"),
|
||||||
|
(0.455, "The user said pure functions are yucky"),
|
||||||
(
|
(
|
||||||
0.455,
|
0.455,
|
||||||
"The user believes functional programming leads to better code quality",
|
"The user believes functional programming leads to better code quality",
|
||||||
),
|
),
|
||||||
(0.455, "The user said pure functions are yucky"),
|
|
||||||
],
|
],
|
||||||
"temporal": [
|
"temporal": [
|
||||||
(
|
(
|
||||||
|
|||||||
@ -137,10 +137,9 @@ class TestBuildPrompt:
|
|||||||
):
|
):
|
||||||
prompt = _build_prompt()
|
prompt = _build_prompt()
|
||||||
|
|
||||||
assert "lesswrong" in prompt.lower()
|
assert "Remove meta-language" in prompt
|
||||||
assert "comic" in prompt.lower()
|
|
||||||
assert "Remove" in prompt
|
|
||||||
assert "Return ONLY valid JSON" in prompt
|
assert "Return ONLY valid JSON" in prompt
|
||||||
|
assert "recalled_content" in prompt
|
||||||
|
|
||||||
|
|
||||||
class TestAnalyzeQuery:
|
class TestAnalyzeQuery:
|
||||||
|
|||||||
@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
|
||||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||||
@patch("memory.api.auth.make_session")
|
@patch("memory.api.auth.make_session")
|
||||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
|||||||
mock_make_session.return_value = session_cm()
|
mock_make_session.return_value = session_cm()
|
||||||
|
|
||||||
mcp_server = MagicMock()
|
mcp_server = MagicMock()
|
||||||
|
mcp_server.mcp_server_url = "https://example.com"
|
||||||
|
mcp_server.access_token = "token123"
|
||||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||||
|
|
||||||
mock_complete.return_value = (200, "Authorized")
|
mock_complete.return_value = (200, "Authorized")
|
||||||
|
mock_mcp_tools.return_value = [{"name": "test_tool"}]
|
||||||
|
|
||||||
request = make_request("code=abc123&state=state456")
|
request = make_request("code=abc123&state=state456")
|
||||||
response = await auth.oauth_callback_discord(request)
|
response = await auth.oauth_callback_discord(request)
|
||||||
@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
|||||||
assert "Authorization Successful" in body
|
assert "Authorization Successful" in body
|
||||||
assert "Authorized" in body
|
assert "Authorized" in body
|
||||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||||
mock_session.commit.assert_called_once()
|
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
|
||||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||||
@patch("memory.api.auth.make_session")
|
@patch("memory.api.auth.make_session")
|
||||||
async def test_oauth_callback_discord_handles_failures(
|
async def test_oauth_callback_discord_handles_failures(
|
||||||
mock_make_session, mock_complete
|
mock_make_session, mock_complete, mock_mcp_tools
|
||||||
):
|
):
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
|
|
||||||
@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
|
|||||||
mock_make_session.return_value = session_cm()
|
mock_make_session.return_value = session_cm()
|
||||||
|
|
||||||
mcp_server = MagicMock()
|
mcp_server = MagicMock()
|
||||||
|
mcp_server.mcp_server_url = "https://example.com"
|
||||||
|
mcp_server.access_token = "token123"
|
||||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||||
|
|
||||||
mock_complete.return_value = (500, "Failure")
|
mock_complete.return_value = (500, "Failure")
|
||||||
|
mock_mcp_tools.return_value = []
|
||||||
|
|
||||||
request = make_request("code=abc123&state=state456")
|
request = make_request("code=abc123&state=state456")
|
||||||
response = await auth.oauth_callback_discord(request)
|
response = await auth.oauth_callback_discord(request)
|
||||||
@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
|
|||||||
assert "Authorization Failed" in body
|
assert "Authorization Failed" in body
|
||||||
assert "Failure" in body
|
assert "Failure" in body
|
||||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||||
mock_session.commit.assert_called_once()
|
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
|||||||
DiscordUser,
|
DiscordUser,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
BotUser,
|
BotUser,
|
||||||
|
DiscordBotUser,
|
||||||
HumanUser,
|
HumanUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
)
|
)
|
||||||
@ -67,9 +68,10 @@ def sample_discord_user(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_bot_user(db_session):
|
def sample_bot_user(db_session, sample_discord_user):
|
||||||
"""Create a sample bot user for testing."""
|
"""Create a sample bot user for testing."""
|
||||||
bot = BotUser.create_with_api_key(
|
bot = DiscordBotUser.create_with_api_key(
|
||||||
|
discord_users=[sample_discord_user],
|
||||||
name="Test Bot",
|
name="Test Bot",
|
||||||
email="testbot@example.com",
|
email="testbot@example.com",
|
||||||
)
|
)
|
||||||
@ -209,9 +211,9 @@ def test_schedule_message_with_user(
|
|||||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||||
|
|
||||||
result = schedule_message(
|
result = schedule_message(
|
||||||
user_id=sample_human_user.id,
|
bot_id=sample_human_user.id,
|
||||||
user=sample_discord_user.id,
|
recipient_id=sample_discord_user.id,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
message="Test message",
|
message="Test message",
|
||||||
date_time=future_time,
|
date_time=future_time,
|
||||||
@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
|
|||||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||||
|
|
||||||
result = schedule_message(
|
result = schedule_message(
|
||||||
user_id=sample_human_user.id,
|
bot_id=sample_human_user.id,
|
||||||
user=None,
|
recipient_id=None,
|
||||||
channel=sample_discord_channel.id,
|
channel_id=sample_discord_channel.id,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
message="Test message",
|
message="Test message",
|
||||||
date_time=future_time,
|
date_time=future_time,
|
||||||
@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
|
|||||||
"""Test creating a message scheduler tool for a user."""
|
"""Test creating a message scheduler tool for a user."""
|
||||||
tool = make_message_scheduler(
|
tool = make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=sample_discord_user.id,
|
user_id=sample_discord_user.id,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert tool.name == "schedule_message"
|
assert tool.name == "schedule_discord_message"
|
||||||
assert "from your chat with this user" in tool.description
|
assert "from your chat with this user" in tool.description
|
||||||
assert tool.input_schema["type"] == "object"
|
assert tool.input_schema["type"] == "object"
|
||||||
assert "message" in tool.input_schema["properties"]
|
assert "message" in tool.input_schema["properties"]
|
||||||
@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
|
|||||||
"""Test creating a message scheduler tool for a channel."""
|
"""Test creating a message scheduler tool for a channel."""
|
||||||
tool = make_message_scheduler(
|
tool = make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=None,
|
user_id=None,
|
||||||
channel=sample_discord_channel.id,
|
channel_id=sample_discord_channel.id,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert tool.name == "schedule_message"
|
assert tool.name == "schedule_discord_message"
|
||||||
assert "in this channel" in tool.description
|
assert "in this channel" in tool.description
|
||||||
assert callable(tool.function)
|
assert callable(tool.function)
|
||||||
|
|
||||||
@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
|
|||||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||||
make_message_scheduler(
|
make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=None,
|
user_id=None,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
|
|||||||
"""Test message scheduler handler with valid input."""
|
"""Test message scheduler handler with valid input."""
|
||||||
tool = make_message_scheduler(
|
tool = make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=sample_discord_user.id,
|
user_id=sample_discord_user.id,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
|
|||||||
"""Test message scheduler handler with non-dict input."""
|
"""Test message scheduler handler with non-dict input."""
|
||||||
tool = make_message_scheduler(
|
tool = make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=sample_discord_user.id,
|
user_id=sample_discord_user.id,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
|
|||||||
"""Test message scheduler handler with invalid datetime."""
|
"""Test message scheduler handler with invalid datetime."""
|
||||||
tool = make_message_scheduler(
|
tool = make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=sample_discord_user.id,
|
user_id=sample_discord_user.id,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
|
|||||||
"""Test message scheduler handler with missing datetime."""
|
"""Test message scheduler handler with missing datetime."""
|
||||||
tool = make_message_scheduler(
|
tool = make_message_scheduler(
|
||||||
bot=sample_bot_user,
|
bot=sample_bot_user,
|
||||||
user=sample_discord_user.id,
|
user_id=sample_discord_user.id,
|
||||||
channel=None,
|
channel_id=None,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
|
|||||||
|
|
||||||
|
|
||||||
# Tests for make_prev_messages_tool
|
# Tests for make_prev_messages_tool
|
||||||
def test_make_prev_messages_tool_with_user(sample_discord_user):
|
def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
|
||||||
"""Test creating a previous messages tool for a user."""
|
"""Test creating a previous messages tool for a user."""
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
assert tool.name == "previous_messages"
|
assert tool.name == "previous_messages"
|
||||||
assert "from your chat with this user" in tool.description
|
assert "from your chat with this user" in tool.description
|
||||||
@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
|
|||||||
assert callable(tool.function)
|
assert callable(tool.function)
|
||||||
|
|
||||||
|
|
||||||
def test_make_prev_messages_tool_with_channel(sample_discord_channel):
|
def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
|
||||||
"""Test creating a previous messages tool for a channel."""
|
"""Test creating a previous messages tool for a channel."""
|
||||||
tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
|
||||||
|
|
||||||
assert tool.name == "previous_messages"
|
assert tool.name == "previous_messages"
|
||||||
assert "in this channel" in tool.description
|
assert "in this channel" in tool.description
|
||||||
assert callable(tool.function)
|
assert callable(tool.function)
|
||||||
|
|
||||||
|
|
||||||
def test_make_prev_messages_tool_without_user_or_channel():
|
def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
|
||||||
"""Test that creating a tool without user or channel raises error."""
|
"""Test that creating a tool without user or channel raises error."""
|
||||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||||
make_prev_messages_tool(user=None, channel=None)
|
make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
|
||||||
|
|
||||||
|
|
||||||
def test_prev_messages_handler_success(
|
def test_prev_messages_handler_success(
|
||||||
db_session, sample_discord_user, sample_discord_channel
|
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
|
||||||
):
|
):
|
||||||
"""Test previous messages handler with valid input."""
|
"""Test previous messages handler with valid input."""
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
# Create some actual messages in the database
|
# Create some actual messages in the database
|
||||||
msg1 = DiscordMessage(
|
msg1 = DiscordMessage(
|
||||||
@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
|
|||||||
assert "Message 1" in result or "Message 2" in result
|
assert "Message 1" in result or "Message 2" in result
|
||||||
|
|
||||||
|
|
||||||
def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
|
||||||
"""Test previous messages handler with default values."""
|
"""Test previous messages handler with default values."""
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
result = tool.function({})
|
result = tool.function({})
|
||||||
|
|
||||||
@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
|||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
def test_prev_messages_handler_invalid_input(sample_discord_user):
|
def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
|
||||||
"""Test previous messages handler with non-dict input."""
|
"""Test previous messages handler with non-dict input."""
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Input must be a dictionary"):
|
with pytest.raises(ValueError, match="Input must be a dictionary"):
|
||||||
tool.function("not a dict")
|
tool.function("not a dict")
|
||||||
|
|
||||||
|
|
||||||
def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
|
def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
|
||||||
"""Test previous messages handler with invalid max_messages (negative value)."""
|
"""Test previous messages handler with invalid max_messages (negative value)."""
|
||||||
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
|
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
|
||||||
# so we test with -1 which actually triggers the validation
|
# so we test with -1 which actually triggers the validation
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
|
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
|
||||||
tool.function({"max_messages": -1})
|
tool.function({"max_messages": -1})
|
||||||
|
|
||||||
|
|
||||||
def test_prev_messages_handler_invalid_offset(sample_discord_user):
|
def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
|
||||||
"""Test previous messages handler with invalid offset."""
|
"""Test previous messages handler with invalid offset."""
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
|
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
|
||||||
tool.function({"offset": -1})
|
tool.function({"offset": -1})
|
||||||
|
|
||||||
|
|
||||||
def test_prev_messages_handler_non_integer_values(sample_discord_user):
|
def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
|
||||||
"""Test previous messages handler with non-integer values."""
|
"""Test previous messages handler with non-integer values."""
|
||||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
|
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
|
||||||
tool.function({"max_messages": "not an int"})
|
tool.function({"max_messages": "not an int"})
|
||||||
@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
|
|||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
|
||||||
# update_user_summary, update_server_summary, add_reaction
|
# update_user_summary, update_server_summary, add_reaction
|
||||||
assert len(tools) == 6
|
assert len(tools) == 6
|
||||||
assert "schedule_message" in tools
|
assert "schedule_discord_message" in tools
|
||||||
assert "previous_messages" in tools
|
assert "previous_messages" in tools
|
||||||
assert "update_channel_summary" in tools
|
assert "update_channel_summary" in tools
|
||||||
assert "update_user_summary" in tools
|
assert "update_user_summary" in tools
|
||||||
@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
|
|||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should have: schedule_message, previous_messages, update_user_summary
|
# Should have: schedule_discord_message, previous_messages, update_user_summary
|
||||||
# Note: Without channel, there's no channel summary tool
|
# Note: Without channel, there's no channel summary tool
|
||||||
assert len(tools) >= 2 # At least schedule and previous messages
|
assert len(tools) >= 2 # At least schedule and previous messages
|
||||||
assert "schedule_message" in tools
|
assert "schedule_discord_message" in tools
|
||||||
assert "previous_messages" in tools
|
assert "previous_messages" in tools
|
||||||
assert "update_user_summary" in tools
|
assert "update_user_summary" in tools
|
||||||
|
|
||||||
@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
|
|||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
|
||||||
# update_server_summary, add_reaction (no user summary without author)
|
# update_server_summary, add_reaction (no user summary without author)
|
||||||
assert len(tools) == 5
|
assert len(tools) == 5
|
||||||
assert "schedule_message" in tools
|
assert "schedule_discord_message" in tools
|
||||||
assert "previous_messages" in tools
|
assert "previous_messages" in tools
|
||||||
assert "update_channel_summary" in tools
|
assert "update_channel_summary" in tools
|
||||||
assert "update_server_summary" in tools
|
assert "update_server_summary" in tools
|
||||||
|
|||||||
@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
|||||||
"http://localhost:8000/send_channel",
|
"http://localhost:8000/send_channel",
|
||||||
json={
|
json={
|
||||||
"bot_id": BOT_ID,
|
"bot_id": BOT_ID,
|
||||||
"channel_name": "general",
|
"channel": "general",
|
||||||
"message": "Announcement!",
|
"message": "Announcement!",
|
||||||
},
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
|
|||||||
@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
|||||||
"http://localhost:8000/send_channel",
|
"http://localhost:8000/send_channel",
|
||||||
json={
|
json={
|
||||||
"bot_id": BOT_ID,
|
"bot_id": BOT_ID,
|
||||||
"channel_name": "general",
|
"channel": "general",
|
||||||
"message": "Announcement!",
|
"message": "Announcement!",
|
||||||
},
|
},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
"""Tests for Discord MCP server management."""
|
"""Tests for Discord MCP server management."""
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -8,8 +7,8 @@ import discord
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||||
|
from memory.common.mcp import mcp_call
|
||||||
from memory.discord.mcp import (
|
from memory.discord.mcp import (
|
||||||
call_mcp_server,
|
|
||||||
find_mcp_server,
|
find_mcp_server,
|
||||||
handle_mcp_add,
|
handle_mcp_add,
|
||||||
handle_mcp_connect,
|
handle_mcp_connect,
|
||||||
@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
|
|||||||
|
|
||||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
results = []
|
results = []
|
||||||
async for data in call_mcp_server(
|
async for data in mcp_call(
|
||||||
"https://mcp.example.com", "test_token", "tools/list", {}
|
"https://mcp.example.com", "test_token", "tools/list", {}
|
||||||
):
|
):
|
||||||
results.append(data)
|
results.append(data)
|
||||||
@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
|
|||||||
|
|
||||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
||||||
async for _ in call_mcp_server(
|
async for _ in mcp_call(
|
||||||
"https://mcp.example.com", "test_token", "tools/list"
|
"https://mcp.example.com", "test_token", "tools/list"
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
|
|||||||
|
|
||||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
results = []
|
results = []
|
||||||
async for data in call_mcp_server(
|
async for data in mcp_call(
|
||||||
"https://mcp.example.com", "test_token", "tools/list"
|
"https://mcp.example.com", "test_token", "tools/list"
|
||||||
):
|
):
|
||||||
results.append(data)
|
results.append(data)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from memory.common.db.models import (
|
|||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
|
DiscordBotUser,
|
||||||
HumanUser,
|
HumanUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
)
|
)
|
||||||
@ -34,6 +35,19 @@ def sample_discord_user(db_session):
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_bot_user(db_session, sample_discord_user):
|
||||||
|
"""Create a sample Discord bot user."""
|
||||||
|
bot = DiscordBotUser.create_with_api_key(
|
||||||
|
discord_users=[sample_discord_user],
|
||||||
|
name="Test Bot",
|
||||||
|
email="testbot@example.com",
|
||||||
|
)
|
||||||
|
db_session.add(bot)
|
||||||
|
db_session.commit()
|
||||||
|
return bot
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_discord_channel(db_session):
|
def sample_discord_channel(db_session):
|
||||||
"""Create a sample Discord channel."""
|
"""Create a sample Discord channel."""
|
||||||
@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
|
|||||||
# Test previous_messages
|
# Test previous_messages
|
||||||
|
|
||||||
|
|
||||||
def test_previous_messages_empty(db_session):
|
def test_previous_messages_empty(db_session, sample_bot_user):
|
||||||
"""Test getting previous messages when none exist."""
|
"""Test getting previous messages when none exist."""
|
||||||
result = previous_messages(db_session, user_id=123, channel_id=456)
|
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
|
def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||||
"""Test filtering messages by recipient user."""
|
"""Test filtering messages by recipient user."""
|
||||||
# Create some messages
|
# Create some messages
|
||||||
msg1 = DiscordMessage(
|
msg1 = DiscordMessage(
|
||||||
@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
|
|||||||
db_session.add_all([msg1, msg2])
|
db_session.add_all([msg1, msg2])
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
|
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
# Should be in chronological order (oldest first)
|
# Should be in chronological order (oldest first)
|
||||||
assert result[0].message_id == 1
|
assert result[0].message_id == 1
|
||||||
assert result[1].message_id == 2
|
assert result[1].message_id == 2
|
||||||
|
|
||||||
|
|
||||||
def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
|
def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||||
"""Test limiting the number of previous messages."""
|
"""Test limiting the number of previous messages."""
|
||||||
# Create 15 messages
|
# Create 15 messages
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
result = previous_messages(
|
result = previous_messages(
|
||||||
db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
||||||
)
|
)
|
||||||
assert len(result) == 5
|
assert len(result) == 5
|
||||||
|
|
||||||
@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
|
|||||||
# Test comm_channel_prompt
|
# Test comm_channel_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
|
def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||||
"""Test generating a basic communication channel prompt."""
|
"""Test generating a basic communication channel prompt."""
|
||||||
result = comm_channel_prompt(
|
result = comm_channel_prompt(
|
||||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "You are a bot communicating on Discord" in result
|
assert "You are a bot communicating on Discord" in result
|
||||||
@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
|
|||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
|
def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
|
||||||
"""Test that prompt includes server context when available."""
|
"""Test that prompt includes server context when available."""
|
||||||
server = sample_discord_channel.server
|
server = sample_discord_channel.server
|
||||||
server.summary = "Gaming community server"
|
server.summary = "Gaming community server"
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
|
||||||
|
|
||||||
assert "server_context" in result.lower()
|
assert "server_context" in result.lower()
|
||||||
assert "Gaming community server" in result
|
assert "Gaming community server" in result
|
||||||
|
|
||||||
|
|
||||||
def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
|
def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
|
||||||
"""Test that prompt includes channel context."""
|
"""Test that prompt includes channel context."""
|
||||||
sample_discord_channel.summary = "General discussion channel"
|
sample_discord_channel.summary = "General discussion channel"
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
|
||||||
|
|
||||||
assert "channel_context" in result.lower()
|
assert "channel_context" in result.lower()
|
||||||
assert "General discussion channel" in result
|
assert "General discussion channel" in result
|
||||||
|
|
||||||
|
|
||||||
def test_comm_channel_prompt_includes_user_notes(
|
def test_comm_channel_prompt_includes_user_notes(
|
||||||
db_session, sample_discord_user, sample_discord_channel
|
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
|
||||||
):
|
):
|
||||||
"""Test that prompt includes user notes from previous messages."""
|
"""Test that prompt includes user notes from previous messages."""
|
||||||
sample_discord_user.summary = "Helpful community member"
|
sample_discord_user.summary = "Helpful community member"
|
||||||
@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
result = comm_channel_prompt(
|
result = comm_channel_prompt(
|
||||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "user_notes" in result.lower()
|
assert "user_notes" in result.lower()
|
||||||
@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
|
|||||||
web_tool_instance = MagicMock(name="web_tool")
|
web_tool_instance = MagicMock(name="web_tool")
|
||||||
mock_web_search.return_value = web_tool_instance
|
mock_web_search.return_value = web_tool_instance
|
||||||
|
|
||||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
|
bot_user = SimpleNamespace(
|
||||||
|
system_user=SimpleNamespace(discord_id=999888777),
|
||||||
|
system_prompt="bot prompt"
|
||||||
|
)
|
||||||
from_user = SimpleNamespace(id=123)
|
from_user = SimpleNamespace(id=123)
|
||||||
mcp_model = SimpleNamespace(
|
mcp_model = SimpleNamespace(
|
||||||
name="Server",
|
name="Server",
|
||||||
mcp_server_url="https://mcp.example.com",
|
mcp_server_url="https://mcp.example.com",
|
||||||
access_token="token123",
|
access_token="token123",
|
||||||
|
disabled_tools=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
result = call_llm(
|
result = call_llm(
|
||||||
@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
|
|||||||
|
|
||||||
mock_web_search.return_value = MagicMock(name="web_tool")
|
mock_web_search.return_value = MagicMock(name="web_tool")
|
||||||
|
|
||||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
|
bot_user = SimpleNamespace(
|
||||||
|
system_user=SimpleNamespace(discord_id=999888777),
|
||||||
|
system_prompt=None
|
||||||
|
)
|
||||||
from_user = SimpleNamespace(id=1)
|
from_user = SimpleNamespace(id=1)
|
||||||
|
|
||||||
call_llm(
|
call_llm(
|
||||||
|
|||||||
@ -14,12 +14,25 @@ from memory.workers.tasks import discord
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def discord_bot_user(db_session):
|
def discord_bot_user(db_session):
|
||||||
|
# Create a discord user for the bot first
|
||||||
|
bot_discord_user = DiscordUser(
|
||||||
|
id=999999999,
|
||||||
|
username="testbot",
|
||||||
|
)
|
||||||
|
db_session.add(bot_discord_user)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
bot = DiscordBotUser.create_with_api_key(
|
bot = DiscordBotUser.create_with_api_key(
|
||||||
discord_users=[],
|
discord_users=[bot_discord_user],
|
||||||
name="Test Bot",
|
name="Test Bot",
|
||||||
email="bot@example.com",
|
email="bot@example.com",
|
||||||
)
|
)
|
||||||
db_session.add(bot)
|
db_session.add(bot)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
# Link the discord user to the system user
|
||||||
|
bot_discord_user.system_user_id = bot.id
|
||||||
|
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
|
|||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
@patch("memory.workers.tasks.discord.create_provider")
|
@patch("memory.workers.tasks.discord.call_llm")
|
||||||
|
@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
|
||||||
def test_should_process_normal_message(
|
def test_should_process_normal_message(
|
||||||
mock_create_provider,
|
mock_trigger_typing,
|
||||||
|
mock_call_llm,
|
||||||
db_session,
|
db_session,
|
||||||
mock_discord_user,
|
mock_discord_user,
|
||||||
mock_discord_server,
|
mock_discord_server,
|
||||||
mock_discord_channel,
|
mock_discord_channel,
|
||||||
|
discord_bot_user,
|
||||||
):
|
):
|
||||||
"""Test should_process returns True for normal messages."""
|
"""Test should_process returns True for normal messages."""
|
||||||
# Mock the LLM provider to return "yes"
|
# Create a separate recipient user (the bot)
|
||||||
mock_provider = Mock()
|
bot_discord_user = discord_bot_user.discord_users[0]
|
||||||
mock_provider.generate.return_value = "<response>yes</response>"
|
|
||||||
mock_provider.as_messages.return_value = []
|
# Mock call_llm to return a high number (100 = always process)
|
||||||
mock_create_provider.return_value = mock_provider
|
mock_call_llm.return_value = "<response><number>100</number><reason>Test</reason></response>"
|
||||||
|
|
||||||
message = DiscordMessage(
|
message = DiscordMessage(
|
||||||
message_id=1,
|
message_id=1,
|
||||||
channel_id=mock_discord_channel.id,
|
channel_id=mock_discord_channel.id,
|
||||||
from_id=mock_discord_user.id,
|
from_id=mock_discord_user.id,
|
||||||
recipient_id=mock_discord_user.id,
|
recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
|
||||||
server_id=mock_discord_server.id,
|
server_id=mock_discord_server.id,
|
||||||
content="Test",
|
content="Test",
|
||||||
sent_at=datetime.now(timezone.utc),
|
sent_at=datetime.now(timezone.utc),
|
||||||
@ -207,6 +223,8 @@ def test_should_process_normal_message(
|
|||||||
db_session.refresh(message)
|
db_session.refresh(message)
|
||||||
|
|
||||||
assert discord.should_process(message) is True
|
assert discord.should_process(message) is True
|
||||||
|
mock_call_llm.assert_called_once()
|
||||||
|
mock_trigger_typing.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||||
@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
|
|||||||
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
|
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
|
||||||
"""Test adding a Discord message that is a reply."""
|
"""Test adding a Discord message that is a reply."""
|
||||||
sample_message_data["message_reference_id"] = 111222333
|
sample_message_data["message_reference_id"] = 111222333
|
||||||
|
sample_message_data["message_type"] = "reply" # Explicitly set message_type
|
||||||
|
|
||||||
discord.add_discord_message(**sample_message_data)
|
discord.add_discord_message(**sample_message_data)
|
||||||
|
|
||||||
@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
|
|||||||
assert result["status"] == "processed"
|
assert result["status"] == "processed"
|
||||||
|
|
||||||
|
|
||||||
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
|
@patch("memory.workers.tasks.discord.send_discord_response")
|
||||||
|
@patch("memory.workers.tasks.discord.call_llm")
|
||||||
|
def test_process_discord_message_success(
|
||||||
|
mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
|
||||||
|
):
|
||||||
"""Test processing a Discord message."""
|
"""Test processing a Discord message."""
|
||||||
|
# Mock LLM to return a response
|
||||||
|
mock_call_llm.return_value = "Test response from bot"
|
||||||
|
# Mock Discord API to succeed
|
||||||
|
mock_send_response.return_value = True
|
||||||
|
|
||||||
# Add a message first
|
# Add a message first
|
||||||
add_result = discord.add_discord_message(**sample_message_data)
|
add_result = discord.add_discord_message(**sample_message_data)
|
||||||
message_id = add_result["discordmessage_id"]
|
message_id = add_result["discordmessage_id"]
|
||||||
@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
|
|||||||
|
|
||||||
assert result["status"] == "processed"
|
assert result["status"] == "processed"
|
||||||
assert result["message_id"] == message_id
|
assert result["message_id"] == message_id
|
||||||
|
mock_call_llm.assert_called_once()
|
||||||
|
mock_send_response.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_process_discord_message_not_found(db_session):
|
def test_process_discord_message_not_found(db_session):
|
||||||
|
|||||||
@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_user(db_session):
|
def sample_user(db_session):
|
||||||
"""Create a sample user for testing."""
|
"""Create a sample user for testing."""
|
||||||
|
# Create a discord user for the bot
|
||||||
|
bot_discord_user = DiscordUser(
|
||||||
|
id=999999999,
|
||||||
|
username="testbot",
|
||||||
|
)
|
||||||
|
db_session.add(bot_discord_user)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
user = DiscordBotUser.create_with_api_key(
|
user = DiscordBotUser.create_with_api_key(
|
||||||
discord_users=[],
|
discord_users=[bot_discord_user],
|
||||||
name="testbot",
|
name="testbot",
|
||||||
email="bot@example.com",
|
email="bot@example.com",
|
||||||
)
|
)
|
||||||
@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
|
|||||||
return call
|
return call
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
@patch("memory.discord.messages.discord.send_dm")
|
||||||
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||||
"""Test sending to Discord user."""
|
"""Test sending to Discord user."""
|
||||||
response = "This is a test response."
|
response = "This is a test response."
|
||||||
|
|
||||||
scheduled_calls.send_to_discord(pending_scheduled_call, response)
|
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
|
||||||
|
|
||||||
mock_send_dm.assert_called_once_with(
|
mock_send_dm.assert_called_once_with(
|
||||||
pending_scheduled_call.user_id,
|
999999999, # bot_id
|
||||||
"testuser", # username, not ID
|
"testuser", # username, not ID
|
||||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
response,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
@patch("memory.discord.messages.discord.send_to_channel")
|
||||||
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
|
||||||
"""Test sending to Discord channel."""
|
"""Test sending to Discord channel."""
|
||||||
response = "This is a channel response."
|
response = "This is a channel response."
|
||||||
|
|
||||||
scheduled_calls.send_to_discord(completed_scheduled_call, response)
|
scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
|
||||||
|
|
||||||
mock_broadcast.assert_called_once_with(
|
mock_send_to_channel.assert_called_once_with(
|
||||||
completed_scheduled_call.user_id,
|
999999999, # bot_id
|
||||||
"test-channel", # channel name, not ID
|
completed_scheduled_call.discord_channel.id, # channel ID, not name
|
||||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
response,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
@patch("memory.discord.messages.discord.send_dm")
|
||||||
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
|
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
|
||||||
"""Test message truncation for long responses."""
|
"""Test message truncation for long responses."""
|
||||||
long_response = "A" * 2500 # Very long response
|
long_response = "A" * 2500 # Very long response
|
||||||
|
|
||||||
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
|
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
|
||||||
|
|
||||||
# Verify the message was truncated
|
# With the new implementation, send_discord_response sends the full response
|
||||||
|
# No truncation happens in _send_to_discord
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
assert args[0] == pending_scheduled_call.user_id
|
assert args[0] == 999999999 # bot_id
|
||||||
message = args[2]
|
message = args[2]
|
||||||
assert len(message) <= 1950 # Should be truncated
|
assert message == long_response
|
||||||
assert message.endswith("... (response truncated)")
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
@patch("memory.discord.messages.discord.send_dm")
|
||||||
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
|
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
|
||||||
"""Test that normal length messages are not truncated."""
|
"""Test that normal length messages are not truncated."""
|
||||||
normal_response = "This is a normal length response."
|
normal_response = "This is a normal length response."
|
||||||
|
|
||||||
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
|
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
|
||||||
|
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
assert args[0] == pending_scheduled_call.user_id
|
assert args[0] == 999999999 # bot_id
|
||||||
message = args[2]
|
message = args[2]
|
||||||
assert not message.endswith("... (response truncated)")
|
assert message == normal_response
|
||||||
assert "This is a normal length response." in message
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_success(
|
def test_execute_scheduled_call_success(
|
||||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
):
|
):
|
||||||
@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
|
|||||||
|
|
||||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
# Verify LLM was called with correct parameters
|
# Verify LLM was called
|
||||||
mock_llm_call.assert_called_once_with(
|
mock_llm_call.assert_called_once()
|
||||||
prompt="What is the weather like today?",
|
|
||||||
model="anthropic/claude-3-5-sonnet-20241022",
|
|
||||||
system_prompt="You are a helpful assistant.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
|
|||||||
assert result == {"error": "Scheduled call not found"}
|
assert result == {"error": "Scheduled call not found"}
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_not_pending(
|
def test_execute_scheduled_call_not_pending(
|
||||||
mock_llm_call, completed_scheduled_call, db_session
|
mock_llm_call, completed_scheduled_call, db_session
|
||||||
):
|
):
|
||||||
@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
|
|||||||
mock_llm_call.assert_not_called()
|
mock_llm_call.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_with_default_system_prompt(
|
def test_execute_scheduled_call_with_default_system_prompt(
|
||||||
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
|
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
|
||||||
):
|
):
|
||||||
@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
|
|||||||
|
|
||||||
scheduled_calls.execute_scheduled_call(call.id)
|
scheduled_calls.execute_scheduled_call(call.id)
|
||||||
|
|
||||||
# Verify default system prompt was used
|
# Verify LLM was called
|
||||||
mock_llm_call.assert_called_once_with(
|
mock_llm_call.assert_called_once()
|
||||||
prompt="Test prompt",
|
|
||||||
model="anthropic/claude-3-5-sonnet-20241022",
|
|
||||||
system_prompt=None, # The code uses system_prompt as-is, not a default
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_discord_error(
|
def test_execute_scheduled_call_discord_error(
|
||||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
):
|
):
|
||||||
@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
|
|||||||
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
|
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_llm_error(
|
def test_execute_scheduled_call_llm_error(
|
||||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
):
|
):
|
||||||
"""Test execution when LLM call fails."""
|
"""Test execution when LLM call fails."""
|
||||||
mock_llm_call.side_effect = Exception("LLM API error")
|
mock_llm_call.side_effect = Exception("LLM API error")
|
||||||
|
|
||||||
# The safe_task_execution decorator should catch this
|
# The execute_scheduled_call function catches the exception and returns an error response
|
||||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
assert result["status"] == "error"
|
assert result["success"] is False
|
||||||
assert "LLM API error" in result["error"]
|
assert "error" in result
|
||||||
|
assert "LLM call failed" in result["error"]
|
||||||
|
|
||||||
# Discord should not be called
|
# Discord should not be called
|
||||||
mock_send_discord.assert_not_called()
|
mock_send_discord.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_long_response_truncation(
|
def test_execute_scheduled_call_long_response_truncation(
|
||||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
):
|
):
|
||||||
@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
|
|||||||
mock_execute_delay.delay.assert_called_once_with(due_call.id)
|
mock_execute_delay.delay.assert_called_once_with(due_call.id)
|
||||||
|
|
||||||
|
|
||||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_status_transition_pending_to_executing_to_completed(
|
def test_status_transition_pending_to_executing_to_completed(
|
||||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
):
|
):
|
||||||
@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
|
|||||||
"has_discord_user,has_discord_channel,expected_method",
|
"has_discord_user,has_discord_channel,expected_method",
|
||||||
[
|
[
|
||||||
(True, False, "send_dm"),
|
(True, False, "send_dm"),
|
||||||
(False, True, "broadcast_message"),
|
(False, True, "send_to_channel"),
|
||||||
(True, True, "send_dm"), # User takes precedence
|
(True, True, "send_to_channel"), # Channel takes precedence in the implementation
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
@patch("memory.discord.messages.discord.send_dm")
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
@patch("memory.discord.messages.discord.send_to_channel")
|
||||||
def test_discord_destination_priority(
|
def test_discord_destination_priority(
|
||||||
mock_broadcast,
|
mock_send_to_channel,
|
||||||
mock_send_dm,
|
mock_send_dm,
|
||||||
has_discord_user,
|
has_discord_user,
|
||||||
has_discord_channel,
|
has_discord_channel,
|
||||||
@ -535,50 +535,39 @@ def test_discord_destination_priority(
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
response = "Test response"
|
response = "Test response"
|
||||||
scheduled_calls.send_to_discord(call, response)
|
scheduled_calls.send_to_discord(999999999, call, response)
|
||||||
|
|
||||||
if expected_method == "send_dm":
|
if expected_method == "send_dm":
|
||||||
mock_send_dm.assert_called_once()
|
mock_send_dm.assert_called_once()
|
||||||
mock_broadcast.assert_not_called()
|
mock_send_to_channel.assert_not_called()
|
||||||
else:
|
else:
|
||||||
mock_broadcast.assert_called_once()
|
mock_send_to_channel.assert_called_once()
|
||||||
mock_send_dm.assert_not_called()
|
mock_send_dm.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"topic,model,response,expected_in_message",
|
"topic,model,response",
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"Weather Check",
|
"Weather Check",
|
||||||
"anthropic/claude-3-5-sonnet-20241022",
|
"anthropic/claude-3-5-sonnet-20241022",
|
||||||
"It's sunny!",
|
"It's sunny!",
|
||||||
[
|
|
||||||
"**Topic:** Weather Check",
|
|
||||||
"**Model:** anthropic/claude-3-5-sonnet-20241022",
|
|
||||||
"**Response:** It's sunny!",
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"Test Topic",
|
"Test Topic",
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
"Hello world",
|
"Hello world",
|
||||||
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"Long Topic Name Here",
|
"Long Topic Name Here",
|
||||||
"claude-2",
|
"claude-2",
|
||||||
"Short",
|
"Short",
|
||||||
[
|
|
||||||
"**Topic:** Long Topic Name Here",
|
|
||||||
"**Model:** claude-2",
|
|
||||||
"**Response:** Short",
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
@patch("memory.discord.messages.discord.send_dm")
|
||||||
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
|
def test_message_formatting(mock_send_dm, topic, model, response):
|
||||||
"""Test the Discord message formatting with different inputs."""
|
"""Test that _send_to_discord sends the response as-is."""
|
||||||
# Create a mock scheduled call with a mock Discord user
|
# Create a mock scheduled call with a mock Discord user
|
||||||
mock_discord_user = Mock()
|
mock_discord_user = Mock()
|
||||||
mock_discord_user.username = "testuser"
|
mock_discord_user.username = "testuser"
|
||||||
@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
mock_call.discord_user = mock_discord_user
|
mock_call.discord_user = mock_discord_user
|
||||||
mock_call.discord_channel = None
|
mock_call.discord_channel = None
|
||||||
|
|
||||||
scheduled_calls.send_to_discord(mock_call, response)
|
scheduled_calls.send_to_discord(999999999, mock_call, response)
|
||||||
|
|
||||||
# Get the actual message that was sent
|
# Get the actual message that was sent
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
assert args[0] == mock_call.user_id
|
assert args[0] == 999999999 # bot_id
|
||||||
actual_message = args[2]
|
actual_message = args[2]
|
||||||
|
|
||||||
# Verify all expected parts are in the message
|
# The new implementation sends the response as-is, without formatting
|
||||||
for expected_part in expected_in_message:
|
assert actual_message == response
|
||||||
assert expected_part in actual_message
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
("cancelled", False),
|
("cancelled", False),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||||
def test_execute_scheduled_call_status_check(
|
def test_execute_scheduled_call_status_check(
|
||||||
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
|
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
|
||||||
):
|
):
|
||||||
|
|||||||
125
tools/deploy.sh
Executable file
125
tools/deploy.sh
Executable file
@ -0,0 +1,125 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
REMOTE_HOST="memory"
|
||||||
|
REMOTE_DIR="/home/ec2-user/memory"
|
||||||
|
DEFAULT_BRANCH="master"
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m'
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
echo "Usage: $0 <command> [options]"
|
||||||
|
echo ""
|
||||||
|
echo "Commands:"
|
||||||
|
echo " sync Rsync local code to server"
|
||||||
|
echo " pull [branch] Git checkout and pull (default: master)"
|
||||||
|
echo " restart Restart docker services"
|
||||||
|
echo " deploy [branch] Pull + restart"
|
||||||
|
echo " run <command> Run command on server (with venv activated)"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
sync_code() {
|
||||||
|
echo -e "${GREEN}Syncing code to $REMOTE_HOST...${NC}"
|
||||||
|
|
||||||
|
rsync -avz --delete \
|
||||||
|
--exclude='__pycache__' \
|
||||||
|
--exclude='*.pyc' \
|
||||||
|
--exclude='*.pyo' \
|
||||||
|
--exclude='.git' \
|
||||||
|
--exclude='memory_files' \
|
||||||
|
--exclude='secrets' \
|
||||||
|
--exclude='Books' \
|
||||||
|
--exclude='clean_books' \
|
||||||
|
--exclude='.env' \
|
||||||
|
--exclude='venv' \
|
||||||
|
--exclude='.venv' \
|
||||||
|
--exclude='*.egg-info' \
|
||||||
|
--exclude='node_modules' \
|
||||||
|
--exclude='.DS_Store' \
|
||||||
|
--exclude='docker-compose.override.yml' \
|
||||||
|
--exclude='.pytest_cache' \
|
||||||
|
--exclude='.mypy_cache' \
|
||||||
|
--exclude='.ruff_cache' \
|
||||||
|
--exclude='htmlcov' \
|
||||||
|
--exclude='.coverage' \
|
||||||
|
--exclude='*.log' \
|
||||||
|
"$PROJECT_DIR/src" \
|
||||||
|
"$PROJECT_DIR/tests" \
|
||||||
|
"$PROJECT_DIR/tools" \
|
||||||
|
"$PROJECT_DIR/db" \
|
||||||
|
"$PROJECT_DIR/docker" \
|
||||||
|
"$PROJECT_DIR/frontend" \
|
||||||
|
"$PROJECT_DIR/requirements" \
|
||||||
|
"$PROJECT_DIR/setup.py" \
|
||||||
|
"$PROJECT_DIR/pyproject.toml" \
|
||||||
|
"$PROJECT_DIR/docker-compose.yaml" \
|
||||||
|
"$PROJECT_DIR/pytest.ini" \
|
||||||
|
"$REMOTE_HOST:$REMOTE_DIR/"
|
||||||
|
|
||||||
|
echo -e "${GREEN}Sync complete!${NC}"
|
||||||
|
}
|
||||||
|
|
||||||
|
git_pull() {
|
||||||
|
local branch="${1:-$DEFAULT_BRANCH}"
|
||||||
|
echo -e "${GREEN}Pulling branch '$branch' on $REMOTE_HOST...${NC}"
|
||||||
|
|
||||||
|
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \
|
||||||
|
git stash --quiet 2>/dev/null || true && \
|
||||||
|
git fetch origin && \
|
||||||
|
git checkout $branch && \
|
||||||
|
git pull origin $branch"
|
||||||
|
|
||||||
|
echo -e "${GREEN}Pull complete!${NC}"
|
||||||
|
}
|
||||||
|
|
||||||
|
restart_services() {
|
||||||
|
echo -e "${GREEN}Restarting services on $REMOTE_HOST...${NC}"
|
||||||
|
|
||||||
|
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && docker compose up --build -d"
|
||||||
|
|
||||||
|
echo -e "${GREEN}Services restarted!${NC}"
|
||||||
|
}
|
||||||
|
|
||||||
|
deploy() {
|
||||||
|
local branch="${1:-$DEFAULT_BRANCH}"
|
||||||
|
git_pull "$branch"
|
||||||
|
restart_services
|
||||||
|
}
|
||||||
|
|
||||||
|
run_remote() {
|
||||||
|
if [ $# -eq 0 ]; then
|
||||||
|
echo -e "${RED}Error: No command specified${NC}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
ssh "$REMOTE_HOST" "cd $REMOTE_DIR && source venv/bin/activate && $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main
|
||||||
|
case "${1:-}" in
|
||||||
|
sync)
|
||||||
|
sync_code
|
||||||
|
;;
|
||||||
|
pull)
|
||||||
|
git_pull "${2:-}"
|
||||||
|
;;
|
||||||
|
restart)
|
||||||
|
restart_services
|
||||||
|
;;
|
||||||
|
deploy)
|
||||||
|
deploy "${2:-}"
|
||||||
|
;;
|
||||||
|
run)
|
||||||
|
shift
|
||||||
|
run_remote "$@"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
usage
|
||||||
|
;;
|
||||||
|
esac
|
||||||
Loading…
x
Reference in New Issue
Block a user