diff --git a/db/migrations/versions/20251224_160000_add_proactive_checkins.py b/db/migrations/versions/20251224_160000_add_proactive_checkins.py new file mode 100644 index 0000000..e8a8a5d --- /dev/null +++ b/db/migrations/versions/20251224_160000_add_proactive_checkins.py @@ -0,0 +1,45 @@ +"""Add proactive check-in fields to Discord entities + +Revision ID: e1f2a3b4c5d6 +Revises: d0e1f2a3b4c5 +Create Date: 2025-12-24 16:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "e1f2a3b4c5d6" +down_revision: Union[str, None] = "d0e1f2a3b4c5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add proactive fields to all MessageProcessor tables + for table in ["discord_servers", "discord_channels", "discord_users"]: + op.add_column( + table, + sa.Column("proactive_cron", sa.Text(), nullable=True), + ) + op.add_column( + table, + sa.Column("proactive_prompt", sa.Text(), nullable=True), + ) + op.add_column( + table, + sa.Column( + "last_proactive_at", sa.DateTime(timezone=True), nullable=True + ), + ) + + +def downgrade() -> None: + for table in ["discord_servers", "discord_channels", "discord_users"]: + op.drop_column(table, "last_proactive_at") + op.drop_column(table, "proactive_prompt") + op.drop_column(table, "proactive_cron") diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt index 9269445..2cbcd4e 100644 --- a/requirements/requirements-common.txt +++ b/requirements/requirements-common.txt @@ -10,5 +10,6 @@ openai==2.3.0 # Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1) httpx>=0.28.1 celery[redis,sqs]==5.3.6 +croniter==2.0.1 cryptography==43.0.0 bcrypt==4.1.2 \ No newline at end of file diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 445b8ba..2a89bfb 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,7 +1,9 @@ pytest==7.4.4 pytest-cov==4.1.0 +pytest-asyncio==0.23.0 black==23.12.1 mypy==1.8.0 -isort==5.13.2 +isort==5.13.2 testcontainers[qdrant]==4.10.0 -click==8.1.7 \ No newline at end of file +click==8.1.7 +croniter==2.0.1 \ No newline at end of file diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index 4ec849e..46d9622 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord" BACKUP_ROOT = "memory.workers.tasks.backup" GITHUB_ROOT = "memory.workers.tasks.github" PEOPLE_ROOT = "memory.workers.tasks.people" +PROACTIVE_ROOT = "memory.workers.tasks.proactive" ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message" EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message" PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message" @@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person" UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person" SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file" +# Proactive check-in tasks +EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins" +EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin" + def get_broker_url() -> str: protocol = settings.CELERY_BROKER_TYPE @@ -130,12 +135,17 @@ app.conf.update( f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"}, f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"}, f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"}, + f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"}, }, beat_schedule={ "sync-github-repos-hourly": { "task": SYNC_ALL_GITHUB_REPOS, "schedule": crontab(minute=0), # Every hour at :00 }, + "evaluate-proactive-checkins": { + "task": EVALUATE_PROACTIVE_CHECKINS, + "schedule": crontab(), # Every minute + }, }, ) diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 0b5c568..da8d70a 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -63,13 +63,29 @@ class MessageProcessor: doc=textwrap.dedent( """ A summary of this processor, made by and for AI systems. - + The idea here is that AI systems can use this summary to keep notes on the given processor. - These should automatically be injected into the context of the messages that are processed by this processor. + These should automatically be injected into the context of the messages that are processed by this processor. """ ), ) + proactive_cron = Column( + Text, + nullable=True, + doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.", + ) + proactive_prompt = Column( + Text, + nullable=True, + doc="Custom instructions for proactive check-ins.", + ) + last_proactive_at = Column( + DateTime(timezone=True), + nullable=True, + doc="When the last proactive check-in was sent.", + ) + @property def entity_type(self) -> str: return self.__class__.__tablename__[8:-1] # type: ignore diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 07a4306..d157cf0 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60)) NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60)) LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24)) SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60)) +PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60)) CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24)) diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py index 0c432bf..668ad81 100644 --- a/src/memory/discord/commands.py +++ b/src/memory/discord/commands.py @@ -167,6 +167,25 @@ def _create_scope_group( url=url and url.strip(), ) + # Proactive command + @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins") + @discord.app_commands.describe( + cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable", + prompt="Optional custom instructions for check-ins", + ) + async def proactive_cmd( + interaction: discord.Interaction, + cron: str | None = None, + prompt: str | None = None, + ): + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_proactive, + cron=cron and cron.strip(), + prompt=prompt, + ) + return group @@ -265,6 +284,28 @@ def _create_user_scope_group( url=url and url.strip(), ) + # Proactive command + @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins") + @discord.app_commands.describe( + user="Target user", + cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable", + prompt="Optional custom instructions for check-ins", + ) + async def proactive_cmd( + interaction: discord.Interaction, + user: discord.User, + cron: str | None = None, + prompt: str | None = None, + ): + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_proactive, + target_user=user, + cron=cron and cron.strip(), + prompt=prompt, + ) + return group @@ -663,3 +704,68 @@ async def handle_mcp_servers( except Exception as exc: logger.error(f"Error running MCP server command: {exc}", exc_info=True) raise CommandError(f"Error: {exc}") from exc + + +async def handle_proactive( + context: CommandContext, + *, + cron: str | None = None, + prompt: str | None = None, +) -> CommandResponse: + """Handle proactive check-in configuration.""" + from croniter import croniter + + model = context.target + + # If no arguments, show current settings + if cron is None and prompt is None: + current_cron = getattr(model, "proactive_cron", None) + current_prompt = getattr(model, "proactive_prompt", None) + + if not current_cron: + return CommandResponse( + content=f"Proactive check-ins are disabled for {context.display_name}." + ) + + lines = [f"Proactive check-ins for {context.display_name}:"] + lines.append(f" Schedule: `{current_cron}`") + if current_prompt: + lines.append(f" Prompt: {current_prompt}") + return CommandResponse(content="\n".join(lines)) + + # Handle cron setting + if cron is not None: + if cron.lower() == "off": + setattr(model, "proactive_cron", None) + return CommandResponse( + content=f"Proactive check-ins disabled for {context.display_name}." + ) + + # Validate cron expression + try: + croniter(cron) + except (ValueError, KeyError) as e: + raise CommandError( + f"Invalid cron expression: {cron}\n" + "Examples:\n" + " `0 9 * * *` - 9am daily\n" + " `0 9,17 * * 1-5` - 9am and 5pm weekdays\n" + " `0 */4 * * *` - every 4 hours" + ) from e + + setattr(model, "proactive_cron", cron) + + # Handle prompt setting + if prompt is not None: + setattr(model, "proactive_prompt", prompt or None) + + # Build response + current_cron = getattr(model, "proactive_cron", None) + current_prompt = getattr(model, "proactive_prompt", None) + + lines = [f"Updated proactive settings for {context.display_name}:"] + lines.append(f" Schedule: `{current_cron}`") + if current_prompt: + lines.append(f" Prompt: {current_prompt}") + + return CommandResponse(content="\n".join(lines)) diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py index 6b0b4e8..b18810c 100644 --- a/src/memory/workers/ingest.py +++ b/src/memory/workers/ingest.py @@ -11,6 +11,7 @@ from memory.common.celery_app import ( SYNC_LESSWRONG, RUN_SCHEDULED_CALLS, BACKUP_ALL, + EVALUATE_PROACTIVE_CHECKINS, ) logger = logging.getLogger(__name__) @@ -53,4 +54,8 @@ app.conf.beat_schedule = { "task": BACKUP_ALL, "schedule": settings.S3_BACKUP_INTERVAL, }, + "evaluate-proactive-checkins": { + "task": EVALUATE_PROACTIVE_CHECKINS, + "schedule": settings.PROACTIVE_CHECKIN_INTERVAL, + }, } diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index a6ae73c..f7c0a97 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -15,6 +15,7 @@ from memory.workers.tasks import ( notes, observations, people, + proactive, scheduled_calls, ) # noqa @@ -31,5 +32,6 @@ __all__ = [ "notes", "observations", "people", + "proactive", "scheduled_calls", ] diff --git a/src/memory/workers/tasks/proactive.py b/src/memory/workers/tasks/proactive.py new file mode 100644 index 0000000..689cfc9 --- /dev/null +++ b/src/memory/workers/tasks/proactive.py @@ -0,0 +1,341 @@ +""" +Celery tasks for proactive Discord check-ins. +""" + +import logging +import re +import textwrap +from datetime import datetime, timezone +from typing import Any, Literal, cast + +from croniter import croniter +from sqlalchemy import or_ + +from memory.common import settings +from memory.common.celery_app import app +from memory.common.db.connection import make_session +from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser +from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response +from memory.workers.tasks.content_processing import safe_task_execution + +logger = logging.getLogger(__name__) + +EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins" +EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin" + +EntityType = Literal["user", "channel", "server"] + + +def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool: + """Check if a cron expression is due to run now. + + Uses croniter to determine if the current time falls within the cron's schedule + and enough time has passed since the last run. + """ + try: + cron = croniter(cron_expr, now) + # Get the previous scheduled time from now + prev_run = cron.get_prev(datetime) + # Get the one before that to determine the interval + cron.get_prev(datetime) + prev_prev_run = cron.get_current(datetime) + + # If we haven't run since the last scheduled time, we should run + if last_run is None: + # Never run before - check if current time is within a minute of prev_run + time_since_scheduled = (now - prev_run).total_seconds() + return time_since_scheduled < 120 # Within 2 minutes of scheduled time + + # Make sure last_run is timezone aware + if last_run.tzinfo is None: + last_run = last_run.replace(tzinfo=timezone.utc) + + # We should run if last_run is before the previous scheduled time + return last_run < prev_run + except Exception as e: + logger.warning(f"Invalid cron expression '{cron_expr}': {e}") + return False + + +def get_bot_for_entity( + session, entity_type: EntityType, entity_id: int +) -> DiscordUser | None: + """Get the bot user associated with an entity.""" + from memory.common.db.models import DiscordBotUser, DiscordMessage + + from sqlalchemy.orm import joinedload + + # For servers, find a bot that has sent messages in that server + if entity_type == "server": + # Find bots that have interacted with this server + bot_users = ( + session.query(DiscordUser) + .options(joinedload(DiscordUser.system_user)) + .join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id) + .filter( + DiscordMessage.server_id == entity_id, + DiscordUser.system_user_id.isnot(None), + ) + .distinct() + .all() + ) + # Find one that's actually a bot + for user in bot_users: + if user.system_user and user.system_user.user_type == "discord_bot": + return user + + # For channels, check the server the channel belongs to + if entity_type == "channel": + channel = session.get(DiscordChannel, entity_id) + if channel and channel.server_id: + return get_bot_for_entity(session, "server", channel.server_id) + + # Fallback: use first available bot + bot = ( + session.query(DiscordBotUser) + .options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user)) + .first() + ) + if bot and bot.discord_users: + return bot.discord_users[0] + return None + + +def get_target_user_for_entity( + session, entity_type: EntityType, entity_id: int +) -> DiscordUser | None: + """Get the target user for sending a proactive message.""" + if entity_type == "user": + return session.get(DiscordUser, entity_id) + # For channels and servers, we don't have a specific target user + return None + + +def get_channel_for_entity( + session, entity_type: EntityType, entity_id: int +) -> DiscordChannel | None: + """Get the channel for sending a proactive message.""" + if entity_type == "channel": + return session.get(DiscordChannel, entity_id) + if entity_type == "server": + # For servers, find the first text channel (prefer "general") + channels = ( + session.query(DiscordChannel) + .filter( + DiscordChannel.server_id == entity_id, + DiscordChannel.channel_type == "text", + ) + .all() + ) + if not channels: + return None + # Prefer a channel named "general" if it exists + for channel in channels: + if channel.name and "general" in channel.name.lower(): + return channel + return channels[0] + # For users, we use DMs (no channel) + return None + + +@app.task(name=EVALUATE_PROACTIVE_CHECKINS) +@safe_task_execution +def evaluate_proactive_checkins() -> dict[str, Any]: + """ + Evaluate which entities need proactive check-ins. + + This task runs every minute and checks all entities with proactive_cron set + to see if they're due for a check-in. + """ + now = datetime.now(timezone.utc) + dispatched = [] + + with make_session() as session: + # Query all entities with proactive_cron set + for model, entity_type in [ + (DiscordUser, "user"), + (DiscordChannel, "channel"), + (DiscordServer, "server"), + ]: + entities = ( + session.query(model) + .filter(model.proactive_cron.isnot(None)) + .all() + ) + + for entity in entities: + cron_expr = cast(str, entity.proactive_cron) + last_run = entity.last_proactive_at + + if is_cron_due(cron_expr, last_run, now): + logger.info( + f"Proactive check-in due for {entity_type} {entity.id}" + ) + execute_proactive_checkin.delay(entity_type, entity.id) + dispatched.append({"type": entity_type, "id": entity.id}) + + return { + "evaluated_at": now.isoformat(), + "dispatched": dispatched, + "count": len(dispatched), + } + + +@app.task(name=EXECUTE_PROACTIVE_CHECKIN) +@safe_task_execution +def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]: + """ + Execute a proactive check-in for a specific entity. + + This evaluates whether the bot should reach out and, if so, generates + and sends a check-in message. + """ + logger.info(f"Executing proactive check-in for {entity_type} {entity_id}") + + with make_session() as session: + # Get the entity + model_class = { + "user": DiscordUser, + "channel": DiscordChannel, + "server": DiscordServer, + }[entity_type] + + entity = session.get(model_class, entity_id) + if not entity: + return {"error": f"{entity_type} {entity_id} not found"} + + # Get the bot user + bot_user = get_bot_for_entity(session, entity_type, entity_id) + if not bot_user: + return {"error": "No bot user found"} + + # Get target user and channel + target_user = get_target_user_for_entity(session, entity_type, entity_id) + channel = get_channel_for_entity(session, entity_type, entity_id) + + if not target_user and not channel: + return {"error": "No target user or channel for proactive check-in"} + + # Get chattiness threshold + chattiness = entity.chattiness_threshold or 90 + + # Build the evaluation prompt + proactive_prompt = entity.proactive_prompt or "" + eval_prompt = textwrap.dedent(""" + You are considering whether to proactively reach out to check in. + + {proactive_prompt} + + Based on your notes and the context of previous conversations: + 1. Is there anything worth checking in about? + 2. Has enough happened or enough time passed to warrant a check-in? + 3. Would reaching out now be welcome or intrusive? + + Please return a number between 0 and 100 indicating how strongly you want to check in + (0 = definitely not, 100 = definitely yes). + + + 50 + Your reasoning here + + """).format(proactive_prompt=proactive_prompt) + + # Build context + system_prompt = comm_channel_prompt( + session, bot_user, target_user, channel + ) + + # First, evaluate whether we should check in + eval_response = call_llm( + session, + bot_user=bot_user, + from_user=target_user, + channel=channel, + model=settings.SUMMARIZER_MODEL, + system_prompt=system_prompt, + messages=[eval_prompt], + allowed_tools=[ + "update_channel_summary", + "update_user_summary", + "update_server_summary", + ], + ) + + if not eval_response: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id} + + # Parse the interest score + match = re.search(r"(\d+)", eval_response) + if not match: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id} + + interest_score = int(match.group(1)) + threshold = 100 - chattiness + + logger.info( + f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}" + ) + + if interest_score < threshold: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return { + "status": "below_threshold", + "interest": interest_score, + "threshold": threshold, + "entity_type": entity_type, + "entity_id": entity_id, + } + + # Generate the actual check-in message + checkin_prompt = textwrap.dedent(""" + You've decided to proactively check in. Generate a natural, friendly check-in message. + + {proactive_prompt} + + Keep it brief and genuine. Don't be overly formal or robotic. + Reference specific things from your notes if relevant. + """).format(proactive_prompt=proactive_prompt) + + response = call_llm( + session, + bot_user=bot_user, + from_user=target_user, + channel=channel, + model=settings.DISCORD_MODEL, + system_prompt=system_prompt, + messages=[checkin_prompt], + ) + + if not response: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id} + + # Send the message + bot_id = bot_user.system_user.id if bot_user.system_user else None + if not bot_id: + return {"error": "No system user for bot"} + + success = send_discord_response( + bot_id=bot_id, + response=response, + channel_id=channel.id if channel else None, + user_identifier=target_user.username if target_user else None, + ) + + # Update last_proactive_at + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + + return { + "status": "sent" if success else "send_failed", + "interest": interest_score, + "entity_type": entity_type, + "entity_id": entity_id, + "response_preview": response[:100] + "..." if len(response) > 100 else response, + } diff --git a/tests/memory/discord_tests/test_commands.py b/tests/memory/discord_tests/test_commands.py index fd2bf9d..9d05195 100644 --- a/tests/memory/discord_tests/test_commands.py +++ b/tests/memory/discord_tests/test_commands.py @@ -15,6 +15,7 @@ from memory.discord.commands import ( handle_chattiness, handle_ignore, handle_summary, + handle_proactive, respond, with_object_context, handle_mcp_servers, @@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction): await handle_mcp_servers(context, action="list", url=None) assert "Error: boom" in str(exc.value) + + +# ============================================================================ +# Tests for handle_proactive +# ============================================================================ + + +@pytest.mark.asyncio +async def test_handle_proactive_show_disabled(db_session, interaction, guild): + """Test showing proactive settings when disabled.""" + server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None) + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context) + + assert "disabled" in response.content.lower() + + +@pytest.mark.asyncio +async def test_handle_proactive_show_enabled(db_session, interaction, guild): + """Test showing proactive settings when enabled.""" + server = DiscordServer( + id=guild.id, + name="Guild", + proactive_cron="0 9 * * *", + proactive_prompt="Check on projects", + ) + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context) + + assert "0 9 * * *" in response.content + assert "Check on projects" in response.content + + +@pytest.mark.asyncio +async def test_handle_proactive_set_cron(db_session, interaction, guild): + """Test setting proactive cron schedule.""" + server = DiscordServer(id=guild.id, name="Guild") + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context, cron="0 9 * * *") + + assert "Updated" in response.content + assert "0 9 * * *" in response.content + assert server.proactive_cron == "0 9 * * *" + + +@pytest.mark.asyncio +async def test_handle_proactive_set_prompt(db_session, interaction, guild): + """Test setting proactive prompt.""" + server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *") + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context, prompt="Focus on daily standups") + + assert "Updated" in response.content + assert server.proactive_prompt == "Focus on daily standups" + + +@pytest.mark.asyncio +async def test_handle_proactive_disable(db_session, interaction, guild): + """Test disabling proactive check-ins.""" + server = DiscordServer( + id=guild.id, + name="Guild", + proactive_cron="0 9 * * *", + proactive_prompt="Some prompt", + ) + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context, cron="off") + + assert "disabled" in response.content.lower() + assert server.proactive_cron is None + + +@pytest.mark.asyncio +async def test_handle_proactive_invalid_cron(db_session, interaction, guild): + """Test error on invalid cron expression.""" + server = DiscordServer(id=guild.id, name="Guild") + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + with pytest.raises(CommandError) as exc: + await handle_proactive(context, cron="not a valid cron") + + assert "Invalid cron expression" in str(exc.value) + + +@pytest.mark.asyncio +async def test_handle_proactive_user_scope(db_session, interaction, discord_user): + """Test proactive settings for user scope.""" + user_model = DiscordUser( + id=discord_user.id, username="testuser", proactive_cron=None + ) + db_session.add(user_model) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="me", + target=user_model, + display_name="you (**testuser**)", + ) + + response = await handle_proactive(context, cron="0 9,17 * * 1-5") + + assert "Updated" in response.content + assert user_model.proactive_cron == "0 9,17 * * 1-5" + + +@pytest.mark.asyncio +async def test_handle_proactive_channel_scope( + db_session, interaction, guild, text_channel +): + """Test proactive settings for channel scope.""" + server = DiscordServer(id=guild.id, name="Guild") + db_session.add(server) + db_session.flush() + + channel_model = DiscordChannel( + id=text_channel.id, + name="general", + channel_type="text", + server_id=guild.id, + ) + db_session.add(channel_model) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="channel", + target=channel_model, + display_name="channel **#general**", + ) + + response = await handle_proactive(context, cron="0 12 * * *") + + assert "Updated" in response.content + assert channel_model.proactive_cron == "0 12 * * *" diff --git a/tests/memory/workers/tasks/test_proactive.py b/tests/memory/workers/tasks/test_proactive.py new file mode 100644 index 0000000..3edcd38 --- /dev/null +++ b/tests/memory/workers/tasks/test_proactive.py @@ -0,0 +1,536 @@ +"""Tests for proactive check-in tasks.""" + +import pytest +from datetime import datetime, timezone, timedelta +from unittest.mock import Mock, patch, MagicMock + +from memory.common.db.models import ( + DiscordBotUser, + DiscordUser, + DiscordChannel, + DiscordServer, +) +from memory.workers.tasks import proactive +from memory.workers.tasks.proactive import is_cron_due + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def bot_user(db_session): + """Create a bot user for testing.""" + bot_discord_user = DiscordUser( + id=999999999, + username="testbot", + ) + db_session.add(bot_discord_user) + db_session.flush() + + user = DiscordBotUser.create_with_api_key( + discord_users=[bot_discord_user], + name="testbot", + email="bot@example.com", + ) + db_session.add(user) + db_session.commit() + return user + + +@pytest.fixture +def target_user(db_session): + """Create a target Discord user for testing.""" + discord_user = DiscordUser( + id=123456789, + username="targetuser", + proactive_cron="0 9 * * *", # 9am daily + chattiness_threshold=50, + ) + db_session.add(discord_user) + db_session.commit() + return discord_user + + +@pytest.fixture +def target_user_no_cron(db_session): + """Create a target Discord user without proactive cron.""" + discord_user = DiscordUser( + id=123456790, + username="nocronuser", + proactive_cron=None, + ) + db_session.add(discord_user) + db_session.commit() + return discord_user + + +@pytest.fixture +def target_server(db_session): + """Create a target Discord server for testing.""" + server = DiscordServer( + id=987654321, + name="Test Server", + proactive_cron="0 */4 * * *", # Every 4 hours + chattiness_threshold=30, + ) + db_session.add(server) + db_session.commit() + return server + + +@pytest.fixture +def target_channel(db_session, target_server): + """Create a target Discord channel for testing.""" + channel = DiscordChannel( + id=111222333, + name="test-channel", + channel_type="text", + server_id=target_server.id, + proactive_cron="0 12 * * 1-5", # Noon on weekdays + chattiness_threshold=70, + ) + db_session.add(channel) + db_session.commit() + return channel + + +# ============================================================================ +# Tests for is_cron_due helper +# ============================================================================ + + +@pytest.mark.parametrize( + "cron_expr,now,last_run,expected", + [ + # Cron is due when never run before and time matches + ( + "0 9 * * *", + datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc), + None, + True, + ), + # Cron is due when last run was before the scheduled time + ( + "0 9 * * *", + datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc), + datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc), + True, + ), + # Cron is NOT due when already run this period + ( + "0 9 * * *", + datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc), + datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc), + False, + ), + # Cron is NOT due when current time is before scheduled time + ( + "0 9 * * *", + datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc), + None, + False, + ), + # Hourly cron schedule + ( + "0 * * * *", + datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc), + datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc), + True, + ), + # Every 4 hours cron schedule + ( + "0 */4 * * *", + datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc), + datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc), + True, + ), + ], + ids=[ + "due_never_run", + "due_last_run_before_schedule", + "not_due_already_run", + "not_due_too_early", + "due_hourly", + "due_every_4_hours", + ], +) +def test_is_cron_due(cron_expr, now, last_run, expected): + """Test is_cron_due with various scenarios.""" + assert is_cron_due(cron_expr, last_run, now) is expected + + +def test_is_cron_due_invalid_expression(): + """Test invalid cron expression returns False.""" + now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc) + assert is_cron_due("invalid cron", None, now) is False + + +def test_is_cron_due_with_naive_last_run(): + """Test cron handles naive datetime for last_run.""" + now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc) + cron_expr = "0 9 * * *" + last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime + assert is_cron_due(cron_expr, last_run, now) is True + + +# ============================================================================ +# Tests for evaluate_proactive_checkins task +# ============================================================================ + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.is_cron_due") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_dispatches_due( + mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user +): + """Test that due check-ins are dispatched.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_is_cron_due.return_value = True + + result = proactive.evaluate_proactive_checkins() + + assert result["count"] >= 1 + mock_execute.delay.assert_called() + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.is_cron_due") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_skips_not_due( + mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user +): + """Test that not-due check-ins are not dispatched.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_is_cron_due.return_value = False + + result = proactive.evaluate_proactive_checkins() + + assert result["count"] == 0 + mock_execute.delay.assert_not_called() + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_skips_no_cron( + mock_make_session, mock_execute, db_session, target_user_no_cron +): + """Test that entities without proactive_cron are skipped.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + result = proactive.evaluate_proactive_checkins() + + for call in mock_execute.delay.call_args_list: + entity_type, entity_id = call[0] + assert entity_id != target_user_no_cron.id + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.is_cron_due") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_multiple_entity_types( + mock_make_session, + mock_is_cron_due, + mock_execute, + db_session, + target_user, + target_server, + target_channel, +): + """Test that check-ins are dispatched for users, channels, and servers.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_is_cron_due.return_value = True + + result = proactive.evaluate_proactive_checkins() + + assert result["count"] == 3 + dispatched_types = {d["type"] for d in result["dispatched"]} + assert "user" in dispatched_types + assert "channel" in dispatched_types + assert "server" in dispatched_types + + +# ============================================================================ +# Tests for execute_proactive_checkin task +# ============================================================================ + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_sends_when_above_threshold( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + target_user, + bot_user, +): + """Test check-in is sent when interest exceeds threshold.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "80Should check in", + "Hey! Just checking in - how are things going?", + ] + mock_send.return_value = True + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert result["status"] == "sent" + assert result["interest"] == 80 + mock_send.assert_called_once() + + db_session.refresh(target_user) + assert target_user.last_proactive_at is not None + + +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_skips_below_threshold( + mock_make_session, + mock_get_bot, + mock_call_llm, + db_session, + target_user, + bot_user, +): + """Test check-in is skipped when interest is below threshold.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.return_value = ( + "30Not much to say" + ) + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert result["status"] == "below_threshold" + assert result["interest"] == 30 + assert result["threshold"] == 50 + + +@pytest.mark.parametrize( + "llm_response,expected_status", + [ + (None, "no_eval_response"), + ("I'm not sure what to say.", "no_score_in_response"), + ], + ids=["no_response", "malformed_response"], +) +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_handles_bad_llm_response( + mock_make_session, + mock_get_bot, + mock_call_llm, + llm_response, + expected_status, + db_session, + target_user, + bot_user, +): + """Test handling of missing or malformed LLM responses.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + mock_call_llm.return_value = llm_response + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert result["status"] == expected_status + + +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session): + """Test handling when entity doesn't exist.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + result = proactive.execute_proactive_checkin("user", 999999) + + assert "error" in result + assert "not found" in result["error"] + + +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_no_bot_user( + mock_make_session, mock_get_bot, db_session, target_user +): + """Test handling when no bot user is found.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_get_bot.return_value = None + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert "error" in result + assert "No bot user" in result["error"] + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_uses_proactive_prompt( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + bot_user, +): + """Test that proactive_prompt is included in the evaluation.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + user_with_prompt = DiscordUser( + id=555666777, + username="promptuser", + proactive_cron="0 9 * * *", + proactive_prompt="Focus on their coding projects", + chattiness_threshold=50, + ) + db_session.add(user_with_prompt) + db_session.commit() + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "80Check on projects", + "How are your coding projects coming along?", + ] + mock_send.return_value = True + + result = proactive.execute_proactive_checkin("user", user_with_prompt.id) + + assert result["status"] == "sent" + call_args = mock_call_llm.call_args_list[0] + messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages") + assert any("Focus on their coding projects" in str(m) for m in messages_arg) + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_channel( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + target_channel, + bot_user, +): + """Test check-in to a channel.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "50Check channel", + "Good morning everyone!", + ] + mock_send.return_value = True + + result = proactive.execute_proactive_checkin("channel", target_channel.id) + + assert result["status"] == "sent" + assert result["entity_type"] == "channel" + + send_call = mock_send.call_args + assert send_call.kwargs.get("channel_id") == target_channel.id + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_updates_last_proactive_at( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + target_user, + bot_user, +): + """Test that last_proactive_at is updated after successful check-in.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "80Check in", + "Hey there!", + ] + mock_send.return_value = True + + before_time = datetime.now(timezone.utc) + proactive.execute_proactive_checkin("user", target_user.id) + after_time = datetime.now(timezone.utc) + + db_session.refresh(target_user) + assert target_user.last_proactive_at is not None + assert before_time <= target_user.last_proactive_at <= after_time + + +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_updates_last_proactive_at_on_skip( + mock_make_session, + mock_get_bot, + mock_call_llm, + db_session, + target_user, + bot_user, +): + """Test that last_proactive_at is updated even when check-in is skipped.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.return_value = ( + "10Nothing to say" + ) + + proactive.execute_proactive_checkin("user", target_user.id) + + db_session.refresh(target_user) + assert target_user.last_proactive_at is not None