Add proactive check-in functionality for Discord

- Add proactive_cron, proactive_prompt, last_proactive_at fields to Discord models
- Add /proactive command handler for configuring check-in schedules
- Add evaluate_proactive_checkins task (runs every minute via celery beat)
- Add execute_proactive_checkin task that evaluates interest and sends messages
- Smart bot selection finds the correct bot for each server
- Channel selection defaults to "general" text channel for servers
- Add database migration for new fields
- Add comprehensive tests for commands and tasks

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Daniel O'Connell 2025-12-25 09:21:30 +01:00
parent 9088997295
commit a238ca6329
12 changed files with 1274 additions and 4 deletions

View File

@ -0,0 +1,45 @@
"""Add proactive check-in fields to Discord entities
Revision ID: e1f2a3b4c5d6
Revises: d0e1f2a3b4c5
Create Date: 2025-12-24 16:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "e1f2a3b4c5d6"
down_revision: Union[str, None] = "d0e1f2a3b4c5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add proactive fields to all MessageProcessor tables
for table in ["discord_servers", "discord_channels", "discord_users"]:
op.add_column(
table,
sa.Column("proactive_cron", sa.Text(), nullable=True),
)
op.add_column(
table,
sa.Column("proactive_prompt", sa.Text(), nullable=True),
)
op.add_column(
table,
sa.Column(
"last_proactive_at", sa.DateTime(timezone=True), nullable=True
),
)
def downgrade() -> None:
for table in ["discord_servers", "discord_channels", "discord_users"]:
op.drop_column(table, "last_proactive_at")
op.drop_column(table, "proactive_prompt")
op.drop_column(table, "proactive_cron")

View File

@ -10,5 +10,6 @@ openai==2.3.0
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
httpx>=0.28.1
celery[redis,sqs]==5.3.6
croniter==2.0.1
cryptography==43.0.0
bcrypt==4.1.2

View File

@ -1,7 +1,9 @@
pytest==7.4.4
pytest-cov==4.1.0
pytest-asyncio==0.23.0
black==23.12.1
mypy==1.8.0
isort==5.13.2
isort==5.13.2
testcontainers[qdrant]==4.10.0
click==8.1.7
click==8.1.7
croniter==2.0.1

View File

@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
GITHUB_ROOT = "memory.workers.tasks.github"
PEOPLE_ROOT = "memory.workers.tasks.people"
PROACTIVE_ROOT = "memory.workers.tasks.proactive"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
# Proactive check-in tasks
EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
@ -130,12 +135,17 @@ app.conf.update(
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
},
beat_schedule={
"sync-github-repos-hourly": {
"task": SYNC_ALL_GITHUB_REPOS,
"schedule": crontab(minute=0), # Every hour at :00
},
"evaluate-proactive-checkins": {
"task": EVALUATE_PROACTIVE_CHECKINS,
"schedule": crontab(), # Every minute
},
},
)

View File

@ -63,13 +63,29 @@ class MessageProcessor:
doc=textwrap.dedent(
"""
A summary of this processor, made by and for AI systems.
The idea here is that AI systems can use this summary to keep notes on the given processor.
These should automatically be injected into the context of the messages that are processed by this processor.
These should automatically be injected into the context of the messages that are processed by this processor.
"""
),
)
proactive_cron = Column(
Text,
nullable=True,
doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
)
proactive_prompt = Column(
Text,
nullable=True,
doc="Custom instructions for proactive check-ins.",
)
last_proactive_at = Column(
DateTime(timezone=True),
nullable=True,
doc="When the last proactive check-in was sent.",
)
@property
def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore

View File

@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))

View File

@ -167,6 +167,25 @@ def _create_scope_group(
url=url and url.strip(),
)
# Proactive command
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
@discord.app_commands.describe(
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
prompt="Optional custom instructions for check-ins",
)
async def proactive_cmd(
interaction: discord.Interaction,
cron: str | None = None,
prompt: str | None = None,
):
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_proactive,
cron=cron and cron.strip(),
prompt=prompt,
)
return group
@ -265,6 +284,28 @@ def _create_user_scope_group(
url=url and url.strip(),
)
# Proactive command
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
@discord.app_commands.describe(
user="Target user",
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
prompt="Optional custom instructions for check-ins",
)
async def proactive_cmd(
interaction: discord.Interaction,
user: discord.User,
cron: str | None = None,
prompt: str | None = None,
):
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_proactive,
target_user=user,
cron=cron and cron.strip(),
prompt=prompt,
)
return group
@ -663,3 +704,68 @@ async def handle_mcp_servers(
except Exception as exc:
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
raise CommandError(f"Error: {exc}") from exc
async def handle_proactive(
context: CommandContext,
*,
cron: str | None = None,
prompt: str | None = None,
) -> CommandResponse:
"""Handle proactive check-in configuration."""
from croniter import croniter
model = context.target
# If no arguments, show current settings
if cron is None and prompt is None:
current_cron = getattr(model, "proactive_cron", None)
current_prompt = getattr(model, "proactive_prompt", None)
if not current_cron:
return CommandResponse(
content=f"Proactive check-ins are disabled for {context.display_name}."
)
lines = [f"Proactive check-ins for {context.display_name}:"]
lines.append(f" Schedule: `{current_cron}`")
if current_prompt:
lines.append(f" Prompt: {current_prompt}")
return CommandResponse(content="\n".join(lines))
# Handle cron setting
if cron is not None:
if cron.lower() == "off":
setattr(model, "proactive_cron", None)
return CommandResponse(
content=f"Proactive check-ins disabled for {context.display_name}."
)
# Validate cron expression
try:
croniter(cron)
except (ValueError, KeyError) as e:
raise CommandError(
f"Invalid cron expression: {cron}\n"
"Examples:\n"
" `0 9 * * *` - 9am daily\n"
" `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
" `0 */4 * * *` - every 4 hours"
) from e
setattr(model, "proactive_cron", cron)
# Handle prompt setting
if prompt is not None:
setattr(model, "proactive_prompt", prompt or None)
# Build response
current_cron = getattr(model, "proactive_cron", None)
current_prompt = getattr(model, "proactive_prompt", None)
lines = [f"Updated proactive settings for {context.display_name}:"]
lines.append(f" Schedule: `{current_cron}`")
if current_prompt:
lines.append(f" Prompt: {current_prompt}")
return CommandResponse(content="\n".join(lines))

View File

@ -11,6 +11,7 @@ from memory.common.celery_app import (
SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS,
BACKUP_ALL,
EVALUATE_PROACTIVE_CHECKINS,
)
logger = logging.getLogger(__name__)
@ -53,4 +54,8 @@ app.conf.beat_schedule = {
"task": BACKUP_ALL,
"schedule": settings.S3_BACKUP_INTERVAL,
},
"evaluate-proactive-checkins": {
"task": EVALUATE_PROACTIVE_CHECKINS,
"schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
},
}

View File

@ -15,6 +15,7 @@ from memory.workers.tasks import (
notes,
observations,
people,
proactive,
scheduled_calls,
) # noqa
@ -31,5 +32,6 @@ __all__ = [
"notes",
"observations",
"people",
"proactive",
"scheduled_calls",
]

View File

@ -0,0 +1,341 @@
"""
Celery tasks for proactive Discord check-ins.
"""
import logging
import re
import textwrap
from datetime import datetime, timezone
from typing import Any, Literal, cast
from croniter import croniter
from sqlalchemy import or_
from memory.common import settings
from memory.common.celery_app import app
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__)
EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
EntityType = Literal["user", "channel", "server"]
def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
"""Check if a cron expression is due to run now.
Uses croniter to determine if the current time falls within the cron's schedule
and enough time has passed since the last run.
"""
try:
cron = croniter(cron_expr, now)
# Get the previous scheduled time from now
prev_run = cron.get_prev(datetime)
# Get the one before that to determine the interval
cron.get_prev(datetime)
prev_prev_run = cron.get_current(datetime)
# If we haven't run since the last scheduled time, we should run
if last_run is None:
# Never run before - check if current time is within a minute of prev_run
time_since_scheduled = (now - prev_run).total_seconds()
return time_since_scheduled < 120 # Within 2 minutes of scheduled time
# Make sure last_run is timezone aware
if last_run.tzinfo is None:
last_run = last_run.replace(tzinfo=timezone.utc)
# We should run if last_run is before the previous scheduled time
return last_run < prev_run
except Exception as e:
logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
return False
def get_bot_for_entity(
session, entity_type: EntityType, entity_id: int
) -> DiscordUser | None:
"""Get the bot user associated with an entity."""
from memory.common.db.models import DiscordBotUser, DiscordMessage
from sqlalchemy.orm import joinedload
# For servers, find a bot that has sent messages in that server
if entity_type == "server":
# Find bots that have interacted with this server
bot_users = (
session.query(DiscordUser)
.options(joinedload(DiscordUser.system_user))
.join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
.filter(
DiscordMessage.server_id == entity_id,
DiscordUser.system_user_id.isnot(None),
)
.distinct()
.all()
)
# Find one that's actually a bot
for user in bot_users:
if user.system_user and user.system_user.user_type == "discord_bot":
return user
# For channels, check the server the channel belongs to
if entity_type == "channel":
channel = session.get(DiscordChannel, entity_id)
if channel and channel.server_id:
return get_bot_for_entity(session, "server", channel.server_id)
# Fallback: use first available bot
bot = (
session.query(DiscordBotUser)
.options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
.first()
)
if bot and bot.discord_users:
return bot.discord_users[0]
return None
def get_target_user_for_entity(
session, entity_type: EntityType, entity_id: int
) -> DiscordUser | None:
"""Get the target user for sending a proactive message."""
if entity_type == "user":
return session.get(DiscordUser, entity_id)
# For channels and servers, we don't have a specific target user
return None
def get_channel_for_entity(
session, entity_type: EntityType, entity_id: int
) -> DiscordChannel | None:
"""Get the channel for sending a proactive message."""
if entity_type == "channel":
return session.get(DiscordChannel, entity_id)
if entity_type == "server":
# For servers, find the first text channel (prefer "general")
channels = (
session.query(DiscordChannel)
.filter(
DiscordChannel.server_id == entity_id,
DiscordChannel.channel_type == "text",
)
.all()
)
if not channels:
return None
# Prefer a channel named "general" if it exists
for channel in channels:
if channel.name and "general" in channel.name.lower():
return channel
return channels[0]
# For users, we use DMs (no channel)
return None
@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
@safe_task_execution
def evaluate_proactive_checkins() -> dict[str, Any]:
"""
Evaluate which entities need proactive check-ins.
This task runs every minute and checks all entities with proactive_cron set
to see if they're due for a check-in.
"""
now = datetime.now(timezone.utc)
dispatched = []
with make_session() as session:
# Query all entities with proactive_cron set
for model, entity_type in [
(DiscordUser, "user"),
(DiscordChannel, "channel"),
(DiscordServer, "server"),
]:
entities = (
session.query(model)
.filter(model.proactive_cron.isnot(None))
.all()
)
for entity in entities:
cron_expr = cast(str, entity.proactive_cron)
last_run = entity.last_proactive_at
if is_cron_due(cron_expr, last_run, now):
logger.info(
f"Proactive check-in due for {entity_type} {entity.id}"
)
execute_proactive_checkin.delay(entity_type, entity.id)
dispatched.append({"type": entity_type, "id": entity.id})
return {
"evaluated_at": now.isoformat(),
"dispatched": dispatched,
"count": len(dispatched),
}
@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
@safe_task_execution
def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
"""
Execute a proactive check-in for a specific entity.
This evaluates whether the bot should reach out and, if so, generates
and sends a check-in message.
"""
logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
with make_session() as session:
# Get the entity
model_class = {
"user": DiscordUser,
"channel": DiscordChannel,
"server": DiscordServer,
}[entity_type]
entity = session.get(model_class, entity_id)
if not entity:
return {"error": f"{entity_type} {entity_id} not found"}
# Get the bot user
bot_user = get_bot_for_entity(session, entity_type, entity_id)
if not bot_user:
return {"error": "No bot user found"}
# Get target user and channel
target_user = get_target_user_for_entity(session, entity_type, entity_id)
channel = get_channel_for_entity(session, entity_type, entity_id)
if not target_user and not channel:
return {"error": "No target user or channel for proactive check-in"}
# Get chattiness threshold
chattiness = entity.chattiness_threshold or 90
# Build the evaluation prompt
proactive_prompt = entity.proactive_prompt or ""
eval_prompt = textwrap.dedent("""
You are considering whether to proactively reach out to check in.
{proactive_prompt}
Based on your notes and the context of previous conversations:
1. Is there anything worth checking in about?
2. Has enough happened or enough time passed to warrant a check-in?
3. Would reaching out now be welcome or intrusive?
Please return a number between 0 and 100 indicating how strongly you want to check in
(0 = definitely not, 100 = definitely yes).
<response>
<number>50</number>
<reason>Your reasoning here</reason>
</response>
""").format(proactive_prompt=proactive_prompt)
# Build context
system_prompt = comm_channel_prompt(
session, bot_user, target_user, channel
)
# First, evaluate whether we should check in
eval_response = call_llm(
session,
bot_user=bot_user,
from_user=target_user,
channel=channel,
model=settings.SUMMARIZER_MODEL,
system_prompt=system_prompt,
messages=[eval_prompt],
allowed_tools=[
"update_channel_summary",
"update_user_summary",
"update_server_summary",
],
)
if not eval_response:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
# Parse the interest score
match = re.search(r"<number>(\d+)</number>", eval_response)
if not match:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
interest_score = int(match.group(1))
threshold = 100 - chattiness
logger.info(
f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
)
if interest_score < threshold:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {
"status": "below_threshold",
"interest": interest_score,
"threshold": threshold,
"entity_type": entity_type,
"entity_id": entity_id,
}
# Generate the actual check-in message
checkin_prompt = textwrap.dedent("""
You've decided to proactively check in. Generate a natural, friendly check-in message.
{proactive_prompt}
Keep it brief and genuine. Don't be overly formal or robotic.
Reference specific things from your notes if relevant.
""").format(proactive_prompt=proactive_prompt)
response = call_llm(
session,
bot_user=bot_user,
from_user=target_user,
channel=channel,
model=settings.DISCORD_MODEL,
system_prompt=system_prompt,
messages=[checkin_prompt],
)
if not response:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
# Send the message
bot_id = bot_user.system_user.id if bot_user.system_user else None
if not bot_id:
return {"error": "No system user for bot"}
success = send_discord_response(
bot_id=bot_id,
response=response,
channel_id=channel.id if channel else None,
user_identifier=target_user.username if target_user else None,
)
# Update last_proactive_at
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {
"status": "sent" if success else "send_failed",
"interest": interest_score,
"entity_type": entity_type,
"entity_id": entity_id,
"response_preview": response[:100] + "..." if len(response) > 100 else response,
}

View File

@ -15,6 +15,7 @@ from memory.discord.commands import (
handle_chattiness,
handle_ignore,
handle_summary,
handle_proactive,
respond,
with_object_context,
handle_mcp_servers,
@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
await handle_mcp_servers(context, action="list", url=None)
assert "Error: boom" in str(exc.value)
# ============================================================================
# Tests for handle_proactive
# ============================================================================
@pytest.mark.asyncio
async def test_handle_proactive_show_disabled(db_session, interaction, guild):
"""Test showing proactive settings when disabled."""
server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context)
assert "disabled" in response.content.lower()
@pytest.mark.asyncio
async def test_handle_proactive_show_enabled(db_session, interaction, guild):
"""Test showing proactive settings when enabled."""
server = DiscordServer(
id=guild.id,
name="Guild",
proactive_cron="0 9 * * *",
proactive_prompt="Check on projects",
)
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context)
assert "0 9 * * *" in response.content
assert "Check on projects" in response.content
@pytest.mark.asyncio
async def test_handle_proactive_set_cron(db_session, interaction, guild):
"""Test setting proactive cron schedule."""
server = DiscordServer(id=guild.id, name="Guild")
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context, cron="0 9 * * *")
assert "Updated" in response.content
assert "0 9 * * *" in response.content
assert server.proactive_cron == "0 9 * * *"
@pytest.mark.asyncio
async def test_handle_proactive_set_prompt(db_session, interaction, guild):
"""Test setting proactive prompt."""
server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context, prompt="Focus on daily standups")
assert "Updated" in response.content
assert server.proactive_prompt == "Focus on daily standups"
@pytest.mark.asyncio
async def test_handle_proactive_disable(db_session, interaction, guild):
"""Test disabling proactive check-ins."""
server = DiscordServer(
id=guild.id,
name="Guild",
proactive_cron="0 9 * * *",
proactive_prompt="Some prompt",
)
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context, cron="off")
assert "disabled" in response.content.lower()
assert server.proactive_cron is None
@pytest.mark.asyncio
async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
"""Test error on invalid cron expression."""
server = DiscordServer(id=guild.id, name="Guild")
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
with pytest.raises(CommandError) as exc:
await handle_proactive(context, cron="not a valid cron")
assert "Invalid cron expression" in str(exc.value)
@pytest.mark.asyncio
async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
"""Test proactive settings for user scope."""
user_model = DiscordUser(
id=discord_user.id, username="testuser", proactive_cron=None
)
db_session.add(user_model)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="me",
target=user_model,
display_name="you (**testuser**)",
)
response = await handle_proactive(context, cron="0 9,17 * * 1-5")
assert "Updated" in response.content
assert user_model.proactive_cron == "0 9,17 * * 1-5"
@pytest.mark.asyncio
async def test_handle_proactive_channel_scope(
db_session, interaction, guild, text_channel
):
"""Test proactive settings for channel scope."""
server = DiscordServer(id=guild.id, name="Guild")
db_session.add(server)
db_session.flush()
channel_model = DiscordChannel(
id=text_channel.id,
name="general",
channel_type="text",
server_id=guild.id,
)
db_session.add(channel_model)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="channel",
target=channel_model,
display_name="channel **#general**",
)
response = await handle_proactive(context, cron="0 12 * * *")
assert "Updated" in response.content
assert channel_model.proactive_cron == "0 12 * * *"

View File

@ -0,0 +1,536 @@
"""Tests for proactive check-in tasks."""
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch, MagicMock
from memory.common.db.models import (
DiscordBotUser,
DiscordUser,
DiscordChannel,
DiscordServer,
)
from memory.workers.tasks import proactive
from memory.workers.tasks.proactive import is_cron_due
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def bot_user(db_session):
"""Create a bot user for testing."""
bot_discord_user = DiscordUser(
id=999999999,
username="testbot",
)
db_session.add(bot_discord_user)
db_session.flush()
user = DiscordBotUser.create_with_api_key(
discord_users=[bot_discord_user],
name="testbot",
email="bot@example.com",
)
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def target_user(db_session):
"""Create a target Discord user for testing."""
discord_user = DiscordUser(
id=123456789,
username="targetuser",
proactive_cron="0 9 * * *", # 9am daily
chattiness_threshold=50,
)
db_session.add(discord_user)
db_session.commit()
return discord_user
@pytest.fixture
def target_user_no_cron(db_session):
"""Create a target Discord user without proactive cron."""
discord_user = DiscordUser(
id=123456790,
username="nocronuser",
proactive_cron=None,
)
db_session.add(discord_user)
db_session.commit()
return discord_user
@pytest.fixture
def target_server(db_session):
"""Create a target Discord server for testing."""
server = DiscordServer(
id=987654321,
name="Test Server",
proactive_cron="0 */4 * * *", # Every 4 hours
chattiness_threshold=30,
)
db_session.add(server)
db_session.commit()
return server
@pytest.fixture
def target_channel(db_session, target_server):
"""Create a target Discord channel for testing."""
channel = DiscordChannel(
id=111222333,
name="test-channel",
channel_type="text",
server_id=target_server.id,
proactive_cron="0 12 * * 1-5", # Noon on weekdays
chattiness_threshold=70,
)
db_session.add(channel)
db_session.commit()
return channel
# ============================================================================
# Tests for is_cron_due helper
# ============================================================================
@pytest.mark.parametrize(
"cron_expr,now,last_run,expected",
[
# Cron is due when never run before and time matches
(
"0 9 * * *",
datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
None,
True,
),
# Cron is due when last run was before the scheduled time
(
"0 9 * * *",
datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
True,
),
# Cron is NOT due when already run this period
(
"0 9 * * *",
datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
False,
),
# Cron is NOT due when current time is before scheduled time
(
"0 9 * * *",
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
None,
False,
),
# Hourly cron schedule
(
"0 * * * *",
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
True,
),
# Every 4 hours cron schedule
(
"0 */4 * * *",
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
True,
),
],
ids=[
"due_never_run",
"due_last_run_before_schedule",
"not_due_already_run",
"not_due_too_early",
"due_hourly",
"due_every_4_hours",
],
)
def test_is_cron_due(cron_expr, now, last_run, expected):
"""Test is_cron_due with various scenarios."""
assert is_cron_due(cron_expr, last_run, now) is expected
def test_is_cron_due_invalid_expression():
"""Test invalid cron expression returns False."""
now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
assert is_cron_due("invalid cron", None, now) is False
def test_is_cron_due_with_naive_last_run():
"""Test cron handles naive datetime for last_run."""
now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
cron_expr = "0 9 * * *"
last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
assert is_cron_due(cron_expr, last_run, now) is True
# ============================================================================
# Tests for evaluate_proactive_checkins task
# ============================================================================
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.is_cron_due")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_dispatches_due(
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
):
"""Test that due check-ins are dispatched."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_is_cron_due.return_value = True
result = proactive.evaluate_proactive_checkins()
assert result["count"] >= 1
mock_execute.delay.assert_called()
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.is_cron_due")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_skips_not_due(
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
):
"""Test that not-due check-ins are not dispatched."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_is_cron_due.return_value = False
result = proactive.evaluate_proactive_checkins()
assert result["count"] == 0
mock_execute.delay.assert_not_called()
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_skips_no_cron(
mock_make_session, mock_execute, db_session, target_user_no_cron
):
"""Test that entities without proactive_cron are skipped."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
result = proactive.evaluate_proactive_checkins()
for call in mock_execute.delay.call_args_list:
entity_type, entity_id = call[0]
assert entity_id != target_user_no_cron.id
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.is_cron_due")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_multiple_entity_types(
mock_make_session,
mock_is_cron_due,
mock_execute,
db_session,
target_user,
target_server,
target_channel,
):
"""Test that check-ins are dispatched for users, channels, and servers."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_is_cron_due.return_value = True
result = proactive.evaluate_proactive_checkins()
assert result["count"] == 3
dispatched_types = {d["type"] for d in result["dispatched"]}
assert "user" in dispatched_types
assert "channel" in dispatched_types
assert "server" in dispatched_types
# ============================================================================
# Tests for execute_proactive_checkin task
# ============================================================================
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_sends_when_above_threshold(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
target_user,
bot_user,
):
"""Test check-in is sent when interest exceeds threshold."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>80</number><reason>Should check in</reason></response>",
"Hey! Just checking in - how are things going?",
]
mock_send.return_value = True
result = proactive.execute_proactive_checkin("user", target_user.id)
assert result["status"] == "sent"
assert result["interest"] == 80
mock_send.assert_called_once()
db_session.refresh(target_user)
assert target_user.last_proactive_at is not None
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_skips_below_threshold(
mock_make_session,
mock_get_bot,
mock_call_llm,
db_session,
target_user,
bot_user,
):
"""Test check-in is skipped when interest is below threshold."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.return_value = (
"<response><number>30</number><reason>Not much to say</reason></response>"
)
result = proactive.execute_proactive_checkin("user", target_user.id)
assert result["status"] == "below_threshold"
assert result["interest"] == 30
assert result["threshold"] == 50
@pytest.mark.parametrize(
"llm_response,expected_status",
[
(None, "no_eval_response"),
("I'm not sure what to say.", "no_score_in_response"),
],
ids=["no_response", "malformed_response"],
)
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_handles_bad_llm_response(
mock_make_session,
mock_get_bot,
mock_call_llm,
llm_response,
expected_status,
db_session,
target_user,
bot_user,
):
"""Test handling of missing or malformed LLM responses."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.return_value = llm_response
result = proactive.execute_proactive_checkin("user", target_user.id)
assert result["status"] == expected_status
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
"""Test handling when entity doesn't exist."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
result = proactive.execute_proactive_checkin("user", 999999)
assert "error" in result
assert "not found" in result["error"]
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_no_bot_user(
mock_make_session, mock_get_bot, db_session, target_user
):
"""Test handling when no bot user is found."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_get_bot.return_value = None
result = proactive.execute_proactive_checkin("user", target_user.id)
assert "error" in result
assert "No bot user" in result["error"]
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_uses_proactive_prompt(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
bot_user,
):
"""Test that proactive_prompt is included in the evaluation."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
user_with_prompt = DiscordUser(
id=555666777,
username="promptuser",
proactive_cron="0 9 * * *",
proactive_prompt="Focus on their coding projects",
chattiness_threshold=50,
)
db_session.add(user_with_prompt)
db_session.commit()
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>80</number><reason>Check on projects</reason></response>",
"How are your coding projects coming along?",
]
mock_send.return_value = True
result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
assert result["status"] == "sent"
call_args = mock_call_llm.call_args_list[0]
messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
assert any("Focus on their coding projects" in str(m) for m in messages_arg)
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_channel(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
target_channel,
bot_user,
):
"""Test check-in to a channel."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>50</number><reason>Check channel</reason></response>",
"Good morning everyone!",
]
mock_send.return_value = True
result = proactive.execute_proactive_checkin("channel", target_channel.id)
assert result["status"] == "sent"
assert result["entity_type"] == "channel"
send_call = mock_send.call_args
assert send_call.kwargs.get("channel_id") == target_channel.id
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_updates_last_proactive_at(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
target_user,
bot_user,
):
"""Test that last_proactive_at is updated after successful check-in."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>80</number><reason>Check in</reason></response>",
"Hey there!",
]
mock_send.return_value = True
before_time = datetime.now(timezone.utc)
proactive.execute_proactive_checkin("user", target_user.id)
after_time = datetime.now(timezone.utc)
db_session.refresh(target_user)
assert target_user.last_proactive_at is not None
assert before_time <= target_user.last_proactive_at <= after_time
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
mock_make_session,
mock_get_bot,
mock_call_llm,
db_session,
target_user,
bot_user,
):
"""Test that last_proactive_at is updated even when check-in is skipped."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.return_value = (
"<response><number>10</number><reason>Nothing to say</reason></response>"
)
proactive.execute_proactive_checkin("user", target_user.id)
db_session.refresh(target_user)
assert target_user.last_proactive_at is not None