diff --git a/docker/ingest_hub/supervisor.conf b/docker/ingest_hub/supervisor.conf
index 0f53377..af5e8c9 100644
--- a/docker/ingest_hub/supervisor.conf
+++ b/docker/ingest_hub/supervisor.conf
@@ -16,7 +16,7 @@ autorestart=true
startsecs=10
[program:discord-api]
-command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_API_PORT)s
+command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_COLLECTOR_PORT)s
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py
index cc930da..a415c19 100644
--- a/src/memory/common/celery_app.py
+++ b/src/memory/common/celery_app.py
@@ -98,7 +98,6 @@ app.conf.update(
@app.on_after_configure.connect # type: ignore[attr-defined]
def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant
- from memory.common.discord import load_servers
qdrant.setup_qdrant()
- load_servers()
+ # Note: load_servers() was removed as it's no longer needed
diff --git a/src/memory/common/db/models/scheduled_calls.py b/src/memory/common/db/models/scheduled_calls.py
index 75d2184..d2114b6 100644
--- a/src/memory/common/db/models/scheduled_calls.py
+++ b/src/memory/common/db/models/scheduled_calls.py
@@ -21,7 +21,7 @@ class ScheduledLLMCall(Base):
__tablename__ = "scheduled_llm_calls"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
- user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
+ user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False)
topic = Column(Text, nullable=True)
# Scheduling info
diff --git a/src/memory/discord/__init__.py b/src/memory/discord/__init__.py
new file mode 100644
index 0000000..d46a1c4
--- /dev/null
+++ b/src/memory/discord/__init__.py
@@ -0,0 +1 @@
+"""Discord integration for memory system."""
diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py
index d64d02e..b6f4cfa 100644
--- a/src/memory/discord/messages.py
+++ b/src/memory/discord/messages.py
@@ -27,11 +27,7 @@ def resolve_discord_user(
if isinstance(entity, int):
return session.get(DiscordUser, entity)
- entity = session.query(DiscordUser).filter(DiscordUser.username == entity).first()
- if not entity:
- entity = DiscordUser(id=entity, username=entity)
- session.add(entity)
- return entity
+ return session.query(DiscordUser).filter(DiscordUser.username == entity).first()
def resolve_discord_channel(
diff --git a/src/memory/workers/tasks/blogs.py b/src/memory/workers/tasks/blogs.py
index ca0ca14..34ff867 100644
--- a/src/memory/workers/tasks/blogs.py
+++ b/src/memory/workers/tasks/blogs.py
@@ -264,10 +264,10 @@ def sync_website_archive(
if existing:
continue
- task_ids.append(sync_webpage.delay(feed_item.url, list(tags)).id)
- new_articles += 1
+ task_ids.append(sync_webpage.delay(feed_item.url, list(tags)).id)
+ new_articles += 1
- logger.info(f"Scheduled sync for: {feed_item.title} ({feed_item.url})")
+ logger.info(f"Scheduled sync for: {feed_item.title} ({feed_item.url})")
return {
"status": "completed",
diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py
index 34dbb41..d1584e6 100644
--- a/src/memory/workers/tasks/content_processing.py
+++ b/src/memory/workers/tasks/content_processing.py
@@ -12,6 +12,7 @@ import traceback
import logging
from typing import Any, Callable, Sequence, cast
+from sqlalchemy import or_
from memory.common import embedding, qdrant
from memory.common.db.models import SourceItem, Chunk
from memory.common.discord import notify_task_failure
@@ -29,6 +30,7 @@ def check_content_exists(
Searches for existing content by any of the provided attributes
(typically URL, file_path, or SHA256 hash).
+ Uses OR logic - returns content if ANY attribute matches.
Args:
session: Database session for querying
@@ -38,11 +40,21 @@ def check_content_exists(
Returns:
Existing SourceItem if found, None otherwise
"""
- query = session.query(model_class)
+ # Return None if no search criteria provided
+ if not kwargs:
+ return None
+
+ filters = []
for key, value in kwargs.items():
if hasattr(model_class, key):
- query = query.filter(getattr(model_class, key) == value)
+ filters.append(getattr(model_class, key) == value)
+ # Return None if none of the provided attributes exist on the model
+ if not filters:
+ return None
+
+ # Use OR logic to find content matching any of the provided attributes
+ query = session.query(model_class).filter(or_(*filters))
return query.first()
diff --git a/tests/memory/common/db/models/test_discord_models.py b/tests/memory/common/db/models/test_discord_models.py
new file mode 100644
index 0000000..bd83bc6
--- /dev/null
+++ b/tests/memory/common/db/models/test_discord_models.py
@@ -0,0 +1,254 @@
+"""Tests for Discord database models."""
+
+import pytest
+from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
+
+
+def test_create_discord_server(db_session):
+ """Test creating a Discord server."""
+ server = DiscordServer(
+ id=123456789,
+ name="Test Server",
+ description="A test Discord server",
+ member_count=100,
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ assert server.id == 123456789
+ assert server.name == "Test Server"
+ assert server.description == "A test Discord server"
+ assert server.member_count == 100
+ assert server.track_messages is True # default value
+ assert server.ignore_messages is False
+
+
+def test_discord_server_as_xml(db_session):
+ """Test DiscordServer.as_xml() method."""
+ server = DiscordServer(
+ id=123456789,
+ name="Test Server",
+ summary="This is a test server for gaming",
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ xml = server.as_xml()
+ assert "" in xml # tablename is discord_servers, strips to "servers"
+ assert "Test Server" in xml
+ assert "This is a test server for gaming" in xml
+ assert "" in xml
+
+
+def test_discord_server_message_tracking(db_session):
+ """Test Discord server message tracking flags."""
+ server = DiscordServer(
+ id=123456789,
+ name="Test Server",
+ track_messages=False,
+ ignore_messages=True,
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ assert server.track_messages is False
+ assert server.ignore_messages is True
+
+
+def test_discord_server_allowed_tools(db_session):
+ """Test Discord server allowed/disallowed tools."""
+ server = DiscordServer(
+ id=123456789,
+ name="Test Server",
+ allowed_tools=["search", "schedule"],
+ disallowed_tools=["delete", "ban"],
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ assert "search" in server.allowed_tools
+ assert "schedule" in server.allowed_tools
+ assert "delete" in server.disallowed_tools
+ assert "ban" in server.disallowed_tools
+
+
+def test_create_discord_channel(db_session):
+ """Test creating a Discord channel."""
+ server = DiscordServer(id=987654321, name="Parent Server")
+ db_session.add(server)
+ db_session.commit()
+
+ channel = DiscordChannel(
+ id=111222333,
+ server_id=server.id,
+ name="general",
+ channel_type="text",
+ )
+ db_session.add(channel)
+ db_session.commit()
+
+ assert channel.id == 111222333
+ assert channel.server_id == server.id
+ assert channel.name == "general"
+ assert channel.channel_type == "text"
+ assert channel.server.name == "Parent Server"
+
+
+def test_discord_channel_without_server(db_session):
+ """Test creating a Discord DM channel without a server."""
+ channel = DiscordChannel(
+ id=111222333,
+ name="dm-channel",
+ channel_type="dm",
+ server_id=None,
+ )
+ db_session.add(channel)
+ db_session.commit()
+
+ assert channel.id == 111222333
+ assert channel.server_id is None
+ assert channel.channel_type == "dm"
+
+
+def test_discord_channel_as_xml(db_session):
+ """Test DiscordChannel.as_xml() method."""
+ channel = DiscordChannel(
+ id=111222333,
+ name="general",
+ channel_type="text",
+ summary="Main discussion channel",
+ )
+ db_session.add(channel)
+ db_session.commit()
+
+ xml = channel.as_xml()
+ assert "" in xml # tablename is discord_channels, strips to "channels"
+ assert "general" in xml
+ assert "Main discussion channel" in xml
+ assert "" in xml
+
+
+def test_discord_channel_inherits_server_settings(db_session):
+ """Test that channels can have their own or inherit server settings."""
+ server = DiscordServer(
+ id=987654321, name="Server", track_messages=True, ignore_messages=False
+ )
+ channel = DiscordChannel(
+ id=111222333,
+ server_id=server.id,
+ name="announcements",
+ channel_type="text",
+ track_messages=False, # Override server setting
+ )
+ db_session.add_all([server, channel])
+ db_session.commit()
+
+ assert server.track_messages is True
+ assert channel.track_messages is False
+
+
+def test_create_discord_user(db_session):
+ """Test creating a Discord user."""
+ user = DiscordUser(
+ id=555666777,
+ username="testuser",
+ display_name="Test User",
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ assert user.id == 555666777
+ assert user.username == "testuser"
+ assert user.display_name == "Test User"
+ assert user.system_user_id is None
+
+
+def test_discord_user_with_system_user(db_session):
+ """Test Discord user linked to a system user."""
+ from memory.common.db.models import HumanUser
+
+ system_user = HumanUser.create_with_password(
+ email="user@example.com", name="System User", password="password123"
+ )
+ db_session.add(system_user)
+ db_session.commit()
+
+ discord_user = DiscordUser(
+ id=555666777,
+ username="testuser",
+ system_user_id=system_user.id,
+ )
+ db_session.add(discord_user)
+ db_session.commit()
+
+ assert discord_user.system_user_id == system_user.id
+ assert discord_user.system_user.email == "user@example.com"
+
+
+def test_discord_user_as_xml(db_session):
+ """Test DiscordUser.as_xml() method."""
+ user = DiscordUser(
+ id=555666777,
+ username="testuser",
+ summary="Friendly and helpful community member",
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ xml = user.as_xml()
+ assert "" in xml # tablename is discord_users, strips to "users"
+ assert "testuser" in xml
+ assert "Friendly and helpful community member" in xml
+ assert "" in xml
+
+
+def test_discord_user_message_preferences(db_session):
+ """Test Discord user message tracking preferences."""
+ user = DiscordUser(
+ id=555666777,
+ username="testuser",
+ track_messages=True,
+ ignore_messages=False,
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ assert user.track_messages is True
+ assert user.ignore_messages is False
+
+
+def test_discord_server_channel_relationship(db_session):
+ """Test the relationship between servers and channels."""
+ server = DiscordServer(id=987654321, name="Test Server")
+ channel1 = DiscordChannel(
+ id=111222333, server_id=server.id, name="general", channel_type="text"
+ )
+ channel2 = DiscordChannel(
+ id=111222334, server_id=server.id, name="off-topic", channel_type="text"
+ )
+ db_session.add_all([server, channel1, channel2])
+ db_session.commit()
+
+ assert len(server.channels) == 2
+ assert channel1 in server.channels
+ assert channel2 in server.channels
+
+
+def test_discord_server_cascade_delete(db_session):
+ """Test that deleting a server cascades to channels."""
+ server = DiscordServer(id=987654321, name="Test Server")
+ channel = DiscordChannel(
+ id=111222333, server_id=server.id, name="general", channel_type="text"
+ )
+ db_session.add_all([server, channel])
+ db_session.commit()
+
+ channel_id = channel.id
+
+ # Delete server
+ db_session.delete(server)
+ db_session.commit()
+
+ # Channel should be deleted too
+ deleted_channel = db_session.get(DiscordChannel, channel_id)
+ assert deleted_channel is None
diff --git a/tests/memory/common/db/models/test_users.py b/tests/memory/common/db/models/test_users.py
index 56b0473..dd07bf8 100644
--- a/tests/memory/common/db/models/test_users.py
+++ b/tests/memory/common/db/models/test_users.py
@@ -1,5 +1,12 @@
import pytest
-from memory.common.db.models.users import hash_password, verify_password
+from memory.common.db.models.users import (
+ hash_password,
+ verify_password,
+ User,
+ HumanUser,
+ BotUser,
+ DiscordBotUser,
+)
@pytest.mark.parametrize(
@@ -102,3 +109,144 @@ def test_hash_verify_roundtrip(test_password):
# Wrong password should not verify
assert not verify_password(test_password + "_wrong", password_hash)
+
+
+# Test User Model Hierarchy
+
+
+def test_create_human_user(db_session):
+ """Test creating a HumanUser with password"""
+ user = HumanUser.create_with_password(
+ email="human@example.com", name="Human User", password="test_password123"
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ assert user.id is not None
+ assert user.email == "human@example.com"
+ assert user.name == "Human User"
+ assert user.user_type == "human"
+ assert user.password_hash is not None
+ assert user.api_key is None
+ assert user.is_valid_password("test_password123")
+ assert not user.is_valid_password("wrong_password")
+
+
+def test_create_bot_user(db_session):
+ """Test creating a BotUser with API key"""
+ user = BotUser.create_with_api_key(
+ name="Test Bot", email="bot@example.com", api_key="test_api_key_123"
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ assert user.id is not None
+ assert user.email == "bot@example.com"
+ assert user.name == "Test Bot"
+ assert user.user_type == "bot"
+ assert user.api_key == "test_api_key_123"
+ assert user.password_hash is None
+
+
+def test_create_bot_user_auto_api_key(db_session):
+ """Test creating a BotUser with auto-generated API key"""
+ user = BotUser.create_with_api_key(name="Auto Bot", email="autobot@example.com")
+ db_session.add(user)
+ db_session.commit()
+
+ assert user.id is not None
+ assert user.api_key is not None
+ assert user.api_key.startswith("bot_")
+ assert len(user.api_key) == 68 # "bot_" + 32 bytes hex encoded (64 chars)
+
+
+def test_create_discord_bot_user(db_session):
+ """Test creating a DiscordBotUser"""
+ user = DiscordBotUser.create_with_api_key(
+ discord_users=[],
+ name="Discord Bot",
+ email="discordbot@example.com",
+ api_key="discord_key_123",
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ assert user.id is not None
+ assert user.email == "discordbot@example.com"
+ assert user.name == "Discord Bot"
+ assert user.user_type == "discord_bot"
+ assert user.api_key == "discord_key_123"
+
+
+def test_user_serialization_human(db_session):
+ """Test HumanUser serialization"""
+ user = HumanUser.create_with_password(
+ email="serialize@example.com", name="Serialize User", password="password123"
+ )
+ db_session.add(user)
+ db_session.commit()
+
+ serialized = user.serialize()
+ assert serialized["user_id"] == user.id
+ assert serialized["name"] == "Serialize User"
+ assert serialized["email"] == "serialize@example.com"
+ assert serialized["user_type"] == "human"
+ assert "password_hash" not in serialized # Should not expose password hash
+
+
+def test_user_serialization_bot(db_session):
+ """Test BotUser serialization"""
+ user = BotUser.create_with_api_key(name="Bot", email="bot@example.com")
+ db_session.add(user)
+ db_session.commit()
+
+ serialized = user.serialize()
+ assert serialized["user_id"] == user.id
+ assert serialized["name"] == "Bot"
+ assert serialized["email"] == "bot@example.com"
+ assert serialized["user_type"] == "bot"
+ assert "api_key" not in serialized # Should not expose API key
+
+
+def test_bot_user_api_key_uniqueness(db_session):
+ """Test that API keys must be unique"""
+ user1 = BotUser.create_with_api_key(
+ name="Bot 1", email="bot1@example.com", api_key="same_key"
+ )
+ user2 = BotUser.create_with_api_key(
+ name="Bot 2", email="bot2@example.com", api_key="same_key"
+ )
+ db_session.add(user1)
+ db_session.commit()
+
+ db_session.add(user2)
+ with pytest.raises(Exception): # IntegrityError from unique constraint
+ db_session.commit()
+
+
+def test_human_user_factory_method(db_session):
+ """Test that HumanUser factory method sets all required fields"""
+ user = HumanUser.create_with_password(
+ email="factory@example.com", name="Factory User", password="test123"
+ )
+
+ # Factory method should set all required fields
+ assert user.email == "factory@example.com"
+ assert user.name == "Factory User"
+ assert user.password_hash is not None
+ assert user.user_type == "human"
+ assert user.api_key is None
+
+
+def test_bot_user_factory_method(db_session):
+ """Test that BotUser factory method sets all required fields"""
+ user = BotUser.create_with_api_key(
+ name="Factory Bot", email="factorybot@example.com", api_key="test_key"
+ )
+
+ # Factory method should set all required fields
+ assert user.email == "factorybot@example.com"
+ assert user.name == "Factory Bot"
+ assert user.api_key == "test_key"
+ assert user.user_type == "bot"
+ assert user.password_hash is None
diff --git a/tests/memory/common/llms/tools/test_discord_tools.py b/tests/memory/common/llms/tools/test_discord_tools.py
new file mode 100644
index 0000000..1a30cdf
--- /dev/null
+++ b/tests/memory/common/llms/tools/test_discord_tools.py
@@ -0,0 +1,584 @@
+"""Tests for Discord LLM tools."""
+
+import pytest
+from datetime import datetime, timezone, timedelta
+from unittest.mock import Mock, patch
+
+from memory.common.llms.tools.discord import (
+ handle_update_summary_call,
+ make_summary_tool,
+ schedule_message,
+ make_message_scheduler,
+ make_prev_messages_tool,
+ make_discord_tools,
+)
+from memory.common.db.models import (
+ DiscordServer,
+ DiscordChannel,
+ DiscordUser,
+ DiscordMessage,
+ BotUser,
+ HumanUser,
+ ScheduledLLMCall,
+)
+
+
+# Fixtures for Discord entities
+@pytest.fixture
+def sample_discord_server(db_session):
+ """Create a sample Discord server for testing."""
+ server = DiscordServer(
+ id=123456789,
+ name="Test Server",
+ summary="A test server for testing",
+ )
+ db_session.add(server)
+ db_session.commit()
+ return server
+
+
+@pytest.fixture
+def sample_discord_channel(db_session, sample_discord_server):
+ """Create a sample Discord channel for testing."""
+ channel = DiscordChannel(
+ id=987654321,
+ server_id=sample_discord_server.id,
+ name="general",
+ channel_type="text",
+ summary="General discussion channel",
+ )
+ db_session.add(channel)
+ db_session.commit()
+ return channel
+
+
+@pytest.fixture
+def sample_discord_user(db_session):
+ """Create a sample Discord user for testing."""
+ user = DiscordUser(
+ id=111222333,
+ username="testuser",
+ display_name="Test User",
+ summary="A test user",
+ )
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+@pytest.fixture
+def sample_bot_user(db_session):
+ """Create a sample bot user for testing."""
+ bot = BotUser.create_with_api_key(
+ name="Test Bot",
+ email="testbot@example.com",
+ )
+ db_session.add(bot)
+ db_session.commit()
+ return bot
+
+
+@pytest.fixture
+def sample_human_user(db_session):
+ """Create a sample human user for testing."""
+ user = HumanUser.create_with_password(
+ email="human@example.com",
+ name="Human User",
+ password="test_password123",
+ )
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+# Tests for handle_update_summary_call
+def test_handle_update_summary_call_server_dict_input(
+ db_session, sample_discord_server
+):
+ """Test updating server summary with dict input."""
+ handler = handle_update_summary_call("server", sample_discord_server.id)
+
+ result = handler({"summary": "New server summary"})
+
+ assert result == "Updated summary"
+
+ # Verify the summary was updated in the database
+ db_session.refresh(sample_discord_server)
+ assert sample_discord_server.summary == "New server summary"
+
+
+def test_handle_update_summary_call_channel_dict_input(
+ db_session, sample_discord_channel
+):
+ """Test updating channel summary with dict input."""
+ handler = handle_update_summary_call("channel", sample_discord_channel.id)
+
+ result = handler({"summary": "New channel summary"})
+
+ assert result == "Updated summary"
+
+ db_session.refresh(sample_discord_channel)
+ assert sample_discord_channel.summary == "New channel summary"
+
+
+def test_handle_update_summary_call_user_dict_input(db_session, sample_discord_user):
+ """Test updating user summary with dict input."""
+ handler = handle_update_summary_call("user", sample_discord_user.id)
+
+ result = handler({"summary": "New user summary"})
+
+ assert result == "Updated summary"
+
+ db_session.refresh(sample_discord_user)
+ assert sample_discord_user.summary == "New user summary"
+
+
+def test_handle_update_summary_call_string_input(db_session, sample_discord_server):
+ """Test updating summary with string input."""
+ handler = handle_update_summary_call("server", sample_discord_server.id)
+
+ result = handler("String summary")
+
+ assert result == "Updated summary"
+
+ db_session.refresh(sample_discord_server)
+ assert sample_discord_server.summary == "String summary"
+
+
+def test_handle_update_summary_call_dict_without_summary_key(
+ db_session, sample_discord_server
+):
+ """Test updating summary with dict that doesn't have 'summary' key."""
+ handler = handle_update_summary_call("server", sample_discord_server.id)
+
+ result = handler({"other_key": "value"})
+
+ assert result == "Updated summary"
+
+ db_session.refresh(sample_discord_server)
+ # Should use string representation of the dict
+ assert "other_key" in sample_discord_server.summary
+
+
+def test_handle_update_summary_call_nonexistent_entity(db_session):
+ """Test updating summary for nonexistent entity."""
+ handler = handle_update_summary_call("server", 999999999)
+
+ result = handler({"summary": "New summary"})
+
+ assert "Error updating summary" in result
+
+
+# Tests for make_summary_tool
+def test_make_summary_tool_server(sample_discord_server):
+ """Test creating a summary tool for a server."""
+ tool = make_summary_tool("server", sample_discord_server.id)
+
+ assert tool.name == "update_server_summary"
+ assert "server" in tool.description
+ assert tool.input_schema["type"] == "object"
+ assert "summary" in tool.input_schema["properties"]
+ assert callable(tool.function)
+
+
+def test_make_summary_tool_channel(sample_discord_channel):
+ """Test creating a summary tool for a channel."""
+ tool = make_summary_tool("channel", sample_discord_channel.id)
+
+ assert tool.name == "update_channel_summary"
+ assert "channel" in tool.description
+ assert callable(tool.function)
+
+
+def test_make_summary_tool_user(sample_discord_user):
+ """Test creating a summary tool for a user."""
+ tool = make_summary_tool("user", sample_discord_user.id)
+
+ assert tool.name == "update_user_summary"
+ assert "user" in tool.description
+ assert callable(tool.function)
+
+
+# Tests for schedule_message
+def test_schedule_message_with_user(
+ db_session,
+ sample_human_user,
+ sample_discord_user,
+):
+ """Test scheduling a message to a Discord user."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ result = schedule_message(
+ user_id=sample_human_user.id,
+ user=sample_discord_user.id,
+ channel=None,
+ model="test-model",
+ message="Test message",
+ date_time=future_time,
+ )
+
+ # Result should be the ID of the created scheduled call (UUID string)
+ assert isinstance(result, str)
+
+ # Verify the scheduled call was created in the database
+ # Need to use a fresh query since schedule_message uses its own session
+ scheduled_call = db_session.query(ScheduledLLMCall).filter_by(id=result).first()
+ assert scheduled_call is not None
+ assert scheduled_call.user_id == sample_human_user.id
+ assert scheduled_call.discord_user_id == sample_discord_user.id
+ assert scheduled_call.discord_channel_id is None
+ assert scheduled_call.message == "Test message"
+ assert scheduled_call.model == "test-model"
+
+
+def test_schedule_message_with_channel(
+ db_session,
+ sample_human_user,
+ sample_discord_channel,
+):
+ """Test scheduling a message to a Discord channel."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ result = schedule_message(
+ user_id=sample_human_user.id,
+ user=None,
+ channel=sample_discord_channel.id,
+ model="test-model",
+ message="Test message",
+ date_time=future_time,
+ )
+
+ # Result should be the ID of the created scheduled call (UUID string)
+ assert isinstance(result, str)
+
+ # Verify the scheduled call was created in the database
+ scheduled_call = db_session.query(ScheduledLLMCall).filter_by(id=result).first()
+ assert scheduled_call is not None
+ assert scheduled_call.user_id == sample_human_user.id
+ assert scheduled_call.discord_user_id is None
+ assert scheduled_call.discord_channel_id == sample_discord_channel.id
+ assert scheduled_call.message == "Test message"
+
+
+# Tests for make_message_scheduler
+def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
+ """Test creating a message scheduler tool for a user."""
+ tool = make_message_scheduler(
+ bot=sample_bot_user,
+ user=sample_discord_user.id,
+ channel=None,
+ model="test-model",
+ )
+
+ assert tool.name == "schedule_message"
+ assert "from your chat with this user" in tool.description
+ assert tool.input_schema["type"] == "object"
+ assert "message" in tool.input_schema["properties"]
+ assert "date_time" in tool.input_schema["properties"]
+ assert callable(tool.function)
+
+
+def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_channel):
+ """Test creating a message scheduler tool for a channel."""
+ tool = make_message_scheduler(
+ bot=sample_bot_user,
+ user=None,
+ channel=sample_discord_channel.id,
+ model="test-model",
+ )
+
+ assert tool.name == "schedule_message"
+ assert "in this channel" in tool.description
+ assert callable(tool.function)
+
+
+def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
+ """Test that creating a scheduler without user or channel raises error."""
+ with pytest.raises(ValueError, match="Either user or channel must be provided"):
+ make_message_scheduler(
+ bot=sample_bot_user,
+ user=None,
+ channel=None,
+ model="test-model",
+ )
+
+
+@patch("memory.common.llms.tools.discord.schedule_message")
+def test_message_scheduler_handler_success(
+ mock_schedule_message, sample_bot_user, sample_discord_user
+):
+ """Test message scheduler handler with valid input."""
+ tool = make_message_scheduler(
+ bot=sample_bot_user,
+ user=sample_discord_user.id,
+ channel=None,
+ model="test-model",
+ )
+
+ mock_schedule_message.return_value = "123"
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ result = tool.function(
+ {"message": "Test message", "date_time": future_time.isoformat()}
+ )
+
+ assert result == "123"
+ mock_schedule_message.assert_called_once()
+
+
+def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord_user):
+ """Test message scheduler handler with non-dict input."""
+ tool = make_message_scheduler(
+ bot=sample_bot_user,
+ user=sample_discord_user.id,
+ channel=None,
+ model="test-model",
+ )
+
+ with pytest.raises(ValueError, match="Input must be a dictionary"):
+ tool.function("not a dict")
+
+
+def test_message_scheduler_handler_invalid_datetime(
+ sample_bot_user, sample_discord_user
+):
+ """Test message scheduler handler with invalid datetime."""
+ tool = make_message_scheduler(
+ bot=sample_bot_user,
+ user=sample_discord_user.id,
+ channel=None,
+ model="test-model",
+ )
+
+ with pytest.raises(ValueError, match="Invalid date time format"):
+ tool.function(
+ {
+ "message": "Test message",
+ "date_time": "not a valid datetime",
+ }
+ )
+
+
+def test_message_scheduler_handler_missing_datetime(
+ sample_bot_user, sample_discord_user
+):
+ """Test message scheduler handler with missing datetime."""
+ tool = make_message_scheduler(
+ bot=sample_bot_user,
+ user=sample_discord_user.id,
+ channel=None,
+ model="test-model",
+ )
+
+ with pytest.raises(ValueError, match="Date time is required"):
+ tool.function({"message": "Test message"})
+
+
+# Tests for make_prev_messages_tool
+def test_make_prev_messages_tool_with_user(sample_discord_user):
+ """Test creating a previous messages tool for a user."""
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ assert tool.name == "previous_messages"
+ assert "from your chat with this user" in tool.description
+ assert tool.input_schema["type"] == "object"
+ assert "max_messages" in tool.input_schema["properties"]
+ assert "offset" in tool.input_schema["properties"]
+ assert callable(tool.function)
+
+
+def test_make_prev_messages_tool_with_channel(sample_discord_channel):
+ """Test creating a previous messages tool for a channel."""
+ tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
+
+ assert tool.name == "previous_messages"
+ assert "in this channel" in tool.description
+ assert callable(tool.function)
+
+
+def test_make_prev_messages_tool_without_user_or_channel():
+ """Test that creating a tool without user or channel raises error."""
+ with pytest.raises(ValueError, match="Either user or channel must be provided"):
+ make_prev_messages_tool(user=None, channel=None)
+
+
+def test_prev_messages_handler_success(
+ db_session, sample_discord_user, sample_discord_channel
+):
+ """Test previous messages handler with valid input."""
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ # Create some actual messages in the database
+ msg1 = DiscordMessage(
+ message_id=1,
+ channel_id=sample_discord_channel.id,
+ from_id=sample_discord_user.id,
+ recipient_id=sample_discord_user.id,
+ content="Message 1",
+ sent_at=datetime.now(timezone.utc) - timedelta(minutes=10),
+ modality="text",
+ sha256=b"hash1" + bytes(26),
+ )
+ msg2 = DiscordMessage(
+ message_id=2,
+ channel_id=sample_discord_channel.id,
+ from_id=sample_discord_user.id,
+ recipient_id=sample_discord_user.id,
+ content="Message 2",
+ sent_at=datetime.now(timezone.utc) - timedelta(minutes=5),
+ modality="text",
+ sha256=b"hash2" + bytes(26),
+ )
+ db_session.add_all([msg1, msg2])
+ db_session.commit()
+
+ result = tool.function({"max_messages": 10, "offset": 0})
+
+ # Should return messages formatted as strings
+ assert isinstance(result, str)
+ # Both messages should be in the result
+ assert "Message 1" in result or "Message 2" in result
+
+
+def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
+ """Test previous messages handler with default values."""
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ result = tool.function({})
+
+ # Should return empty string when no messages
+ assert isinstance(result, str)
+
+
+def test_prev_messages_handler_invalid_input(sample_discord_user):
+ """Test previous messages handler with non-dict input."""
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ with pytest.raises(ValueError, match="Input must be a dictionary"):
+ tool.function("not a dict")
+
+
+def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
+ """Test previous messages handler with invalid max_messages (negative value)."""
+ # Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
+ # so we test with -1 which actually triggers the validation
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ with pytest.raises(ValueError, match="Max messages must be greater than 0"):
+ tool.function({"max_messages": -1})
+
+
+def test_prev_messages_handler_invalid_offset(sample_discord_user):
+ """Test previous messages handler with invalid offset."""
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
+ tool.function({"offset": -1})
+
+
+def test_prev_messages_handler_non_integer_values(sample_discord_user):
+ """Test previous messages handler with non-integer values."""
+ tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+
+ with pytest.raises(ValueError, match="Max messages and offset must be integers"):
+ tool.function({"max_messages": "not an int"})
+
+
+# Tests for make_discord_tools
+def test_make_discord_tools_with_user_and_channel(
+ sample_bot_user, sample_discord_user, sample_discord_channel
+):
+ """Test creating Discord tools with both user and channel."""
+ tools = make_discord_tools(
+ bot=sample_bot_user,
+ author=sample_discord_user,
+ channel=sample_discord_channel,
+ model="test-model",
+ )
+
+ # Should have: schedule_message, previous_messages, update_channel_summary,
+ # update_user_summary, update_server_summary
+ assert len(tools) == 5
+ assert "schedule_message" in tools
+ assert "previous_messages" in tools
+ assert "update_channel_summary" in tools
+ assert "update_user_summary" in tools
+ assert "update_server_summary" in tools
+
+
+def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
+ """Test creating Discord tools with only user (DM scenario)."""
+ tools = make_discord_tools(
+ bot=sample_bot_user,
+ author=sample_discord_user,
+ channel=None,
+ model="test-model",
+ )
+
+ # Should have: schedule_message, previous_messages, update_user_summary
+ # Note: Without channel, there's no channel summary tool
+ assert len(tools) >= 2 # At least schedule and previous messages
+ assert "schedule_message" in tools
+ assert "previous_messages" in tools
+ assert "update_user_summary" in tools
+
+
+def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_channel):
+ """Test creating Discord tools with only channel (no specific author)."""
+ tools = make_discord_tools(
+ bot=sample_bot_user,
+ author=None,
+ channel=sample_discord_channel,
+ model="test-model",
+ )
+
+ # Should have: schedule_message, previous_messages, update_channel_summary,
+ # update_server_summary (no user summary without author)
+ assert len(tools) == 4
+ assert "schedule_message" in tools
+ assert "previous_messages" in tools
+ assert "update_channel_summary" in tools
+ assert "update_server_summary" in tools
+ assert "update_user_summary" not in tools
+
+
+def test_make_discord_tools_channel_without_server(
+ db_session, sample_bot_user, sample_discord_user
+):
+ """Test creating Discord tools with channel that has no server (DM channel)."""
+ dm_channel = DiscordChannel(
+ id=999888777,
+ server_id=None,
+ name="DM Channel",
+ channel_type="dm",
+ )
+ db_session.add(dm_channel)
+ db_session.commit()
+
+ tools = make_discord_tools(
+ bot=sample_bot_user,
+ author=sample_discord_user,
+ channel=dm_channel,
+ model="test-model",
+ )
+
+ # Should not have server summary tool since channel has no server
+ assert "update_server_summary" not in tools
+ assert "update_channel_summary" in tools
+ assert "update_user_summary" in tools
+
+
+def test_make_discord_tools_returns_dict_with_correct_keys(
+ sample_bot_user, sample_discord_user, sample_discord_channel
+):
+ """Test that make_discord_tools returns a dict with tool names as keys."""
+ tools = make_discord_tools(
+ bot=sample_bot_user,
+ author=sample_discord_user,
+ channel=sample_discord_channel,
+ model="test-model",
+ )
+
+ # Verify all keys match the tool names
+ for tool_name, tool in tools.items():
+ assert tool_name == tool.name
diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py
index 3d46a0d..274ba8e 100644
--- a/tests/memory/common/test_discord.py
+++ b/tests/memory/common/test_discord.py
@@ -34,7 +34,7 @@ def test_send_dm_success(mock_post, mock_api_url):
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
- json={"user_identifier": "user123", "message": "Hello!"},
+ json={"user": "user123", "message": "Hello!"},
timeout=10,
)
diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py
index 3d46a0d..274ba8e 100644
--- a/tests/memory/common/test_discord_integration.py
+++ b/tests/memory/common/test_discord_integration.py
@@ -34,7 +34,7 @@ def test_send_dm_success(mock_post, mock_api_url):
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
- json={"user_identifier": "user123", "message": "Hello!"},
+ json={"user": "user123", "message": "Hello!"},
timeout=10,
)
diff --git a/tests/memory/discord/test_collector.py b/tests/memory/discord_tests/test_collector.py
similarity index 99%
rename from tests/memory/discord/test_collector.py
rename to tests/memory/discord_tests/test_collector.py
index b3624e8..63f7aef 100644
--- a/tests/memory/discord/test_collector.py
+++ b/tests/memory/discord_tests/test_collector.py
@@ -15,7 +15,7 @@ from memory.discord.collector import (
sync_guild_metadata,
MessageCollector,
)
-from memory.common.db.models.sources import (
+from memory.common.db.models import (
DiscordServer,
DiscordChannel,
DiscordUser,
diff --git a/tests/memory/discord_tests/test_messages.py b/tests/memory/discord_tests/test_messages.py
new file mode 100644
index 0000000..6c06662
--- /dev/null
+++ b/tests/memory/discord_tests/test_messages.py
@@ -0,0 +1,413 @@
+"""Tests for Discord message helper functions."""
+
+import pytest
+from datetime import datetime, timedelta, timezone
+from memory.discord.messages import (
+ resolve_discord_user,
+ resolve_discord_channel,
+ schedule_discord_message,
+ upsert_scheduled_message,
+ previous_messages,
+ comm_channel_prompt,
+)
+from memory.common.db.models import (
+ DiscordUser,
+ DiscordChannel,
+ DiscordServer,
+ DiscordMessage,
+ HumanUser,
+ ScheduledLLMCall,
+)
+
+
+@pytest.fixture
+def sample_discord_user(db_session):
+ """Create a sample Discord user."""
+ user = DiscordUser(id=123456789, username="testuser")
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+@pytest.fixture
+def sample_discord_channel(db_session):
+ """Create a sample Discord channel."""
+ server = DiscordServer(id=987654321, name="Test Server")
+ channel = DiscordChannel(
+ id=111222333, name="general", channel_type="text", server_id=server.id
+ )
+ db_session.add_all([server, channel])
+ db_session.commit()
+ return channel
+
+
+@pytest.fixture
+def sample_system_user(db_session):
+ """Create a sample system user."""
+ user = HumanUser.create_with_password(
+ email="user@example.com", name="Test User", password="password123"
+ )
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+# Test resolve_discord_user
+
+
+def test_resolve_discord_user_with_none(db_session):
+ """Test resolving None returns None."""
+ result = resolve_discord_user(db_session, None)
+ assert result is None
+
+
+def test_resolve_discord_user_with_discord_user_object(
+ db_session, sample_discord_user
+):
+ """Test resolving a DiscordUser object returns it unchanged."""
+ result = resolve_discord_user(db_session, sample_discord_user)
+ assert result == sample_discord_user
+ assert result.id == 123456789
+
+
+def test_resolve_discord_user_with_id(db_session, sample_discord_user):
+ """Test resolving by integer ID."""
+ result = resolve_discord_user(db_session, 123456789)
+ assert result is not None
+ assert result.id == 123456789
+ assert result.username == "testuser"
+
+
+def test_resolve_discord_user_with_username(db_session, sample_discord_user):
+ """Test resolving by username string."""
+ result = resolve_discord_user(db_session, "testuser")
+ assert result is not None
+ assert result.username == "testuser"
+
+
+def test_resolve_discord_user_with_nonexistent_username_returns_none(db_session):
+ """Test that resolving a non-existent username returns None."""
+ result = resolve_discord_user(db_session, "nonexistent")
+ assert result is None
+
+
+# Test resolve_discord_channel
+
+
+def test_resolve_discord_channel_with_none(db_session):
+ """Test resolving None returns None."""
+ result = resolve_discord_channel(db_session, None)
+ assert result is None
+
+
+def test_resolve_discord_channel_with_channel_object(
+ db_session, sample_discord_channel
+):
+ """Test resolving a DiscordChannel object returns it unchanged."""
+ result = resolve_discord_channel(db_session, sample_discord_channel)
+ assert result == sample_discord_channel
+ assert result.id == 111222333
+
+
+def test_resolve_discord_channel_with_id(db_session, sample_discord_channel):
+ """Test resolving by integer ID."""
+ result = resolve_discord_channel(db_session, 111222333)
+ assert result is not None
+ assert result.id == 111222333
+ assert result.name == "general"
+
+
+def test_resolve_discord_channel_with_name(db_session, sample_discord_channel):
+ """Test resolving by channel name string."""
+ result = resolve_discord_channel(db_session, "general")
+ assert result is not None
+ assert result.name == "general"
+
+
+def test_resolve_discord_channel_returns_none_if_not_found(db_session):
+ """Test that resolving a non-existent channel returns None."""
+ result = resolve_discord_channel(db_session, "nonexistent")
+ assert result is None
+
+
+# Test schedule_discord_message
+
+
+def test_schedule_discord_message_with_user(
+ db_session, sample_discord_user, sample_system_user
+):
+ """Test scheduling a message to a Discord user."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ result = schedule_discord_message(
+ db_session,
+ scheduled_time=future_time,
+ message="Test message",
+ user_id=sample_system_user.id,
+ discord_user=sample_discord_user,
+ model="test-model",
+ topic="Test Topic",
+ )
+ db_session.flush() # Flush to populate the foreign key IDs
+
+ assert result is not None
+ assert isinstance(result, ScheduledLLMCall)
+ assert result.message == "Test message"
+ assert result.discord_user_id == sample_discord_user.id
+ assert result.user_id == sample_system_user.id
+
+
+def test_schedule_discord_message_with_channel(
+ db_session, sample_discord_channel, sample_system_user
+):
+ """Test scheduling a message to a Discord channel."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ result = schedule_discord_message(
+ db_session,
+ scheduled_time=future_time,
+ message="Channel message",
+ user_id=sample_system_user.id,
+ discord_channel=sample_discord_channel,
+ )
+ db_session.flush() # Flush to populate the foreign key IDs
+
+ assert result is not None
+ assert result.discord_channel_id == sample_discord_channel.id
+
+
+def test_schedule_discord_message_requires_user_or_channel(
+ db_session, sample_system_user
+):
+ """Test that scheduling requires either user or channel."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ with pytest.raises(ValueError, match="Either discord_user or discord_channel must be provided"):
+ schedule_discord_message(
+ db_session,
+ scheduled_time=future_time,
+ message="Test",
+ user_id=sample_system_user.id,
+ )
+
+
+def test_schedule_discord_message_requires_future_time(
+ db_session, sample_discord_user, sample_system_user
+):
+ """Test that scheduling requires a future time."""
+ past_time = datetime.now(timezone.utc) - timedelta(hours=1)
+
+ with pytest.raises(ValueError, match="Scheduled time must be in the future"):
+ schedule_discord_message(
+ db_session,
+ scheduled_time=past_time,
+ message="Test",
+ user_id=sample_system_user.id,
+ discord_user=sample_discord_user,
+ )
+
+
+def test_schedule_discord_message_with_metadata(
+ db_session, sample_discord_user, sample_system_user
+):
+ """Test scheduling with custom metadata."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+ metadata = {"priority": "high", "tags": ["urgent"]}
+
+ result = schedule_discord_message(
+ db_session,
+ scheduled_time=future_time,
+ message="Urgent message",
+ user_id=sample_system_user.id,
+ discord_user=sample_discord_user,
+ metadata=metadata,
+ )
+
+ assert result.data == metadata
+
+
+# Test upsert_scheduled_message
+
+
+def test_upsert_scheduled_message_creates_new(
+ db_session, sample_discord_user, sample_system_user
+):
+ """Test upserting creates a new message if none exists."""
+ future_time = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ result = upsert_scheduled_message(
+ db_session,
+ scheduled_time=future_time,
+ message="New message",
+ user_id=sample_system_user.id,
+ discord_user=sample_discord_user,
+ model="test-model",
+ )
+
+ assert result is not None
+ assert result.message == "New message"
+
+
+def test_upsert_scheduled_message_cancels_earlier_call(
+ db_session, sample_discord_user, sample_system_user
+):
+ """Test upserting cancels an earlier scheduled call for the same user/channel."""
+ future_time1 = datetime.now(timezone.utc) + timedelta(hours=2)
+ future_time2 = datetime.now(timezone.utc) + timedelta(hours=1)
+
+ # Create first scheduled message
+ first_call = schedule_discord_message(
+ db_session,
+ scheduled_time=future_time1,
+ message="First message",
+ user_id=sample_system_user.id,
+ discord_user=sample_discord_user,
+ model="test-model",
+ )
+ db_session.commit()
+
+ # Upsert with earlier time should cancel the first
+ second_call = upsert_scheduled_message(
+ db_session,
+ scheduled_time=future_time2,
+ message="Second message",
+ user_id=sample_system_user.id,
+ discord_user=sample_discord_user,
+ model="test-model",
+ )
+ db_session.commit()
+
+ db_session.refresh(first_call)
+ assert first_call.status == "cancelled"
+ assert second_call.status == "pending"
+
+
+# Test previous_messages
+
+
+def test_previous_messages_empty(db_session):
+ """Test getting previous messages when none exist."""
+ result = previous_messages(db_session, user_id=123, channel_id=456)
+ assert result == []
+
+
+def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
+ """Test filtering messages by recipient user."""
+ # Create some messages
+ msg1 = DiscordMessage(
+ message_id=1,
+ channel_id=sample_discord_channel.id,
+ from_id=sample_discord_user.id,
+ recipient_id=sample_discord_user.id,
+ content="Message 1",
+ sent_at=datetime.now(timezone.utc) - timedelta(minutes=10),
+ modality="text",
+ sha256=b"hash1" + bytes(26),
+ )
+ msg2 = DiscordMessage(
+ message_id=2,
+ channel_id=sample_discord_channel.id,
+ from_id=sample_discord_user.id,
+ recipient_id=sample_discord_user.id,
+ content="Message 2",
+ sent_at=datetime.now(timezone.utc) - timedelta(minutes=5),
+ modality="text",
+ sha256=b"hash2" + bytes(26),
+ )
+ db_session.add_all([msg1, msg2])
+ db_session.commit()
+
+ result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
+ assert len(result) == 2
+ # Should be in chronological order (oldest first)
+ assert result[0].message_id == 1
+ assert result[1].message_id == 2
+
+
+def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
+ """Test limiting the number of previous messages."""
+ # Create 15 messages
+ for i in range(15):
+ msg = DiscordMessage(
+ message_id=i,
+ channel_id=sample_discord_channel.id,
+ from_id=sample_discord_user.id,
+ recipient_id=sample_discord_user.id,
+ content=f"Message {i}",
+ sent_at=datetime.now(timezone.utc) - timedelta(minutes=15 - i),
+ modality="text",
+ sha256=f"hash{i}".encode() + bytes(26),
+ )
+ db_session.add(msg)
+ db_session.commit()
+
+ result = previous_messages(
+ db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
+ )
+ assert len(result) == 5
+
+
+# Test comm_channel_prompt
+
+
+def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
+ """Test generating a basic communication channel prompt."""
+ result = comm_channel_prompt(
+ db_session, user=sample_discord_user, channel=sample_discord_channel
+ )
+
+ assert "You are a bot communicating on Discord" in result
+ assert isinstance(result, str)
+ assert len(result) > 0
+
+
+def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
+ """Test that prompt includes server context when available."""
+ server = sample_discord_channel.server
+ server.summary = "Gaming community server"
+ db_session.commit()
+
+ result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
+
+ assert "server_context" in result.lower()
+ assert "Gaming community server" in result
+
+
+def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
+ """Test that prompt includes channel context."""
+ sample_discord_channel.summary = "General discussion channel"
+ db_session.commit()
+
+ result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
+
+ assert "channel_context" in result.lower()
+ assert "General discussion channel" in result
+
+
+def test_comm_channel_prompt_includes_user_notes(
+ db_session, sample_discord_user, sample_discord_channel
+):
+ """Test that prompt includes user notes from previous messages."""
+ sample_discord_user.summary = "Helpful community member"
+ db_session.commit()
+
+ # Create a message from this user
+ msg = DiscordMessage(
+ message_id=1,
+ from_id=sample_discord_user.id,
+ recipient_id=sample_discord_user.id,
+ channel_id=sample_discord_channel.id,
+ content="Hello",
+ sent_at=datetime.now(timezone.utc),
+ modality="text",
+ sha256=b"hash" + bytes(27),
+ )
+ db_session.add(msg)
+ db_session.commit()
+
+ result = comm_channel_prompt(
+ db_session, user=sample_discord_user, channel=sample_discord_channel
+ )
+
+ assert "user_notes" in result.lower()
+ assert "testuser" in result # username should appear
diff --git a/tests/memory/workers/tasks/test_content_processing.py b/tests/memory/workers/tasks/test_content_processing.py
index 0918d9c..d0310bc 100644
--- a/tests/memory/workers/tasks/test_content_processing.py
+++ b/tests/memory/workers/tasks/test_content_processing.py
@@ -733,6 +733,8 @@ def test_safe_task_execution_exception_logging(caplog):
result = failing_task()
- assert result == {"status": "error", "error": "Test runtime error"}
+ assert result["status"] == "error"
+ assert result["error"] == "Test runtime error"
+ assert "traceback" in result
assert "Task failing_task failed:" in caplog.text
assert "Test runtime error" in caplog.text
diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py
index 5ce5765..61e985d 100644
--- a/tests/memory/workers/tasks/test_discord_tasks.py
+++ b/tests/memory/workers/tasks/test_discord_tasks.py
@@ -59,6 +59,7 @@ def sample_message_data(mock_discord_user, mock_discord_channel):
"message_id": 999888777,
"channel_id": mock_discord_channel.id,
"author_id": mock_discord_user.id,
+ "recipient_id": mock_discord_user.id,
"content": "This is a test Discord message with enough content to be processed.",
"sent_at": "2024-01-01T12:00:00Z",
"server_id": None,
@@ -74,7 +75,8 @@ def test_get_prev_returns_previous_messages(
msg1 = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="First message",
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
modality="text",
@@ -83,7 +85,8 @@ def test_get_prev_returns_previous_messages(
msg2 = DiscordMessage(
message_id=2,
channel_id=mock_discord_channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="Second message",
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
modality="text",
@@ -92,7 +95,8 @@ def test_get_prev_returns_previous_messages(
msg3 = DiscordMessage(
message_id=3,
channel_id=mock_discord_channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="Third message",
sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
modality="text",
@@ -123,7 +127,8 @@ def test_get_prev_limits_context_window(
msg = DiscordMessage(
message_id=i,
channel_id=mock_discord_channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content=f"Message {i}",
sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc),
modality="text",
@@ -157,14 +162,26 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
+@patch("memory.workers.tasks.discord.create_provider")
def test_should_process_normal_message(
- db_session, mock_discord_user, mock_discord_server, mock_discord_channel
+ mock_create_provider,
+ db_session,
+ mock_discord_user,
+ mock_discord_server,
+ mock_discord_channel,
):
"""Test should_process returns True for normal messages."""
+ # Mock the LLM provider to return "yes"
+ mock_provider = Mock()
+ mock_provider.generate.return_value = "yes"
+ mock_provider.as_messages.return_value = []
+ mock_create_provider.return_value = mock_provider
+
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@@ -210,7 +227,8 @@ def test_should_process_server_ignored(
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
server_id=server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@@ -243,7 +261,8 @@ def test_should_process_channel_ignored(
message = DiscordMessage(
message_id=1,
channel_id=channel.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@@ -274,7 +293,8 @@ def test_should_process_user_ignored(
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
- discord_user_id=user.id,
+ from_id=user.id,
+ recipient_id=user.id,
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@@ -350,7 +370,8 @@ def test_add_discord_message_with_context(
prev_msg = DiscordMessage(
message_id=111111,
channel_id=sample_message_data["channel_id"],
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="Previous message",
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
modality="text",
@@ -370,11 +391,13 @@ def test_add_discord_message_with_context(
assert result["status"] == "processed"
+@patch("memory.workers.tasks.discord.should_process")
@patch("memory.workers.tasks.discord.process_discord_message")
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_add_discord_message_triggers_processing(
mock_process,
+ mock_should_process,
db_session,
sample_message_data,
mock_discord_server,
@@ -382,6 +405,7 @@ def test_add_discord_message_triggers_processing(
qdrant,
):
"""Test that add_discord_message triggers process_discord_message when conditions are met."""
+ mock_should_process.return_value = True
mock_process.delay = Mock()
sample_message_data["server_id"] = mock_discord_server.id
@@ -454,7 +478,8 @@ def test_edit_discord_message_updates_context(
prev_msg = DiscordMessage(
message_id=111111,
channel_id=sample_message_data["channel_id"],
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="Previous message",
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
modality="text",
@@ -576,7 +601,8 @@ def test_get_prev_only_returns_messages_from_same_channel(
msg1 = DiscordMessage(
message_id=1,
channel_id=channel1.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="Message in channel 1",
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
modality="text",
@@ -585,7 +611,8 @@ def test_get_prev_only_returns_messages_from_same_channel(
msg2 = DiscordMessage(
message_id=2,
channel_id=channel2.id,
- discord_user_id=mock_discord_user.id,
+ from_id=mock_discord_user.id,
+ recipient_id=mock_discord_user.id,
content="Message in channel 2",
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
modality="text",
diff --git a/tests/memory/workers/tasks/test_ebook_tasks.py b/tests/memory/workers/tasks/test_ebook_tasks.py
index 390114e..9daa860 100644
--- a/tests/memory/workers/tasks/test_ebook_tasks.py
+++ b/tests/memory/workers/tasks/test_ebook_tasks.py
@@ -12,6 +12,7 @@ from memory.workers.tasks import ebook
def mock_ebook():
"""Mock ebook data for testing."""
return Ebook(
+ relative_path=Path("test/book.epub"),
title="Test Book",
author="Test Author",
metadata={"language": "en", "creator": "Test Publisher"},
@@ -207,10 +208,11 @@ def test_sync_book_already_exists(mock_parse, mock_ebook, db_session, tmp_path):
book_file = tmp_path / "test.epub"
book_file.write_text("dummy content")
+ # Use the same relative path that mock_ebook has
existing_book = Book(
title="Existing Book",
author="Author",
- file_path=str(book_file),
+ file_path="test/book.epub", # Must match mock_ebook.relative_path
)
db_session.add(existing_book)
db_session.commit()
@@ -266,18 +268,18 @@ def test_sync_book_qdrant_failure(mock_parse, mock_ebook, db_session, tmp_path):
# Since embedding is already failing, this test will complete without hitting Qdrant
# So let's just verify that the function completes without raising an exception
with patch.object(ebook, "push_to_qdrant", side_effect=Exception("Qdrant failed")):
- assert ebook.sync_book(str(book_file)) == {
- "status": "error",
- "error": "Qdrant failed",
- }
+ result = ebook.sync_book(str(book_file))
+ assert result["status"] == "error"
+ assert result["error"] == "Qdrant failed"
+ assert "traceback" in result
def test_sync_book_file_not_found():
"""Test handling of missing files."""
- assert ebook.sync_book("/nonexistent/file.epub") == {
- "status": "error",
- "error": "Book file not found: /nonexistent/file.epub",
- }
+ result = ebook.sync_book("/nonexistent/file.epub")
+ assert result["status"] == "error"
+ assert result["error"] == "Book file not found: /nonexistent/file.epub"
+ assert "traceback" in result
def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py
index 8b9be2a..decf6a6 100644
--- a/tests/memory/workers/tasks/test_scheduled_calls.py
+++ b/tests/memory/workers/tasks/test_scheduled_calls.py
@@ -3,18 +3,17 @@ from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch
import uuid
-from memory.common.db.models import ScheduledLLMCall, User
+from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
from memory.workers.tasks import scheduled_calls
@pytest.fixture
def sample_user(db_session):
"""Create a sample user for testing."""
- user = User(
+ user = HumanUser.create_with_password(
name="testuser",
email="test@example.com",
- discord_user_id="123456789",
- password_hash="password",
+ password="password",
)
db_session.add(user)
db_session.commit()
@@ -22,7 +21,45 @@ def sample_user(db_session):
@pytest.fixture
-def pending_scheduled_call(db_session, sample_user):
+def sample_discord_user(db_session):
+ """Create a sample Discord user for testing."""
+ discord_user = DiscordUser(
+ id=123456789,
+ username="testuser",
+ )
+ db_session.add(discord_user)
+ db_session.commit()
+ return discord_user
+
+
+@pytest.fixture
+def sample_discord_server(db_session):
+ """Create a sample Discord server for testing."""
+ server = DiscordServer(
+ id=987654321,
+ name="Test Server",
+ )
+ db_session.add(server)
+ db_session.commit()
+ return server
+
+
+@pytest.fixture
+def sample_discord_channel(db_session, sample_discord_server):
+ """Create a sample Discord channel for testing."""
+ channel = DiscordChannel(
+ id=111222333,
+ name="test-channel",
+ channel_type="text",
+ server_id=sample_discord_server.id,
+ )
+ db_session.add(channel)
+ db_session.commit()
+ return channel
+
+
+@pytest.fixture
+def pending_scheduled_call(db_session, sample_user, sample_discord_user):
"""Create a pending scheduled call for testing."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
@@ -32,7 +69,7 @@ def pending_scheduled_call(db_session, sample_user):
model="anthropic/claude-3-5-sonnet-20241022",
message="What is the weather like today?",
system_prompt="You are a helpful assistant.",
- discord_user="123456789",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
db_session.add(call)
@@ -41,7 +78,7 @@ def pending_scheduled_call(db_session, sample_user):
@pytest.fixture
-def completed_scheduled_call(db_session, sample_user):
+def completed_scheduled_call(db_session, sample_user, sample_discord_channel):
"""Create a completed scheduled call for testing."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
@@ -52,7 +89,7 @@ def completed_scheduled_call(db_session, sample_user):
model="anthropic/claude-3-5-sonnet-20241022",
message="Tell me a joke.",
system_prompt="You are a funny assistant.",
- discord_channel="987654321",
+ discord_channel_id=sample_discord_channel.id,
status="completed",
response="Why did the chicken cross the road? To get to the other side!",
)
@@ -62,7 +99,7 @@ def completed_scheduled_call(db_session, sample_user):
@pytest.fixture
-def future_scheduled_call(db_session, sample_user):
+def future_scheduled_call(db_session, sample_user, sample_discord_user):
"""Create a future scheduled call for testing."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
@@ -71,7 +108,7 @@ def future_scheduled_call(db_session, sample_user):
scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1),
model="anthropic/claude-3-5-sonnet-20241022",
message="What will happen tomorrow?",
- discord_user="123456789",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
db_session.add(call)
@@ -87,7 +124,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
scheduled_calls._send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
- "123456789",
+ "testuser", # username, not ID
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
)
@@ -100,7 +137,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
scheduled_calls._send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with(
- "987654321",
+ "test-channel", # channel name, not ID
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
)
@@ -133,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_success(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -171,7 +208,7 @@ def test_execute_scheduled_call_not_found(db_session):
assert result == {"error": "Scheduled call not found"}
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_not_pending(
mock_llm_call, completed_scheduled_call, db_session
):
@@ -183,9 +220,9 @@ def test_execute_scheduled_call_not_pending(
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_with_default_system_prompt(
- mock_llm_call, mock_send_discord, db_session, sample_user
+ mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
):
"""Test execution when system_prompt is None, should use default."""
# Create call without system prompt
@@ -197,7 +234,7 @@ def test_execute_scheduled_call_with_default_system_prompt(
model="anthropic/claude-3-5-sonnet-20241022",
message="Test prompt",
system_prompt=None,
- discord_user="123456789",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
db_session.add(call)
@@ -211,12 +248,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
mock_llm_call.assert_called_once_with(
prompt="Test prompt",
model="anthropic/claude-3-5-sonnet-20241022",
- system_prompt=scheduled_calls.llms.SYSTEM_PROMPT,
+ system_prompt=None, # The code uses system_prompt as-is, not a default
)
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_discord_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -240,7 +277,7 @@ def test_execute_scheduled_call_discord_error(
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_llm_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -258,7 +295,7 @@ def test_execute_scheduled_call_llm_error(
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_long_response_truncation(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -279,7 +316,7 @@ def test_execute_scheduled_call_long_response_truncation(
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_with_due_calls(
- mock_execute_delay, db_session, sample_user
+ mock_execute_delay, db_session, sample_user, sample_discord_user
):
"""Test running scheduled calls with due calls."""
# Create multiple due calls
@@ -289,7 +326,7 @@ def test_run_scheduled_calls_with_due_calls(
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10),
model="test-model",
message="Test 1",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
due_call2 = ScheduledLLMCall(
@@ -298,7 +335,7 @@ def test_run_scheduled_calls_with_due_calls(
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
message="Test 2",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
@@ -337,7 +374,7 @@ def test_run_scheduled_calls_no_due_calls(
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_mixed_statuses(
- mock_execute_delay, db_session, sample_user
+ mock_execute_delay, db_session, sample_user, sample_discord_user
):
"""Test that only pending calls are processed."""
# Create calls with different statuses
@@ -347,7 +384,7 @@ def test_run_scheduled_calls_mixed_statuses(
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
message="Pending",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
executing_call = ScheduledLLMCall(
@@ -356,7 +393,7 @@ def test_run_scheduled_calls_mixed_statuses(
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
message="Executing",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="executing",
)
completed_call = ScheduledLLMCall(
@@ -365,7 +402,7 @@ def test_run_scheduled_calls_mixed_statuses(
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
message="Completed",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="completed",
)
@@ -387,7 +424,7 @@ def test_run_scheduled_calls_mixed_statuses(
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_timezone_handling(
- mock_execute_delay, db_session, sample_user
+ mock_execute_delay, db_session, sample_user, sample_discord_user
):
"""Test that timezone handling works correctly."""
# Create a call that's due (scheduled time in the past)
@@ -398,7 +435,7 @@ def test_run_scheduled_calls_timezone_handling(
scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime
model="test-model",
message="Due call",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
@@ -410,7 +447,7 @@ def test_run_scheduled_calls_timezone_handling(
scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime
model="test-model",
message="Future call",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status="pending",
)
@@ -431,7 +468,7 @@ def test_run_scheduled_calls_timezone_handling(
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_status_transition_pending_to_executing_to_completed(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -452,11 +489,11 @@ def test_status_transition_pending_to_executing_to_completed(
@pytest.mark.parametrize(
- "discord_user,discord_channel,expected_method",
+ "has_discord_user,has_discord_channel,expected_method",
[
- ("123456789", None, "send_dm"),
- (None, "987654321", "broadcast_message"),
- ("123456789", "987654321", "send_dm"), # User takes precedence
+ (True, False, "send_dm"),
+ (False, True, "broadcast_message"),
+ (True, True, "send_dm"), # User takes precedence
],
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
@@ -464,11 +501,13 @@ def test_status_transition_pending_to_executing_to_completed(
def test_discord_destination_priority(
mock_broadcast,
mock_send_dm,
- discord_user,
- discord_channel,
+ has_discord_user,
+ has_discord_channel,
expected_method,
db_session,
sample_user,
+ sample_discord_user,
+ sample_discord_channel,
):
"""Test that Discord user takes precedence over channel."""
call = ScheduledLLMCall(
@@ -478,8 +517,8 @@ def test_discord_destination_priority(
scheduled_time=datetime.now(timezone.utc),
model="test-model",
message="Test",
- discord_user=discord_user,
- discord_channel=discord_channel,
+ discord_user_id=sample_discord_user.id if has_discord_user else None,
+ discord_channel_id=sample_discord_channel.id if has_discord_channel else None,
status="pending",
)
db_session.add(call)
@@ -530,11 +569,15 @@ def test_discord_destination_priority(
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
"""Test the Discord message formatting with different inputs."""
- # Create a mock scheduled call
+ # Create a mock scheduled call with a mock Discord user
+ mock_discord_user = Mock()
+ mock_discord_user.username = "testuser"
+
mock_call = Mock()
mock_call.topic = topic
mock_call.model = model
- mock_call.discord_user = "123456789"
+ mock_call.discord_user = mock_discord_user
+ mock_call.discord_channel = None
scheduled_calls._send_to_discord(mock_call, response)
@@ -557,9 +600,9 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
("cancelled", False),
],
)
-@patch("memory.workers.tasks.scheduled_calls.llms.call")
+@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
def test_execute_scheduled_call_status_check(
- mock_llm_call, status, should_execute, db_session, sample_user
+ mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
):
"""Test that only pending calls are executed."""
call = ScheduledLLMCall(
@@ -569,7 +612,7 @@ def test_execute_scheduled_call_status_check(
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
message="Test",
- discord_user="123",
+ discord_user_id=sample_discord_user.id,
status=status,
)
db_session.add(call)