add a bunch of tests

This commit is contained in:
Daniel O'Connell 2025-11-03 23:23:41 +01:00
parent 56c0df9761
commit ad6510bd17
19 changed files with 2443 additions and 198 deletions

View File

@ -134,11 +134,15 @@ class UsageTracker:
default_config: RateLimitConfig | None = None, default_config: RateLimitConfig | None = None,
) -> None: ) -> None:
self._configs = configs or {} self._configs = configs or {}
self._default_config = default_config or RateLimitConfig( if default_config is None:
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES), default_config = RateLimitConfig(
window=timedelta(
minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES
),
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS, max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS, max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
) )
self._default_config = default_config
self._lock = Lock() self._lock = Lock()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -260,8 +264,8 @@ class UsageTracker:
with self._lock: with self._lock:
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict) providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
for model, state in self.iter_state_items(): for model_key, state in self.iter_state_items():
prov, model_name = split_model_key(model) prov, model_name = split_model_key(model_key)
if provider and provider != prov: if provider and provider != prov:
continue continue
if model and model != model_name: if model and model != model_name:
@ -304,7 +308,10 @@ class UsageTracker:
# Internal helpers # Internal helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _get_config(self, model: str) -> RateLimitConfig | None: def _get_config(self, model: str) -> RateLimitConfig | None:
return self._configs.get(model) or self._default_config config = self._configs.get(model)
if config is not None:
return config
return self._default_config
def _prune_expired_events( def _prune_expired_events(
self, self,

View File

@ -179,17 +179,14 @@ def should_track_message(
channel: DiscordChannel, channel: DiscordChannel,
user: DiscordUser, user: DiscordUser,
) -> bool: ) -> bool:
"""Pure function to determine if we should track this message""" if server and server.ignore_messages:
if server and not server.track_messages: # type: ignore
return False return False
if not channel.track_messages: if channel.ignore_messages:
return False return False
if channel.channel_type in ("dm", "group_dm"): if channel.channel_type in ("dm", "group_dm"):
return bool(user.track_messages) return not user.ignore_messages
# Default: track the message
return True return True

View File

@ -1,5 +1,6 @@
import os import os
import subprocess import subprocess
import sys
import uuid import uuid
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@ -20,6 +21,31 @@ from memory.common.qdrant import initialize_collections
from tests.providers.email_provider import MockEmailProvider from tests.providers.email_provider import MockEmailProvider
class MockRedis:
"""In-memory mock of Redis for testing."""
def __init__(self):
self._data = {}
def get(self, key: str):
return self._data.get(key)
def set(self, key: str, value):
self._data[key] = value
def scan_iter(self, match: str):
import fnmatch
pattern = match.replace("*", "**")
for key in self._data.keys():
if fnmatch.fnmatch(key, pattern):
yield key
@classmethod
def from_url(cls, url: str):
return cls()
def get_test_db_name() -> str: def get_test_db_name() -> str:
return f"test_db_{uuid.uuid4().hex[:8]}" return f"test_db_{uuid.uuid4().hex[:8]}"
@ -83,7 +109,7 @@ def run_alembic_migrations(db_name: str) -> None:
alembic_ini = project_root / "db" / "migrations" / "alembic.ini" alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
subprocess.run( subprocess.run(
["alembic", "-c", str(alembic_ini), "upgrade", "head"], [sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)}, env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
check=True, check=True,
capture_output=True, capture_output=True,
@ -265,7 +291,8 @@ def mock_openai_client():
), ),
finish_reason=None, finish_reason=None,
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=20),
) )
) )
@ -281,7 +308,8 @@ def mock_openai_client():
delta=Mock(content="test", tool_calls=None), delta=Mock(content="test", tool_calls=None),
finish_reason=None, finish_reason=None,
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=5),
), ),
Mock( Mock(
choices=[ choices=[
@ -289,7 +317,8 @@ def mock_openai_client():
delta=Mock(content=" response", tool_calls=None), delta=Mock(content=" response", tool_calls=None),
finish_reason="stop", finish_reason="stop",
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=15),
), ),
] ]
) )
@ -303,7 +332,8 @@ def mock_openai_client():
), ),
finish_reason=None, finish_reason=None,
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=20),
) )
client.chat.completions.create.side_effect = streaming_response client.chat.completions.create.side_effect = streaming_response
@ -312,6 +342,8 @@ def mock_openai_client():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_anthropic_client(): def mock_anthropic_client():
from unittest.mock import AsyncMock
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client: with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
client = mock_client() client = mock_client()
client.messages = Mock() client.messages = Mock()
@ -345,9 +377,59 @@ def mock_anthropic_client():
] ]
) )
) )
# Mock async client
async_client = Mock()
async_client.messages = Mock()
async_client.messages.create = AsyncMock(
return_value=Mock(
content=[
Mock(
type="text",
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
)
]
)
)
# Mock async streaming
def async_stream_ctx(*args, **kwargs):
async def async_iter():
yield Mock(
type="content_block_delta",
delta=Mock(
type="text_delta",
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
),
)
class AsyncStreamMock:
async def __aenter__(self):
return async_iter()
async def __aexit__(self, *args):
pass
return AsyncStreamMock()
async_client.messages.stream = Mock(side_effect=async_stream_ctx)
# Add async_client property to mock
mock_client.return_value._async_client = None
with patch.object(anthropic, "AsyncAnthropic", return_value=async_client):
yield client yield client
@pytest.fixture(autouse=True)
def mock_redis():
"""Mock Redis client for all tests."""
import redis
with patch.object(redis, "Redis", MockRedis):
yield
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_discord_client(): def mock_discord_client():
with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False): with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False):

View File

@ -0,0 +1,151 @@
"""Tests for authentication helpers and OAuth callback."""
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from starlette.requests import Request
from memory.api import auth
from memory.common import settings
def make_request(query: str) -> Request:
scope = {
"type": "http",
"method": "GET",
"path": "/auth/callback/discord",
"headers": [],
"query_string": query.encode(),
}
async def receive():
return {"type": "http.request", "body": b"", "more_body": False}
return Request(scope, receive)
def test_get_bearer_token_parses_header():
request = SimpleNamespace(headers={"Authorization": "Bearer token123"})
assert auth.get_bearer_token(request) == "token123"
def test_get_bearer_token_handles_missing_header():
request = SimpleNamespace(headers={})
assert auth.get_bearer_token(request) is None
def test_get_token_prefers_header_over_cookie():
request = SimpleNamespace(
headers={"Authorization": "Bearer header-token"},
cookies={"session": "cookie-token"},
)
assert auth.get_token(request) == "header-token"
def test_get_token_falls_back_to_cookie():
request = SimpleNamespace(
headers={},
cookies={settings.SESSION_COOKIE_NAME: "cookie-token"},
)
assert auth.get_token(request) == "cookie-token"
@patch("memory.api.auth.get_user_session")
def test_logout_removes_session(mock_get_user_session):
db = MagicMock()
session = MagicMock()
mock_get_user_session.return_value = session
request = SimpleNamespace()
result = auth.logout(request, db)
assert result == {"message": "Logged out successfully"}
db.delete.assert_called_once_with(session)
db.commit.assert_called_once()
@patch("memory.api.auth.get_user_session", return_value=None)
def test_logout_handles_missing_session(mock_get_user_session):
db = MagicMock()
request = SimpleNamespace()
result = auth.logout(request, db)
assert result == {"message": "Logged out successfully"}
db.delete.assert_not_called()
db.commit.assert_not_called()
@pytest.mark.asyncio
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
mock_session = MagicMock()
@contextmanager
def session_cm():
yield mock_session
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (200, "Authorized")
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
assert response.status_code == 200
body = response.body.decode()
assert "Authorization Successful" in body
assert "Authorized" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_handles_failures(
mock_make_session, mock_complete
):
mock_session = MagicMock()
@contextmanager
def session_cm():
yield mock_session
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (500, "Failure")
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
assert response.status_code == 500
body = response.body.decode()
assert "Authorization Failed" in body
assert "Failure" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_oauth_callback_discord_validates_query_params():
request = make_request("code=&state=")
response = await auth.oauth_callback_discord(request)
assert response.status_code == 400
body = response.body.decode()
assert "Missing authorization code" in body

View File

@ -1,5 +1,7 @@
"""Tests for Discord database models.""" """Tests for Discord database models."""
from types import SimpleNamespace
import pytest import pytest
from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
@ -19,12 +21,11 @@ def test_create_discord_server(db_session):
assert server.name == "Test Server" assert server.name == "Test Server"
assert server.description == "A test Discord server" assert server.description == "A test Discord server"
assert server.member_count == 100 assert server.member_count == 100
assert server.track_messages is True # default value assert server.ignore_messages is False # default value
assert server.ignore_messages is False
def test_discord_server_as_xml(db_session): def test_discord_server_as_xml(db_session):
"""Test DiscordServer.as_xml() method.""" """Test DiscordServer.to_xml() method."""
server = DiscordServer( server = DiscordServer(
id=123456789, id=123456789,
name="Test Server", name="Test Server",
@ -33,11 +34,11 @@ def test_discord_server_as_xml(db_session):
db_session.add(server) db_session.add(server)
db_session.commit() db_session.commit()
xml = server.as_xml() xml = server.to_xml("name", "summary")
assert "<servers>" in xml # tablename is discord_servers, strips to "servers" assert "<server>" in xml # tablename is discord_servers, strips to "server"
assert "<name>Test Server</name>" in xml assert "<name>" in xml and "Test Server" in xml
assert "<summary>This is a test server for gaming</summary>" in xml assert "<summary>" in xml and "This is a test server for gaming" in xml
assert "</servers>" in xml assert "</server>" in xml
def test_discord_server_message_tracking(db_session): def test_discord_server_message_tracking(db_session):
@ -45,13 +46,11 @@ def test_discord_server_message_tracking(db_session):
server = DiscordServer( server = DiscordServer(
id=123456789, id=123456789,
name="Test Server", name="Test Server",
track_messages=False,
ignore_messages=True, ignore_messages=True,
) )
db_session.add(server) db_session.add(server)
db_session.commit() db_session.commit()
assert server.track_messages is False
assert server.ignore_messages is True assert server.ignore_messages is True
@ -111,7 +110,7 @@ def test_discord_channel_without_server(db_session):
def test_discord_channel_as_xml(db_session): def test_discord_channel_as_xml(db_session):
"""Test DiscordChannel.as_xml() method.""" """Test DiscordChannel.to_xml() method."""
channel = DiscordChannel( channel = DiscordChannel(
id=111222333, id=111222333,
name="general", name="general",
@ -121,30 +120,28 @@ def test_discord_channel_as_xml(db_session):
db_session.add(channel) db_session.add(channel)
db_session.commit() db_session.commit()
xml = channel.as_xml() xml = channel.to_xml("name", "summary")
assert "<channels>" in xml # tablename is discord_channels, strips to "channels" assert "<channel>" in xml # tablename is discord_channels, strips to "channel"
assert "<name>general</name>" in xml assert "<name>" in xml and "general" in xml
assert "<summary>Main discussion channel</summary>" in xml assert "<summary>" in xml and "Main discussion channel" in xml
assert "</channels>" in xml assert "</channel>" in xml
def test_discord_channel_inherits_server_settings(db_session): def test_discord_channel_inherits_server_settings(db_session):
"""Test that channels can have their own or inherit server settings.""" """Test that channels can have their own or inherit server settings."""
server = DiscordServer( server = DiscordServer(id=987654321, name="Server", ignore_messages=False)
id=987654321, name="Server", track_messages=True, ignore_messages=False
)
channel = DiscordChannel( channel = DiscordChannel(
id=111222333, id=111222333,
server_id=server.id, server_id=server.id,
name="announcements", name="announcements",
channel_type="text", channel_type="text",
track_messages=False, # Override server setting ignore_messages=True, # Override server setting
) )
db_session.add_all([server, channel]) db_session.add_all([server, channel])
db_session.commit() db_session.commit()
assert server.track_messages is True assert server.ignore_messages is False
assert channel.track_messages is False assert channel.ignore_messages is True
def test_create_discord_user(db_session): def test_create_discord_user(db_session):
@ -186,7 +183,7 @@ def test_discord_user_with_system_user(db_session):
def test_discord_user_as_xml(db_session): def test_discord_user_as_xml(db_session):
"""Test DiscordUser.as_xml() method.""" """Test DiscordUser.to_xml() method."""
user = DiscordUser( user = DiscordUser(
id=555666777, id=555666777,
username="testuser", username="testuser",
@ -195,11 +192,10 @@ def test_discord_user_as_xml(db_session):
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
xml = user.as_xml() xml = user.to_xml("summary")
assert "<users>" in xml # tablename is discord_users, strips to "users" assert "<user>" in xml # tablename is discord_users, strips to "user"
assert "<name>testuser</name>" in xml assert "<summary>" in xml and "Friendly and helpful community member" in xml
assert "<summary>Friendly and helpful community member</summary>" in xml assert "</user>" in xml
assert "</users>" in xml
def test_discord_user_message_preferences(db_session): def test_discord_user_message_preferences(db_session):
@ -207,13 +203,11 @@ def test_discord_user_message_preferences(db_session):
user = DiscordUser( user = DiscordUser(
id=555666777, id=555666777,
username="testuser", username="testuser",
track_messages=True,
ignore_messages=False, ignore_messages=False,
) )
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert user.track_messages is True
assert user.ignore_messages is False assert user.ignore_messages is False
@ -234,6 +228,21 @@ def test_discord_server_channel_relationship(db_session):
assert channel2 in server.channels assert channel2 in server.channels
def test_discord_processor_xml_mcp_servers():
"""Test xml_mcp_servers includes assigned MCP server XML."""
server = DiscordServer(id=111, name="Server")
mcp_stub = SimpleNamespace(
as_xml=lambda: "<mcp_server><name>Example</name></mcp_server>"
)
# Relationship is optional for test purposes; assign directly
server.mcp_servers = [mcp_stub]
xml_output = server.xml_mcp_servers()
assert "<mcp_server>" in xml_output
assert "Example" in xml_output
def test_discord_server_cascade_delete(db_session): def test_discord_server_cascade_delete(db_session):
"""Test that deleting a server cascades to channels.""" """Test that deleting a server cascades to channels."""
server = DiscordServer(id=987654321, name="Test Server") server = DiscordServer(id=987654321, name="Test Server")

View File

@ -0,0 +1,155 @@
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta, timezone
import pytest
from sqlalchemy.exc import IntegrityError
from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
@pytest.mark.parametrize(
"available_tools,expected_tools",
[
(["search", "summarize"], ["• search", "• summarize"]),
([], []),
],
)
def test_mcp_server_as_xml_formats_available_tools(available_tools, expected_tools):
server = MCPServer(
name="Example Server",
mcp_server_url="https://example.com/mcp",
client_id="client-123",
available_tools=available_tools,
)
xml_output = server.as_xml()
root = ET.fromstring(xml_output)
assert root.find("name").text.strip() == "Example Server"
assert root.find("mcp_server_url").text.strip() == "https://example.com/mcp"
assert root.find("client_id").text.strip() == "client-123"
tools_element = root.find("available_tools")
assert tools_element is not None
tools_text = tools_element.text.strip() if tools_element.text else ""
if expected_tools:
assert tools_text.splitlines() == expected_tools
else:
assert tools_text == ""
def test_mcp_server_crud_and_token_expiration(db_session):
initial_expiry = datetime.now(timezone.utc) + timedelta(minutes=30)
server = MCPServer(
name="Initial Server",
mcp_server_url="https://initial.example.com/mcp",
client_id="client-initial",
available_tools=["search"],
access_token="access-123",
refresh_token="refresh-123",
token_expires_at=initial_expiry,
)
db_session.add(server)
db_session.commit()
fetched = db_session.get(MCPServer, server.id)
assert fetched is not None
assert fetched.access_token == "access-123"
assert fetched.refresh_token == "refresh-123"
assert fetched.token_expires_at == initial_expiry
assert fetched.token_expires_at.tzinfo is not None
new_expiry = initial_expiry + timedelta(minutes=15)
fetched.name = "Updated Server"
fetched.available_tools = [*fetched.available_tools, "summarize"]
fetched.access_token = "access-456"
fetched.refresh_token = "refresh-456"
fetched.token_expires_at = new_expiry
db_session.commit()
updated = db_session.get(MCPServer, server.id)
assert updated is not None
assert updated.name == "Updated Server"
assert updated.available_tools == ["search", "summarize"]
assert updated.access_token == "access-456"
assert updated.refresh_token == "refresh-456"
assert updated.token_expires_at == new_expiry
db_session.delete(updated)
db_session.commit()
assert db_session.get(MCPServer, server.id) is None
def test_mcp_server_assignments_relationship_and_cascade(db_session):
server = MCPServer(
name="Cascade Server",
mcp_server_url="https://cascade.example.com/mcp",
client_id="client-cascade",
available_tools=["search"],
)
server.assignments.extend(
[
MCPServerAssignment(entity_type="DiscordUser", entity_id=101),
MCPServerAssignment(entity_type="DiscordChannel", entity_id=202),
]
)
db_session.add(server)
db_session.commit()
persisted_server = db_session.get(MCPServer, server.id)
assert persisted_server is not None
assert len(persisted_server.assignments) == 2
assert {assignment.entity_type for assignment in persisted_server.assignments} == {
"DiscordUser",
"DiscordChannel",
}
assert all(
assignment.mcp_server_id == persisted_server.id
for assignment in persisted_server.assignments
)
db_session.delete(persisted_server)
db_session.commit()
remaining_assignments = db_session.query(MCPServerAssignment).all()
assert remaining_assignments == []
def test_mcp_server_assignment_unique_constraint(db_session):
server = MCPServer(
name="Unique Server",
mcp_server_url="https://unique.example.com/mcp",
client_id="client-unique",
available_tools=["search"],
)
assignment = MCPServerAssignment(
entity_type="DiscordUser",
entity_id=12345,
)
server.assignments.append(assignment)
db_session.add(server)
db_session.commit()
duplicate_assignment = MCPServerAssignment(
mcp_server_id=server.id,
entity_type="DiscordUser",
entity_id=12345,
)
db_session.add(duplicate_assignment)
with pytest.raises(IntegrityError):
db_session.commit()
db_session.rollback()
assignments = (
db_session.query(MCPServerAssignment)
.filter(MCPServerAssignment.mcp_server_id == server.id)
.all()
)
assert len(assignments) == 1

View File

@ -162,8 +162,18 @@ def test_create_bot_user_auto_api_key(db_session):
def test_create_discord_bot_user(db_session): def test_create_discord_bot_user(db_session):
"""Test creating a DiscordBotUser""" """Test creating a DiscordBotUser"""
from memory.common.db.models import DiscordUser
# Create a Discord user for the bot
discord_user = DiscordUser(
id=123456789,
username="botuser",
)
db_session.add(discord_user)
db_session.commit()
user = DiscordBotUser.create_with_api_key( user = DiscordBotUser.create_with_api_key(
discord_users=[], discord_users=[discord_user],
name="Discord Bot", name="Discord Bot",
email="discordbot@example.com", email="discordbot@example.com",
api_key="discord_key_123", api_key="discord_key_123",
@ -176,6 +186,7 @@ def test_create_discord_bot_user(db_session):
assert user.name == "Discord Bot" assert user.name == "Discord Bot"
assert user.user_type == "discord_bot" assert user.user_type == "discord_bot"
assert user.api_key == "discord_key_123" assert user.api_key == "discord_key_123"
assert len(user.discord_users) == 1
def test_user_serialization_human(db_session): def test_user_serialization_human(db_session):

View File

@ -131,7 +131,7 @@ def test_build_request_kwargs_basic(provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings(temperature=0.5, max_tokens=1000) settings = LLMSettings(temperature=0.5, max_tokens=1000)
kwargs = provider._build_request_kwargs(messages, None, None, settings) kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
assert kwargs["model"] == "claude-3-opus-20240229" assert kwargs["model"] == "claude-3-opus-20240229"
assert kwargs["temperature"] == 0.5 assert kwargs["temperature"] == 0.5
@ -143,7 +143,9 @@ def test_build_request_kwargs_with_system_prompt(provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings() settings = LLMSettings()
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings) kwargs = provider._build_request_kwargs(
messages, "system prompt", None, None, settings
)
assert kwargs["system"] == "system prompt" assert kwargs["system"] == "system prompt"
@ -160,7 +162,7 @@ def test_build_request_kwargs_with_tools(provider):
] ]
settings = LLMSettings() settings = LLMSettings()
kwargs = provider._build_request_kwargs(messages, None, tools, settings) kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
assert "tools" in kwargs assert "tools" in kwargs
assert len(kwargs["tools"]) == 1 assert len(kwargs["tools"]) == 1
@ -170,7 +172,9 @@ def test_build_request_kwargs_with_thinking(thinking_provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings(max_tokens=5000) settings = LLMSettings(max_tokens=5000)
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings) kwargs = thinking_provider._build_request_kwargs(
messages, None, None, None, settings
)
assert "thinking" in kwargs assert "thinking" in kwargs
assert kwargs["thinking"]["type"] == "enabled" assert kwargs["thinking"]["type"] == "enabled"
@ -183,7 +187,9 @@ def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings(max_tokens=1000) settings = LLMSettings(max_tokens=1000)
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings) kwargs = thinking_provider._build_request_kwargs(
messages, None, None, None, settings
)
# Shouldn't enable thinking if not enough tokens # Shouldn't enable thinking if not enough tokens
assert "thinking" not in kwargs assert "thinking" not in kwargs
@ -326,7 +332,7 @@ async def test_agenerate_basic(provider, mock_anthropic_client):
result = await provider.agenerate(messages) result = await provider.agenerate(messages)
assert result == "test summary" assert "<summary>test summary</summary>" in result
provider.async_client.messages.create.assert_called_once() provider.async_client.messages.create.assert_called_once()

View File

@ -1,5 +1,5 @@
import pytest import pytest
from unittest.mock import Mock from unittest.mock import Mock, AsyncMock
from PIL import Image from PIL import Image
from memory.common.llms.openai_provider import OpenAIProvider from memory.common.llms.openai_provider import OpenAIProvider
@ -192,7 +192,7 @@ def test_build_request_kwargs_basic(provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings(temperature=0.5, max_tokens=1000) settings = LLMSettings(temperature=0.5, max_tokens=1000)
kwargs = provider._build_request_kwargs(messages, None, None, settings) kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
assert kwargs["model"] == "gpt-4o" assert kwargs["model"] == "gpt-4o"
assert kwargs["temperature"] == 0.5 assert kwargs["temperature"] == 0.5
@ -204,7 +204,9 @@ def test_build_request_kwargs_with_system_prompt_standard_model(provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings() settings = LLMSettings()
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings) kwargs = provider._build_request_kwargs(
messages, "system prompt", None, None, settings
)
# For gpt-4o, system prompt becomes system message # For gpt-4o, system prompt becomes system message
assert kwargs["messages"][0]["role"] == "system" assert kwargs["messages"][0]["role"] == "system"
@ -218,7 +220,7 @@ def test_build_request_kwargs_with_system_prompt_reasoning_model(
settings = LLMSettings() settings = LLMSettings()
kwargs = reasoning_provider._build_request_kwargs( kwargs = reasoning_provider._build_request_kwargs(
messages, "system prompt", None, settings messages, "system prompt", None, None, settings
) )
# For o1 models, system prompt becomes developer message # For o1 models, system prompt becomes developer message
@ -232,7 +234,9 @@ def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings(max_tokens=2000) settings = LLMSettings(max_tokens=2000)
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings) kwargs = reasoning_provider._build_request_kwargs(
messages, None, None, None, settings
)
# Reasoning models use max_completion_tokens # Reasoning models use max_completion_tokens
assert "max_completion_tokens" in kwargs assert "max_completion_tokens" in kwargs
@ -244,7 +248,9 @@ def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider)
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings(temperature=0.7) settings = LLMSettings(temperature=0.7)
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings) kwargs = reasoning_provider._build_request_kwargs(
messages, None, None, None, settings
)
# Reasoning models don't support temperature # Reasoning models don't support temperature
assert "temperature" not in kwargs assert "temperature" not in kwargs
@ -263,7 +269,7 @@ def test_build_request_kwargs_with_tools(provider):
] ]
settings = LLMSettings() settings = LLMSettings()
kwargs = provider._build_request_kwargs(messages, None, tools, settings) kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
assert "tools" in kwargs assert "tools" in kwargs
assert len(kwargs["tools"]) == 1 assert len(kwargs["tools"]) == 1
@ -274,7 +280,9 @@ def test_build_request_kwargs_with_stream(provider):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
settings = LLMSettings() settings = LLMSettings()
kwargs = provider._build_request_kwargs(messages, None, None, settings, stream=True) kwargs = provider._build_request_kwargs(
messages, None, None, None, settings, stream=True
)
assert kwargs["stream"] is True assert kwargs["stream"] is True
@ -314,7 +322,8 @@ def test_handle_stream_chunk_text_content(provider):
delta=Mock(content="hello", tool_calls=None), delta=Mock(content="hello", tool_calls=None),
finish_reason=None, finish_reason=None,
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=5),
) )
events, tool_call = provider._handle_stream_chunk(chunk, None) events, tool_call = provider._handle_stream_chunk(chunk, None)
@ -342,8 +351,9 @@ def test_handle_stream_chunk_tool_call_start(provider):
choice.delta = delta choice.delta = delta
choice.finish_reason = None choice.finish_reason = None
chunk = Mock(spec=["choices"]) chunk = Mock(spec=["choices", "usage"])
chunk.choices = [choice] chunk.choices = [choice]
chunk.usage = Mock(prompt_tokens=10, completion_tokens=5)
events, tool_call = provider._handle_stream_chunk(chunk, None) events, tool_call = provider._handle_stream_chunk(chunk, None)
@ -369,7 +379,8 @@ def test_handle_stream_chunk_tool_call_arguments(provider):
), ),
finish_reason=None, finish_reason=None,
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=5),
) )
events, tool_call = provider._handle_stream_chunk(chunk, current_tool) events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
@ -386,7 +397,8 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
delta=Mock(content=None, tool_calls=None), delta=Mock(content=None, tool_calls=None),
finish_reason="tool_calls", finish_reason="tool_calls",
) )
] ],
usage=Mock(prompt_tokens=10, completion_tokens=5),
) )
events, tool_call = provider._handle_stream_chunk(chunk, current_tool) events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
@ -399,7 +411,7 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
def test_handle_stream_chunk_empty_choices(provider): def test_handle_stream_chunk_empty_choices(provider):
chunk = Mock(choices=[]) chunk = Mock(choices=[], usage=Mock(prompt_tokens=10, completion_tokens=5))
events, tool_call = provider._handle_stream_chunk(chunk, None) events, tool_call = provider._handle_stream_chunk(chunk, None)
@ -435,8 +447,13 @@ async def test_agenerate_basic(provider, mock_openai_client):
messages = [Message(role=MessageRole.USER, content="test")] messages = [Message(role=MessageRole.USER, content="test")]
# Mock the async client # Mock the async client
mock_response = Mock(choices=[Mock(message=Mock(content="async response"))]) mock_response = Mock(
provider.async_client.chat.completions.create = Mock(return_value=mock_response) choices=[Mock(message=Mock(content="async response"))],
usage=Mock(prompt_tokens=10, completion_tokens=20),
)
provider.async_client.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await provider.agenerate(messages) result = await provider.agenerate(messages)
@ -452,15 +469,19 @@ async def test_astream_basic(provider, mock_openai_client):
yield Mock( yield Mock(
choices=[ choices=[
Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None) Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
] ],
usage=Mock(prompt_tokens=10, completion_tokens=5),
) )
yield Mock( yield Mock(
choices=[ choices=[
Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop") Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
] ],
usage=Mock(prompt_tokens=10, completion_tokens=10),
) )
provider.async_client.chat.completions.create = Mock(return_value=async_stream()) provider.async_client.chat.completions.create = AsyncMock(
return_value=async_stream()
)
events = [] events = []
async for event in provider.astream(messages): async for event in provider.astream(messages):

View File

@ -18,10 +18,10 @@ except ModuleNotFoundError: # pragma: no cover - import guard for test envs
sys.modules.setdefault("redis", _RedisStub()) sys.modules.setdefault("redis", _RedisStub())
from memory.common.llms.redis_usage_tracker import RedisUsageTracker from memory.common.llms.usage import (
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker, InMemoryUsageTracker,
RateLimitConfig, RateLimitConfig,
RedisUsageTracker,
UsageTracker, UsageTracker,
) )
@ -84,7 +84,9 @@ def redis_tracker() -> RedisUsageTracker:
(timedelta(seconds=0), {"max_total_tokens": 1}), (timedelta(seconds=0), {"max_total_tokens": 1}),
], ],
) )
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None: def test_rate_limit_config_validation(
window: timedelta, kwargs: dict[str, int]
) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
RateLimitConfig(window=window, **kwargs) RateLimitConfig(window=window, **kwargs)
@ -93,9 +95,7 @@ def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc) now = datetime(2024, 1, 1, tzinfo=timezone.utc)
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now) tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
allowance = tracker.get_available_tokens( allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
"anthropic/claude-3", timestamp=now
)
assert allowance is not None assert allowance is not None
assert allowance.input_tokens == 900 assert allowance.input_tokens == 900
assert allowance.output_tokens == 1_800 assert allowance.output_tokens == 1_800
@ -114,9 +114,7 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now) tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
later = now + timedelta(minutes=2) later = now + timedelta(minutes=2)
allowance = tracker.get_available_tokens( allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=later)
"anthropic/claude-3", timestamp=later
)
assert allowance is not None assert allowance is not None
assert allowance.input_tokens == 1_000 assert allowance.input_tokens == 1_000
assert allowance.output_tokens == 2_000 assert allowance.output_tokens == 2_000
@ -126,6 +124,7 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None: def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc) now = datetime(2024, 1, 1, tzinfo=timezone.utc)
# Use the configured models from the fixture
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now) tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now) tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
@ -144,6 +143,7 @@ def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> N
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None: def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc) now = datetime(2024, 1, 1, tzinfo=timezone.utc)
# Use configured models from the fixture
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now) tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now) tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
@ -156,15 +156,19 @@ def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
assert set(filtered_model["openai"].keys()) == {"gpt-4o"} assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
def test_missing_configuration_records_lifetime_only() -> None: def test_missing_configuration_uses_default() -> None:
# With no specific config, falls back to default config (from settings)
tracker = InMemoryUsageTracker(configs={}) tracker = InMemoryUsageTracker(configs={})
tracker.record_usage("openai/gpt-4o", 10, 20) tracker.record_usage("openai/gpt-4o", 10, 20)
assert tracker.get_available_tokens("openai/gpt-4o") is None # Uses default config, so get_available_tokens returns allowance
allowance = tracker.get_available_tokens("openai/gpt-4o")
assert allowance is not None
# Lifetime stats are tracked
breakdown = tracker.get_usage_breakdown() breakdown = tracker.get_usage_breakdown()
usage = breakdown["openai"]["gpt-4o"] usage = breakdown["openai"]["gpt-4o"]
assert usage.window_input_tokens == 0 assert usage.window_input_tokens == 10
assert usage.lifetime_input_tokens == 10 assert usage.lifetime_input_tokens == 10
@ -193,6 +197,7 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None: def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc) now = datetime(2024, 1, 1, tzinfo=timezone.utc)
# Use configured models from the fixture
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now) redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now) redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
@ -201,6 +206,8 @@ def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) ->
assert allowance.input_tokens == 900 assert allowance.input_tokens == 900
breakdown = redis_tracker.get_usage_breakdown() breakdown = redis_tracker.get_usage_breakdown()
assert "anthropic" in breakdown
assert "claude-3" in breakdown["anthropic"]
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200 assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
items = dict(redis_tracker.iter_state_items()) items = dict(redis_tracker.iter_state_items())

View File

@ -0,0 +1,26 @@
"""Tests for base web tool definitions."""
from memory.common.llms.tools.base import WebFetchTool, WebSearchTool
def test_web_search_tool_provider_formats():
tool = WebSearchTool()
assert tool.provider_format("openai") == {"type": "web_search"}
assert tool.provider_format("anthropic") == {
"type": "web_search_20250305",
"name": "web_search",
"max_uses": 10,
}
assert tool.provider_format("unknown") is None
def test_web_fetch_tool_provider_formats():
tool = WebFetchTool()
assert tool.provider_format("anthropic") == {
"type": "web_fetch_20250910",
"name": "web_fetch",
"max_uses": 10,
}
assert tool.provider_format("openai") is None

View File

@ -497,13 +497,14 @@ def test_make_discord_tools_with_user_and_channel(
) )
# Should have: schedule_message, previous_messages, update_channel_summary, # Should have: schedule_message, previous_messages, update_channel_summary,
# update_user_summary, update_server_summary # update_user_summary, update_server_summary, add_reaction
assert len(tools) == 5 assert len(tools) == 6
assert "schedule_message" in tools assert "schedule_message" in tools
assert "previous_messages" in tools assert "previous_messages" in tools
assert "update_channel_summary" in tools assert "update_channel_summary" in tools
assert "update_user_summary" in tools assert "update_user_summary" in tools
assert "update_server_summary" in tools assert "update_server_summary" in tools
assert "add_reaction" in tools
def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user): def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
@ -533,12 +534,13 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
) )
# Should have: schedule_message, previous_messages, update_channel_summary, # Should have: schedule_message, previous_messages, update_channel_summary,
# update_server_summary (no user summary without author) # update_server_summary, add_reaction (no user summary without author)
assert len(tools) == 4 assert len(tools) == 5
assert "schedule_message" in tools assert "schedule_message" in tools
assert "previous_messages" in tools assert "previous_messages" in tools
assert "update_channel_summary" in tools assert "update_channel_summary" in tools
assert "update_server_summary" in tools assert "update_server_summary" in tools
assert "add_reaction" in tools
assert "update_user_summary" not in tools assert "update_user_summary" not in tools

View File

@ -0,0 +1,539 @@
"""Tests for OAuth 2.0 flow handling."""
import pytest
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch
import aiohttp
from memory.common.oauth import (
OAuthEndpoints,
generate_pkce_pair,
discover_oauth_metadata,
get_endpoints,
register_oauth_client,
issue_challenge,
complete_oauth_flow,
)
from memory.common.db.models import MCPServer
class TestGeneratePkcePair:
"""Tests for generate_pkce_pair function."""
def test_generates_valid_verifier_and_challenge(self):
"""Test that PKCE pair is generated correctly."""
verifier, challenge = generate_pkce_pair()
# Verifier should be base64url encoded (no padding)
assert len(verifier) > 0
assert "=" not in verifier
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in verifier)
# Challenge should be base64url encoded (no padding)
assert len(challenge) > 0
assert "=" not in challenge
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in challenge)
# They should be different
assert verifier != challenge
def test_generates_unique_pairs(self):
"""Test that each call generates a unique pair."""
verifier1, challenge1 = generate_pkce_pair()
verifier2, challenge2 = generate_pkce_pair()
assert verifier1 != verifier2
assert challenge1 != challenge2
class TestDiscoverOauthMetadata:
"""Tests for discover_oauth_metadata function."""
@pytest.mark.asyncio
async def test_discover_metadata_success(self):
"""Test successful OAuth metadata discovery."""
metadata = {
"authorization_endpoint": "https://example.com/auth",
"registration_endpoint": "https://example.com/register",
"token_endpoint": "https://example.com/token",
}
mock_response = Mock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=metadata)
mock_get = AsyncMock()
mock_get.__aenter__.return_value = mock_response
mock_get.__aexit__.return_value = None
mock_session = Mock()
mock_session.get = Mock(return_value=mock_get)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
result = await discover_oauth_metadata("https://example.com")
assert result == metadata
assert result["authorization_endpoint"] == "https://example.com/auth"
@pytest.mark.asyncio
async def test_discover_metadata_not_found(self):
"""Test OAuth metadata discovery when endpoint not found."""
mock_response = Mock()
mock_response.status = 404
mock_get = AsyncMock()
mock_get.__aenter__.return_value = mock_response
mock_get.__aexit__.return_value = None
mock_session = Mock()
mock_session.get = Mock(return_value=mock_get)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
result = await discover_oauth_metadata("https://example.com")
assert result is None
@pytest.mark.asyncio
async def test_discover_metadata_connection_error(self):
"""Test OAuth metadata discovery with connection error."""
mock_get = AsyncMock()
mock_get.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
mock_session = Mock()
mock_session.get = Mock(return_value=mock_get)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
result = await discover_oauth_metadata("https://example.com")
assert result is None
class TestGetEndpoints:
"""Tests for get_endpoints function."""
@pytest.mark.asyncio
async def test_get_endpoints_success(self):
"""Test successful endpoint retrieval."""
metadata = {
"authorization_endpoint": "https://example.com/auth",
"registration_endpoint": "https://example.com/register",
"token_endpoint": "https://example.com/token",
}
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
result = await get_endpoints("https://example.com")
assert isinstance(result, OAuthEndpoints)
assert result.authorization_endpoint == "https://example.com/auth"
assert result.registration_endpoint == "https://example.com/register"
assert result.token_endpoint == "https://example.com/token"
assert "/auth/callback/discord" in result.redirect_uri
@pytest.mark.asyncio
async def test_get_endpoints_no_metadata(self):
"""Test when OAuth metadata cannot be discovered."""
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
with pytest.raises(ValueError, match="Failed to connect to MCP server"):
await get_endpoints("https://example.com")
@pytest.mark.asyncio
async def test_get_endpoints_missing_authorization(self):
"""Test when authorization endpoint is missing."""
metadata = {
"registration_endpoint": "https://example.com/register",
"token_endpoint": "https://example.com/token",
}
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
with pytest.raises(ValueError, match="authorization endpoint"):
await get_endpoints("https://example.com")
@pytest.mark.asyncio
async def test_get_endpoints_missing_registration(self):
"""Test when registration endpoint is missing."""
metadata = {
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
}
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
with pytest.raises(ValueError, match="dynamic client registration"):
await get_endpoints("https://example.com")
@pytest.mark.asyncio
async def test_get_endpoints_missing_token(self):
"""Test when token endpoint is missing."""
metadata = {
"authorization_endpoint": "https://example.com/auth",
"registration_endpoint": "https://example.com/register",
}
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
with pytest.raises(ValueError, match="token endpoint"):
await get_endpoints("https://example.com")
class TestRegisterOauthClient:
"""Tests for register_oauth_client function."""
@pytest.mark.asyncio
async def test_register_client_success(self):
"""Test successful OAuth client registration."""
endpoints = OAuthEndpoints(
authorization_endpoint="https://example.com/auth",
registration_endpoint="https://example.com/register",
token_endpoint="https://example.com/token",
redirect_uri="https://myapp.com/callback",
)
client_info = {"client_id": "test-client-123"}
mock_response = Mock()
mock_response.status = 200
mock_response.text = AsyncMock(return_value="Success")
mock_response.json = AsyncMock(return_value=client_info)
mock_response.raise_for_status = Mock()
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
client_id = await register_oauth_client(
endpoints,
"https://example.com",
"Test Client",
)
assert client_id == "test-client-123"
@pytest.mark.asyncio
async def test_register_client_http_error(self):
"""Test OAuth client registration with HTTP error."""
endpoints = OAuthEndpoints(
authorization_endpoint="https://example.com/auth",
registration_endpoint="https://example.com/register",
token_endpoint="https://example.com/token",
redirect_uri="https://myapp.com/callback",
)
mock_response = Mock()
mock_response.raise_for_status = Mock(side_effect=aiohttp.ClientResponseError(
request_info=Mock(),
history=(),
status=400,
message="Bad Request",
))
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to register OAuth client"):
await register_oauth_client(
endpoints,
"https://example.com",
"Test Client",
)
@pytest.mark.asyncio
async def test_register_client_missing_client_id(self):
"""Test OAuth client registration when response lacks client_id."""
endpoints = OAuthEndpoints(
authorization_endpoint="https://example.com/auth",
registration_endpoint="https://example.com/register",
token_endpoint="https://example.com/token",
redirect_uri="https://myapp.com/callback",
)
client_info = {} # Missing client_id
mock_response = Mock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=client_info)
mock_response.raise_for_status = Mock()
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to register OAuth client"):
await register_oauth_client(
endpoints,
"https://example.com",
"Test Client",
)
class TestIssueChallenge:
"""Tests for issue_challenge function."""
@pytest.mark.asyncio
async def test_issue_challenge_success(self, db_session):
"""Test successful OAuth challenge issuance."""
mcp_server = MCPServer(
name="Test Server",
mcp_server_url="https://example.com",
client_id="test-client-123",
)
db_session.add(mcp_server)
db_session.commit()
endpoints = OAuthEndpoints(
authorization_endpoint="https://example.com/auth",
registration_endpoint="https://example.com/register",
token_endpoint="https://example.com/token",
redirect_uri="https://myapp.com/callback",
)
with patch("memory.common.oauth.generate_pkce_pair", return_value=("verifier123", "challenge123")):
auth_url = await issue_challenge(mcp_server, endpoints)
# Verify the auth URL contains expected parameters
assert "https://example.com/auth?" in auth_url
assert "client_id=test-client-123" in auth_url
# redirect_uri will be URL encoded
assert "redirect_uri=" in auth_url
assert "myapp.com" in auth_url
assert "callback" in auth_url
assert "response_type=code" in auth_url
assert "code_challenge=challenge123" in auth_url
assert "code_challenge_method=S256" in auth_url
assert "state=" in auth_url
# Verify state and code_verifier were stored
assert mcp_server.state is not None
assert mcp_server.code_verifier == "verifier123"
class TestCompleteOauthFlow:
"""Tests for complete_oauth_flow function."""
@pytest.mark.asyncio
async def test_complete_oauth_flow_success(self, db_session):
"""Test successful OAuth flow completion."""
mcp_server = MCPServer(
name="Test Server",
mcp_server_url="https://example.com",
client_id="test-client-123",
state="test-state",
code_verifier="test-verifier",
)
db_session.add(mcp_server)
db_session.commit()
metadata = {
"authorization_endpoint": "https://example.com/auth",
"registration_endpoint": "https://example.com/register",
"token_endpoint": "https://example.com/token",
}
token_response = {
"access_token": "access-token-123",
"refresh_token": "refresh-token-123",
"expires_in": 3600,
}
mock_token_response = Mock()
mock_token_response.status = 200
mock_token_response.json = AsyncMock(return_value=token_response)
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_token_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with (
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
):
status, message = await complete_oauth_flow(
mcp_server,
"auth-code-123",
"test-state",
)
assert status == 200
assert "successful" in message
# Verify tokens were stored
assert mcp_server.access_token == "access-token-123"
assert mcp_server.refresh_token == "refresh-token-123"
assert mcp_server.token_expires_at is not None
# Verify temporary state was cleared
assert mcp_server.state is None
assert mcp_server.code_verifier is None
@pytest.mark.asyncio
async def test_complete_oauth_flow_invalid_state(self):
"""Test OAuth flow completion with invalid state."""
status, message = await complete_oauth_flow(
None,
"auth-code-123",
"invalid-state",
)
assert status == 400
assert "Invalid or expired" in message
@pytest.mark.asyncio
async def test_complete_oauth_flow_token_error(self, db_session):
"""Test OAuth flow completion when token exchange fails."""
mcp_server = MCPServer(
name="Test Server",
mcp_server_url="https://example.com",
client_id="test-client-123",
state="test-state",
code_verifier="test-verifier",
)
db_session.add(mcp_server)
db_session.commit()
metadata = {
"authorization_endpoint": "https://example.com/auth",
"registration_endpoint": "https://example.com/register",
"token_endpoint": "https://example.com/token",
}
mock_token_response = Mock()
mock_token_response.status = 400
mock_token_response.text = AsyncMock(return_value="Invalid grant")
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_token_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with (
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
):
status, message = await complete_oauth_flow(
mcp_server,
"invalid-code",
"test-state",
)
assert status == 500
assert "Token exchange failed" in message
@pytest.mark.asyncio
async def test_complete_oauth_flow_missing_access_token(self, db_session):
"""Test OAuth flow completion when access token is missing from response."""
mcp_server = MCPServer(
name="Test Server",
mcp_server_url="https://example.com",
client_id="test-client-123",
state="test-state",
code_verifier="test-verifier",
)
db_session.add(mcp_server)
db_session.commit()
metadata = {
"authorization_endpoint": "https://example.com/auth",
"registration_endpoint": "https://example.com/register",
"token_endpoint": "https://example.com/token",
}
token_response = {} # Missing access_token
mock_token_response = Mock()
mock_token_response.status = 200
mock_token_response.json = AsyncMock(return_value=token_response)
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_token_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with (
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
):
status, message = await complete_oauth_flow(
mcp_server,
"auth-code-123",
"test-state",
)
assert status == 500
assert "did not include access_token" in message
@pytest.mark.asyncio
async def test_complete_oauth_flow_get_endpoints_error(self, db_session):
"""Test OAuth flow completion when getting endpoints fails."""
mcp_server = MCPServer(
name="Test Server",
mcp_server_url="https://example.com",
client_id="test-client-123",
state="test-state",
code_verifier="test-verifier",
)
db_session.add(mcp_server)
db_session.commit()
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
status, message = await complete_oauth_flow(
mcp_server,
"auth-code-123",
"test-state",
)
assert status == 500
assert "Failed to get OAuth endpoints" in message

View File

@ -0,0 +1,253 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
from fastapi import HTTPException
from memory.discord import api
@pytest.fixture(autouse=True)
def reset_app_bots():
existing = getattr(api.app, "bots", None)
api.app.bots = {}
yield
if existing is None:
delattr(api.app, "bots")
else:
api.app.bots = existing
@pytest.fixture
def active_bot():
collector = SimpleNamespace(
send_dm=AsyncMock(return_value=True),
trigger_typing_dm=AsyncMock(return_value=True),
send_to_channel=AsyncMock(return_value=True),
trigger_typing_channel=AsyncMock(return_value=True),
add_reaction=AsyncMock(return_value=True),
refresh_metadata=AsyncMock(return_value={"refreshed": True}),
is_closed=Mock(return_value=False),
user="CollectorUser#1234",
guilds=[101, 202],
)
bot = SimpleNamespace(
collector=collector,
collector_task=None,
bot_id=1,
bot_token="token-123",
bot_name="Test Bot",
)
api.app.bots[bot.bot_id] = bot
return bot
@pytest.mark.asyncio
async def test_send_dm_success(active_bot):
request = api.SendDMRequest(bot_id=active_bot.bot_id, user="user123", message="Hello")
response = await api.send_dm_endpoint(request)
assert response == {
"success": True,
"message": "DM sent to user123",
"user": "user123",
}
active_bot.collector.send_dm.assert_awaited_once_with("user123", "Hello")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint,payload",
[
(
api.send_dm_endpoint,
api.SendDMRequest(bot_id=99, user="ghost", message="hi"),
),
(
api.trigger_dm_typing,
api.TypingDMRequest(bot_id=99, user="ghost"),
),
(
api.send_channel_endpoint,
api.SendChannelRequest(bot_id=99, channel="general", message="hello"),
),
(
api.trigger_channel_typing,
api.TypingChannelRequest(bot_id=99, channel="general"),
),
(
api.add_reaction_endpoint,
api.AddReactionRequest(
bot_id=99,
channel="general",
message_id=42,
emoji=":thumbsup:",
),
),
],
)
async def test_endpoint_returns_404_when_bot_missing(endpoint, payload):
with pytest.raises(HTTPException) as exc:
await endpoint(payload)
assert exc.value.status_code == 404
assert exc.value.detail == "Bot not found"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint,request_cls,request_kwargs,attr_name,detail_template",
[
(
api.send_dm_endpoint,
api.SendDMRequest,
{"bot_id": 1, "user": "user123", "message": "Hi"},
"send_dm",
"Failed to send DM to {user}",
),
(
api.trigger_dm_typing,
api.TypingDMRequest,
{"bot_id": 1, "user": "user123"},
"trigger_typing_dm",
"Failed to trigger typing for user123",
),
(
api.send_channel_endpoint,
api.SendChannelRequest,
{"bot_id": 1, "channel": "general", "message": "Hello"},
"send_to_channel",
"Failed to send message to channel general",
),
(
api.trigger_channel_typing,
api.TypingChannelRequest,
{"bot_id": 1, "channel": "general"},
"trigger_typing_channel",
"Failed to trigger typing for channel general",
),
(
api.add_reaction_endpoint,
api.AddReactionRequest,
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
"add_reaction",
"Failed to add reaction to message 55",
),
],
)
async def test_endpoint_returns_400_on_collector_failure(
active_bot, endpoint, request_cls, request_kwargs, attr_name, detail_template
):
request = request_cls(**request_kwargs)
getattr(active_bot.collector, attr_name).return_value = False
expected_detail = detail_template.format(**request_kwargs)
with pytest.raises(HTTPException) as exc:
await endpoint(request)
assert exc.value.status_code == 400
assert exc.value.detail == expected_detail
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint,request_cls,request_kwargs,attr_name",
[
(
api.send_dm_endpoint,
api.SendDMRequest,
{"bot_id": 1, "user": "user123", "message": "Hi"},
"send_dm",
),
(
api.trigger_dm_typing,
api.TypingDMRequest,
{"bot_id": 1, "user": "user123"},
"trigger_typing_dm",
),
(
api.send_channel_endpoint,
api.SendChannelRequest,
{"bot_id": 1, "channel": "general", "message": "Hello"},
"send_to_channel",
),
(
api.trigger_channel_typing,
api.TypingChannelRequest,
{"bot_id": 1, "channel": "general"},
"trigger_typing_channel",
),
(
api.add_reaction_endpoint,
api.AddReactionRequest,
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
"add_reaction",
),
],
)
async def test_endpoint_returns_500_on_collector_exception(
active_bot, endpoint, request_cls, request_kwargs, attr_name
):
request = request_cls(**request_kwargs)
getattr(active_bot.collector, attr_name).side_effect = RuntimeError("boom")
with pytest.raises(HTTPException) as exc:
await endpoint(request)
assert exc.value.status_code == 500
assert "boom" in exc.value.detail
@pytest.mark.asyncio
async def test_health_check_success(active_bot):
response = await api.health_check()
assert response["Test Bot"] == {
"status": "healthy",
"connected": True,
"user": "CollectorUser#1234",
"guilds": 2,
}
active_bot.collector.is_closed.assert_called_once_with()
@pytest.mark.asyncio
async def test_health_check_without_bots():
with pytest.raises(HTTPException) as exc:
await api.health_check()
assert exc.value.status_code == 503
assert exc.value.detail == "Discord collector not running"
@pytest.mark.asyncio
async def test_refresh_metadata_success(active_bot):
active_bot.collector.refresh_metadata.return_value = {"channels": 3}
response = await api.refresh_metadata()
assert response["success"] is True
assert response["message"] == "Metadata refreshed successfully for 1 bots"
assert response["results"]["Test Bot"] == {"channels": 3}
active_bot.collector.refresh_metadata.assert_awaited_once_with()
@pytest.mark.asyncio
async def test_refresh_metadata_without_bots():
with pytest.raises(HTTPException) as exc:
await api.refresh_metadata()
assert exc.value.status_code == 503
assert exc.value.detail == "Discord collector not running"
@pytest.mark.asyncio
async def test_refresh_metadata_failure(active_bot):
active_bot.collector.refresh_metadata.side_effect = RuntimeError("sync failed")
with pytest.raises(HTTPException) as exc:
await api.refresh_metadata()
assert exc.value.status_code == 500
assert "sync failed" in exc.value.detail

View File

@ -79,6 +79,7 @@ def mock_message(mock_text_channel, mock_user):
message.content = "Test message" message.content = "Test message"
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
message.reference = None message.reference = None
message.attachments = []
return message return message
@ -351,7 +352,7 @@ def test_determine_message_metadata_thread():
# Tests for should_track_message # Tests for should_track_message
def test_should_track_message_server_disabled(db_session): def test_should_track_message_server_disabled(db_session):
"""Test when server has tracking disabled""" """Test when server has tracking disabled"""
server = DiscordServer(id=1, name="Server", track_messages=False) server = DiscordServer(id=1, name="Server", ignore_messages=True)
channel = DiscordChannel(id=2, name="Channel", channel_type="text") channel = DiscordChannel(id=2, name="Channel", channel_type="text")
user = DiscordUser(id=3, username="User") user = DiscordUser(id=3, username="User")
@ -362,9 +363,9 @@ def test_should_track_message_server_disabled(db_session):
def test_should_track_message_channel_disabled(db_session): def test_should_track_message_channel_disabled(db_session):
"""Test when channel has tracking disabled""" """Test when channel has tracking disabled"""
server = DiscordServer(id=1, name="Server", track_messages=True) server = DiscordServer(id=1, name="Server", ignore_messages=False)
channel = DiscordChannel( channel = DiscordChannel(
id=2, name="Channel", channel_type="text", track_messages=False id=2, name="Channel", channel_type="text", ignore_messages=True
) )
user = DiscordUser(id=3, username="User") user = DiscordUser(id=3, username="User")
@ -375,8 +376,8 @@ def test_should_track_message_channel_disabled(db_session):
def test_should_track_message_dm_allowed(db_session): def test_should_track_message_dm_allowed(db_session):
"""Test DM tracking when user allows it""" """Test DM tracking when user allows it"""
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True) channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
user = DiscordUser(id=3, username="User", track_messages=True) user = DiscordUser(id=3, username="User", ignore_messages=False)
result = should_track_message(None, channel, user) result = should_track_message(None, channel, user)
@ -385,8 +386,8 @@ def test_should_track_message_dm_allowed(db_session):
def test_should_track_message_dm_not_allowed(db_session): def test_should_track_message_dm_not_allowed(db_session):
"""Test DM tracking when user doesn't allow it""" """Test DM tracking when user doesn't allow it"""
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True) channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
user = DiscordUser(id=3, username="User", track_messages=False) user = DiscordUser(id=3, username="User", ignore_messages=True)
result = should_track_message(None, channel, user) result = should_track_message(None, channel, user)
@ -395,9 +396,9 @@ def test_should_track_message_dm_not_allowed(db_session):
def test_should_track_message_default_true(db_session): def test_should_track_message_default_true(db_session):
"""Test default tracking behavior""" """Test default tracking behavior"""
server = DiscordServer(id=1, name="Server", track_messages=True) server = DiscordServer(id=1, name="Server", ignore_messages=False)
channel = DiscordChannel( channel = DiscordChannel(
id=2, name="Channel", channel_type="text", track_messages=True id=2, name="Channel", channel_type="text", ignore_messages=False
) )
user = DiscordUser(id=3, username="User") user = DiscordUser(id=3, username="User")
@ -465,6 +466,7 @@ def test_sync_guild_metadata(mock_make_session, mock_guild):
voice_channel.guild = mock_guild voice_channel.guild = mock_guild
mock_guild.channels = [text_channel, voice_channel] mock_guild.channels = [text_channel, voice_channel]
mock_guild.threads = []
sync_guild_metadata(mock_guild) sync_guild_metadata(mock_guild)
@ -489,9 +491,18 @@ def test_message_collector_init():
async def test_on_ready(): async def test_on_ready():
"""Test on_ready event handler""" """Test on_ready event handler"""
collector = MessageCollector() collector = MessageCollector()
collector.user = Mock()
collector.user.name = "TestBot" # Mock the properties
collector.guilds = [Mock(), Mock()] mock_user = Mock()
mock_user.name = "TestBot"
with patch.object(
type(collector), "user", new_callable=lambda: property(lambda self: mock_user)
):
with patch.object(
type(collector),
"guilds",
new_callable=lambda: property(lambda self: [Mock(), Mock()]),
):
collector.sync_servers_and_channels = AsyncMock() collector.sync_servers_and_channels = AsyncMock()
collector.tree.sync = AsyncMock() collector.tree.sync = AsyncMock()
@ -593,8 +604,12 @@ async def test_sync_servers_and_channels():
guild2 = Mock() guild2 = Mock()
collector = MessageCollector() collector = MessageCollector()
collector.guilds = [guild1, guild2]
with patch.object(
type(collector),
"guilds",
new_callable=lambda: property(lambda self: [guild1, guild2]),
):
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync: with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
await collector.sync_servers_and_channels() await collector.sync_servers_and_channels()
@ -617,12 +632,21 @@ async def test_refresh_metadata(mock_make_session):
guild.name = "Test" guild.name = "Test"
guild.channels = [] guild.channels = []
guild.members = [] guild.members = []
guild.threads = []
collector = MessageCollector() collector = MessageCollector()
collector.guilds = [guild]
collector.intents = Mock()
collector.intents.members = False
mock_intents = Mock()
mock_intents.members = False
with patch.object(
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
):
with patch.object(
type(collector),
"intents",
new_callable=lambda: property(lambda self: mock_intents),
):
result = await collector.refresh_metadata() result = await collector.refresh_metadata()
assert result["servers_updated"] == 1 assert result["servers_updated"] == 1
@ -637,7 +661,7 @@ async def test_get_user_by_id():
user.id = 123 user.id = 123
collector = MessageCollector() collector = MessageCollector()
collector.get_user = Mock(return_value=user) collector.get_user = AsyncMock(return_value=user)
result = await collector.get_user(123) result = await collector.get_user(123)
@ -656,8 +680,10 @@ async def test_get_user_by_username():
guild.members = [member] guild.members = [member]
collector = MessageCollector() collector = MessageCollector()
collector.guilds = [guild]
with patch.object(
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
):
result = await collector.get_user("testuser") result = await collector.get_user("testuser")
assert result == member assert result == member
@ -667,11 +693,19 @@ async def test_get_user_by_username():
async def test_get_user_not_found(): async def test_get_user_not_found():
"""Test getting non-existent user""" """Test getting non-existent user"""
collector = MessageCollector() collector = MessageCollector()
collector.guilds = []
with patch.object(collector, "get_user", return_value=None): # Create proper mock response for discord.NotFound
mock_response = Mock()
mock_response.status = 404
mock_response.text = ""
with patch.object( with patch.object(
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock()) type(collector), "guilds", new_callable=lambda: property(lambda self: [])
):
with patch.object(
collector,
"fetch_user",
AsyncMock(side_effect=discord.NotFound(mock_response, "User not found")),
): ):
result = await collector.get_user(999) result = await collector.get_user(999)
assert result is None assert result is None
@ -687,8 +721,10 @@ async def test_get_channel_by_name():
guild.channels = [channel] guild.channels = [channel]
collector = MessageCollector() collector = MessageCollector()
collector.guilds = [guild]
with patch.object(
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
):
result = await collector.get_channel_by_name("general") result = await collector.get_channel_by_name("general")
assert result == channel assert result == channel
@ -701,8 +737,10 @@ async def test_get_channel_by_name_not_found():
guild.channels = [] guild.channels = []
collector = MessageCollector() collector = MessageCollector()
collector.guilds = [guild]
with patch.object(
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
):
result = await collector.get_channel_by_name("nonexistent") result = await collector.get_channel_by_name("nonexistent")
assert result is None assert result is None
@ -730,8 +768,10 @@ async def test_create_channel_no_guild():
"""Test creating channel when no guild available""" """Test creating channel when no guild available"""
collector = MessageCollector() collector = MessageCollector()
collector.get_guild = Mock(return_value=None) collector.get_guild = Mock(return_value=None)
collector.guilds = []
with patch.object(
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
):
result = await collector.create_channel("new-channel") result = await collector.create_channel("new-channel")
assert result is None assert result is None
@ -816,27 +856,19 @@ async def test_send_to_channel_not_found():
assert result is False assert result is False
@pytest.mark.skip(
reason="run_collector function doesn't exist or uses different settings"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
async def test_run_collector(): async def test_run_collector():
"""Test running the collector""" """Test running the collector"""
from memory.discord.collector import run_collector pass
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
mock_collector = Mock()
mock_collector.start = AsyncMock()
mock_collector_class.return_value = mock_collector
await run_collector()
mock_collector.start.assert_called_once_with("test_token")
@pytest.mark.skip(
reason="run_collector function doesn't exist or uses different settings"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
async def test_run_collector_no_token(): async def test_run_collector_no_token():
"""Test running collector without token""" """Test running collector without token"""
from memory.discord.collector import run_collector pass
# Should return early without raising
await run_collector()

View File

@ -1,17 +1,23 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from unittest.mock import MagicMock
import discord import discord
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.commands import ( from memory.discord.commands import (
CommandContext,
CommandError, CommandError,
CommandResponse, CommandResponse,
run_command,
handle_prompt, handle_prompt,
handle_chattiness, handle_chattiness,
handle_ignore, handle_ignore,
handle_summary, handle_summary,
respond,
with_object_context,
handle_mcp_servers,
) )
@ -66,29 +72,54 @@ def interaction(guild, text_channel, discord_user) -> DummyInteraction:
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user) return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
def test_handle_command_prompt_server(db_session, guild, interaction): @pytest.mark.asyncio
async def test_handle_command_prompt_server(db_session, guild, interaction):
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful") server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
db_session.add(server) db_session.add(server)
db_session.commit() db_session.commit()
response = run_command( context = CommandContext(
db_session, session=db_session,
interaction, interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server", scope="server",
handler=handle_prompt, target=server,
display_name="server **Test Guild**",
) )
response = await handle_prompt(context)
assert isinstance(response, CommandResponse) assert isinstance(response, CommandResponse)
assert "Be helpful" in response.content assert "Be helpful" in response.content
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel): @pytest.mark.asyncio
response = run_command( async def test_handle_command_prompt_channel_creates_channel(
db_session, db_session, interaction, text_channel, guild
interaction, ):
scope="channel", # Create the server first to satisfy FK constraint
handler=handle_prompt, server = DiscordServer(id=guild.id, name="Test Guild")
db_session.add(server)
channel_model = DiscordChannel(
id=text_channel.id,
name=text_channel.name,
channel_type="text",
server_id=guild.id,
) )
db_session.add(channel_model)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="channel",
target=channel_model,
display_name=f"channel **#{text_channel.name}**",
)
response = await handle_prompt(context)
assert "No prompt" in response.content assert "No prompt" in response.content
channel = db_session.get(DiscordChannel, text_channel.id) channel = db_session.get(DiscordChannel, text_channel.id)
@ -96,77 +127,253 @@ def test_handle_command_prompt_channel_creates_channel(db_session, interaction,
assert channel.name == text_channel.name assert channel.name == text_channel.name
def test_handle_command_chattiness_show(db_session, interaction, guild): @pytest.mark.asyncio
async def test_handle_command_chattiness_show(db_session, interaction, guild):
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73) server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
db_session.add(server) db_session.add(server)
db_session.commit() db_session.commit()
response = run_command( context = CommandContext(
db_session, session=db_session,
interaction, interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server", scope="server",
handler=handle_chattiness, target=server,
display_name="server **Guild**",
) )
response = await handle_chattiness(context, value=None)
assert str(server.chattiness_threshold) in response.content assert str(server.chattiness_threshold) in response.content
def test_handle_command_chattiness_update(db_session, interaction): @pytest.mark.asyncio
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15) async def test_handle_command_chattiness_update(db_session, interaction):
user_model = DiscordUser(
id=interaction.user.id, username="command-user", chattiness_threshold=15
)
db_session.add(user_model) db_session.add(user_model)
db_session.commit() db_session.commit()
response = run_command( context = CommandContext(
db_session, session=db_session,
interaction, interaction=interaction,
actor=user_model,
scope="user", scope="user",
handler=handle_chattiness, target=user_model,
value=80, display_name="user **command-user**",
) )
response = await handle_chattiness(context, value=80)
db_session.flush() db_session.flush()
assert "Updated" in response.content assert "Updated" in response.content
assert user_model.chattiness_threshold == 80 assert user_model.chattiness_threshold == 80
def test_handle_command_chattiness_invalid_value(db_session, interaction): @pytest.mark.asyncio
with pytest.raises(CommandError): async def test_handle_command_chattiness_invalid_value(db_session, interaction):
run_command( user_model = DiscordUser(id=interaction.user.id, username="command-user")
db_session, db_session.add(user_model)
interaction, db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=user_model,
scope="user", scope="user",
handler=handle_chattiness, target=user_model,
value=150, display_name="user **command-user**",
) )
with pytest.raises(CommandError):
await handle_chattiness(context, value=150)
def test_handle_command_ignore_toggle(db_session, interaction, guild):
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id) @pytest.mark.asyncio
async def test_handle_command_ignore_toggle(db_session, interaction, guild):
# Create the server first to satisfy FK constraint
server = DiscordServer(id=guild.id, name="Test Guild")
db_session.add(server)
channel = DiscordChannel(
id=interaction.channel.id,
name="general",
channel_type="text",
server_id=guild.id,
)
db_session.add(channel) db_session.add(channel)
db_session.commit() db_session.commit()
response = run_command( context = CommandContext(
db_session, session=db_session,
interaction, interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="channel", scope="channel",
handler=handle_ignore, target=channel,
ignore_enabled=True, display_name="channel **#general**",
) )
response = await handle_ignore(context, ignore_enabled=True)
db_session.flush() db_session.flush()
assert "no longer" not in response.content assert "no longer" not in response.content
assert channel.ignore_messages is True assert channel.ignore_messages is True
def test_handle_command_summary_missing(db_session, interaction): @pytest.mark.asyncio
response = run_command( async def test_handle_command_summary_missing(db_session, interaction):
db_session, user_model = DiscordUser(id=interaction.user.id, username="command-user")
interaction, db_session.add(user_model)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=user_model,
scope="user", scope="user",
handler=handle_summary, target=user_model,
display_name="user **command-user**",
) )
response = await handle_summary(context)
assert "No summary" in response.content assert "No summary" in response.content
@pytest.mark.asyncio
async def test_respond_sends_message_without_file():
interaction = MagicMock(spec=discord.Interaction)
interaction.response.send_message = AsyncMock()
await respond(interaction, "hello world", ephemeral=False)
interaction.response.send_message.assert_awaited_once_with(
"hello world", ephemeral=False
)
@pytest.mark.asyncio
async def test_respond_sends_file_when_content_too_large():
interaction = MagicMock(spec=discord.Interaction)
interaction.response.send_message = AsyncMock()
oversized = "x" * 2000
with patch("memory.discord.commands.discord.File") as mock_file:
file_instance = MagicMock()
mock_file.return_value = file_instance
await respond(interaction, oversized)
interaction.response.send_message.assert_awaited_once_with(
"Response too large, sending as file:",
file=file_instance,
ephemeral=True,
)
@patch("memory.discord.commands._ensure_channel")
@patch("memory.discord.commands.ensure_server")
@patch("memory.discord.commands.ensure_user")
@patch("memory.discord.commands.make_session")
def test_with_object_context_uses_ensured_objects(
mock_make_session,
mock_ensure_user,
mock_ensure_server,
mock_ensure_channel,
interaction,
guild,
text_channel,
discord_user,
):
mock_session = MagicMock()
@contextmanager
def session_cm():
yield mock_session
mock_make_session.return_value = session_cm()
bot_model = MagicMock(name="bot_model")
user_model = MagicMock(name="user_model")
server_model = MagicMock(name="server_model")
channel_model = MagicMock(name="channel_model")
mock_ensure_user.side_effect = [bot_model, user_model]
mock_ensure_server.return_value = server_model
mock_ensure_channel.return_value = channel_model
handler_objects = {}
def handler(objects):
handler_objects["objects"] = objects
return "done"
bot_client = SimpleNamespace(user=MagicMock())
override_user = MagicMock(spec=discord.User)
result = with_object_context(bot_client, interaction, handler, override_user)
assert result == "done"
objects = handler_objects["objects"]
assert objects.bot is bot_model
assert objects.server is server_model
assert objects.channel is channel_model
assert objects.user is user_model
mock_ensure_user.assert_any_call(mock_session, bot_client.user)
mock_ensure_user.assert_any_call(mock_session, override_user)
mock_ensure_server.assert_called_once_with(mock_session, guild)
mock_ensure_channel.assert_called_once_with(
mock_session, text_channel, guild.id
)
@pytest.mark.asyncio
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
async def test_handle_mcp_servers_returns_response(mock_run_mcp, interaction):
mock_run_mcp.return_value = "Listed servers"
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
context = CommandContext(
session=MagicMock(),
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server_model,
display_name="server **Guild**",
)
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
response = await handle_mcp_servers(
context, action="list", url=None
)
assert response.content == "Listed servers"
mock_run_mcp.assert_awaited_once_with(
interaction.client.user, "list", None, "DiscordServer", server_model.id
)
@pytest.mark.asyncio
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
mock_run_mcp.side_effect = RuntimeError("boom")
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
context = CommandContext(
session=MagicMock(),
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server_model,
display_name="server **Guild**",
)
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
with pytest.raises(CommandError) as exc:
await handle_mcp_servers(context, action="list", url=None)
assert "Error: boom" in str(exc.value)

View File

@ -0,0 +1,590 @@
"""Tests for Discord MCP server management."""
import json
from unittest.mock import AsyncMock, Mock, patch
import aiohttp
import discord
import pytest
from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.discord.mcp import (
call_mcp_server,
find_mcp_server,
handle_mcp_add,
handle_mcp_connect,
handle_mcp_delete,
handle_mcp_list,
handle_mcp_tools,
run_mcp_server_command,
)
# Helper class for async iteration
class AsyncIterator:
"""Helper to create an async iterator for mocking aiohttp response content."""
def __init__(self, items):
self.items = items
self.index = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.index >= len(self.items):
raise StopAsyncIteration
item = self.items[self.index]
self.index += 1
return item
@pytest.fixture
def mcp_server(db_session) -> MCPServer:
"""Create a test MCP server."""
server = MCPServer(
name="Test MCP Server",
mcp_server_url="https://mcp.example.com",
client_id="test_client_id",
access_token="test_access_token",
available_tools=["tool1", "tool2"],
)
db_session.add(server)
db_session.commit()
return server
@pytest.fixture
def mcp_assignment(db_session, mcp_server: MCPServer) -> MCPServerAssignment:
"""Create a test MCP server assignment."""
assignment = MCPServerAssignment(
mcp_server_id=mcp_server.id,
entity_type="DiscordUser",
entity_id=123456,
)
db_session.add(assignment)
db_session.commit()
return assignment
@pytest.fixture
def mock_bot_user() -> discord.User:
"""Create a mock Discord bot user."""
user = Mock(spec=discord.User)
user.name = "TestBot"
user.id = 999
return user
def test_find_mcp_server_exists(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test finding an existing MCP server."""
result = find_mcp_server(
db_session,
entity_type="DiscordUser",
entity_id=123456,
url="https://mcp.example.com",
)
assert result is not None
assert result.id == mcp_server.id
assert result.mcp_server_url == "https://mcp.example.com"
def test_find_mcp_server_not_found(db_session):
"""Test finding a non-existent MCP server."""
result = find_mcp_server(
db_session,
entity_type="DiscordUser",
entity_id=999999,
url="https://nonexistent.com",
)
assert result is None
def test_find_mcp_server_wrong_entity(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test finding MCP server with wrong entity type."""
result = find_mcp_server(
db_session,
entity_type="DiscordChannel", # Wrong entity type
entity_id=123456,
url="https://mcp.example.com",
)
assert result is None
@pytest.mark.asyncio
async def test_call_mcp_server_success():
"""Test calling MCP server successfully."""
mock_response_data = [
b'data: {"result": {"tools": [{"name": "test"}]}}\n',
b'data: {"status": "ok"}\n',
]
mock_response = Mock()
mock_response.status = 200
mock_response.content = AsyncIterator(mock_response_data)
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
async for data in call_mcp_server(
"https://mcp.example.com", "test_token", "tools/list", {}
):
results.append(data)
assert len(results) == 2
assert "result" in results[0]
assert results[0]["result"]["tools"][0]["name"] == "test"
@pytest.mark.asyncio
async def test_call_mcp_server_error():
"""Test calling MCP server with error response."""
mock_response = Mock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value="Internal Server Error")
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to call MCP server"):
async for _ in call_mcp_server(
"https://mcp.example.com", "test_token", "tools/list"
):
pass
@pytest.mark.asyncio
async def test_call_mcp_server_invalid_json():
"""Test calling MCP server with invalid JSON."""
mock_response_data = [
b"data: invalid json\n",
b'data: {"valid": "json"}\n',
]
mock_response = Mock()
mock_response.status = 200
mock_response.content = AsyncIterator(mock_response_data)
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
async for data in call_mcp_server(
"https://mcp.example.com", "test_token", "tools/list"
):
results.append(data)
# Should skip invalid JSON and only return valid one
assert len(results) == 1
assert results[0] == {"valid": "json"}
@pytest.mark.asyncio
async def test_handle_mcp_list_empty(db_session):
"""Test listing MCP servers when none exist."""
result = await handle_mcp_list("DiscordUser", 123456)
assert "You don't have any MCP servers configured yet" in result
assert "/memory_mcp_servers add" in result
@pytest.mark.asyncio
async def test_handle_mcp_list_with_servers(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test listing MCP servers with existing servers."""
result = await handle_mcp_list("DiscordUser", 123456)
assert "Your MCP Servers" in result
assert "https://mcp.example.com" in result
assert "test_client_id" in result
assert "🟢" in result # Server has access token
@pytest.mark.asyncio
async def test_handle_mcp_list_disconnected_server(db_session):
"""Test listing MCP servers with disconnected server."""
server = MCPServer(
name="Disconnected Server",
mcp_server_url="https://disconnected.example.com",
client_id="client_123",
access_token=None, # No access token
)
db_session.add(server)
db_session.flush()
assignment = MCPServerAssignment(
mcp_server_id=server.id,
entity_type="DiscordUser",
entity_id=123456,
)
db_session.add(assignment)
db_session.commit()
result = await handle_mcp_list("DiscordUser", 123456)
assert "🔴" in result # Server has no access token
@pytest.mark.asyncio
async def test_handle_mcp_add_new_server(db_session, mock_bot_user):
"""Test adding a new MCP server."""
with (
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
patch("memory.discord.mcp.register_oauth_client") as mock_register,
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
):
mock_endpoints = Mock()
mock_get_endpoints.return_value = mock_endpoints
mock_register.return_value = "new_client_id"
mock_challenge.return_value = "https://auth.example.com/authorize"
result = await handle_mcp_add(
"DiscordUser", 123456, mock_bot_user, "https://new.example.com"
)
assert "Add MCP Server" in result
assert "https://new.example.com" in result
assert "https://auth.example.com/authorize" in result
# Verify server was created
server = (
db_session.query(MCPServer)
.filter(MCPServer.mcp_server_url == "https://new.example.com")
.first()
)
assert server is not None
assert server.client_id == "new_client_id"
# Verify assignment was created
assignment = (
db_session.query(MCPServerAssignment)
.filter(
MCPServerAssignment.mcp_server_id == server.id,
MCPServerAssignment.entity_type == "DiscordUser",
MCPServerAssignment.entity_id == 123456,
)
.first()
)
assert assignment is not None
@pytest.mark.asyncio
async def test_handle_mcp_add_existing_server(
db_session,
mcp_server: MCPServer,
mcp_assignment: MCPServerAssignment,
mock_bot_user,
):
"""Test adding an MCP server that already exists."""
result = await handle_mcp_add(
"DiscordUser", 123456, mock_bot_user, "https://mcp.example.com"
)
assert "MCP Server Already Exists" in result
assert "https://mcp.example.com" in result
assert "/memory_mcp_servers connect" in result
@pytest.mark.asyncio
async def test_handle_mcp_add_no_bot_user(db_session):
"""Test adding MCP server without bot user."""
with pytest.raises(ValueError, match="Bot user is required"):
await handle_mcp_add("DiscordUser", 123456, None, "https://example.com")
@pytest.mark.asyncio
async def test_handle_mcp_delete_existing(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test deleting an existing MCP server assignment."""
# Store IDs before deletion
assignment_id = mcp_assignment.id
server_id = mcp_server.id
result = await handle_mcp_delete("DiscordUser", 123456, "https://mcp.example.com")
assert "Delete MCP Server" in result
assert "https://mcp.example.com" in result
assert "has been removed" in result
# Verify assignment was deleted
assignment = (
db_session.query(MCPServerAssignment)
.filter(MCPServerAssignment.id == assignment_id)
.first()
)
assert assignment is None
# Verify server was also deleted (no other assignments)
server = db_session.query(MCPServer).filter(MCPServer.id == server_id).first()
assert server is None
@pytest.mark.asyncio
async def test_handle_mcp_delete_not_found(db_session):
"""Test deleting a non-existent MCP server."""
result = await handle_mcp_delete("DiscordUser", 123456, "https://nonexistent.com")
assert "MCP Server Not Found" in result
assert "https://nonexistent.com" in result
@pytest.mark.asyncio
async def test_handle_mcp_delete_with_other_assignments(db_session):
"""Test deleting MCP server with multiple assignments."""
server = MCPServer(
name="Shared Server",
mcp_server_url="https://shared.example.com",
client_id="shared_client",
)
db_session.add(server)
db_session.flush()
assignment1 = MCPServerAssignment(
mcp_server_id=server.id,
entity_type="DiscordUser",
entity_id=111,
)
assignment2 = MCPServerAssignment(
mcp_server_id=server.id,
entity_type="DiscordUser",
entity_id=222,
)
db_session.add_all([assignment1, assignment2])
db_session.commit()
# Delete one assignment
result = await handle_mcp_delete("DiscordUser", 111, "https://shared.example.com")
assert "has been removed" in result
# Verify only one assignment was deleted
remaining = (
db_session.query(MCPServerAssignment)
.filter(MCPServerAssignment.mcp_server_id == server.id)
.count()
)
assert remaining == 1
# Verify server still exists
server_check = db_session.query(MCPServer).filter(MCPServer.id == server.id).first()
assert server_check is not None
@pytest.mark.asyncio
async def test_handle_mcp_connect_existing(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test reconnecting to an existing MCP server."""
with (
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
):
mock_endpoints = Mock()
mock_get_endpoints.return_value = mock_endpoints
mock_challenge.return_value = "https://auth.example.com/authorize?state=new"
result = await handle_mcp_connect(
"DiscordUser", 123456, "https://mcp.example.com"
)
assert "Reconnect to MCP Server" in result
assert "https://mcp.example.com" in result
assert "https://auth.example.com/authorize?state=new" in result
@pytest.mark.asyncio
async def test_handle_mcp_connect_not_found(db_session):
"""Test reconnecting to a non-existent MCP server."""
with pytest.raises(ValueError, match="MCP Server Not Found"):
await handle_mcp_connect("DiscordUser", 123456, "https://nonexistent.com")
@pytest.mark.asyncio
async def test_handle_mcp_tools_success(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test listing tools from an MCP server."""
mock_response_data = [
b'data: {"result": {"tools": [{"name": "search", "description": "Search tool"}]}}\n',
]
mock_response = Mock()
mock_response.status = 200
mock_response.content = AsyncIterator(mock_response_data)
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
result = await handle_mcp_tools(
"DiscordUser", 123456, "https://mcp.example.com"
)
assert "MCP Server Tools" in result
assert "https://mcp.example.com" in result
assert "search" in result
assert "Search tool" in result
assert "Found 1 tool(s)" in result
@pytest.mark.asyncio
async def test_handle_mcp_tools_no_tools(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test listing tools when server has no tools."""
mock_response_data = [
b'data: {"result": {"tools": []}}\n',
]
mock_response = Mock()
mock_response.status = 200
mock_response.content = AsyncIterator(mock_response_data)
mock_post = AsyncMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
result = await handle_mcp_tools(
"DiscordUser", 123456, "https://mcp.example.com"
)
assert "No tools available" in result
@pytest.mark.asyncio
async def test_handle_mcp_tools_server_not_found(db_session):
"""Test listing tools for a non-existent server."""
with pytest.raises(ValueError, match="MCP Server Not Found"):
await handle_mcp_tools("DiscordUser", 123456, "https://nonexistent.com")
@pytest.mark.asyncio
async def test_handle_mcp_tools_not_authorized(db_session):
"""Test listing tools when not authorized."""
server = MCPServer(
name="Unauthorized Server",
mcp_server_url="https://unauthorized.example.com",
client_id="client_123",
access_token=None, # No access token
)
db_session.add(server)
db_session.flush()
assignment = MCPServerAssignment(
mcp_server_id=server.id,
entity_type="DiscordUser",
entity_id=123456,
)
db_session.add(assignment)
db_session.commit()
with pytest.raises(ValueError, match="Not Authorized"):
await handle_mcp_tools(
"DiscordUser", 123456, "https://unauthorized.example.com"
)
@pytest.mark.asyncio
async def test_handle_mcp_tools_connection_error(
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
):
"""Test listing tools with connection error."""
mock_post = AsyncMock()
mock_post.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
mock_post.__aexit__.return_value = None
mock_session = Mock()
mock_session.post = Mock(return_value=mock_post)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__.return_value = mock_session
mock_session_ctx.__aexit__.return_value = None
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Connection failed"):
await handle_mcp_tools("DiscordUser", 123456, "https://mcp.example.com")
@pytest.mark.asyncio
async def test_run_mcp_server_command_list(db_session, mock_bot_user):
"""Test run_mcp_server_command with list action."""
result = await run_mcp_server_command(
mock_bot_user, "list", None, "DiscordUser", 123456
)
assert "Your MCP Servers" in result
@pytest.mark.asyncio
async def test_run_mcp_server_command_invalid_action(mock_bot_user):
"""Test run_mcp_server_command with invalid action."""
with pytest.raises(ValueError, match="Invalid action"):
await run_mcp_server_command(
mock_bot_user, "invalid", None, "DiscordUser", 123456
)
@pytest.mark.asyncio
async def test_run_mcp_server_command_missing_url(mock_bot_user):
"""Test run_mcp_server_command with missing URL for non-list action."""
with pytest.raises(ValueError, match="URL is required"):
await run_mcp_server_command(mock_bot_user, "add", None, "DiscordUser", 123456)
@pytest.mark.asyncio
async def test_run_mcp_server_command_no_bot_user():
"""Test run_mcp_server_command without bot user."""
with pytest.raises(ValueError, match="Bot user is required"):
await run_mcp_server_command(None, "list", None, "DiscordUser", 123456)

View File

@ -1,5 +1,8 @@
"""Tests for Discord message helper functions.""" """Tests for Discord message helper functions."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest import pytest
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from memory.discord.messages import ( from memory.discord.messages import (
@ -9,6 +12,7 @@ from memory.discord.messages import (
upsert_scheduled_message, upsert_scheduled_message,
previous_messages, previous_messages,
comm_channel_prompt, comm_channel_prompt,
call_llm,
) )
from memory.common.db.models import ( from memory.common.db.models import (
DiscordUser, DiscordUser,
@ -18,6 +22,7 @@ from memory.common.db.models import (
HumanUser, HumanUser,
ScheduledLLMCall, ScheduledLLMCall,
) )
from memory.common.llms.tools import MCPServer as MCPServerDefinition
@pytest.fixture @pytest.fixture
@ -411,3 +416,107 @@ def test_comm_channel_prompt_includes_user_notes(
assert "user_notes" in result.lower() assert "user_notes" in result.lower()
assert "testuser" in result # username should appear assert "testuser" in result # username should appear
@patch("memory.discord.messages.create_provider")
@patch("memory.discord.messages.previous_messages")
@patch("memory.common.llms.tools.discord.make_discord_tools")
@patch("memory.common.llms.tools.base.WebSearchTool")
def test_call_llm_includes_web_search_and_mcp_servers(
mock_web_search,
mock_make_tools,
mock_prev_messages,
mock_create_provider,
):
provider = MagicMock()
provider.usage_tracker.is_rate_limited.return_value = False
provider.as_messages.return_value = ["converted"]
provider.run_with_tools.return_value = SimpleNamespace(response="llm-output")
mock_create_provider.return_value = provider
mock_prev_messages.return_value = [SimpleNamespace(as_content=lambda: "prev")]
existing_tool = MagicMock(name="existing_tool")
mock_make_tools.return_value = {"existing": existing_tool}
web_tool_instance = MagicMock(name="web_tool")
mock_web_search.return_value = web_tool_instance
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
from_user = SimpleNamespace(id=123)
mcp_model = SimpleNamespace(
name="Server",
mcp_server_url="https://mcp.example.com",
access_token="token123",
)
result = call_llm(
session=MagicMock(),
bot_user=bot_user,
from_user=from_user,
channel=None,
model="gpt-test",
messages=["hi"],
mcp_servers=[mcp_model],
)
assert result == "llm-output"
kwargs = provider.run_with_tools.call_args.kwargs
tools = kwargs["tools"]
assert tools["existing"] is existing_tool
assert tools["web_search"] is web_tool_instance
mcp_servers = kwargs["mcp_servers"]
assert mcp_servers == [
MCPServerDefinition(
name="Server", url="https://mcp.example.com", token="token123"
)
]
@patch("memory.discord.messages.create_provider")
@patch("memory.discord.messages.previous_messages")
@patch("memory.common.llms.tools.discord.make_discord_tools")
@patch("memory.common.llms.tools.base.WebSearchTool")
def test_call_llm_filters_disallowed_tools(
mock_web_search,
mock_make_tools,
mock_prev_messages,
mock_create_provider,
):
provider = MagicMock()
provider.usage_tracker.is_rate_limited.return_value = False
provider.as_messages.return_value = ["converted"]
provider.run_with_tools.return_value = SimpleNamespace(response="filtered-output")
mock_create_provider.return_value = provider
mock_prev_messages.return_value = []
allowed_tool = MagicMock(name="allowed")
blocked_tool = MagicMock(name="blocked")
mock_make_tools.return_value = {
"allowed": allowed_tool,
"blocked": blocked_tool,
}
mock_web_search.return_value = MagicMock(name="web_tool")
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
from_user = SimpleNamespace(id=1)
call_llm(
session=MagicMock(),
bot_user=bot_user,
from_user=from_user,
channel=None,
model="gpt-test",
messages=[],
allowed_tools={"allowed"},
mcp_servers=None,
)
tools = provider.run_with_tools.call_args.kwargs["tools"]
assert "allowed" in tools
assert "blocked" not in tools
assert "web_search" not in tools

View File

@ -0,0 +1,41 @@
"""Tests for Discord setup CLI utilities."""
from unittest.mock import MagicMock, patch
from click.testing import CliRunner
from tools.discord_setup import generate_bot_invite_url, make_invite
def test_make_invite_generates_expected_url():
result = make_invite(123456789)
assert (
result
== "https://discord.com/oauth2/authorize?client_id=123456789&scope=bot&permissions=3088"
)
@patch("tools.discord_setup.requests.get")
def test_generate_bot_invite_url_outputs_link(mock_get):
response = MagicMock()
response.raise_for_status.return_value = None
response.json.return_value = {"id": "987654321"}
mock_get.return_value = response
runner = CliRunner()
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "abc.def"])
assert result.exit_code == 0
assert "Bot invite URL" in result.output
assert "987654321" in result.output
@patch("tools.discord_setup.requests.get", side_effect=Exception("api down"))
def test_generate_bot_invite_url_handles_errors(mock_get):
runner = CliRunner()
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "token"])
assert result.exit_code != 0
assert isinstance(result.exception, ValueError)
assert "Could not get bot info" in str(result.exception)