mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-22 22:56:38 +02:00
add tetsts
This commit is contained in:
parent
1606348d8b
commit
1a3cf9c931
@ -16,7 +16,7 @@ autorestart=true
|
||||
startsecs=10
|
||||
|
||||
[program:discord-api]
|
||||
command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_API_PORT)s
|
||||
command=uvicorn memory.discord.api:app --host 0.0.0.0 --port %(ENV_DISCORD_COLLECTOR_PORT)s
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
|
@ -98,7 +98,6 @@ app.conf.update(
|
||||
@app.on_after_configure.connect # type: ignore[attr-defined]
|
||||
def ensure_qdrant_initialised(sender, **_):
|
||||
from memory.common import qdrant
|
||||
from memory.common.discord import load_servers
|
||||
|
||||
qdrant.setup_qdrant()
|
||||
load_servers()
|
||||
# Note: load_servers() was removed as it's no longer needed
|
||||
|
@ -21,7 +21,7 @@ class ScheduledLLMCall(Base):
|
||||
__tablename__ = "scheduled_llm_calls"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False)
|
||||
topic = Column(Text, nullable=True)
|
||||
|
||||
# Scheduling info
|
||||
|
1
src/memory/discord/__init__.py
Normal file
1
src/memory/discord/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Discord integration for memory system."""
|
@ -27,11 +27,7 @@ def resolve_discord_user(
|
||||
if isinstance(entity, int):
|
||||
return session.get(DiscordUser, entity)
|
||||
|
||||
entity = session.query(DiscordUser).filter(DiscordUser.username == entity).first()
|
||||
if not entity:
|
||||
entity = DiscordUser(id=entity, username=entity)
|
||||
session.add(entity)
|
||||
return entity
|
||||
return session.query(DiscordUser).filter(DiscordUser.username == entity).first()
|
||||
|
||||
|
||||
def resolve_discord_channel(
|
||||
|
@ -264,10 +264,10 @@ def sync_website_archive(
|
||||
if existing:
|
||||
continue
|
||||
|
||||
task_ids.append(sync_webpage.delay(feed_item.url, list(tags)).id)
|
||||
new_articles += 1
|
||||
task_ids.append(sync_webpage.delay(feed_item.url, list(tags)).id)
|
||||
new_articles += 1
|
||||
|
||||
logger.info(f"Scheduled sync for: {feed_item.title} ({feed_item.url})")
|
||||
logger.info(f"Scheduled sync for: {feed_item.title} ({feed_item.url})")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
|
@ -12,6 +12,7 @@ import traceback
|
||||
import logging
|
||||
from typing import Any, Callable, Sequence, cast
|
||||
|
||||
from sqlalchemy import or_
|
||||
from memory.common import embedding, qdrant
|
||||
from memory.common.db.models import SourceItem, Chunk
|
||||
from memory.common.discord import notify_task_failure
|
||||
@ -29,6 +30,7 @@ def check_content_exists(
|
||||
|
||||
Searches for existing content by any of the provided attributes
|
||||
(typically URL, file_path, or SHA256 hash).
|
||||
Uses OR logic - returns content if ANY attribute matches.
|
||||
|
||||
Args:
|
||||
session: Database session for querying
|
||||
@ -38,11 +40,21 @@ def check_content_exists(
|
||||
Returns:
|
||||
Existing SourceItem if found, None otherwise
|
||||
"""
|
||||
query = session.query(model_class)
|
||||
# Return None if no search criteria provided
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
filters = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model_class, key):
|
||||
query = query.filter(getattr(model_class, key) == value)
|
||||
filters.append(getattr(model_class, key) == value)
|
||||
|
||||
# Return None if none of the provided attributes exist on the model
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
# Use OR logic to find content matching any of the provided attributes
|
||||
query = session.query(model_class).filter(or_(*filters))
|
||||
return query.first()
|
||||
|
||||
|
||||
|
254
tests/memory/common/db/models/test_discord_models.py
Normal file
254
tests/memory/common/db/models/test_discord_models.py
Normal file
@ -0,0 +1,254 @@
|
||||
"""Tests for Discord database models."""
|
||||
|
||||
import pytest
|
||||
from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
|
||||
|
||||
|
||||
def test_create_discord_server(db_session):
|
||||
"""Test creating a Discord server."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
description="A test Discord server",
|
||||
member_count=100,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
assert server.id == 123456789
|
||||
assert server.name == "Test Server"
|
||||
assert server.description == "A test Discord server"
|
||||
assert server.member_count == 100
|
||||
assert server.track_messages is True # default value
|
||||
assert server.ignore_messages is False
|
||||
|
||||
|
||||
def test_discord_server_as_xml(db_session):
|
||||
"""Test DiscordServer.as_xml() method."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
summary="This is a test server for gaming",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
xml = server.as_xml()
|
||||
assert "<servers>" in xml # tablename is discord_servers, strips to "servers"
|
||||
assert "<name>Test Server</name>" in xml
|
||||
assert "<summary>This is a test server for gaming</summary>" in xml
|
||||
assert "</servers>" in xml
|
||||
|
||||
|
||||
def test_discord_server_message_tracking(db_session):
|
||||
"""Test Discord server message tracking flags."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
track_messages=False,
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
assert server.track_messages is False
|
||||
assert server.ignore_messages is True
|
||||
|
||||
|
||||
def test_discord_server_allowed_tools(db_session):
|
||||
"""Test Discord server allowed/disallowed tools."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
allowed_tools=["search", "schedule"],
|
||||
disallowed_tools=["delete", "ban"],
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
assert "search" in server.allowed_tools
|
||||
assert "schedule" in server.allowed_tools
|
||||
assert "delete" in server.disallowed_tools
|
||||
assert "ban" in server.disallowed_tools
|
||||
|
||||
|
||||
def test_create_discord_channel(db_session):
|
||||
"""Test creating a Discord channel."""
|
||||
server = DiscordServer(id=987654321, name="Parent Server")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
server_id=server.id,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
assert channel.id == 111222333
|
||||
assert channel.server_id == server.id
|
||||
assert channel.name == "general"
|
||||
assert channel.channel_type == "text"
|
||||
assert channel.server.name == "Parent Server"
|
||||
|
||||
|
||||
def test_discord_channel_without_server(db_session):
|
||||
"""Test creating a Discord DM channel without a server."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="dm-channel",
|
||||
channel_type="dm",
|
||||
server_id=None,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
assert channel.id == 111222333
|
||||
assert channel.server_id is None
|
||||
assert channel.channel_type == "dm"
|
||||
|
||||
|
||||
def test_discord_channel_as_xml(db_session):
|
||||
"""Test DiscordChannel.as_xml() method."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
summary="Main discussion channel",
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
xml = channel.as_xml()
|
||||
assert "<channels>" in xml # tablename is discord_channels, strips to "channels"
|
||||
assert "<name>general</name>" in xml
|
||||
assert "<summary>Main discussion channel</summary>" in xml
|
||||
assert "</channels>" in xml
|
||||
|
||||
|
||||
def test_discord_channel_inherits_server_settings(db_session):
|
||||
"""Test that channels can have their own or inherit server settings."""
|
||||
server = DiscordServer(
|
||||
id=987654321, name="Server", track_messages=True, ignore_messages=False
|
||||
)
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
server_id=server.id,
|
||||
name="announcements",
|
||||
channel_type="text",
|
||||
track_messages=False, # Override server setting
|
||||
)
|
||||
db_session.add_all([server, channel])
|
||||
db_session.commit()
|
||||
|
||||
assert server.track_messages is True
|
||||
assert channel.track_messages is False
|
||||
|
||||
|
||||
def test_create_discord_user(db_session):
|
||||
"""Test creating a Discord user."""
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.id == 555666777
|
||||
assert user.username == "testuser"
|
||||
assert user.display_name == "Test User"
|
||||
assert user.system_user_id is None
|
||||
|
||||
|
||||
def test_discord_user_with_system_user(db_session):
|
||||
"""Test Discord user linked to a system user."""
|
||||
from memory.common.db.models import HumanUser
|
||||
|
||||
system_user = HumanUser.create_with_password(
|
||||
email="user@example.com", name="System User", password="password123"
|
||||
)
|
||||
db_session.add(system_user)
|
||||
db_session.commit()
|
||||
|
||||
discord_user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
system_user_id=system_user.id,
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
|
||||
assert discord_user.system_user_id == system_user.id
|
||||
assert discord_user.system_user.email == "user@example.com"
|
||||
|
||||
|
||||
def test_discord_user_as_xml(db_session):
|
||||
"""Test DiscordUser.as_xml() method."""
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
summary="Friendly and helpful community member",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
xml = user.as_xml()
|
||||
assert "<users>" in xml # tablename is discord_users, strips to "users"
|
||||
assert "<name>testuser</name>" in xml
|
||||
assert "<summary>Friendly and helpful community member</summary>" in xml
|
||||
assert "</users>" in xml
|
||||
|
||||
|
||||
def test_discord_user_message_preferences(db_session):
|
||||
"""Test Discord user message tracking preferences."""
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
track_messages=True,
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.track_messages is True
|
||||
assert user.ignore_messages is False
|
||||
|
||||
|
||||
def test_discord_server_channel_relationship(db_session):
|
||||
"""Test the relationship between servers and channels."""
|
||||
server = DiscordServer(id=987654321, name="Test Server")
|
||||
channel1 = DiscordChannel(
|
||||
id=111222333, server_id=server.id, name="general", channel_type="text"
|
||||
)
|
||||
channel2 = DiscordChannel(
|
||||
id=111222334, server_id=server.id, name="off-topic", channel_type="text"
|
||||
)
|
||||
db_session.add_all([server, channel1, channel2])
|
||||
db_session.commit()
|
||||
|
||||
assert len(server.channels) == 2
|
||||
assert channel1 in server.channels
|
||||
assert channel2 in server.channels
|
||||
|
||||
|
||||
def test_discord_server_cascade_delete(db_session):
|
||||
"""Test that deleting a server cascades to channels."""
|
||||
server = DiscordServer(id=987654321, name="Test Server")
|
||||
channel = DiscordChannel(
|
||||
id=111222333, server_id=server.id, name="general", channel_type="text"
|
||||
)
|
||||
db_session.add_all([server, channel])
|
||||
db_session.commit()
|
||||
|
||||
channel_id = channel.id
|
||||
|
||||
# Delete server
|
||||
db_session.delete(server)
|
||||
db_session.commit()
|
||||
|
||||
# Channel should be deleted too
|
||||
deleted_channel = db_session.get(DiscordChannel, channel_id)
|
||||
assert deleted_channel is None
|
@ -1,5 +1,12 @@
|
||||
import pytest
|
||||
from memory.common.db.models.users import hash_password, verify_password
|
||||
from memory.common.db.models.users import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
User,
|
||||
HumanUser,
|
||||
BotUser,
|
||||
DiscordBotUser,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -102,3 +109,144 @@ def test_hash_verify_roundtrip(test_password):
|
||||
|
||||
# Wrong password should not verify
|
||||
assert not verify_password(test_password + "_wrong", password_hash)
|
||||
|
||||
|
||||
# Test User Model Hierarchy
|
||||
|
||||
|
||||
def test_create_human_user(db_session):
|
||||
"""Test creating a HumanUser with password"""
|
||||
user = HumanUser.create_with_password(
|
||||
email="human@example.com", name="Human User", password="test_password123"
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.id is not None
|
||||
assert user.email == "human@example.com"
|
||||
assert user.name == "Human User"
|
||||
assert user.user_type == "human"
|
||||
assert user.password_hash is not None
|
||||
assert user.api_key is None
|
||||
assert user.is_valid_password("test_password123")
|
||||
assert not user.is_valid_password("wrong_password")
|
||||
|
||||
|
||||
def test_create_bot_user(db_session):
|
||||
"""Test creating a BotUser with API key"""
|
||||
user = BotUser.create_with_api_key(
|
||||
name="Test Bot", email="bot@example.com", api_key="test_api_key_123"
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.id is not None
|
||||
assert user.email == "bot@example.com"
|
||||
assert user.name == "Test Bot"
|
||||
assert user.user_type == "bot"
|
||||
assert user.api_key == "test_api_key_123"
|
||||
assert user.password_hash is None
|
||||
|
||||
|
||||
def test_create_bot_user_auto_api_key(db_session):
|
||||
"""Test creating a BotUser with auto-generated API key"""
|
||||
user = BotUser.create_with_api_key(name="Auto Bot", email="autobot@example.com")
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.id is not None
|
||||
assert user.api_key is not None
|
||||
assert user.api_key.startswith("bot_")
|
||||
assert len(user.api_key) == 68 # "bot_" + 32 bytes hex encoded (64 chars)
|
||||
|
||||
|
||||
def test_create_discord_bot_user(db_session):
|
||||
"""Test creating a DiscordBotUser"""
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
name="Discord Bot",
|
||||
email="discordbot@example.com",
|
||||
api_key="discord_key_123",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.id is not None
|
||||
assert user.email == "discordbot@example.com"
|
||||
assert user.name == "Discord Bot"
|
||||
assert user.user_type == "discord_bot"
|
||||
assert user.api_key == "discord_key_123"
|
||||
|
||||
|
||||
def test_user_serialization_human(db_session):
|
||||
"""Test HumanUser serialization"""
|
||||
user = HumanUser.create_with_password(
|
||||
email="serialize@example.com", name="Serialize User", password="password123"
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
serialized = user.serialize()
|
||||
assert serialized["user_id"] == user.id
|
||||
assert serialized["name"] == "Serialize User"
|
||||
assert serialized["email"] == "serialize@example.com"
|
||||
assert serialized["user_type"] == "human"
|
||||
assert "password_hash" not in serialized # Should not expose password hash
|
||||
|
||||
|
||||
def test_user_serialization_bot(db_session):
|
||||
"""Test BotUser serialization"""
|
||||
user = BotUser.create_with_api_key(name="Bot", email="bot@example.com")
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
serialized = user.serialize()
|
||||
assert serialized["user_id"] == user.id
|
||||
assert serialized["name"] == "Bot"
|
||||
assert serialized["email"] == "bot@example.com"
|
||||
assert serialized["user_type"] == "bot"
|
||||
assert "api_key" not in serialized # Should not expose API key
|
||||
|
||||
|
||||
def test_bot_user_api_key_uniqueness(db_session):
|
||||
"""Test that API keys must be unique"""
|
||||
user1 = BotUser.create_with_api_key(
|
||||
name="Bot 1", email="bot1@example.com", api_key="same_key"
|
||||
)
|
||||
user2 = BotUser.create_with_api_key(
|
||||
name="Bot 2", email="bot2@example.com", api_key="same_key"
|
||||
)
|
||||
db_session.add(user1)
|
||||
db_session.commit()
|
||||
|
||||
db_session.add(user2)
|
||||
with pytest.raises(Exception): # IntegrityError from unique constraint
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_human_user_factory_method(db_session):
|
||||
"""Test that HumanUser factory method sets all required fields"""
|
||||
user = HumanUser.create_with_password(
|
||||
email="factory@example.com", name="Factory User", password="test123"
|
||||
)
|
||||
|
||||
# Factory method should set all required fields
|
||||
assert user.email == "factory@example.com"
|
||||
assert user.name == "Factory User"
|
||||
assert user.password_hash is not None
|
||||
assert user.user_type == "human"
|
||||
assert user.api_key is None
|
||||
|
||||
|
||||
def test_bot_user_factory_method(db_session):
|
||||
"""Test that BotUser factory method sets all required fields"""
|
||||
user = BotUser.create_with_api_key(
|
||||
name="Factory Bot", email="factorybot@example.com", api_key="test_key"
|
||||
)
|
||||
|
||||
# Factory method should set all required fields
|
||||
assert user.email == "factorybot@example.com"
|
||||
assert user.name == "Factory Bot"
|
||||
assert user.api_key == "test_key"
|
||||
assert user.user_type == "bot"
|
||||
assert user.password_hash is None
|
||||
|
584
tests/memory/common/llms/tools/test_discord_tools.py
Normal file
584
tests/memory/common/llms/tools/test_discord_tools.py
Normal file
@ -0,0 +1,584 @@
|
||||
"""Tests for Discord LLM tools."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.llms.tools.discord import (
|
||||
handle_update_summary_call,
|
||||
make_summary_tool,
|
||||
schedule_message,
|
||||
make_message_scheduler,
|
||||
make_prev_messages_tool,
|
||||
make_discord_tools,
|
||||
)
|
||||
from memory.common.db.models import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
DiscordMessage,
|
||||
BotUser,
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures for Discord entities
|
||||
@pytest.fixture
|
||||
def sample_discord_server(db_session):
|
||||
"""Create a sample Discord server for testing."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
summary="A test server for testing",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_channel(db_session, sample_discord_server):
|
||||
"""Create a sample Discord channel for testing."""
|
||||
channel = DiscordChannel(
|
||||
id=987654321,
|
||||
server_id=sample_discord_server.id,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
summary="General discussion channel",
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_user(db_session):
|
||||
"""Create a sample Discord user for testing."""
|
||||
user = DiscordUser(
|
||||
id=111222333,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
summary="A test user",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_bot_user(db_session):
|
||||
"""Create a sample bot user for testing."""
|
||||
bot = BotUser.create_with_api_key(
|
||||
name="Test Bot",
|
||||
email="testbot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_human_user(db_session):
|
||||
"""Create a sample human user for testing."""
|
||||
user = HumanUser.create_with_password(
|
||||
email="human@example.com",
|
||||
name="Human User",
|
||||
password="test_password123",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
# Tests for handle_update_summary_call
|
||||
def test_handle_update_summary_call_server_dict_input(
|
||||
db_session, sample_discord_server
|
||||
):
|
||||
"""Test updating server summary with dict input."""
|
||||
handler = handle_update_summary_call("server", sample_discord_server.id)
|
||||
|
||||
result = handler({"summary": "New server summary"})
|
||||
|
||||
assert result == "Updated summary"
|
||||
|
||||
# Verify the summary was updated in the database
|
||||
db_session.refresh(sample_discord_server)
|
||||
assert sample_discord_server.summary == "New server summary"
|
||||
|
||||
|
||||
def test_handle_update_summary_call_channel_dict_input(
|
||||
db_session, sample_discord_channel
|
||||
):
|
||||
"""Test updating channel summary with dict input."""
|
||||
handler = handle_update_summary_call("channel", sample_discord_channel.id)
|
||||
|
||||
result = handler({"summary": "New channel summary"})
|
||||
|
||||
assert result == "Updated summary"
|
||||
|
||||
db_session.refresh(sample_discord_channel)
|
||||
assert sample_discord_channel.summary == "New channel summary"
|
||||
|
||||
|
||||
def test_handle_update_summary_call_user_dict_input(db_session, sample_discord_user):
|
||||
"""Test updating user summary with dict input."""
|
||||
handler = handle_update_summary_call("user", sample_discord_user.id)
|
||||
|
||||
result = handler({"summary": "New user summary"})
|
||||
|
||||
assert result == "Updated summary"
|
||||
|
||||
db_session.refresh(sample_discord_user)
|
||||
assert sample_discord_user.summary == "New user summary"
|
||||
|
||||
|
||||
def test_handle_update_summary_call_string_input(db_session, sample_discord_server):
|
||||
"""Test updating summary with string input."""
|
||||
handler = handle_update_summary_call("server", sample_discord_server.id)
|
||||
|
||||
result = handler("String summary")
|
||||
|
||||
assert result == "Updated summary"
|
||||
|
||||
db_session.refresh(sample_discord_server)
|
||||
assert sample_discord_server.summary == "String summary"
|
||||
|
||||
|
||||
def test_handle_update_summary_call_dict_without_summary_key(
|
||||
db_session, sample_discord_server
|
||||
):
|
||||
"""Test updating summary with dict that doesn't have 'summary' key."""
|
||||
handler = handle_update_summary_call("server", sample_discord_server.id)
|
||||
|
||||
result = handler({"other_key": "value"})
|
||||
|
||||
assert result == "Updated summary"
|
||||
|
||||
db_session.refresh(sample_discord_server)
|
||||
# Should use string representation of the dict
|
||||
assert "other_key" in sample_discord_server.summary
|
||||
|
||||
|
||||
def test_handle_update_summary_call_nonexistent_entity(db_session):
|
||||
"""Test updating summary for nonexistent entity."""
|
||||
handler = handle_update_summary_call("server", 999999999)
|
||||
|
||||
result = handler({"summary": "New summary"})
|
||||
|
||||
assert "Error updating summary" in result
|
||||
|
||||
|
||||
# Tests for make_summary_tool
|
||||
def test_make_summary_tool_server(sample_discord_server):
|
||||
"""Test creating a summary tool for a server."""
|
||||
tool = make_summary_tool("server", sample_discord_server.id)
|
||||
|
||||
assert tool.name == "update_server_summary"
|
||||
assert "server" in tool.description
|
||||
assert tool.input_schema["type"] == "object"
|
||||
assert "summary" in tool.input_schema["properties"]
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_summary_tool_channel(sample_discord_channel):
|
||||
"""Test creating a summary tool for a channel."""
|
||||
tool = make_summary_tool("channel", sample_discord_channel.id)
|
||||
|
||||
assert tool.name == "update_channel_summary"
|
||||
assert "channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_summary_tool_user(sample_discord_user):
|
||||
"""Test creating a summary tool for a user."""
|
||||
tool = make_summary_tool("user", sample_discord_user.id)
|
||||
|
||||
assert tool.name == "update_user_summary"
|
||||
assert "user" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
# Tests for schedule_message
|
||||
def test_schedule_message_with_user(
|
||||
db_session,
|
||||
sample_human_user,
|
||||
sample_discord_user,
|
||||
):
|
||||
"""Test scheduling a message to a Discord user."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_message(
|
||||
user_id=sample_human_user.id,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
message="Test message",
|
||||
date_time=future_time,
|
||||
)
|
||||
|
||||
# Result should be the ID of the created scheduled call (UUID string)
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Verify the scheduled call was created in the database
|
||||
# Need to use a fresh query since schedule_message uses its own session
|
||||
scheduled_call = db_session.query(ScheduledLLMCall).filter_by(id=result).first()
|
||||
assert scheduled_call is not None
|
||||
assert scheduled_call.user_id == sample_human_user.id
|
||||
assert scheduled_call.discord_user_id == sample_discord_user.id
|
||||
assert scheduled_call.discord_channel_id is None
|
||||
assert scheduled_call.message == "Test message"
|
||||
assert scheduled_call.model == "test-model"
|
||||
|
||||
|
||||
def test_schedule_message_with_channel(
|
||||
db_session,
|
||||
sample_human_user,
|
||||
sample_discord_channel,
|
||||
):
|
||||
"""Test scheduling a message to a Discord channel."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_message(
|
||||
user_id=sample_human_user.id,
|
||||
user=None,
|
||||
channel=sample_discord_channel.id,
|
||||
model="test-model",
|
||||
message="Test message",
|
||||
date_time=future_time,
|
||||
)
|
||||
|
||||
# Result should be the ID of the created scheduled call (UUID string)
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Verify the scheduled call was created in the database
|
||||
scheduled_call = db_session.query(ScheduledLLMCall).filter_by(id=result).first()
|
||||
assert scheduled_call is not None
|
||||
assert scheduled_call.user_id == sample_human_user.id
|
||||
assert scheduled_call.discord_user_id is None
|
||||
assert scheduled_call.discord_channel_id == sample_discord_channel.id
|
||||
assert scheduled_call.message == "Test message"
|
||||
|
||||
|
||||
# Tests for make_message_scheduler
|
||||
def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
|
||||
"""Test creating a message scheduler tool for a user."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert tool.name == "schedule_message"
|
||||
assert "from your chat with this user" in tool.description
|
||||
assert tool.input_schema["type"] == "object"
|
||||
assert "message" in tool.input_schema["properties"]
|
||||
assert "date_time" in tool.input_schema["properties"]
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_channel):
|
||||
"""Test creating a message scheduler tool for a channel."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=None,
|
||||
channel=sample_discord_channel.id,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert tool.name == "schedule_message"
|
||||
assert "in this channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
|
||||
"""Test that creating a scheduler without user or channel raises error."""
|
||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||
make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=None,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.llms.tools.discord.schedule_message")
|
||||
def test_message_scheduler_handler_success(
|
||||
mock_schedule_message, sample_bot_user, sample_discord_user
|
||||
):
|
||||
"""Test message scheduler handler with valid input."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
mock_schedule_message.return_value = "123"
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = tool.function(
|
||||
{"message": "Test message", "date_time": future_time.isoformat()}
|
||||
)
|
||||
|
||||
assert result == "123"
|
||||
mock_schedule_message.assert_called_once()
|
||||
|
||||
|
||||
def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord_user):
|
||||
"""Test message scheduler handler with non-dict input."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Input must be a dictionary"):
|
||||
tool.function("not a dict")
|
||||
|
||||
|
||||
def test_message_scheduler_handler_invalid_datetime(
|
||||
sample_bot_user, sample_discord_user
|
||||
):
|
||||
"""Test message scheduler handler with invalid datetime."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid date time format"):
|
||||
tool.function(
|
||||
{
|
||||
"message": "Test message",
|
||||
"date_time": "not a valid datetime",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_message_scheduler_handler_missing_datetime(
|
||||
sample_bot_user, sample_discord_user
|
||||
):
|
||||
"""Test message scheduler handler with missing datetime."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Date time is required"):
|
||||
tool.function({"message": "Test message"})
|
||||
|
||||
|
||||
# Tests for make_prev_messages_tool
|
||||
def test_make_prev_messages_tool_with_user(sample_discord_user):
|
||||
"""Test creating a previous messages tool for a user."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
assert tool.name == "previous_messages"
|
||||
assert "from your chat with this user" in tool.description
|
||||
assert tool.input_schema["type"] == "object"
|
||||
assert "max_messages" in tool.input_schema["properties"]
|
||||
assert "offset" in tool.input_schema["properties"]
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_prev_messages_tool_with_channel(sample_discord_channel):
|
||||
"""Test creating a previous messages tool for a channel."""
|
||||
tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
|
||||
|
||||
assert tool.name == "previous_messages"
|
||||
assert "in this channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_prev_messages_tool_without_user_or_channel():
|
||||
"""Test that creating a tool without user or channel raises error."""
|
||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||
make_prev_messages_tool(user=None, channel=None)
|
||||
|
||||
|
||||
def test_prev_messages_handler_success(
|
||||
db_session, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test previous messages handler with valid input."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
# Create some actual messages in the database
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=sample_discord_channel.id,
|
||||
from_id=sample_discord_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
content="Message 1",
|
||||
sent_at=datetime.now(timezone.utc) - timedelta(minutes=10),
|
||||
modality="text",
|
||||
sha256=b"hash1" + bytes(26),
|
||||
)
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=sample_discord_channel.id,
|
||||
from_id=sample_discord_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
content="Message 2",
|
||||
sent_at=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
modality="text",
|
||||
sha256=b"hash2" + bytes(26),
|
||||
)
|
||||
db_session.add_all([msg1, msg2])
|
||||
db_session.commit()
|
||||
|
||||
result = tool.function({"max_messages": 10, "offset": 0})
|
||||
|
||||
# Should return messages formatted as strings
|
||||
assert isinstance(result, str)
|
||||
# Both messages should be in the result
|
||||
assert "Message 1" in result or "Message 2" in result
|
||||
|
||||
|
||||
def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
||||
"""Test previous messages handler with default values."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
result = tool.function({})
|
||||
|
||||
# Should return empty string when no messages
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_input(sample_discord_user):
|
||||
"""Test previous messages handler with non-dict input."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Input must be a dictionary"):
|
||||
tool.function("not a dict")
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
|
||||
"""Test previous messages handler with invalid max_messages (negative value)."""
|
||||
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
|
||||
# so we test with -1 which actually triggers the validation
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
|
||||
tool.function({"max_messages": -1})
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_offset(sample_discord_user):
|
||||
"""Test previous messages handler with invalid offset."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
|
||||
tool.function({"offset": -1})
|
||||
|
||||
|
||||
def test_prev_messages_handler_non_integer_values(sample_discord_user):
|
||||
"""Test previous messages handler with non-integer values."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
|
||||
tool.function({"max_messages": "not an int"})
|
||||
|
||||
|
||||
# Tests for make_discord_tools
|
||||
def test_make_discord_tools_with_user_and_channel(
|
||||
sample_bot_user, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test creating Discord tools with both user and channel."""
|
||||
tools = make_discord_tools(
|
||||
bot=sample_bot_user,
|
||||
author=sample_discord_user,
|
||||
channel=sample_discord_channel,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# update_user_summary, update_server_summary
|
||||
assert len(tools) == 5
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_user_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
|
||||
|
||||
def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
|
||||
"""Test creating Discord tools with only user (DM scenario)."""
|
||||
tools = make_discord_tools(
|
||||
bot=sample_bot_user,
|
||||
author=sample_discord_user,
|
||||
channel=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_user_summary
|
||||
# Note: Without channel, there's no channel summary tool
|
||||
assert len(tools) >= 2 # At least schedule and previous messages
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_user_summary" in tools
|
||||
|
||||
|
||||
def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_channel):
|
||||
"""Test creating Discord tools with only channel (no specific author)."""
|
||||
tools = make_discord_tools(
|
||||
bot=sample_bot_user,
|
||||
author=None,
|
||||
channel=sample_discord_channel,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# update_server_summary (no user summary without author)
|
||||
assert len(tools) == 4
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
assert "update_user_summary" not in tools
|
||||
|
||||
|
||||
def test_make_discord_tools_channel_without_server(
|
||||
db_session, sample_bot_user, sample_discord_user
|
||||
):
|
||||
"""Test creating Discord tools with channel that has no server (DM channel)."""
|
||||
dm_channel = DiscordChannel(
|
||||
id=999888777,
|
||||
server_id=None,
|
||||
name="DM Channel",
|
||||
channel_type="dm",
|
||||
)
|
||||
db_session.add(dm_channel)
|
||||
db_session.commit()
|
||||
|
||||
tools = make_discord_tools(
|
||||
bot=sample_bot_user,
|
||||
author=sample_discord_user,
|
||||
channel=dm_channel,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should not have server summary tool since channel has no server
|
||||
assert "update_server_summary" not in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_user_summary" in tools
|
||||
|
||||
|
||||
def test_make_discord_tools_returns_dict_with_correct_keys(
|
||||
sample_bot_user, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test that make_discord_tools returns a dict with tool names as keys."""
|
||||
tools = make_discord_tools(
|
||||
bot=sample_bot_user,
|
||||
author=sample_discord_user,
|
||||
channel=sample_discord_channel,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Verify all keys match the tool names
|
||||
for tool_name, tool in tools.items():
|
||||
assert tool_name == tool.name
|
@ -34,7 +34,7 @@ def test_send_dm_success(mock_post, mock_api_url):
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user_identifier": "user123", "message": "Hello!"},
|
||||
json={"user": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
@ -34,7 +34,7 @@ def test_send_dm_success(mock_post, mock_api_url):
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user_identifier": "user123", "message": "Hello!"},
|
||||
json={"user": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
@ -15,7 +15,7 @@ from memory.discord.collector import (
|
||||
sync_guild_metadata,
|
||||
MessageCollector,
|
||||
)
|
||||
from memory.common.db.models.sources import (
|
||||
from memory.common.db.models import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
413
tests/memory/discord_tests/test_messages.py
Normal file
413
tests/memory/discord_tests/test_messages.py
Normal file
@ -0,0 +1,413 @@
|
||||
"""Tests for Discord message helper functions."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from memory.discord.messages import (
|
||||
resolve_discord_user,
|
||||
resolve_discord_channel,
|
||||
schedule_discord_message,
|
||||
upsert_scheduled_message,
|
||||
previous_messages,
|
||||
comm_channel_prompt,
|
||||
)
|
||||
from memory.common.db.models import (
|
||||
DiscordUser,
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
DiscordMessage,
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_user(db_session):
|
||||
"""Create a sample Discord user."""
|
||||
user = DiscordUser(id=123456789, username="testuser")
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_channel(db_session):
|
||||
"""Create a sample Discord channel."""
|
||||
server = DiscordServer(id=987654321, name="Test Server")
|
||||
channel = DiscordChannel(
|
||||
id=111222333, name="general", channel_type="text", server_id=server.id
|
||||
)
|
||||
db_session.add_all([server, channel])
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_system_user(db_session):
|
||||
"""Create a sample system user."""
|
||||
user = HumanUser.create_with_password(
|
||||
email="user@example.com", name="Test User", password="password123"
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
# Test resolve_discord_user
|
||||
|
||||
|
||||
def test_resolve_discord_user_with_none(db_session):
|
||||
"""Test resolving None returns None."""
|
||||
result = resolve_discord_user(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_resolve_discord_user_with_discord_user_object(
|
||||
db_session, sample_discord_user
|
||||
):
|
||||
"""Test resolving a DiscordUser object returns it unchanged."""
|
||||
result = resolve_discord_user(db_session, sample_discord_user)
|
||||
assert result == sample_discord_user
|
||||
assert result.id == 123456789
|
||||
|
||||
|
||||
def test_resolve_discord_user_with_id(db_session, sample_discord_user):
|
||||
"""Test resolving by integer ID."""
|
||||
result = resolve_discord_user(db_session, 123456789)
|
||||
assert result is not None
|
||||
assert result.id == 123456789
|
||||
assert result.username == "testuser"
|
||||
|
||||
|
||||
def test_resolve_discord_user_with_username(db_session, sample_discord_user):
|
||||
"""Test resolving by username string."""
|
||||
result = resolve_discord_user(db_session, "testuser")
|
||||
assert result is not None
|
||||
assert result.username == "testuser"
|
||||
|
||||
|
||||
def test_resolve_discord_user_with_nonexistent_username_returns_none(db_session):
|
||||
"""Test that resolving a non-existent username returns None."""
|
||||
result = resolve_discord_user(db_session, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
# Test resolve_discord_channel
|
||||
|
||||
|
||||
def test_resolve_discord_channel_with_none(db_session):
|
||||
"""Test resolving None returns None."""
|
||||
result = resolve_discord_channel(db_session, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_resolve_discord_channel_with_channel_object(
|
||||
db_session, sample_discord_channel
|
||||
):
|
||||
"""Test resolving a DiscordChannel object returns it unchanged."""
|
||||
result = resolve_discord_channel(db_session, sample_discord_channel)
|
||||
assert result == sample_discord_channel
|
||||
assert result.id == 111222333
|
||||
|
||||
|
||||
def test_resolve_discord_channel_with_id(db_session, sample_discord_channel):
|
||||
"""Test resolving by integer ID."""
|
||||
result = resolve_discord_channel(db_session, 111222333)
|
||||
assert result is not None
|
||||
assert result.id == 111222333
|
||||
assert result.name == "general"
|
||||
|
||||
|
||||
def test_resolve_discord_channel_with_name(db_session, sample_discord_channel):
|
||||
"""Test resolving by channel name string."""
|
||||
result = resolve_discord_channel(db_session, "general")
|
||||
assert result is not None
|
||||
assert result.name == "general"
|
||||
|
||||
|
||||
def test_resolve_discord_channel_returns_none_if_not_found(db_session):
|
||||
"""Test that resolving a non-existent channel returns None."""
|
||||
result = resolve_discord_channel(db_session, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
# Test schedule_discord_message
|
||||
|
||||
|
||||
def test_schedule_discord_message_with_user(
|
||||
db_session, sample_discord_user, sample_system_user
|
||||
):
|
||||
"""Test scheduling a message to a Discord user."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_discord_message(
|
||||
db_session,
|
||||
scheduled_time=future_time,
|
||||
message="Test message",
|
||||
user_id=sample_system_user.id,
|
||||
discord_user=sample_discord_user,
|
||||
model="test-model",
|
||||
topic="Test Topic",
|
||||
)
|
||||
db_session.flush() # Flush to populate the foreign key IDs
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, ScheduledLLMCall)
|
||||
assert result.message == "Test message"
|
||||
assert result.discord_user_id == sample_discord_user.id
|
||||
assert result.user_id == sample_system_user.id
|
||||
|
||||
|
||||
def test_schedule_discord_message_with_channel(
|
||||
db_session, sample_discord_channel, sample_system_user
|
||||
):
|
||||
"""Test scheduling a message to a Discord channel."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_discord_message(
|
||||
db_session,
|
||||
scheduled_time=future_time,
|
||||
message="Channel message",
|
||||
user_id=sample_system_user.id,
|
||||
discord_channel=sample_discord_channel,
|
||||
)
|
||||
db_session.flush() # Flush to populate the foreign key IDs
|
||||
|
||||
assert result is not None
|
||||
assert result.discord_channel_id == sample_discord_channel.id
|
||||
|
||||
|
||||
def test_schedule_discord_message_requires_user_or_channel(
|
||||
db_session, sample_system_user
|
||||
):
|
||||
"""Test that scheduling requires either user or channel."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
with pytest.raises(ValueError, match="Either discord_user or discord_channel must be provided"):
|
||||
schedule_discord_message(
|
||||
db_session,
|
||||
scheduled_time=future_time,
|
||||
message="Test",
|
||||
user_id=sample_system_user.id,
|
||||
)
|
||||
|
||||
|
||||
def test_schedule_discord_message_requires_future_time(
|
||||
db_session, sample_discord_user, sample_system_user
|
||||
):
|
||||
"""Test that scheduling requires a future time."""
|
||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
|
||||
with pytest.raises(ValueError, match="Scheduled time must be in the future"):
|
||||
schedule_discord_message(
|
||||
db_session,
|
||||
scheduled_time=past_time,
|
||||
message="Test",
|
||||
user_id=sample_system_user.id,
|
||||
discord_user=sample_discord_user,
|
||||
)
|
||||
|
||||
|
||||
def test_schedule_discord_message_with_metadata(
|
||||
db_session, sample_discord_user, sample_system_user
|
||||
):
|
||||
"""Test scheduling with custom metadata."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
metadata = {"priority": "high", "tags": ["urgent"]}
|
||||
|
||||
result = schedule_discord_message(
|
||||
db_session,
|
||||
scheduled_time=future_time,
|
||||
message="Urgent message",
|
||||
user_id=sample_system_user.id,
|
||||
discord_user=sample_discord_user,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert result.data == metadata
|
||||
|
||||
|
||||
# Test upsert_scheduled_message
|
||||
|
||||
|
||||
def test_upsert_scheduled_message_creates_new(
|
||||
db_session, sample_discord_user, sample_system_user
|
||||
):
|
||||
"""Test upserting creates a new message if none exists."""
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = upsert_scheduled_message(
|
||||
db_session,
|
||||
scheduled_time=future_time,
|
||||
message="New message",
|
||||
user_id=sample_system_user.id,
|
||||
discord_user=sample_discord_user,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.message == "New message"
|
||||
|
||||
|
||||
def test_upsert_scheduled_message_cancels_earlier_call(
|
||||
db_session, sample_discord_user, sample_system_user
|
||||
):
|
||||
"""Test upserting cancels an earlier scheduled call for the same user/channel."""
|
||||
future_time1 = datetime.now(timezone.utc) + timedelta(hours=2)
|
||||
future_time2 = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
# Create first scheduled message
|
||||
first_call = schedule_discord_message(
|
||||
db_session,
|
||||
scheduled_time=future_time1,
|
||||
message="First message",
|
||||
user_id=sample_system_user.id,
|
||||
discord_user=sample_discord_user,
|
||||
model="test-model",
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Upsert with earlier time should cancel the first
|
||||
second_call = upsert_scheduled_message(
|
||||
db_session,
|
||||
scheduled_time=future_time2,
|
||||
message="Second message",
|
||||
user_id=sample_system_user.id,
|
||||
discord_user=sample_discord_user,
|
||||
model="test-model",
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
db_session.refresh(first_call)
|
||||
assert first_call.status == "cancelled"
|
||||
assert second_call.status == "pending"
|
||||
|
||||
|
||||
# Test previous_messages
|
||||
|
||||
|
||||
def test_previous_messages_empty(db_session):
|
||||
"""Test getting previous messages when none exist."""
|
||||
result = previous_messages(db_session, user_id=123, channel_id=456)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
|
||||
"""Test filtering messages by recipient user."""
|
||||
# Create some messages
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=sample_discord_channel.id,
|
||||
from_id=sample_discord_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
content="Message 1",
|
||||
sent_at=datetime.now(timezone.utc) - timedelta(minutes=10),
|
||||
modality="text",
|
||||
sha256=b"hash1" + bytes(26),
|
||||
)
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=sample_discord_channel.id,
|
||||
from_id=sample_discord_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
content="Message 2",
|
||||
sent_at=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
modality="text",
|
||||
sha256=b"hash2" + bytes(26),
|
||||
)
|
||||
db_session.add_all([msg1, msg2])
|
||||
db_session.commit()
|
||||
|
||||
result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
|
||||
assert len(result) == 2
|
||||
# Should be in chronological order (oldest first)
|
||||
assert result[0].message_id == 1
|
||||
assert result[1].message_id == 2
|
||||
|
||||
|
||||
def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
|
||||
"""Test limiting the number of previous messages."""
|
||||
# Create 15 messages
|
||||
for i in range(15):
|
||||
msg = DiscordMessage(
|
||||
message_id=i,
|
||||
channel_id=sample_discord_channel.id,
|
||||
from_id=sample_discord_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
content=f"Message {i}",
|
||||
sent_at=datetime.now(timezone.utc) - timedelta(minutes=15 - i),
|
||||
modality="text",
|
||||
sha256=f"hash{i}".encode() + bytes(26),
|
||||
)
|
||||
db_session.add(msg)
|
||||
db_session.commit()
|
||||
|
||||
result = previous_messages(
|
||||
db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
||||
)
|
||||
assert len(result) == 5
|
||||
|
||||
|
||||
# Test comm_channel_prompt
|
||||
|
||||
|
||||
def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
|
||||
"""Test generating a basic communication channel prompt."""
|
||||
result = comm_channel_prompt(
|
||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
||||
)
|
||||
|
||||
assert "You are a bot communicating on Discord" in result
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
|
||||
"""Test that prompt includes server context when available."""
|
||||
server = sample_discord_channel.server
|
||||
server.summary = "Gaming community server"
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
||||
|
||||
assert "server_context" in result.lower()
|
||||
assert "Gaming community server" in result
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
|
||||
"""Test that prompt includes channel context."""
|
||||
sample_discord_channel.summary = "General discussion channel"
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
||||
|
||||
assert "channel_context" in result.lower()
|
||||
assert "General discussion channel" in result
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_user_notes(
|
||||
db_session, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test that prompt includes user notes from previous messages."""
|
||||
sample_discord_user.summary = "Helpful community member"
|
||||
db_session.commit()
|
||||
|
||||
# Create a message from this user
|
||||
msg = DiscordMessage(
|
||||
message_id=1,
|
||||
from_id=sample_discord_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
channel_id=sample_discord_channel.id,
|
||||
content="Hello",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
modality="text",
|
||||
sha256=b"hash" + bytes(27),
|
||||
)
|
||||
db_session.add(msg)
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(
|
||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
||||
)
|
||||
|
||||
assert "user_notes" in result.lower()
|
||||
assert "testuser" in result # username should appear
|
@ -733,6 +733,8 @@ def test_safe_task_execution_exception_logging(caplog):
|
||||
|
||||
result = failing_task()
|
||||
|
||||
assert result == {"status": "error", "error": "Test runtime error"}
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Test runtime error"
|
||||
assert "traceback" in result
|
||||
assert "Task failing_task failed:" in caplog.text
|
||||
assert "Test runtime error" in caplog.text
|
||||
|
@ -59,6 +59,7 @@ def sample_message_data(mock_discord_user, mock_discord_channel):
|
||||
"message_id": 999888777,
|
||||
"channel_id": mock_discord_channel.id,
|
||||
"author_id": mock_discord_user.id,
|
||||
"recipient_id": mock_discord_user.id,
|
||||
"content": "This is a test Discord message with enough content to be processed.",
|
||||
"sent_at": "2024-01-01T12:00:00Z",
|
||||
"server_id": None,
|
||||
@ -74,7 +75,8 @@ def test_get_prev_returns_previous_messages(
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="First message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -83,7 +85,8 @@ def test_get_prev_returns_previous_messages(
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="Second message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -92,7 +95,8 @@ def test_get_prev_returns_previous_messages(
|
||||
msg3 = DiscordMessage(
|
||||
message_id=3,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="Third message",
|
||||
sent_at=datetime(2024, 1, 1, 10, 10, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -123,7 +127,8 @@ def test_get_prev_limits_context_window(
|
||||
msg = DiscordMessage(
|
||||
message_id=i,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content=f"Message {i}",
|
||||
sent_at=datetime(2024, 1, 1, 10, i, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -157,14 +162,26 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.workers.tasks.discord.create_provider")
|
||||
def test_should_process_normal_message(
|
||||
db_session, mock_discord_user, mock_discord_server, mock_discord_channel
|
||||
mock_create_provider,
|
||||
db_session,
|
||||
mock_discord_user,
|
||||
mock_discord_server,
|
||||
mock_discord_channel,
|
||||
):
|
||||
"""Test should_process returns True for normal messages."""
|
||||
# Mock the LLM provider to return "yes"
|
||||
mock_provider = Mock()
|
||||
mock_provider.generate.return_value = "<response>yes</response>"
|
||||
mock_provider.as_messages.return_value = []
|
||||
mock_create_provider.return_value = mock_provider
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
@ -210,7 +227,8 @@ def test_should_process_server_ignored(
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
server_id=server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
@ -243,7 +261,8 @@ def test_should_process_channel_ignored(
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=channel.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
@ -274,7 +293,8 @@ def test_should_process_user_ignored(
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
discord_user_id=user.id,
|
||||
from_id=user.id,
|
||||
recipient_id=user.id,
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
@ -350,7 +370,8 @@ def test_add_discord_message_with_context(
|
||||
prev_msg = DiscordMessage(
|
||||
message_id=111111,
|
||||
channel_id=sample_message_data["channel_id"],
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="Previous message",
|
||||
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -370,11 +391,13 @@ def test_add_discord_message_with_context(
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.discord.should_process")
|
||||
@patch("memory.workers.tasks.discord.process_discord_message")
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_add_discord_message_triggers_processing(
|
||||
mock_process,
|
||||
mock_should_process,
|
||||
db_session,
|
||||
sample_message_data,
|
||||
mock_discord_server,
|
||||
@ -382,6 +405,7 @@ def test_add_discord_message_triggers_processing(
|
||||
qdrant,
|
||||
):
|
||||
"""Test that add_discord_message triggers process_discord_message when conditions are met."""
|
||||
mock_should_process.return_value = True
|
||||
mock_process.delay = Mock()
|
||||
sample_message_data["server_id"] = mock_discord_server.id
|
||||
|
||||
@ -454,7 +478,8 @@ def test_edit_discord_message_updates_context(
|
||||
prev_msg = DiscordMessage(
|
||||
message_id=111111,
|
||||
channel_id=sample_message_data["channel_id"],
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="Previous message",
|
||||
sent_at=datetime(2024, 1, 1, 11, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -576,7 +601,8 @@ def test_get_prev_only_returns_messages_from_same_channel(
|
||||
msg1 = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=channel1.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="Message in channel 1",
|
||||
sent_at=datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
@ -585,7 +611,8 @@ def test_get_prev_only_returns_messages_from_same_channel(
|
||||
msg2 = DiscordMessage(
|
||||
message_id=2,
|
||||
channel_id=channel2.id,
|
||||
discord_user_id=mock_discord_user.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
content="Message in channel 2",
|
||||
sent_at=datetime(2024, 1, 1, 10, 5, 0, tzinfo=timezone.utc),
|
||||
modality="text",
|
||||
|
@ -12,6 +12,7 @@ from memory.workers.tasks import ebook
|
||||
def mock_ebook():
|
||||
"""Mock ebook data for testing."""
|
||||
return Ebook(
|
||||
relative_path=Path("test/book.epub"),
|
||||
title="Test Book",
|
||||
author="Test Author",
|
||||
metadata={"language": "en", "creator": "Test Publisher"},
|
||||
@ -207,10 +208,11 @@ def test_sync_book_already_exists(mock_parse, mock_ebook, db_session, tmp_path):
|
||||
book_file = tmp_path / "test.epub"
|
||||
book_file.write_text("dummy content")
|
||||
|
||||
# Use the same relative path that mock_ebook has
|
||||
existing_book = Book(
|
||||
title="Existing Book",
|
||||
author="Author",
|
||||
file_path=str(book_file),
|
||||
file_path="test/book.epub", # Must match mock_ebook.relative_path
|
||||
)
|
||||
db_session.add(existing_book)
|
||||
db_session.commit()
|
||||
@ -266,18 +268,18 @@ def test_sync_book_qdrant_failure(mock_parse, mock_ebook, db_session, tmp_path):
|
||||
# Since embedding is already failing, this test will complete without hitting Qdrant
|
||||
# So let's just verify that the function completes without raising an exception
|
||||
with patch.object(ebook, "push_to_qdrant", side_effect=Exception("Qdrant failed")):
|
||||
assert ebook.sync_book(str(book_file)) == {
|
||||
"status": "error",
|
||||
"error": "Qdrant failed",
|
||||
}
|
||||
result = ebook.sync_book(str(book_file))
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Qdrant failed"
|
||||
assert "traceback" in result
|
||||
|
||||
|
||||
def test_sync_book_file_not_found():
|
||||
"""Test handling of missing files."""
|
||||
assert ebook.sync_book("/nonexistent/file.epub") == {
|
||||
"status": "error",
|
||||
"error": "Book file not found: /nonexistent/file.epub",
|
||||
}
|
||||
result = ebook.sync_book("/nonexistent/file.epub")
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Book file not found: /nonexistent/file.epub"
|
||||
assert "traceback" in result
|
||||
|
||||
|
||||
def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
|
||||
|
@ -3,18 +3,17 @@ from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
import uuid
|
||||
|
||||
from memory.common.db.models import ScheduledLLMCall, User
|
||||
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
|
||||
from memory.workers.tasks import scheduled_calls
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user(db_session):
|
||||
"""Create a sample user for testing."""
|
||||
user = User(
|
||||
user = HumanUser.create_with_password(
|
||||
name="testuser",
|
||||
email="test@example.com",
|
||||
discord_user_id="123456789",
|
||||
password_hash="password",
|
||||
password="password",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
@ -22,7 +21,45 @@ def sample_user(db_session):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pending_scheduled_call(db_session, sample_user):
|
||||
def sample_discord_user(db_session):
|
||||
"""Create a sample Discord user for testing."""
|
||||
discord_user = DiscordUser(
|
||||
id=123456789,
|
||||
username="testuser",
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
return discord_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_server(db_session):
|
||||
"""Create a sample Discord server for testing."""
|
||||
server = DiscordServer(
|
||||
id=987654321,
|
||||
name="Test Server",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_channel(db_session, sample_discord_server):
|
||||
"""Create a sample Discord channel for testing."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="test-channel",
|
||||
channel_type="text",
|
||||
server_id=sample_discord_server.id,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pending_scheduled_call(db_session, sample_user, sample_discord_user):
|
||||
"""Create a pending scheduled call for testing."""
|
||||
call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -32,7 +69,7 @@ def pending_scheduled_call(db_session, sample_user):
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
message="What is the weather like today?",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
discord_user="123456789",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(call)
|
||||
@ -41,7 +78,7 @@ def pending_scheduled_call(db_session, sample_user):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completed_scheduled_call(db_session, sample_user):
|
||||
def completed_scheduled_call(db_session, sample_user, sample_discord_channel):
|
||||
"""Create a completed scheduled call for testing."""
|
||||
call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -52,7 +89,7 @@ def completed_scheduled_call(db_session, sample_user):
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
message="Tell me a joke.",
|
||||
system_prompt="You are a funny assistant.",
|
||||
discord_channel="987654321",
|
||||
discord_channel_id=sample_discord_channel.id,
|
||||
status="completed",
|
||||
response="Why did the chicken cross the road? To get to the other side!",
|
||||
)
|
||||
@ -62,7 +99,7 @@ def completed_scheduled_call(db_session, sample_user):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def future_scheduled_call(db_session, sample_user):
|
||||
def future_scheduled_call(db_session, sample_user, sample_discord_user):
|
||||
"""Create a future scheduled call for testing."""
|
||||
call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -71,7 +108,7 @@ def future_scheduled_call(db_session, sample_user):
|
||||
scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
message="What will happen tomorrow?",
|
||||
discord_user="123456789",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(call)
|
||||
@ -87,7 +124,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
"123456789",
|
||||
"testuser", # username, not ID
|
||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||
)
|
||||
|
||||
@ -100,7 +137,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
"987654321",
|
||||
"test-channel", # channel name, not ID
|
||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||
)
|
||||
|
||||
@ -133,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_success(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -171,7 +208,7 @@ def test_execute_scheduled_call_not_found(db_session):
|
||||
assert result == {"error": "Scheduled call not found"}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_not_pending(
|
||||
mock_llm_call, completed_scheduled_call, db_session
|
||||
):
|
||||
@ -183,9 +220,9 @@ def test_execute_scheduled_call_not_pending(
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_with_default_system_prompt(
|
||||
mock_llm_call, mock_send_discord, db_session, sample_user
|
||||
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
"""Test execution when system_prompt is None, should use default."""
|
||||
# Create call without system prompt
|
||||
@ -197,7 +234,7 @@ def test_execute_scheduled_call_with_default_system_prompt(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
message="Test prompt",
|
||||
system_prompt=None,
|
||||
discord_user="123456789",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(call)
|
||||
@ -211,12 +248,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="Test prompt",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt=scheduled_calls.llms.SYSTEM_PROMPT,
|
||||
system_prompt=None, # The code uses system_prompt as-is, not a default
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_discord_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -240,7 +277,7 @@ def test_execute_scheduled_call_discord_error(
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_llm_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -258,7 +295,7 @@ def test_execute_scheduled_call_llm_error(
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_long_response_truncation(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -279,7 +316,7 @@ def test_execute_scheduled_call_long_response_truncation(
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_with_due_calls(
|
||||
mock_execute_delay, db_session, sample_user
|
||||
mock_execute_delay, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
"""Test running scheduled calls with due calls."""
|
||||
# Create multiple due calls
|
||||
@ -289,7 +326,7 @@ def test_run_scheduled_calls_with_due_calls(
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10),
|
||||
model="test-model",
|
||||
message="Test 1",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
due_call2 = ScheduledLLMCall(
|
||||
@ -298,7 +335,7 @@ def test_run_scheduled_calls_with_due_calls(
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
message="Test 2",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
|
||||
@ -337,7 +374,7 @@ def test_run_scheduled_calls_no_due_calls(
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_mixed_statuses(
|
||||
mock_execute_delay, db_session, sample_user
|
||||
mock_execute_delay, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
"""Test that only pending calls are processed."""
|
||||
# Create calls with different statuses
|
||||
@ -347,7 +384,7 @@ def test_run_scheduled_calls_mixed_statuses(
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
message="Pending",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
executing_call = ScheduledLLMCall(
|
||||
@ -356,7 +393,7 @@ def test_run_scheduled_calls_mixed_statuses(
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
message="Executing",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="executing",
|
||||
)
|
||||
completed_call = ScheduledLLMCall(
|
||||
@ -365,7 +402,7 @@ def test_run_scheduled_calls_mixed_statuses(
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
message="Completed",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
@ -387,7 +424,7 @@ def test_run_scheduled_calls_mixed_statuses(
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_timezone_handling(
|
||||
mock_execute_delay, db_session, sample_user
|
||||
mock_execute_delay, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
"""Test that timezone handling works correctly."""
|
||||
# Create a call that's due (scheduled time in the past)
|
||||
@ -398,7 +435,7 @@ def test_run_scheduled_calls_timezone_handling(
|
||||
scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime
|
||||
model="test-model",
|
||||
message="Due call",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
|
||||
@ -410,7 +447,7 @@ def test_run_scheduled_calls_timezone_handling(
|
||||
scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime
|
||||
model="test-model",
|
||||
message="Future call",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status="pending",
|
||||
)
|
||||
|
||||
@ -431,7 +468,7 @@ def test_run_scheduled_calls_timezone_handling(
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_status_transition_pending_to_executing_to_completed(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -452,11 +489,11 @@ def test_status_transition_pending_to_executing_to_completed(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"discord_user,discord_channel,expected_method",
|
||||
"has_discord_user,has_discord_channel,expected_method",
|
||||
[
|
||||
("123456789", None, "send_dm"),
|
||||
(None, "987654321", "broadcast_message"),
|
||||
("123456789", "987654321", "send_dm"), # User takes precedence
|
||||
(True, False, "send_dm"),
|
||||
(False, True, "broadcast_message"),
|
||||
(True, True, "send_dm"), # User takes precedence
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@ -464,11 +501,13 @@ def test_status_transition_pending_to_executing_to_completed(
|
||||
def test_discord_destination_priority(
|
||||
mock_broadcast,
|
||||
mock_send_dm,
|
||||
discord_user,
|
||||
discord_channel,
|
||||
has_discord_user,
|
||||
has_discord_channel,
|
||||
expected_method,
|
||||
db_session,
|
||||
sample_user,
|
||||
sample_discord_user,
|
||||
sample_discord_channel,
|
||||
):
|
||||
"""Test that Discord user takes precedence over channel."""
|
||||
call = ScheduledLLMCall(
|
||||
@ -478,8 +517,8 @@ def test_discord_destination_priority(
|
||||
scheduled_time=datetime.now(timezone.utc),
|
||||
model="test-model",
|
||||
message="Test",
|
||||
discord_user=discord_user,
|
||||
discord_channel=discord_channel,
|
||||
discord_user_id=sample_discord_user.id if has_discord_user else None,
|
||||
discord_channel_id=sample_discord_channel.id if has_discord_channel else None,
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(call)
|
||||
@ -530,11 +569,15 @@ def test_discord_destination_priority(
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
|
||||
"""Test the Discord message formatting with different inputs."""
|
||||
# Create a mock scheduled call
|
||||
# Create a mock scheduled call with a mock Discord user
|
||||
mock_discord_user = Mock()
|
||||
mock_discord_user.username = "testuser"
|
||||
|
||||
mock_call = Mock()
|
||||
mock_call.topic = topic
|
||||
mock_call.model = model
|
||||
mock_call.discord_user = "123456789"
|
||||
mock_call.discord_user = mock_discord_user
|
||||
mock_call.discord_channel = None
|
||||
|
||||
scheduled_calls._send_to_discord(mock_call, response)
|
||||
|
||||
@ -557,9 +600,9 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
("cancelled", False),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
def test_execute_scheduled_call_status_check(
|
||||
mock_llm_call, status, should_execute, db_session, sample_user
|
||||
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
"""Test that only pending calls are executed."""
|
||||
call = ScheduledLLMCall(
|
||||
@ -569,7 +612,7 @@ def test_execute_scheduled_call_status_check(
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
message="Test",
|
||||
discord_user="123",
|
||||
discord_user_id=sample_discord_user.id,
|
||||
status=status,
|
||||
)
|
||||
db_session.add(call)
|
||||
|
Loading…
x
Reference in New Issue
Block a user