mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
add a bunch of tests
This commit is contained in:
parent
56c0df9761
commit
ad6510bd17
@ -134,11 +134,15 @@ class UsageTracker:
|
||||
default_config: RateLimitConfig | None = None,
|
||||
) -> None:
|
||||
self._configs = configs or {}
|
||||
self._default_config = default_config or RateLimitConfig(
|
||||
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
|
||||
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||
)
|
||||
if default_config is None:
|
||||
default_config = RateLimitConfig(
|
||||
window=timedelta(
|
||||
minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES
|
||||
),
|
||||
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||
)
|
||||
self._default_config = default_config
|
||||
self._lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@ -260,8 +264,8 @@ class UsageTracker:
|
||||
|
||||
with self._lock:
|
||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||
for model, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(model)
|
||||
for model_key, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(model_key)
|
||||
if provider and provider != prov:
|
||||
continue
|
||||
if model and model != model_name:
|
||||
@ -304,7 +308,10 @@ class UsageTracker:
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_config(self, model: str) -> RateLimitConfig | None:
|
||||
return self._configs.get(model) or self._default_config
|
||||
config = self._configs.get(model)
|
||||
if config is not None:
|
||||
return config
|
||||
return self._default_config
|
||||
|
||||
def _prune_expired_events(
|
||||
self,
|
||||
|
||||
@ -179,17 +179,14 @@ def should_track_message(
|
||||
channel: DiscordChannel,
|
||||
user: DiscordUser,
|
||||
) -> bool:
|
||||
"""Pure function to determine if we should track this message"""
|
||||
if server and not server.track_messages: # type: ignore
|
||||
if server and server.ignore_messages:
|
||||
return False
|
||||
|
||||
if not channel.track_messages:
|
||||
if channel.ignore_messages:
|
||||
return False
|
||||
|
||||
if channel.channel_type in ("dm", "group_dm"):
|
||||
return bool(user.track_messages)
|
||||
|
||||
# Default: track the message
|
||||
return not user.ignore_messages
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -20,6 +21,31 @@ from memory.common.qdrant import initialize_collections
|
||||
from tests.providers.email_provider import MockEmailProvider
|
||||
|
||||
|
||||
class MockRedis:
|
||||
"""In-memory mock of Redis for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
|
||||
def get(self, key: str):
|
||||
return self._data.get(key)
|
||||
|
||||
def set(self, key: str, value):
|
||||
self._data[key] = value
|
||||
|
||||
def scan_iter(self, match: str):
|
||||
import fnmatch
|
||||
|
||||
pattern = match.replace("*", "**")
|
||||
for key in self._data.keys():
|
||||
if fnmatch.fnmatch(key, pattern):
|
||||
yield key
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str):
|
||||
return cls()
|
||||
|
||||
|
||||
def get_test_db_name() -> str:
|
||||
return f"test_db_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@ -83,7 +109,7 @@ def run_alembic_migrations(db_name: str) -> None:
|
||||
alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
|
||||
|
||||
subprocess.run(
|
||||
["alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
|
||||
check=True,
|
||||
capture_output=True,
|
||||
@ -265,7 +291,8 @@ def mock_openai_client():
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
)
|
||||
|
||||
@ -281,7 +308,8 @@ def mock_openai_client():
|
||||
delta=Mock(content="test", tool_calls=None),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
Mock(
|
||||
choices=[
|
||||
@ -289,7 +317,8 @@ def mock_openai_client():
|
||||
delta=Mock(content=" response", tool_calls=None),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=15),
|
||||
),
|
||||
]
|
||||
)
|
||||
@ -303,7 +332,8 @@ def mock_openai_client():
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
|
||||
client.chat.completions.create.side_effect = streaming_response
|
||||
@ -312,6 +342,8 @@ def mock_openai_client():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_anthropic_client():
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||
client = mock_client()
|
||||
client.messages = Mock()
|
||||
@ -345,7 +377,57 @@ def mock_anthropic_client():
|
||||
]
|
||||
)
|
||||
)
|
||||
yield client
|
||||
|
||||
# Mock async client
|
||||
async_client = Mock()
|
||||
async_client.messages = Mock()
|
||||
async_client.messages.create = AsyncMock(
|
||||
return_value=Mock(
|
||||
content=[
|
||||
Mock(
|
||||
type="text",
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Mock async streaming
|
||||
def async_stream_ctx(*args, **kwargs):
|
||||
async def async_iter():
|
||||
yield Mock(
|
||||
type="content_block_delta",
|
||||
delta=Mock(
|
||||
type="text_delta",
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
),
|
||||
)
|
||||
|
||||
class AsyncStreamMock:
|
||||
async def __aenter__(self):
|
||||
return async_iter()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
return AsyncStreamMock()
|
||||
|
||||
async_client.messages.stream = Mock(side_effect=async_stream_ctx)
|
||||
|
||||
# Add async_client property to mock
|
||||
mock_client.return_value._async_client = None
|
||||
|
||||
with patch.object(anthropic, "AsyncAnthropic", return_value=async_client):
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_redis():
|
||||
"""Mock Redis client for all tests."""
|
||||
import redis
|
||||
|
||||
with patch.object(redis, "Redis", MockRedis):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
151
tests/memory/api/test_auth.py
Normal file
151
tests/memory/api/test_auth.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Tests for authentication helpers and OAuth callback."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
|
||||
from memory.api import auth
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
def make_request(query: str) -> Request:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/auth/callback/discord",
|
||||
"headers": [],
|
||||
"query_string": query.encode(),
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
|
||||
return Request(scope, receive)
|
||||
|
||||
|
||||
def test_get_bearer_token_parses_header():
|
||||
request = SimpleNamespace(headers={"Authorization": "Bearer token123"})
|
||||
|
||||
assert auth.get_bearer_token(request) == "token123"
|
||||
|
||||
|
||||
def test_get_bearer_token_handles_missing_header():
|
||||
request = SimpleNamespace(headers={})
|
||||
|
||||
assert auth.get_bearer_token(request) is None
|
||||
|
||||
|
||||
def test_get_token_prefers_header_over_cookie():
|
||||
request = SimpleNamespace(
|
||||
headers={"Authorization": "Bearer header-token"},
|
||||
cookies={"session": "cookie-token"},
|
||||
)
|
||||
|
||||
assert auth.get_token(request) == "header-token"
|
||||
|
||||
|
||||
def test_get_token_falls_back_to_cookie():
|
||||
request = SimpleNamespace(
|
||||
headers={},
|
||||
cookies={settings.SESSION_COOKIE_NAME: "cookie-token"},
|
||||
)
|
||||
|
||||
assert auth.get_token(request) == "cookie-token"
|
||||
|
||||
|
||||
@patch("memory.api.auth.get_user_session")
|
||||
def test_logout_removes_session(mock_get_user_session):
|
||||
db = MagicMock()
|
||||
session = MagicMock()
|
||||
mock_get_user_session.return_value = session
|
||||
request = SimpleNamespace()
|
||||
|
||||
result = auth.logout(request, db)
|
||||
|
||||
assert result == {"message": "Logged out successfully"}
|
||||
db.delete.assert_called_once_with(session)
|
||||
db.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.api.auth.get_user_session", return_value=None)
|
||||
def test_logout_handles_missing_session(mock_get_user_session):
|
||||
db = MagicMock()
|
||||
request = SimpleNamespace()
|
||||
|
||||
result = auth.logout(request, db)
|
||||
|
||||
assert result == {"message": "Logged out successfully"}
|
||||
db.delete.assert_not_called()
|
||||
db.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
def session_cm():
|
||||
yield mock_session
|
||||
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (200, "Authorized")
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.body.decode()
|
||||
assert "Authorization Successful" in body
|
||||
assert "Authorized" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_handles_failures(
|
||||
mock_make_session, mock_complete
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
def session_cm():
|
||||
yield mock_session
|
||||
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (500, "Failure")
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
|
||||
assert response.status_code == 500
|
||||
body = response.body.decode()
|
||||
assert "Authorization Failed" in body
|
||||
assert "Failure" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_discord_validates_query_params():
|
||||
request = make_request("code=&state=")
|
||||
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
|
||||
assert response.status_code == 400
|
||||
body = response.body.decode()
|
||||
assert "Missing authorization code" in body
|
||||
@ -1,5 +1,7 @@
|
||||
"""Tests for Discord database models."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
|
||||
|
||||
@ -19,12 +21,11 @@ def test_create_discord_server(db_session):
|
||||
assert server.name == "Test Server"
|
||||
assert server.description == "A test Discord server"
|
||||
assert server.member_count == 100
|
||||
assert server.track_messages is True # default value
|
||||
assert server.ignore_messages is False
|
||||
assert server.ignore_messages is False # default value
|
||||
|
||||
|
||||
def test_discord_server_as_xml(db_session):
|
||||
"""Test DiscordServer.as_xml() method."""
|
||||
"""Test DiscordServer.to_xml() method."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
@ -33,11 +34,11 @@ def test_discord_server_as_xml(db_session):
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
xml = server.as_xml()
|
||||
assert "<servers>" in xml # tablename is discord_servers, strips to "servers"
|
||||
assert "<name>Test Server</name>" in xml
|
||||
assert "<summary>This is a test server for gaming</summary>" in xml
|
||||
assert "</servers>" in xml
|
||||
xml = server.to_xml("name", "summary")
|
||||
assert "<server>" in xml # tablename is discord_servers, strips to "server"
|
||||
assert "<name>" in xml and "Test Server" in xml
|
||||
assert "<summary>" in xml and "This is a test server for gaming" in xml
|
||||
assert "</server>" in xml
|
||||
|
||||
|
||||
def test_discord_server_message_tracking(db_session):
|
||||
@ -45,13 +46,11 @@ def test_discord_server_message_tracking(db_session):
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
track_messages=False,
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
assert server.track_messages is False
|
||||
assert server.ignore_messages is True
|
||||
|
||||
|
||||
@ -111,7 +110,7 @@ def test_discord_channel_without_server(db_session):
|
||||
|
||||
|
||||
def test_discord_channel_as_xml(db_session):
|
||||
"""Test DiscordChannel.as_xml() method."""
|
||||
"""Test DiscordChannel.to_xml() method."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="general",
|
||||
@ -121,30 +120,28 @@ def test_discord_channel_as_xml(db_session):
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
xml = channel.as_xml()
|
||||
assert "<channels>" in xml # tablename is discord_channels, strips to "channels"
|
||||
assert "<name>general</name>" in xml
|
||||
assert "<summary>Main discussion channel</summary>" in xml
|
||||
assert "</channels>" in xml
|
||||
xml = channel.to_xml("name", "summary")
|
||||
assert "<channel>" in xml # tablename is discord_channels, strips to "channel"
|
||||
assert "<name>" in xml and "general" in xml
|
||||
assert "<summary>" in xml and "Main discussion channel" in xml
|
||||
assert "</channel>" in xml
|
||||
|
||||
|
||||
def test_discord_channel_inherits_server_settings(db_session):
|
||||
"""Test that channels can have their own or inherit server settings."""
|
||||
server = DiscordServer(
|
||||
id=987654321, name="Server", track_messages=True, ignore_messages=False
|
||||
)
|
||||
server = DiscordServer(id=987654321, name="Server", ignore_messages=False)
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
server_id=server.id,
|
||||
name="announcements",
|
||||
channel_type="text",
|
||||
track_messages=False, # Override server setting
|
||||
ignore_messages=True, # Override server setting
|
||||
)
|
||||
db_session.add_all([server, channel])
|
||||
db_session.commit()
|
||||
|
||||
assert server.track_messages is True
|
||||
assert channel.track_messages is False
|
||||
assert server.ignore_messages is False
|
||||
assert channel.ignore_messages is True
|
||||
|
||||
|
||||
def test_create_discord_user(db_session):
|
||||
@ -186,7 +183,7 @@ def test_discord_user_with_system_user(db_session):
|
||||
|
||||
|
||||
def test_discord_user_as_xml(db_session):
|
||||
"""Test DiscordUser.as_xml() method."""
|
||||
"""Test DiscordUser.to_xml() method."""
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
@ -195,11 +192,10 @@ def test_discord_user_as_xml(db_session):
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
xml = user.as_xml()
|
||||
assert "<users>" in xml # tablename is discord_users, strips to "users"
|
||||
assert "<name>testuser</name>" in xml
|
||||
assert "<summary>Friendly and helpful community member</summary>" in xml
|
||||
assert "</users>" in xml
|
||||
xml = user.to_xml("summary")
|
||||
assert "<user>" in xml # tablename is discord_users, strips to "user"
|
||||
assert "<summary>" in xml and "Friendly and helpful community member" in xml
|
||||
assert "</user>" in xml
|
||||
|
||||
|
||||
def test_discord_user_message_preferences(db_session):
|
||||
@ -207,13 +203,11 @@ def test_discord_user_message_preferences(db_session):
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
track_messages=True,
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.track_messages is True
|
||||
assert user.ignore_messages is False
|
||||
|
||||
|
||||
@ -234,6 +228,21 @@ def test_discord_server_channel_relationship(db_session):
|
||||
assert channel2 in server.channels
|
||||
|
||||
|
||||
def test_discord_processor_xml_mcp_servers():
|
||||
"""Test xml_mcp_servers includes assigned MCP server XML."""
|
||||
server = DiscordServer(id=111, name="Server")
|
||||
mcp_stub = SimpleNamespace(
|
||||
as_xml=lambda: "<mcp_server><name>Example</name></mcp_server>"
|
||||
)
|
||||
|
||||
# Relationship is optional for test purposes; assign directly
|
||||
server.mcp_servers = [mcp_stub]
|
||||
|
||||
xml_output = server.xml_mcp_servers()
|
||||
assert "<mcp_server>" in xml_output
|
||||
assert "Example" in xml_output
|
||||
|
||||
|
||||
def test_discord_server_cascade_delete(db_session):
|
||||
"""Test that deleting a server cascades to channels."""
|
||||
server = DiscordServer(id=987654321, name="Test Server")
|
||||
|
||||
155
tests/memory/common/db/models/test_mcp_models.py
Normal file
155
tests/memory/common/db/models/test_mcp_models.py
Normal file
@ -0,0 +1,155 @@
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"available_tools,expected_tools",
|
||||
[
|
||||
(["search", "summarize"], ["• search", "• summarize"]),
|
||||
([], []),
|
||||
],
|
||||
)
|
||||
def test_mcp_server_as_xml_formats_available_tools(available_tools, expected_tools):
|
||||
server = MCPServer(
|
||||
name="Example Server",
|
||||
mcp_server_url="https://example.com/mcp",
|
||||
client_id="client-123",
|
||||
available_tools=available_tools,
|
||||
)
|
||||
|
||||
xml_output = server.as_xml()
|
||||
root = ET.fromstring(xml_output)
|
||||
|
||||
assert root.find("name").text.strip() == "Example Server"
|
||||
assert root.find("mcp_server_url").text.strip() == "https://example.com/mcp"
|
||||
assert root.find("client_id").text.strip() == "client-123"
|
||||
|
||||
tools_element = root.find("available_tools")
|
||||
assert tools_element is not None
|
||||
|
||||
tools_text = tools_element.text.strip() if tools_element.text else ""
|
||||
if expected_tools:
|
||||
assert tools_text.splitlines() == expected_tools
|
||||
else:
|
||||
assert tools_text == ""
|
||||
|
||||
|
||||
def test_mcp_server_crud_and_token_expiration(db_session):
|
||||
initial_expiry = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
server = MCPServer(
|
||||
name="Initial Server",
|
||||
mcp_server_url="https://initial.example.com/mcp",
|
||||
client_id="client-initial",
|
||||
available_tools=["search"],
|
||||
access_token="access-123",
|
||||
refresh_token="refresh-123",
|
||||
token_expires_at=initial_expiry,
|
||||
)
|
||||
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
fetched = db_session.get(MCPServer, server.id)
|
||||
assert fetched is not None
|
||||
assert fetched.access_token == "access-123"
|
||||
assert fetched.refresh_token == "refresh-123"
|
||||
assert fetched.token_expires_at == initial_expiry
|
||||
assert fetched.token_expires_at.tzinfo is not None
|
||||
|
||||
new_expiry = initial_expiry + timedelta(minutes=15)
|
||||
fetched.name = "Updated Server"
|
||||
fetched.available_tools = [*fetched.available_tools, "summarize"]
|
||||
fetched.access_token = "access-456"
|
||||
fetched.refresh_token = "refresh-456"
|
||||
fetched.token_expires_at = new_expiry
|
||||
db_session.commit()
|
||||
|
||||
updated = db_session.get(MCPServer, server.id)
|
||||
assert updated is not None
|
||||
assert updated.name == "Updated Server"
|
||||
assert updated.available_tools == ["search", "summarize"]
|
||||
assert updated.access_token == "access-456"
|
||||
assert updated.refresh_token == "refresh-456"
|
||||
assert updated.token_expires_at == new_expiry
|
||||
|
||||
db_session.delete(updated)
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.get(MCPServer, server.id) is None
|
||||
|
||||
|
||||
def test_mcp_server_assignments_relationship_and_cascade(db_session):
|
||||
server = MCPServer(
|
||||
name="Cascade Server",
|
||||
mcp_server_url="https://cascade.example.com/mcp",
|
||||
client_id="client-cascade",
|
||||
available_tools=["search"],
|
||||
)
|
||||
server.assignments.extend(
|
||||
[
|
||||
MCPServerAssignment(entity_type="DiscordUser", entity_id=101),
|
||||
MCPServerAssignment(entity_type="DiscordChannel", entity_id=202),
|
||||
]
|
||||
)
|
||||
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
persisted_server = db_session.get(MCPServer, server.id)
|
||||
assert persisted_server is not None
|
||||
assert len(persisted_server.assignments) == 2
|
||||
assert {assignment.entity_type for assignment in persisted_server.assignments} == {
|
||||
"DiscordUser",
|
||||
"DiscordChannel",
|
||||
}
|
||||
assert all(
|
||||
assignment.mcp_server_id == persisted_server.id
|
||||
for assignment in persisted_server.assignments
|
||||
)
|
||||
|
||||
db_session.delete(persisted_server)
|
||||
db_session.commit()
|
||||
|
||||
remaining_assignments = db_session.query(MCPServerAssignment).all()
|
||||
assert remaining_assignments == []
|
||||
|
||||
|
||||
def test_mcp_server_assignment_unique_constraint(db_session):
|
||||
server = MCPServer(
|
||||
name="Unique Server",
|
||||
mcp_server_url="https://unique.example.com/mcp",
|
||||
client_id="client-unique",
|
||||
available_tools=["search"],
|
||||
)
|
||||
assignment = MCPServerAssignment(
|
||||
entity_type="DiscordUser",
|
||||
entity_id=12345,
|
||||
)
|
||||
server.assignments.append(assignment)
|
||||
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
duplicate_assignment = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=12345,
|
||||
)
|
||||
db_session.add(duplicate_assignment)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
assignments = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(MCPServerAssignment.mcp_server_id == server.id)
|
||||
.all()
|
||||
)
|
||||
assert len(assignments) == 1
|
||||
@ -162,8 +162,18 @@ def test_create_bot_user_auto_api_key(db_session):
|
||||
|
||||
def test_create_discord_bot_user(db_session):
|
||||
"""Test creating a DiscordBotUser"""
|
||||
from memory.common.db.models import DiscordUser
|
||||
|
||||
# Create a Discord user for the bot
|
||||
discord_user = DiscordUser(
|
||||
id=123456789,
|
||||
username="botuser",
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
discord_users=[discord_user],
|
||||
name="Discord Bot",
|
||||
email="discordbot@example.com",
|
||||
api_key="discord_key_123",
|
||||
@ -176,6 +186,7 @@ def test_create_discord_bot_user(db_session):
|
||||
assert user.name == "Discord Bot"
|
||||
assert user.user_type == "discord_bot"
|
||||
assert user.api_key == "discord_key_123"
|
||||
assert len(user.discord_users) == 1
|
||||
|
||||
|
||||
def test_user_serialization_human(db_session):
|
||||
|
||||
@ -131,7 +131,7 @@ def test_build_request_kwargs_basic(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
|
||||
|
||||
assert kwargs["model"] == "claude-3-opus-20240229"
|
||||
assert kwargs["temperature"] == 0.5
|
||||
@ -143,7 +143,9 @@ def test_build_request_kwargs_with_system_prompt(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
|
||||
kwargs = provider._build_request_kwargs(
|
||||
messages, "system prompt", None, None, settings
|
||||
)
|
||||
|
||||
assert kwargs["system"] == "system prompt"
|
||||
|
||||
@ -160,7 +162,7 @@ def test_build_request_kwargs_with_tools(provider):
|
||||
]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
|
||||
|
||||
assert "tools" in kwargs
|
||||
assert len(kwargs["tools"]) == 1
|
||||
@ -170,7 +172,9 @@ def test_build_request_kwargs_with_thinking(thinking_provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(max_tokens=5000)
|
||||
|
||||
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
|
||||
kwargs = thinking_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
|
||||
assert "thinking" in kwargs
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
@ -183,7 +187,9 @@ def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(max_tokens=1000)
|
||||
|
||||
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
|
||||
kwargs = thinking_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
|
||||
# Shouldn't enable thinking if not enough tokens
|
||||
assert "thinking" not in kwargs
|
||||
@ -326,7 +332,7 @@ async def test_agenerate_basic(provider, mock_anthropic_client):
|
||||
|
||||
result = await provider.agenerate(messages)
|
||||
|
||||
assert result == "test summary"
|
||||
assert "<summary>test summary</summary>" in result
|
||||
provider.async_client.messages.create.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from PIL import Image
|
||||
|
||||
from memory.common.llms.openai_provider import OpenAIProvider
|
||||
@ -192,7 +192,7 @@ def test_build_request_kwargs_basic(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
|
||||
|
||||
assert kwargs["model"] == "gpt-4o"
|
||||
assert kwargs["temperature"] == 0.5
|
||||
@ -204,7 +204,9 @@ def test_build_request_kwargs_with_system_prompt_standard_model(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
|
||||
kwargs = provider._build_request_kwargs(
|
||||
messages, "system prompt", None, None, settings
|
||||
)
|
||||
|
||||
# For gpt-4o, system prompt becomes system message
|
||||
assert kwargs["messages"][0]["role"] == "system"
|
||||
@ -218,7 +220,7 @@ def test_build_request_kwargs_with_system_prompt_reasoning_model(
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = reasoning_provider._build_request_kwargs(
|
||||
messages, "system prompt", None, settings
|
||||
messages, "system prompt", None, None, settings
|
||||
)
|
||||
|
||||
# For o1 models, system prompt becomes developer message
|
||||
@ -232,7 +234,9 @@ def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(max_tokens=2000)
|
||||
|
||||
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)
|
||||
kwargs = reasoning_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
|
||||
# Reasoning models use max_completion_tokens
|
||||
assert "max_completion_tokens" in kwargs
|
||||
@ -244,7 +248,9 @@ def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider)
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(temperature=0.7)
|
||||
|
||||
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)
|
||||
kwargs = reasoning_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
|
||||
# Reasoning models don't support temperature
|
||||
assert "temperature" not in kwargs
|
||||
@ -263,7 +269,7 @@ def test_build_request_kwargs_with_tools(provider):
|
||||
]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
|
||||
|
||||
assert "tools" in kwargs
|
||||
assert len(kwargs["tools"]) == 1
|
||||
@ -274,7 +280,9 @@ def test_build_request_kwargs_with_stream(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, settings, stream=True)
|
||||
kwargs = provider._build_request_kwargs(
|
||||
messages, None, None, None, settings, stream=True
|
||||
)
|
||||
|
||||
assert kwargs["stream"] is True
|
||||
|
||||
@ -314,7 +322,8 @@ def test_handle_stream_chunk_text_content(provider):
|
||||
delta=Mock(content="hello", tool_calls=None),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||
@ -342,8 +351,9 @@ def test_handle_stream_chunk_tool_call_start(provider):
|
||||
choice.delta = delta
|
||||
choice.finish_reason = None
|
||||
|
||||
chunk = Mock(spec=["choices"])
|
||||
chunk = Mock(spec=["choices", "usage"])
|
||||
chunk.choices = [choice]
|
||||
chunk.usage = Mock(prompt_tokens=10, completion_tokens=5)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||
|
||||
@ -369,7 +379,8 @@ def test_handle_stream_chunk_tool_call_arguments(provider):
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
||||
@ -386,7 +397,8 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
|
||||
delta=Mock(content=None, tool_calls=None),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
||||
@ -399,7 +411,7 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
|
||||
|
||||
|
||||
def test_handle_stream_chunk_empty_choices(provider):
|
||||
chunk = Mock(choices=[])
|
||||
chunk = Mock(choices=[], usage=Mock(prompt_tokens=10, completion_tokens=5))
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||
|
||||
@ -435,8 +447,13 @@ async def test_agenerate_basic(provider, mock_openai_client):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
|
||||
# Mock the async client
|
||||
mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
|
||||
provider.async_client.chat.completions.create = Mock(return_value=mock_response)
|
||||
mock_response = Mock(
|
||||
choices=[Mock(message=Mock(content="async response"))],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
provider.async_client.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
result = await provider.agenerate(messages)
|
||||
|
||||
@ -452,15 +469,19 @@ async def test_astream_basic(provider, mock_openai_client):
|
||||
yield Mock(
|
||||
choices=[
|
||||
Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
yield Mock(
|
||||
choices=[
|
||||
Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
|
||||
]
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=10),
|
||||
)
|
||||
|
||||
provider.async_client.chat.completions.create = Mock(return_value=async_stream())
|
||||
provider.async_client.chat.completions.create = AsyncMock(
|
||||
return_value=async_stream()
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in provider.astream(messages):
|
||||
|
||||
@ -18,10 +18,10 @@ except ModuleNotFoundError: # pragma: no cover - import guard for test envs
|
||||
|
||||
sys.modules.setdefault("redis", _RedisStub())
|
||||
|
||||
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
|
||||
from memory.common.llms.usage_tracker import (
|
||||
from memory.common.llms.usage import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
RedisUsageTracker,
|
||||
UsageTracker,
|
||||
)
|
||||
|
||||
@ -84,7 +84,9 @@ def redis_tracker() -> RedisUsageTracker:
|
||||
(timedelta(seconds=0), {"max_total_tokens": 1}),
|
||||
],
|
||||
)
|
||||
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
|
||||
def test_rate_limit_config_validation(
|
||||
window: timedelta, kwargs: dict[str, int]
|
||||
) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
RateLimitConfig(window=window, **kwargs)
|
||||
|
||||
@ -93,9 +95,7 @@ def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
|
||||
allowance = tracker.get_available_tokens(
|
||||
"anthropic/claude-3", timestamp=now
|
||||
)
|
||||
allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
|
||||
assert allowance is not None
|
||||
assert allowance.input_tokens == 900
|
||||
assert allowance.output_tokens == 1_800
|
||||
@ -114,9 +114,7 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
||||
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
|
||||
|
||||
later = now + timedelta(minutes=2)
|
||||
allowance = tracker.get_available_tokens(
|
||||
"anthropic/claude-3", timestamp=later
|
||||
)
|
||||
allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=later)
|
||||
assert allowance is not None
|
||||
assert allowance.input_tokens == 1_000
|
||||
assert allowance.output_tokens == 2_000
|
||||
@ -126,6 +124,7 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
||||
|
||||
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
# Use the configured models from the fixture
|
||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||
|
||||
@ -144,6 +143,7 @@ def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> N
|
||||
|
||||
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
# Use configured models from the fixture
|
||||
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
|
||||
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
|
||||
|
||||
@ -156,15 +156,19 @@ def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
||||
assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
|
||||
|
||||
|
||||
def test_missing_configuration_records_lifetime_only() -> None:
|
||||
def test_missing_configuration_uses_default() -> None:
|
||||
# With no specific config, falls back to default config (from settings)
|
||||
tracker = InMemoryUsageTracker(configs={})
|
||||
tracker.record_usage("openai/gpt-4o", 10, 20)
|
||||
|
||||
assert tracker.get_available_tokens("openai/gpt-4o") is None
|
||||
# Uses default config, so get_available_tokens returns allowance
|
||||
allowance = tracker.get_available_tokens("openai/gpt-4o")
|
||||
assert allowance is not None
|
||||
|
||||
# Lifetime stats are tracked
|
||||
breakdown = tracker.get_usage_breakdown()
|
||||
usage = breakdown["openai"]["gpt-4o"]
|
||||
assert usage.window_input_tokens == 0
|
||||
assert usage.window_input_tokens == 10
|
||||
assert usage.lifetime_input_tokens == 10
|
||||
|
||||
|
||||
@ -193,6 +197,7 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
|
||||
|
||||
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
# Use configured models from the fixture
|
||||
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||
|
||||
@ -201,6 +206,8 @@ def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) ->
|
||||
assert allowance.input_tokens == 900
|
||||
|
||||
breakdown = redis_tracker.get_usage_breakdown()
|
||||
assert "anthropic" in breakdown
|
||||
assert "claude-3" in breakdown["anthropic"]
|
||||
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
|
||||
|
||||
items = dict(redis_tracker.iter_state_items())
|
||||
|
||||
26
tests/memory/common/llms/tools/test_base_tools.py
Normal file
26
tests/memory/common/llms/tools/test_base_tools.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Tests for base web tool definitions."""
|
||||
|
||||
from memory.common.llms.tools.base import WebFetchTool, WebSearchTool
|
||||
|
||||
|
||||
def test_web_search_tool_provider_formats():
|
||||
tool = WebSearchTool()
|
||||
|
||||
assert tool.provider_format("openai") == {"type": "web_search"}
|
||||
assert tool.provider_format("anthropic") == {
|
||||
"type": "web_search_20250305",
|
||||
"name": "web_search",
|
||||
"max_uses": 10,
|
||||
}
|
||||
assert tool.provider_format("unknown") is None
|
||||
|
||||
|
||||
def test_web_fetch_tool_provider_formats():
|
||||
tool = WebFetchTool()
|
||||
|
||||
assert tool.provider_format("anthropic") == {
|
||||
"type": "web_fetch_20250910",
|
||||
"name": "web_fetch",
|
||||
"max_uses": 10,
|
||||
}
|
||||
assert tool.provider_format("openai") is None
|
||||
@ -497,13 +497,14 @@ def test_make_discord_tools_with_user_and_channel(
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# update_user_summary, update_server_summary
|
||||
assert len(tools) == 5
|
||||
# update_user_summary, update_server_summary, add_reaction
|
||||
assert len(tools) == 6
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_user_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
assert "add_reaction" in tools
|
||||
|
||||
|
||||
def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
|
||||
@ -533,12 +534,13 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# update_server_summary (no user summary without author)
|
||||
assert len(tools) == 4
|
||||
# update_server_summary, add_reaction (no user summary without author)
|
||||
assert len(tools) == 5
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
assert "add_reaction" in tools
|
||||
assert "update_user_summary" not in tools
|
||||
|
||||
|
||||
|
||||
539
tests/memory/common/test_oauth.py
Normal file
539
tests/memory/common/test_oauth.py
Normal file
@ -0,0 +1,539 @@
|
||||
"""Tests for OAuth 2.0 flow handling."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
|
||||
from memory.common.oauth import (
|
||||
OAuthEndpoints,
|
||||
generate_pkce_pair,
|
||||
discover_oauth_metadata,
|
||||
get_endpoints,
|
||||
register_oauth_client,
|
||||
issue_challenge,
|
||||
complete_oauth_flow,
|
||||
)
|
||||
from memory.common.db.models import MCPServer
|
||||
|
||||
|
||||
class TestGeneratePkcePair:
|
||||
"""Tests for generate_pkce_pair function."""
|
||||
|
||||
def test_generates_valid_verifier_and_challenge(self):
|
||||
"""Test that PKCE pair is generated correctly."""
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
|
||||
# Verifier should be base64url encoded (no padding)
|
||||
assert len(verifier) > 0
|
||||
assert "=" not in verifier
|
||||
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in verifier)
|
||||
|
||||
# Challenge should be base64url encoded (no padding)
|
||||
assert len(challenge) > 0
|
||||
assert "=" not in challenge
|
||||
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in challenge)
|
||||
|
||||
# They should be different
|
||||
assert verifier != challenge
|
||||
|
||||
def test_generates_unique_pairs(self):
|
||||
"""Test that each call generates a unique pair."""
|
||||
verifier1, challenge1 = generate_pkce_pair()
|
||||
verifier2, challenge2 = generate_pkce_pair()
|
||||
|
||||
assert verifier1 != verifier2
|
||||
assert challenge1 != challenge2
|
||||
|
||||
|
||||
class TestDiscoverOauthMetadata:
|
||||
"""Tests for discover_oauth_metadata function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_success(self):
|
||||
"""Test successful OAuth metadata discovery."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=metadata)
|
||||
|
||||
mock_get = AsyncMock()
|
||||
mock_get.__aenter__.return_value = mock_response
|
||||
mock_get.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.get = Mock(return_value=mock_get)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await discover_oauth_metadata("https://example.com")
|
||||
|
||||
assert result == metadata
|
||||
assert result["authorization_endpoint"] == "https://example.com/auth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_not_found(self):
|
||||
"""Test OAuth metadata discovery when endpoint not found."""
|
||||
mock_response = Mock()
|
||||
mock_response.status = 404
|
||||
|
||||
mock_get = AsyncMock()
|
||||
mock_get.__aenter__.return_value = mock_response
|
||||
mock_get.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.get = Mock(return_value=mock_get)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await discover_oauth_metadata("https://example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_connection_error(self):
|
||||
"""Test OAuth metadata discovery with connection error."""
|
||||
mock_get = AsyncMock()
|
||||
mock_get.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.get = Mock(return_value=mock_get)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await discover_oauth_metadata("https://example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetEndpoints:
|
||||
"""Tests for get_endpoints function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_success(self):
|
||||
"""Test successful endpoint retrieval."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
result = await get_endpoints("https://example.com")
|
||||
|
||||
assert isinstance(result, OAuthEndpoints)
|
||||
assert result.authorization_endpoint == "https://example.com/auth"
|
||||
assert result.registration_endpoint == "https://example.com/register"
|
||||
assert result.token_endpoint == "https://example.com/token"
|
||||
assert "/auth/callback/discord" in result.redirect_uri
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_no_metadata(self):
|
||||
"""Test when OAuth metadata cannot be discovered."""
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
|
||||
with pytest.raises(ValueError, match="Failed to connect to MCP server"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_missing_authorization(self):
|
||||
"""Test when authorization endpoint is missing."""
|
||||
metadata = {
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
with pytest.raises(ValueError, match="authorization endpoint"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_missing_registration(self):
|
||||
"""Test when registration endpoint is missing."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
with pytest.raises(ValueError, match="dynamic client registration"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_missing_token(self):
|
||||
"""Test when token endpoint is missing."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
with pytest.raises(ValueError, match="token endpoint"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
|
||||
class TestRegisterOauthClient:
|
||||
"""Tests for register_oauth_client function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_success(self):
|
||||
"""Test successful OAuth client registration."""
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
client_info = {"client_id": "test-client-123"}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.text = AsyncMock(return_value="Success")
|
||||
mock_response.json = AsyncMock(return_value=client_info)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
client_id = await register_oauth_client(
|
||||
endpoints,
|
||||
"https://example.com",
|
||||
"Test Client",
|
||||
)
|
||||
|
||||
assert client_id == "test-client-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_http_error(self):
|
||||
"""Test OAuth client registration with HTTP error."""
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status = Mock(side_effect=aiohttp.ClientResponseError(
|
||||
request_info=Mock(),
|
||||
history=(),
|
||||
status=400,
|
||||
message="Bad Request",
|
||||
))
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to register OAuth client"):
|
||||
await register_oauth_client(
|
||||
endpoints,
|
||||
"https://example.com",
|
||||
"Test Client",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_missing_client_id(self):
|
||||
"""Test OAuth client registration when response lacks client_id."""
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
client_info = {} # Missing client_id
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=client_info)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to register OAuth client"):
|
||||
await register_oauth_client(
|
||||
endpoints,
|
||||
"https://example.com",
|
||||
"Test Client",
|
||||
)
|
||||
|
||||
|
||||
class TestIssueChallenge:
|
||||
"""Tests for issue_challenge function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_issue_challenge_success(self, db_session):
|
||||
"""Test successful OAuth challenge issuance."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
with patch("memory.common.oauth.generate_pkce_pair", return_value=("verifier123", "challenge123")):
|
||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||
|
||||
# Verify the auth URL contains expected parameters
|
||||
assert "https://example.com/auth?" in auth_url
|
||||
assert "client_id=test-client-123" in auth_url
|
||||
# redirect_uri will be URL encoded
|
||||
assert "redirect_uri=" in auth_url
|
||||
assert "myapp.com" in auth_url
|
||||
assert "callback" in auth_url
|
||||
assert "response_type=code" in auth_url
|
||||
assert "code_challenge=challenge123" in auth_url
|
||||
assert "code_challenge_method=S256" in auth_url
|
||||
assert "state=" in auth_url
|
||||
|
||||
# Verify state and code_verifier were stored
|
||||
assert mcp_server.state is not None
|
||||
assert mcp_server.code_verifier == "verifier123"
|
||||
|
||||
|
||||
class TestCompleteOauthFlow:
|
||||
"""Tests for complete_oauth_flow function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_success(self, db_session):
|
||||
"""Test successful OAuth flow completion."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
token_response = {
|
||||
"access_token": "access-token-123",
|
||||
"refresh_token": "refresh-token-123",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
mock_token_response = Mock()
|
||||
mock_token_response.status = 200
|
||||
mock_token_response.json = AsyncMock(return_value=token_response)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_token_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||
):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"auth-code-123",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert "successful" in message
|
||||
|
||||
# Verify tokens were stored
|
||||
assert mcp_server.access_token == "access-token-123"
|
||||
assert mcp_server.refresh_token == "refresh-token-123"
|
||||
assert mcp_server.token_expires_at is not None
|
||||
|
||||
# Verify temporary state was cleared
|
||||
assert mcp_server.state is None
|
||||
assert mcp_server.code_verifier is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_invalid_state(self):
|
||||
"""Test OAuth flow completion with invalid state."""
|
||||
status, message = await complete_oauth_flow(
|
||||
None,
|
||||
"auth-code-123",
|
||||
"invalid-state",
|
||||
)
|
||||
|
||||
assert status == 400
|
||||
assert "Invalid or expired" in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_token_error(self, db_session):
|
||||
"""Test OAuth flow completion when token exchange fails."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
mock_token_response = Mock()
|
||||
mock_token_response.status = 400
|
||||
mock_token_response.text = AsyncMock(return_value="Invalid grant")
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_token_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||
):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"invalid-code",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 500
|
||||
assert "Token exchange failed" in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_missing_access_token(self, db_session):
|
||||
"""Test OAuth flow completion when access token is missing from response."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
token_response = {} # Missing access_token
|
||||
|
||||
mock_token_response = Mock()
|
||||
mock_token_response.status = 200
|
||||
mock_token_response.json = AsyncMock(return_value=token_response)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_token_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||
):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"auth-code-123",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 500
|
||||
assert "did not include access_token" in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_get_endpoints_error(self, db_session):
|
||||
"""Test OAuth flow completion when getting endpoints fails."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"auth-code-123",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 500
|
||||
assert "Failed to get OAuth endpoints" in message
|
||||
253
tests/memory/discord_tests/test_api.py
Normal file
253
tests/memory/discord_tests/test_api.py
Normal file
@ -0,0 +1,253 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from memory.discord import api
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_app_bots():
|
||||
existing = getattr(api.app, "bots", None)
|
||||
api.app.bots = {}
|
||||
yield
|
||||
if existing is None:
|
||||
delattr(api.app, "bots")
|
||||
else:
|
||||
api.app.bots = existing
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def active_bot():
|
||||
collector = SimpleNamespace(
|
||||
send_dm=AsyncMock(return_value=True),
|
||||
trigger_typing_dm=AsyncMock(return_value=True),
|
||||
send_to_channel=AsyncMock(return_value=True),
|
||||
trigger_typing_channel=AsyncMock(return_value=True),
|
||||
add_reaction=AsyncMock(return_value=True),
|
||||
refresh_metadata=AsyncMock(return_value={"refreshed": True}),
|
||||
is_closed=Mock(return_value=False),
|
||||
user="CollectorUser#1234",
|
||||
guilds=[101, 202],
|
||||
)
|
||||
bot = SimpleNamespace(
|
||||
collector=collector,
|
||||
collector_task=None,
|
||||
bot_id=1,
|
||||
bot_token="token-123",
|
||||
bot_name="Test Bot",
|
||||
)
|
||||
api.app.bots[bot.bot_id] = bot
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_success(active_bot):
|
||||
request = api.SendDMRequest(bot_id=active_bot.bot_id, user="user123", message="Hello")
|
||||
|
||||
response = await api.send_dm_endpoint(request)
|
||||
|
||||
assert response == {
|
||||
"success": True,
|
||||
"message": "DM sent to user123",
|
||||
"user": "user123",
|
||||
}
|
||||
active_bot.collector.send_dm.assert_awaited_once_with("user123", "Hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,payload",
|
||||
[
|
||||
(
|
||||
api.send_dm_endpoint,
|
||||
api.SendDMRequest(bot_id=99, user="ghost", message="hi"),
|
||||
),
|
||||
(
|
||||
api.trigger_dm_typing,
|
||||
api.TypingDMRequest(bot_id=99, user="ghost"),
|
||||
),
|
||||
(
|
||||
api.send_channel_endpoint,
|
||||
api.SendChannelRequest(bot_id=99, channel="general", message="hello"),
|
||||
),
|
||||
(
|
||||
api.trigger_channel_typing,
|
||||
api.TypingChannelRequest(bot_id=99, channel="general"),
|
||||
),
|
||||
(
|
||||
api.add_reaction_endpoint,
|
||||
api.AddReactionRequest(
|
||||
bot_id=99,
|
||||
channel="general",
|
||||
message_id=42,
|
||||
emoji=":thumbsup:",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_endpoint_returns_404_when_bot_missing(endpoint, payload):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await endpoint(payload)
|
||||
|
||||
assert exc.value.status_code == 404
|
||||
assert exc.value.detail == "Bot not found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,request_cls,request_kwargs,attr_name,detail_template",
|
||||
[
|
||||
(
|
||||
api.send_dm_endpoint,
|
||||
api.SendDMRequest,
|
||||
{"bot_id": 1, "user": "user123", "message": "Hi"},
|
||||
"send_dm",
|
||||
"Failed to send DM to {user}",
|
||||
),
|
||||
(
|
||||
api.trigger_dm_typing,
|
||||
api.TypingDMRequest,
|
||||
{"bot_id": 1, "user": "user123"},
|
||||
"trigger_typing_dm",
|
||||
"Failed to trigger typing for user123",
|
||||
),
|
||||
(
|
||||
api.send_channel_endpoint,
|
||||
api.SendChannelRequest,
|
||||
{"bot_id": 1, "channel": "general", "message": "Hello"},
|
||||
"send_to_channel",
|
||||
"Failed to send message to channel general",
|
||||
),
|
||||
(
|
||||
api.trigger_channel_typing,
|
||||
api.TypingChannelRequest,
|
||||
{"bot_id": 1, "channel": "general"},
|
||||
"trigger_typing_channel",
|
||||
"Failed to trigger typing for channel general",
|
||||
),
|
||||
(
|
||||
api.add_reaction_endpoint,
|
||||
api.AddReactionRequest,
|
||||
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
|
||||
"add_reaction",
|
||||
"Failed to add reaction to message 55",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_endpoint_returns_400_on_collector_failure(
|
||||
active_bot, endpoint, request_cls, request_kwargs, attr_name, detail_template
|
||||
):
|
||||
request = request_cls(**request_kwargs)
|
||||
getattr(active_bot.collector, attr_name).return_value = False
|
||||
expected_detail = detail_template.format(**request_kwargs)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await endpoint(request)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
assert exc.value.detail == expected_detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,request_cls,request_kwargs,attr_name",
|
||||
[
|
||||
(
|
||||
api.send_dm_endpoint,
|
||||
api.SendDMRequest,
|
||||
{"bot_id": 1, "user": "user123", "message": "Hi"},
|
||||
"send_dm",
|
||||
),
|
||||
(
|
||||
api.trigger_dm_typing,
|
||||
api.TypingDMRequest,
|
||||
{"bot_id": 1, "user": "user123"},
|
||||
"trigger_typing_dm",
|
||||
),
|
||||
(
|
||||
api.send_channel_endpoint,
|
||||
api.SendChannelRequest,
|
||||
{"bot_id": 1, "channel": "general", "message": "Hello"},
|
||||
"send_to_channel",
|
||||
),
|
||||
(
|
||||
api.trigger_channel_typing,
|
||||
api.TypingChannelRequest,
|
||||
{"bot_id": 1, "channel": "general"},
|
||||
"trigger_typing_channel",
|
||||
),
|
||||
(
|
||||
api.add_reaction_endpoint,
|
||||
api.AddReactionRequest,
|
||||
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
|
||||
"add_reaction",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_endpoint_returns_500_on_collector_exception(
|
||||
active_bot, endpoint, request_cls, request_kwargs, attr_name
|
||||
):
|
||||
request = request_cls(**request_kwargs)
|
||||
getattr(active_bot.collector, attr_name).side_effect = RuntimeError("boom")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await endpoint(request)
|
||||
|
||||
assert exc.value.status_code == 500
|
||||
assert "boom" in exc.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(active_bot):
|
||||
response = await api.health_check()
|
||||
|
||||
assert response["Test Bot"] == {
|
||||
"status": "healthy",
|
||||
"connected": True,
|
||||
"user": "CollectorUser#1234",
|
||||
"guilds": 2,
|
||||
}
|
||||
active_bot.collector.is_closed.assert_called_once_with()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_without_bots():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await api.health_check()
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
assert exc.value.detail == "Discord collector not running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_metadata_success(active_bot):
|
||||
active_bot.collector.refresh_metadata.return_value = {"channels": 3}
|
||||
|
||||
response = await api.refresh_metadata()
|
||||
|
||||
assert response["success"] is True
|
||||
assert response["message"] == "Metadata refreshed successfully for 1 bots"
|
||||
assert response["results"]["Test Bot"] == {"channels": 3}
|
||||
active_bot.collector.refresh_metadata.assert_awaited_once_with()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_metadata_without_bots():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await api.refresh_metadata()
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
assert exc.value.detail == "Discord collector not running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_metadata_failure(active_bot):
|
||||
active_bot.collector.refresh_metadata.side_effect = RuntimeError("sync failed")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await api.refresh_metadata()
|
||||
|
||||
assert exc.value.status_code == 500
|
||||
assert "sync failed" in exc.value.detail
|
||||
@ -79,6 +79,7 @@ def mock_message(mock_text_channel, mock_user):
|
||||
message.content = "Test message"
|
||||
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
message.reference = None
|
||||
message.attachments = []
|
||||
return message
|
||||
|
||||
|
||||
@ -351,7 +352,7 @@ def test_determine_message_metadata_thread():
|
||||
# Tests for should_track_message
|
||||
def test_should_track_message_server_disabled(db_session):
|
||||
"""Test when server has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=False)
|
||||
server = DiscordServer(id=1, name="Server", ignore_messages=True)
|
||||
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
@ -362,9 +363,9 @@ def test_should_track_message_server_disabled(db_session):
|
||||
|
||||
def test_should_track_message_channel_disabled(db_session):
|
||||
"""Test when channel has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
server = DiscordServer(id=1, name="Server", ignore_messages=False)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", track_messages=False
|
||||
id=2, name="Channel", channel_type="text", ignore_messages=True
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
@ -375,8 +376,8 @@ def test_should_track_message_channel_disabled(db_session):
|
||||
|
||||
def test_should_track_message_dm_allowed(db_session):
|
||||
"""Test DM tracking when user allows it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=True)
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
|
||||
user = DiscordUser(id=3, username="User", ignore_messages=False)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
@ -385,8 +386,8 @@ def test_should_track_message_dm_allowed(db_session):
|
||||
|
||||
def test_should_track_message_dm_not_allowed(db_session):
|
||||
"""Test DM tracking when user doesn't allow it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=False)
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
|
||||
user = DiscordUser(id=3, username="User", ignore_messages=True)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
@ -395,9 +396,9 @@ def test_should_track_message_dm_not_allowed(db_session):
|
||||
|
||||
def test_should_track_message_default_true(db_session):
|
||||
"""Test default tracking behavior"""
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
server = DiscordServer(id=1, name="Server", ignore_messages=False)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", track_messages=True
|
||||
id=2, name="Channel", channel_type="text", ignore_messages=False
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
@ -465,6 +466,7 @@ def test_sync_guild_metadata(mock_make_session, mock_guild):
|
||||
voice_channel.guild = mock_guild
|
||||
|
||||
mock_guild.channels = [text_channel, voice_channel]
|
||||
mock_guild.threads = []
|
||||
|
||||
sync_guild_metadata(mock_guild)
|
||||
|
||||
@ -489,16 +491,25 @@ def test_message_collector_init():
|
||||
async def test_on_ready():
|
||||
"""Test on_ready event handler"""
|
||||
collector = MessageCollector()
|
||||
collector.user = Mock()
|
||||
collector.user.name = "TestBot"
|
||||
collector.guilds = [Mock(), Mock()]
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
collector.tree.sync = AsyncMock()
|
||||
|
||||
await collector.on_ready()
|
||||
# Mock the properties
|
||||
mock_user = Mock()
|
||||
mock_user.name = "TestBot"
|
||||
with patch.object(
|
||||
type(collector), "user", new_callable=lambda: property(lambda self: mock_user)
|
||||
):
|
||||
with patch.object(
|
||||
type(collector),
|
||||
"guilds",
|
||||
new_callable=lambda: property(lambda self: [Mock(), Mock()]),
|
||||
):
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
collector.tree.sync = AsyncMock()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
collector.tree.sync.assert_awaited()
|
||||
await collector.on_ready()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
collector.tree.sync.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -593,14 +604,18 @@ async def test_sync_servers_and_channels():
|
||||
guild2 = Mock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild1, guild2]
|
||||
|
||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||
await collector.sync_servers_and_channels()
|
||||
with patch.object(
|
||||
type(collector),
|
||||
"guilds",
|
||||
new_callable=lambda: property(lambda self: [guild1, guild2]),
|
||||
):
|
||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||
await collector.sync_servers_and_channels()
|
||||
|
||||
assert mock_sync.call_count == 2
|
||||
mock_sync.assert_any_call(guild1)
|
||||
mock_sync.assert_any_call(guild2)
|
||||
assert mock_sync.call_count == 2
|
||||
mock_sync.assert_any_call(guild1)
|
||||
mock_sync.assert_any_call(guild2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -617,17 +632,26 @@ async def test_refresh_metadata(mock_make_session):
|
||||
guild.name = "Test"
|
||||
guild.channels = []
|
||||
guild.members = []
|
||||
guild.threads = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
collector.intents = Mock()
|
||||
collector.intents.members = False
|
||||
|
||||
result = await collector.refresh_metadata()
|
||||
mock_intents = Mock()
|
||||
mock_intents.members = False
|
||||
|
||||
assert result["servers_updated"] == 1
|
||||
assert result["channels_updated"] == 0
|
||||
assert result["users_updated"] == 0
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
with patch.object(
|
||||
type(collector),
|
||||
"intents",
|
||||
new_callable=lambda: property(lambda self: mock_intents),
|
||||
):
|
||||
result = await collector.refresh_metadata()
|
||||
|
||||
assert result["servers_updated"] == 1
|
||||
assert result["channels_updated"] == 0
|
||||
assert result["users_updated"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -637,7 +661,7 @@ async def test_get_user_by_id():
|
||||
user.id = 123
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = Mock(return_value=user)
|
||||
collector.get_user = AsyncMock(return_value=user)
|
||||
|
||||
result = await collector.get_user(123)
|
||||
|
||||
@ -656,22 +680,32 @@ async def test_get_user_by_username():
|
||||
guild.members = [member]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_user("testuser")
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
result = await collector.get_user("testuser")
|
||||
|
||||
assert result == member
|
||||
assert result == member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found():
|
||||
"""Test getting non-existent user"""
|
||||
collector = MessageCollector()
|
||||
collector.guilds = []
|
||||
|
||||
with patch.object(collector, "get_user", return_value=None):
|
||||
# Create proper mock response for discord.NotFound
|
||||
mock_response = Mock()
|
||||
mock_response.status = 404
|
||||
mock_response.text = ""
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
|
||||
):
|
||||
with patch.object(
|
||||
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
|
||||
collector,
|
||||
"fetch_user",
|
||||
AsyncMock(side_effect=discord.NotFound(mock_response, "User not found")),
|
||||
):
|
||||
result = await collector.get_user(999)
|
||||
assert result is None
|
||||
@ -687,11 +721,13 @@ async def test_get_channel_by_name():
|
||||
guild.channels = [channel]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_channel_by_name("general")
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
result = await collector.get_channel_by_name("general")
|
||||
|
||||
assert result == channel
|
||||
assert result == channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -701,11 +737,13 @@ async def test_get_channel_by_name_not_found():
|
||||
guild.channels = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
result = await collector.get_channel_by_name("nonexistent")
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
result = await collector.get_channel_by_name("nonexistent")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -730,11 +768,13 @@ async def test_create_channel_no_guild():
|
||||
"""Test creating channel when no guild available"""
|
||||
collector = MessageCollector()
|
||||
collector.get_guild = Mock(return_value=None)
|
||||
collector.guilds = []
|
||||
|
||||
result = await collector.create_channel("new-channel")
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
|
||||
):
|
||||
result = await collector.create_channel("new-channel")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -816,27 +856,19 @@ async def test_send_to_channel_not_found():
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="run_collector function doesn't exist or uses different settings"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
async def test_run_collector():
|
||||
"""Test running the collector"""
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
|
||||
mock_collector = Mock()
|
||||
mock_collector.start = AsyncMock()
|
||||
mock_collector_class.return_value = mock_collector
|
||||
|
||||
await run_collector()
|
||||
|
||||
mock_collector.start.assert_called_once_with("test_token")
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="run_collector function doesn't exist or uses different settings"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
async def test_run_collector_no_token():
|
||||
"""Test running collector without token"""
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
# Should return early without raising
|
||||
await run_collector()
|
||||
pass
|
||||
|
||||
@ -1,17 +1,23 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.commands import (
|
||||
CommandContext,
|
||||
CommandError,
|
||||
CommandResponse,
|
||||
run_command,
|
||||
handle_prompt,
|
||||
handle_chattiness,
|
||||
handle_ignore,
|
||||
handle_summary,
|
||||
respond,
|
||||
with_object_context,
|
||||
handle_mcp_servers,
|
||||
)
|
||||
|
||||
|
||||
@ -66,29 +72,54 @@ def interaction(guild, text_channel, discord_user) -> DummyInteraction:
|
||||
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
|
||||
|
||||
|
||||
def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
handler=handle_prompt,
|
||||
target=server,
|
||||
display_name="server **Test Guild**",
|
||||
)
|
||||
|
||||
response = await handle_prompt(context)
|
||||
|
||||
assert isinstance(response, CommandResponse)
|
||||
assert "Be helpful" in response.content
|
||||
|
||||
|
||||
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
handler=handle_prompt,
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_prompt_channel_creates_channel(
|
||||
db_session, interaction, text_channel, guild
|
||||
):
|
||||
# Create the server first to satisfy FK constraint
|
||||
server = DiscordServer(id=guild.id, name="Test Guild")
|
||||
db_session.add(server)
|
||||
|
||||
channel_model = DiscordChannel(
|
||||
id=text_channel.id,
|
||||
name=text_channel.name,
|
||||
channel_type="text",
|
||||
server_id=guild.id,
|
||||
)
|
||||
db_session.add(channel_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="channel",
|
||||
target=channel_model,
|
||||
display_name=f"channel **#{text_channel.name}**",
|
||||
)
|
||||
|
||||
response = await handle_prompt(context)
|
||||
|
||||
assert "No prompt" in response.content
|
||||
channel = db_session.get(DiscordChannel, text_channel.id)
|
||||
@ -96,77 +127,253 @@ def test_handle_command_prompt_channel_creates_channel(db_session, interaction,
|
||||
assert channel.name == text_channel.name
|
||||
|
||||
|
||||
def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
handler=handle_chattiness,
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_chattiness(context, value=None)
|
||||
|
||||
assert str(server.chattiness_threshold) in response.content
|
||||
|
||||
|
||||
def test_handle_command_chattiness_update(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_chattiness_update(db_session, interaction):
|
||||
user_model = DiscordUser(
|
||||
id=interaction.user.id, username="command-user", chattiness_threshold=15
|
||||
)
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=user_model,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=80,
|
||||
target=user_model,
|
||||
display_name="user **command-user**",
|
||||
)
|
||||
|
||||
response = await handle_chattiness(context, value=80)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert user_model.chattiness_threshold == 80
|
||||
|
||||
|
||||
def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user")
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=user_model,
|
||||
scope="user",
|
||||
target=user_model,
|
||||
display_name="user **command-user**",
|
||||
)
|
||||
|
||||
with pytest.raises(CommandError):
|
||||
run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=150,
|
||||
)
|
||||
await handle_chattiness(context, value=150)
|
||||
|
||||
|
||||
def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||
# Create the server first to satisfy FK constraint
|
||||
server = DiscordServer(id=guild.id, name="Test Guild")
|
||||
db_session.add(server)
|
||||
|
||||
channel = DiscordChannel(
|
||||
id=interaction.channel.id,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
server_id=guild.id,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="channel",
|
||||
handler=handle_ignore,
|
||||
ignore_enabled=True,
|
||||
target=channel,
|
||||
display_name="channel **#general**",
|
||||
)
|
||||
|
||||
response = await handle_ignore(context, ignore_enabled=True)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "no longer" not in response.content
|
||||
assert channel.ignore_messages is True
|
||||
|
||||
|
||||
def test_handle_command_summary_missing(db_session, interaction):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_summary_missing(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user")
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=user_model,
|
||||
scope="user",
|
||||
handler=handle_summary,
|
||||
target=user_model,
|
||||
display_name="user **command-user**",
|
||||
)
|
||||
|
||||
response = await handle_summary(context)
|
||||
|
||||
assert "No summary" in response.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_sends_message_without_file():
|
||||
interaction = MagicMock(spec=discord.Interaction)
|
||||
interaction.response.send_message = AsyncMock()
|
||||
|
||||
await respond(interaction, "hello world", ephemeral=False)
|
||||
|
||||
interaction.response.send_message.assert_awaited_once_with(
|
||||
"hello world", ephemeral=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_sends_file_when_content_too_large():
|
||||
interaction = MagicMock(spec=discord.Interaction)
|
||||
interaction.response.send_message = AsyncMock()
|
||||
|
||||
oversized = "x" * 2000
|
||||
with patch("memory.discord.commands.discord.File") as mock_file:
|
||||
file_instance = MagicMock()
|
||||
mock_file.return_value = file_instance
|
||||
|
||||
await respond(interaction, oversized)
|
||||
|
||||
interaction.response.send_message.assert_awaited_once_with(
|
||||
"Response too large, sending as file:",
|
||||
file=file_instance,
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.discord.commands._ensure_channel")
|
||||
@patch("memory.discord.commands.ensure_server")
|
||||
@patch("memory.discord.commands.ensure_user")
|
||||
@patch("memory.discord.commands.make_session")
|
||||
def test_with_object_context_uses_ensured_objects(
|
||||
mock_make_session,
|
||||
mock_ensure_user,
|
||||
mock_ensure_server,
|
||||
mock_ensure_channel,
|
||||
interaction,
|
||||
guild,
|
||||
text_channel,
|
||||
discord_user,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
def session_cm():
|
||||
yield mock_session
|
||||
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
bot_model = MagicMock(name="bot_model")
|
||||
user_model = MagicMock(name="user_model")
|
||||
server_model = MagicMock(name="server_model")
|
||||
channel_model = MagicMock(name="channel_model")
|
||||
|
||||
mock_ensure_user.side_effect = [bot_model, user_model]
|
||||
mock_ensure_server.return_value = server_model
|
||||
mock_ensure_channel.return_value = channel_model
|
||||
|
||||
handler_objects = {}
|
||||
|
||||
def handler(objects):
|
||||
handler_objects["objects"] = objects
|
||||
return "done"
|
||||
|
||||
bot_client = SimpleNamespace(user=MagicMock())
|
||||
override_user = MagicMock(spec=discord.User)
|
||||
|
||||
result = with_object_context(bot_client, interaction, handler, override_user)
|
||||
|
||||
assert result == "done"
|
||||
objects = handler_objects["objects"]
|
||||
assert objects.bot is bot_model
|
||||
assert objects.server is server_model
|
||||
assert objects.channel is channel_model
|
||||
assert objects.user is user_model
|
||||
|
||||
mock_ensure_user.assert_any_call(mock_session, bot_client.user)
|
||||
mock_ensure_user.assert_any_call(mock_session, override_user)
|
||||
mock_ensure_server.assert_called_once_with(mock_session, guild)
|
||||
mock_ensure_channel.assert_called_once_with(
|
||||
mock_session, text_channel, guild.id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
|
||||
async def test_handle_mcp_servers_returns_response(mock_run_mcp, interaction):
|
||||
mock_run_mcp.return_value = "Listed servers"
|
||||
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
|
||||
|
||||
context = CommandContext(
|
||||
session=MagicMock(),
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server_model,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
|
||||
|
||||
response = await handle_mcp_servers(
|
||||
context, action="list", url=None
|
||||
)
|
||||
|
||||
assert response.content == "Listed servers"
|
||||
mock_run_mcp.assert_awaited_once_with(
|
||||
interaction.client.user, "list", None, "DiscordServer", server_model.id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
|
||||
async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
|
||||
mock_run_mcp.side_effect = RuntimeError("boom")
|
||||
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
|
||||
|
||||
context = CommandContext(
|
||||
session=MagicMock(),
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server_model,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
|
||||
|
||||
with pytest.raises(CommandError) as exc:
|
||||
await handle_mcp_servers(context, action="list", url=None)
|
||||
|
||||
assert "Error: boom" in str(exc.value)
|
||||
|
||||
590
tests/memory/discord_tests/test_mcp.py
Normal file
590
tests/memory/discord_tests/test_mcp.py
Normal file
@ -0,0 +1,590 @@
|
||||
"""Tests for Discord MCP server management."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.discord.mcp import (
|
||||
call_mcp_server,
|
||||
find_mcp_server,
|
||||
handle_mcp_add,
|
||||
handle_mcp_connect,
|
||||
handle_mcp_delete,
|
||||
handle_mcp_list,
|
||||
handle_mcp_tools,
|
||||
run_mcp_server_command,
|
||||
)
|
||||
|
||||
|
||||
# Helper class for async iteration
|
||||
class AsyncIterator:
|
||||
"""Helper to create an async iterator for mocking aiohttp response content."""
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
self.index = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.index >= len(self.items):
|
||||
raise StopAsyncIteration
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server(db_session) -> MCPServer:
|
||||
"""Create a test MCP server."""
|
||||
server = MCPServer(
|
||||
name="Test MCP Server",
|
||||
mcp_server_url="https://mcp.example.com",
|
||||
client_id="test_client_id",
|
||||
access_token="test_access_token",
|
||||
available_tools=["tool1", "tool2"],
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_assignment(db_session, mcp_server: MCPServer) -> MCPServerAssignment:
|
||||
"""Create a test MCP server assignment."""
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=mcp_server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
)
|
||||
db_session.add(assignment)
|
||||
db_session.commit()
|
||||
return assignment
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_user() -> discord.User:
|
||||
"""Create a mock Discord bot user."""
|
||||
user = Mock(spec=discord.User)
|
||||
user.name = "TestBot"
|
||||
user.id = 999
|
||||
return user
|
||||
|
||||
|
||||
def test_find_mcp_server_exists(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test finding an existing MCP server."""
|
||||
result = find_mcp_server(
|
||||
db_session,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
url="https://mcp.example.com",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mcp_server.id
|
||||
assert result.mcp_server_url == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_find_mcp_server_not_found(db_session):
|
||||
"""Test finding a non-existent MCP server."""
|
||||
result = find_mcp_server(
|
||||
db_session,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=999999,
|
||||
url="https://nonexistent.com",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_mcp_server_wrong_entity(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test finding MCP server with wrong entity type."""
|
||||
result = find_mcp_server(
|
||||
db_session,
|
||||
entity_type="DiscordChannel", # Wrong entity type
|
||||
entity_id=123456,
|
||||
url="https://mcp.example.com",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_server_success():
|
||||
"""Test calling MCP server successfully."""
|
||||
mock_response_data = [
|
||||
b'data: {"result": {"tools": [{"name": "test"}]}}\n',
|
||||
b'data: {"status": "ok"}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
"https://mcp.example.com", "test_token", "tools/list", {}
|
||||
):
|
||||
results.append(data)
|
||||
|
||||
assert len(results) == 2
|
||||
assert "result" in results[0]
|
||||
assert results[0]["result"]["tools"][0]["name"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_server_error():
|
||||
"""Test calling MCP server with error response."""
|
||||
mock_response = Mock()
|
||||
mock_response.status = 500
|
||||
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
||||
async for _ in call_mcp_server(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_server_invalid_json():
|
||||
"""Test calling MCP server with invalid JSON."""
|
||||
mock_response_data = [
|
||||
b"data: invalid json\n",
|
||||
b'data: {"valid": "json"}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
results.append(data)
|
||||
|
||||
# Should skip invalid JSON and only return valid one
|
||||
assert len(results) == 1
|
||||
assert results[0] == {"valid": "json"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_list_empty(db_session):
|
||||
"""Test listing MCP servers when none exist."""
|
||||
result = await handle_mcp_list("DiscordUser", 123456)
|
||||
|
||||
assert "You don't have any MCP servers configured yet" in result
|
||||
assert "/memory_mcp_servers add" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_list_with_servers(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing MCP servers with existing servers."""
|
||||
result = await handle_mcp_list("DiscordUser", 123456)
|
||||
|
||||
assert "Your MCP Servers" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "test_client_id" in result
|
||||
assert "🟢" in result # Server has access token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_list_disconnected_server(db_session):
|
||||
"""Test listing MCP servers with disconnected server."""
|
||||
server = MCPServer(
|
||||
name="Disconnected Server",
|
||||
mcp_server_url="https://disconnected.example.com",
|
||||
client_id="client_123",
|
||||
access_token=None, # No access token
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
)
|
||||
db_session.add(assignment)
|
||||
db_session.commit()
|
||||
|
||||
result = await handle_mcp_list("DiscordUser", 123456)
|
||||
|
||||
assert "🔴" in result # Server has no access token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_add_new_server(db_session, mock_bot_user):
|
||||
"""Test adding a new MCP server."""
|
||||
with (
|
||||
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
|
||||
patch("memory.discord.mcp.register_oauth_client") as mock_register,
|
||||
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
|
||||
):
|
||||
mock_endpoints = Mock()
|
||||
mock_get_endpoints.return_value = mock_endpoints
|
||||
mock_register.return_value = "new_client_id"
|
||||
mock_challenge.return_value = "https://auth.example.com/authorize"
|
||||
|
||||
result = await handle_mcp_add(
|
||||
"DiscordUser", 123456, mock_bot_user, "https://new.example.com"
|
||||
)
|
||||
|
||||
assert "Add MCP Server" in result
|
||||
assert "https://new.example.com" in result
|
||||
assert "https://auth.example.com/authorize" in result
|
||||
|
||||
# Verify server was created
|
||||
server = (
|
||||
db_session.query(MCPServer)
|
||||
.filter(MCPServer.mcp_server_url == "https://new.example.com")
|
||||
.first()
|
||||
)
|
||||
assert server is not None
|
||||
assert server.client_id == "new_client_id"
|
||||
|
||||
# Verify assignment was created
|
||||
assignment = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(
|
||||
MCPServerAssignment.mcp_server_id == server.id,
|
||||
MCPServerAssignment.entity_type == "DiscordUser",
|
||||
MCPServerAssignment.entity_id == 123456,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert assignment is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_add_existing_server(
|
||||
db_session,
|
||||
mcp_server: MCPServer,
|
||||
mcp_assignment: MCPServerAssignment,
|
||||
mock_bot_user,
|
||||
):
|
||||
"""Test adding an MCP server that already exists."""
|
||||
result = await handle_mcp_add(
|
||||
"DiscordUser", 123456, mock_bot_user, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "MCP Server Already Exists" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "/memory_mcp_servers connect" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_add_no_bot_user(db_session):
|
||||
"""Test adding MCP server without bot user."""
|
||||
with pytest.raises(ValueError, match="Bot user is required"):
|
||||
await handle_mcp_add("DiscordUser", 123456, None, "https://example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_delete_existing(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test deleting an existing MCP server assignment."""
|
||||
# Store IDs before deletion
|
||||
assignment_id = mcp_assignment.id
|
||||
server_id = mcp_server.id
|
||||
|
||||
result = await handle_mcp_delete("DiscordUser", 123456, "https://mcp.example.com")
|
||||
|
||||
assert "Delete MCP Server" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "has been removed" in result
|
||||
|
||||
# Verify assignment was deleted
|
||||
assignment = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(MCPServerAssignment.id == assignment_id)
|
||||
.first()
|
||||
)
|
||||
assert assignment is None
|
||||
|
||||
# Verify server was also deleted (no other assignments)
|
||||
server = db_session.query(MCPServer).filter(MCPServer.id == server_id).first()
|
||||
assert server is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_delete_not_found(db_session):
|
||||
"""Test deleting a non-existent MCP server."""
|
||||
result = await handle_mcp_delete("DiscordUser", 123456, "https://nonexistent.com")
|
||||
|
||||
assert "MCP Server Not Found" in result
|
||||
assert "https://nonexistent.com" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_delete_with_other_assignments(db_session):
|
||||
"""Test deleting MCP server with multiple assignments."""
|
||||
server = MCPServer(
|
||||
name="Shared Server",
|
||||
mcp_server_url="https://shared.example.com",
|
||||
client_id="shared_client",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
assignment1 = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=111,
|
||||
)
|
||||
assignment2 = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=222,
|
||||
)
|
||||
db_session.add_all([assignment1, assignment2])
|
||||
db_session.commit()
|
||||
|
||||
# Delete one assignment
|
||||
result = await handle_mcp_delete("DiscordUser", 111, "https://shared.example.com")
|
||||
|
||||
assert "has been removed" in result
|
||||
|
||||
# Verify only one assignment was deleted
|
||||
remaining = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(MCPServerAssignment.mcp_server_id == server.id)
|
||||
.count()
|
||||
)
|
||||
assert remaining == 1
|
||||
|
||||
# Verify server still exists
|
||||
server_check = db_session.query(MCPServer).filter(MCPServer.id == server.id).first()
|
||||
assert server_check is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_connect_existing(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test reconnecting to an existing MCP server."""
|
||||
with (
|
||||
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
|
||||
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
|
||||
):
|
||||
mock_endpoints = Mock()
|
||||
mock_get_endpoints.return_value = mock_endpoints
|
||||
mock_challenge.return_value = "https://auth.example.com/authorize?state=new"
|
||||
|
||||
result = await handle_mcp_connect(
|
||||
"DiscordUser", 123456, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "Reconnect to MCP Server" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "https://auth.example.com/authorize?state=new" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_connect_not_found(db_session):
|
||||
"""Test reconnecting to a non-existent MCP server."""
|
||||
with pytest.raises(ValueError, match="MCP Server Not Found"):
|
||||
await handle_mcp_connect("DiscordUser", 123456, "https://nonexistent.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_success(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing tools from an MCP server."""
|
||||
mock_response_data = [
|
||||
b'data: {"result": {"tools": [{"name": "search", "description": "Search tool"}]}}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await handle_mcp_tools(
|
||||
"DiscordUser", 123456, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "MCP Server Tools" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "search" in result
|
||||
assert "Search tool" in result
|
||||
assert "Found 1 tool(s)" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_no_tools(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing tools when server has no tools."""
|
||||
mock_response_data = [
|
||||
b'data: {"result": {"tools": []}}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await handle_mcp_tools(
|
||||
"DiscordUser", 123456, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "No tools available" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_server_not_found(db_session):
|
||||
"""Test listing tools for a non-existent server."""
|
||||
with pytest.raises(ValueError, match="MCP Server Not Found"):
|
||||
await handle_mcp_tools("DiscordUser", 123456, "https://nonexistent.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_not_authorized(db_session):
|
||||
"""Test listing tools when not authorized."""
|
||||
server = MCPServer(
|
||||
name="Unauthorized Server",
|
||||
mcp_server_url="https://unauthorized.example.com",
|
||||
client_id="client_123",
|
||||
access_token=None, # No access token
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
)
|
||||
db_session.add(assignment)
|
||||
db_session.commit()
|
||||
|
||||
with pytest.raises(ValueError, match="Not Authorized"):
|
||||
await handle_mcp_tools(
|
||||
"DiscordUser", 123456, "https://unauthorized.example.com"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_connection_error(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing tools with connection error."""
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Connection failed"):
|
||||
await handle_mcp_tools("DiscordUser", 123456, "https://mcp.example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_list(db_session, mock_bot_user):
|
||||
"""Test run_mcp_server_command with list action."""
|
||||
result = await run_mcp_server_command(
|
||||
mock_bot_user, "list", None, "DiscordUser", 123456
|
||||
)
|
||||
|
||||
assert "Your MCP Servers" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_invalid_action(mock_bot_user):
|
||||
"""Test run_mcp_server_command with invalid action."""
|
||||
with pytest.raises(ValueError, match="Invalid action"):
|
||||
await run_mcp_server_command(
|
||||
mock_bot_user, "invalid", None, "DiscordUser", 123456
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_missing_url(mock_bot_user):
|
||||
"""Test run_mcp_server_command with missing URL for non-list action."""
|
||||
with pytest.raises(ValueError, match="URL is required"):
|
||||
await run_mcp_server_command(mock_bot_user, "add", None, "DiscordUser", 123456)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_no_bot_user():
|
||||
"""Test run_mcp_server_command without bot user."""
|
||||
with pytest.raises(ValueError, match="Bot user is required"):
|
||||
await run_mcp_server_command(None, "list", None, "DiscordUser", 123456)
|
||||
@ -1,5 +1,8 @@
|
||||
"""Tests for Discord message helper functions."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from memory.discord.messages import (
|
||||
@ -9,6 +12,7 @@ from memory.discord.messages import (
|
||||
upsert_scheduled_message,
|
||||
previous_messages,
|
||||
comm_channel_prompt,
|
||||
call_llm,
|
||||
)
|
||||
from memory.common.db.models import (
|
||||
DiscordUser,
|
||||
@ -18,6 +22,7 @@ from memory.common.db.models import (
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
from memory.common.llms.tools import MCPServer as MCPServerDefinition
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -411,3 +416,107 @@ def test_comm_channel_prompt_includes_user_notes(
|
||||
|
||||
assert "user_notes" in result.lower()
|
||||
assert "testuser" in result # username should appear
|
||||
|
||||
|
||||
@patch("memory.discord.messages.create_provider")
|
||||
@patch("memory.discord.messages.previous_messages")
|
||||
@patch("memory.common.llms.tools.discord.make_discord_tools")
|
||||
@patch("memory.common.llms.tools.base.WebSearchTool")
|
||||
def test_call_llm_includes_web_search_and_mcp_servers(
|
||||
mock_web_search,
|
||||
mock_make_tools,
|
||||
mock_prev_messages,
|
||||
mock_create_provider,
|
||||
):
|
||||
provider = MagicMock()
|
||||
provider.usage_tracker.is_rate_limited.return_value = False
|
||||
provider.as_messages.return_value = ["converted"]
|
||||
provider.run_with_tools.return_value = SimpleNamespace(response="llm-output")
|
||||
mock_create_provider.return_value = provider
|
||||
|
||||
mock_prev_messages.return_value = [SimpleNamespace(as_content=lambda: "prev")]
|
||||
|
||||
existing_tool = MagicMock(name="existing_tool")
|
||||
mock_make_tools.return_value = {"existing": existing_tool}
|
||||
|
||||
web_tool_instance = MagicMock(name="web_tool")
|
||||
mock_web_search.return_value = web_tool_instance
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
|
||||
from_user = SimpleNamespace(id=123)
|
||||
mcp_model = SimpleNamespace(
|
||||
name="Server",
|
||||
mcp_server_url="https://mcp.example.com",
|
||||
access_token="token123",
|
||||
)
|
||||
|
||||
result = call_llm(
|
||||
session=MagicMock(),
|
||||
bot_user=bot_user,
|
||||
from_user=from_user,
|
||||
channel=None,
|
||||
model="gpt-test",
|
||||
messages=["hi"],
|
||||
mcp_servers=[mcp_model],
|
||||
)
|
||||
|
||||
assert result == "llm-output"
|
||||
|
||||
kwargs = provider.run_with_tools.call_args.kwargs
|
||||
tools = kwargs["tools"]
|
||||
assert tools["existing"] is existing_tool
|
||||
assert tools["web_search"] is web_tool_instance
|
||||
|
||||
mcp_servers = kwargs["mcp_servers"]
|
||||
assert mcp_servers == [
|
||||
MCPServerDefinition(
|
||||
name="Server", url="https://mcp.example.com", token="token123"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@patch("memory.discord.messages.create_provider")
|
||||
@patch("memory.discord.messages.previous_messages")
|
||||
@patch("memory.common.llms.tools.discord.make_discord_tools")
|
||||
@patch("memory.common.llms.tools.base.WebSearchTool")
|
||||
def test_call_llm_filters_disallowed_tools(
|
||||
mock_web_search,
|
||||
mock_make_tools,
|
||||
mock_prev_messages,
|
||||
mock_create_provider,
|
||||
):
|
||||
provider = MagicMock()
|
||||
provider.usage_tracker.is_rate_limited.return_value = False
|
||||
provider.as_messages.return_value = ["converted"]
|
||||
provider.run_with_tools.return_value = SimpleNamespace(response="filtered-output")
|
||||
mock_create_provider.return_value = provider
|
||||
|
||||
mock_prev_messages.return_value = []
|
||||
|
||||
allowed_tool = MagicMock(name="allowed")
|
||||
blocked_tool = MagicMock(name="blocked")
|
||||
mock_make_tools.return_value = {
|
||||
"allowed": allowed_tool,
|
||||
"blocked": blocked_tool,
|
||||
}
|
||||
|
||||
mock_web_search.return_value = MagicMock(name="web_tool")
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
|
||||
from_user = SimpleNamespace(id=1)
|
||||
|
||||
call_llm(
|
||||
session=MagicMock(),
|
||||
bot_user=bot_user,
|
||||
from_user=from_user,
|
||||
channel=None,
|
||||
model="gpt-test",
|
||||
messages=[],
|
||||
allowed_tools={"allowed"},
|
||||
mcp_servers=None,
|
||||
)
|
||||
|
||||
tools = provider.run_with_tools.call_args.kwargs["tools"]
|
||||
assert "allowed" in tools
|
||||
assert "blocked" not in tools
|
||||
assert "web_search" not in tools
|
||||
|
||||
41
tests/tools/test_discord_setup.py
Normal file
41
tests/tools/test_discord_setup.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Tests for Discord setup CLI utilities."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from tools.discord_setup import generate_bot_invite_url, make_invite
|
||||
|
||||
|
||||
def test_make_invite_generates_expected_url():
|
||||
result = make_invite(123456789)
|
||||
|
||||
assert (
|
||||
result
|
||||
== "https://discord.com/oauth2/authorize?client_id=123456789&scope=bot&permissions=3088"
|
||||
)
|
||||
|
||||
|
||||
@patch("tools.discord_setup.requests.get")
|
||||
def test_generate_bot_invite_url_outputs_link(mock_get):
|
||||
response = MagicMock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.json.return_value = {"id": "987654321"}
|
||||
mock_get.return_value = response
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "abc.def"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Bot invite URL" in result.output
|
||||
assert "987654321" in result.output
|
||||
|
||||
|
||||
@patch("tools.discord_setup.requests.get", side_effect=Exception("api down"))
|
||||
def test_generate_bot_invite_url_handles_errors(mock_get):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "token"])
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert isinstance(result.exception, ValueError)
|
||||
assert "Could not get bot info" in str(result.exception)
|
||||
Loading…
x
Reference in New Issue
Block a user