diff --git a/db/migrations/versions/20251224_160000_add_proactive_checkins.py b/db/migrations/versions/20251224_160000_add_proactive_checkins.py
new file mode 100644
index 0000000..e8a8a5d
--- /dev/null
+++ b/db/migrations/versions/20251224_160000_add_proactive_checkins.py
@@ -0,0 +1,45 @@
+"""Add proactive check-in fields to Discord entities
+
+Revision ID: e1f2a3b4c5d6
+Revises: d0e1f2a3b4c5
+Create Date: 2025-12-24 16:00:00.000000
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "e1f2a3b4c5d6"
+down_revision: Union[str, None] = "d0e1f2a3b4c5"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # Add proactive fields to all MessageProcessor tables
+ for table in ["discord_servers", "discord_channels", "discord_users"]:
+ op.add_column(
+ table,
+ sa.Column("proactive_cron", sa.Text(), nullable=True),
+ )
+ op.add_column(
+ table,
+ sa.Column("proactive_prompt", sa.Text(), nullable=True),
+ )
+ op.add_column(
+ table,
+ sa.Column(
+ "last_proactive_at", sa.DateTime(timezone=True), nullable=True
+ ),
+ )
+
+
+def downgrade() -> None:
+ for table in ["discord_servers", "discord_channels", "discord_users"]:
+ op.drop_column(table, "last_proactive_at")
+ op.drop_column(table, "proactive_prompt")
+ op.drop_column(table, "proactive_cron")
diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt
index 9269445..2cbcd4e 100644
--- a/requirements/requirements-common.txt
+++ b/requirements/requirements-common.txt
@@ -10,5 +10,6 @@ openai==2.3.0
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
httpx>=0.28.1
celery[redis,sqs]==5.3.6
+croniter==2.0.1
cryptography==43.0.0
bcrypt==4.1.2
\ No newline at end of file
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index 445b8ba..2a89bfb 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -1,7 +1,9 @@
pytest==7.4.4
pytest-cov==4.1.0
+pytest-asyncio==0.23.0
black==23.12.1
mypy==1.8.0
-isort==5.13.2
+isort==5.13.2
testcontainers[qdrant]==4.10.0
-click==8.1.7
\ No newline at end of file
+click==8.1.7
+croniter==2.0.1
\ No newline at end of file
diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py
index 4ec849e..46d9622 100644
--- a/src/memory/common/celery_app.py
+++ b/src/memory/common/celery_app.py
@@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
GITHUB_ROOT = "memory.workers.tasks.github"
PEOPLE_ROOT = "memory.workers.tasks.people"
+PROACTIVE_ROOT = "memory.workers.tasks.proactive"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
+# Proactive check-in tasks
+EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
+EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
+
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
@@ -130,12 +135,17 @@ app.conf.update(
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
+ f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
},
beat_schedule={
"sync-github-repos-hourly": {
"task": SYNC_ALL_GITHUB_REPOS,
"schedule": crontab(minute=0), # Every hour at :00
},
+ "evaluate-proactive-checkins": {
+ "task": EVALUATE_PROACTIVE_CHECKINS,
+ "schedule": crontab(), # Every minute
+ },
},
)
diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py
index 0b5c568..da8d70a 100644
--- a/src/memory/common/db/models/discord.py
+++ b/src/memory/common/db/models/discord.py
@@ -63,13 +63,29 @@ class MessageProcessor:
doc=textwrap.dedent(
"""
A summary of this processor, made by and for AI systems.
-
+
The idea here is that AI systems can use this summary to keep notes on the given processor.
- These should automatically be injected into the context of the messages that are processed by this processor.
+ These should automatically be injected into the context of the messages that are processed by this processor.
"""
),
)
+ proactive_cron = Column(
+ Text,
+ nullable=True,
+ doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
+ )
+ proactive_prompt = Column(
+ Text,
+ nullable=True,
+ doc="Custom instructions for proactive check-ins.",
+ )
+ last_proactive_at = Column(
+ DateTime(timezone=True),
+ nullable=True,
+ doc="When the last proactive check-in was sent.",
+ )
+
@property
def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 07a4306..d157cf0 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
+PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py
index 0c432bf..668ad81 100644
--- a/src/memory/discord/commands.py
+++ b/src/memory/discord/commands.py
@@ -167,6 +167,25 @@ def _create_scope_group(
url=url and url.strip(),
)
+ # Proactive command
+ @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
+ @discord.app_commands.describe(
+ cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
+ prompt="Optional custom instructions for check-ins",
+ )
+ async def proactive_cmd(
+ interaction: discord.Interaction,
+ cron: str | None = None,
+ prompt: str | None = None,
+ ):
+ await _run_interaction_command(
+ interaction,
+ scope=scope,
+ handler=handle_proactive,
+ cron=cron and cron.strip(),
+ prompt=prompt,
+ )
+
return group
@@ -265,6 +284,28 @@ def _create_user_scope_group(
url=url and url.strip(),
)
+ # Proactive command
+ @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
+ @discord.app_commands.describe(
+ user="Target user",
+ cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
+ prompt="Optional custom instructions for check-ins",
+ )
+ async def proactive_cmd(
+ interaction: discord.Interaction,
+ user: discord.User,
+ cron: str | None = None,
+ prompt: str | None = None,
+ ):
+ await _run_interaction_command(
+ interaction,
+ scope=scope,
+ handler=handle_proactive,
+ target_user=user,
+ cron=cron and cron.strip(),
+ prompt=prompt,
+ )
+
return group
@@ -663,3 +704,68 @@ async def handle_mcp_servers(
except Exception as exc:
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
raise CommandError(f"Error: {exc}") from exc
+
+
+async def handle_proactive(
+ context: CommandContext,
+ *,
+ cron: str | None = None,
+ prompt: str | None = None,
+) -> CommandResponse:
+ """Handle proactive check-in configuration."""
+ from croniter import croniter
+
+ model = context.target
+
+ # If no arguments, show current settings
+ if cron is None and prompt is None:
+ current_cron = getattr(model, "proactive_cron", None)
+ current_prompt = getattr(model, "proactive_prompt", None)
+
+ if not current_cron:
+ return CommandResponse(
+ content=f"Proactive check-ins are disabled for {context.display_name}."
+ )
+
+ lines = [f"Proactive check-ins for {context.display_name}:"]
+ lines.append(f" Schedule: `{current_cron}`")
+ if current_prompt:
+ lines.append(f" Prompt: {current_prompt}")
+ return CommandResponse(content="\n".join(lines))
+
+ # Handle cron setting
+ if cron is not None:
+ if cron.lower() == "off":
+ setattr(model, "proactive_cron", None)
+ return CommandResponse(
+ content=f"Proactive check-ins disabled for {context.display_name}."
+ )
+
+ # Validate cron expression
+ try:
+ croniter(cron)
+ except (ValueError, KeyError) as e:
+ raise CommandError(
+ f"Invalid cron expression: {cron}\n"
+ "Examples:\n"
+ " `0 9 * * *` - 9am daily\n"
+ " `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
+ " `0 */4 * * *` - every 4 hours"
+ ) from e
+
+ setattr(model, "proactive_cron", cron)
+
+ # Handle prompt setting
+ if prompt is not None:
+ setattr(model, "proactive_prompt", prompt or None)
+
+ # Build response
+ current_cron = getattr(model, "proactive_cron", None)
+ current_prompt = getattr(model, "proactive_prompt", None)
+
+ lines = [f"Updated proactive settings for {context.display_name}:"]
+ lines.append(f" Schedule: `{current_cron}`")
+ if current_prompt:
+ lines.append(f" Prompt: {current_prompt}")
+
+ return CommandResponse(content="\n".join(lines))
diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py
index 6b0b4e8..b18810c 100644
--- a/src/memory/workers/ingest.py
+++ b/src/memory/workers/ingest.py
@@ -11,6 +11,7 @@ from memory.common.celery_app import (
SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS,
BACKUP_ALL,
+ EVALUATE_PROACTIVE_CHECKINS,
)
logger = logging.getLogger(__name__)
@@ -53,4 +54,8 @@ app.conf.beat_schedule = {
"task": BACKUP_ALL,
"schedule": settings.S3_BACKUP_INTERVAL,
},
+ "evaluate-proactive-checkins": {
+ "task": EVALUATE_PROACTIVE_CHECKINS,
+ "schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
+ },
}
diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py
index a6ae73c..f7c0a97 100644
--- a/src/memory/workers/tasks/__init__.py
+++ b/src/memory/workers/tasks/__init__.py
@@ -15,6 +15,7 @@ from memory.workers.tasks import (
notes,
observations,
people,
+ proactive,
scheduled_calls,
) # noqa
@@ -31,5 +32,6 @@ __all__ = [
"notes",
"observations",
"people",
+ "proactive",
"scheduled_calls",
]
diff --git a/src/memory/workers/tasks/proactive.py b/src/memory/workers/tasks/proactive.py
new file mode 100644
index 0000000..689cfc9
--- /dev/null
+++ b/src/memory/workers/tasks/proactive.py
@@ -0,0 +1,341 @@
+"""
+Celery tasks for proactive Discord check-ins.
+"""
+
+import logging
+import re
+import textwrap
+from datetime import datetime, timezone
+from typing import Any, Literal, cast
+
+from croniter import croniter
+from sqlalchemy import or_
+
+from memory.common import settings
+from memory.common.celery_app import app
+from memory.common.db.connection import make_session
+from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
+from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
+from memory.workers.tasks.content_processing import safe_task_execution
+
+logger = logging.getLogger(__name__)
+
+EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
+EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
+
+EntityType = Literal["user", "channel", "server"]
+
+
+def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
+ """Check if a cron expression is due to run now.
+
+ Uses croniter to determine if the current time falls within the cron's schedule
+ and enough time has passed since the last run.
+ """
+ try:
+ cron = croniter(cron_expr, now)
+ # Get the previous scheduled time from now
+ prev_run = cron.get_prev(datetime)
+ # Get the one before that to determine the interval
+ cron.get_prev(datetime)
+ prev_prev_run = cron.get_current(datetime)
+
+ # If we haven't run since the last scheduled time, we should run
+ if last_run is None:
+ # Never run before - check if current time is within a minute of prev_run
+ time_since_scheduled = (now - prev_run).total_seconds()
+ return time_since_scheduled < 120 # Within 2 minutes of scheduled time
+
+ # Make sure last_run is timezone aware
+ if last_run.tzinfo is None:
+ last_run = last_run.replace(tzinfo=timezone.utc)
+
+ # We should run if last_run is before the previous scheduled time
+ return last_run < prev_run
+ except Exception as e:
+ logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
+ return False
+
+
+def get_bot_for_entity(
+ session, entity_type: EntityType, entity_id: int
+) -> DiscordUser | None:
+ """Get the bot user associated with an entity."""
+ from memory.common.db.models import DiscordBotUser, DiscordMessage
+
+ from sqlalchemy.orm import joinedload
+
+ # For servers, find a bot that has sent messages in that server
+ if entity_type == "server":
+ # Find bots that have interacted with this server
+ bot_users = (
+ session.query(DiscordUser)
+ .options(joinedload(DiscordUser.system_user))
+ .join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
+ .filter(
+ DiscordMessage.server_id == entity_id,
+ DiscordUser.system_user_id.isnot(None),
+ )
+ .distinct()
+ .all()
+ )
+ # Find one that's actually a bot
+ for user in bot_users:
+ if user.system_user and user.system_user.user_type == "discord_bot":
+ return user
+
+ # For channels, check the server the channel belongs to
+ if entity_type == "channel":
+ channel = session.get(DiscordChannel, entity_id)
+ if channel and channel.server_id:
+ return get_bot_for_entity(session, "server", channel.server_id)
+
+ # Fallback: use first available bot
+ bot = (
+ session.query(DiscordBotUser)
+ .options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
+ .first()
+ )
+ if bot and bot.discord_users:
+ return bot.discord_users[0]
+ return None
+
+
+def get_target_user_for_entity(
+ session, entity_type: EntityType, entity_id: int
+) -> DiscordUser | None:
+ """Get the target user for sending a proactive message."""
+ if entity_type == "user":
+ return session.get(DiscordUser, entity_id)
+ # For channels and servers, we don't have a specific target user
+ return None
+
+
+def get_channel_for_entity(
+ session, entity_type: EntityType, entity_id: int
+) -> DiscordChannel | None:
+ """Get the channel for sending a proactive message."""
+ if entity_type == "channel":
+ return session.get(DiscordChannel, entity_id)
+ if entity_type == "server":
+ # For servers, find the first text channel (prefer "general")
+ channels = (
+ session.query(DiscordChannel)
+ .filter(
+ DiscordChannel.server_id == entity_id,
+ DiscordChannel.channel_type == "text",
+ )
+ .all()
+ )
+ if not channels:
+ return None
+ # Prefer a channel named "general" if it exists
+ for channel in channels:
+ if channel.name and "general" in channel.name.lower():
+ return channel
+ return channels[0]
+ # For users, we use DMs (no channel)
+ return None
+
+
+@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
+@safe_task_execution
+def evaluate_proactive_checkins() -> dict[str, Any]:
+ """
+ Evaluate which entities need proactive check-ins.
+
+ This task runs every minute and checks all entities with proactive_cron set
+ to see if they're due for a check-in.
+ """
+ now = datetime.now(timezone.utc)
+ dispatched = []
+
+ with make_session() as session:
+ # Query all entities with proactive_cron set
+ for model, entity_type in [
+ (DiscordUser, "user"),
+ (DiscordChannel, "channel"),
+ (DiscordServer, "server"),
+ ]:
+ entities = (
+ session.query(model)
+ .filter(model.proactive_cron.isnot(None))
+ .all()
+ )
+
+ for entity in entities:
+ cron_expr = cast(str, entity.proactive_cron)
+ last_run = entity.last_proactive_at
+
+ if is_cron_due(cron_expr, last_run, now):
+ logger.info(
+ f"Proactive check-in due for {entity_type} {entity.id}"
+ )
+ execute_proactive_checkin.delay(entity_type, entity.id)
+ dispatched.append({"type": entity_type, "id": entity.id})
+
+ return {
+ "evaluated_at": now.isoformat(),
+ "dispatched": dispatched,
+ "count": len(dispatched),
+ }
+
+
+@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
+@safe_task_execution
+def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
+ """
+ Execute a proactive check-in for a specific entity.
+
+ This evaluates whether the bot should reach out and, if so, generates
+ and sends a check-in message.
+ """
+ logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
+
+ with make_session() as session:
+ # Get the entity
+ model_class = {
+ "user": DiscordUser,
+ "channel": DiscordChannel,
+ "server": DiscordServer,
+ }[entity_type]
+
+ entity = session.get(model_class, entity_id)
+ if not entity:
+ return {"error": f"{entity_type} {entity_id} not found"}
+
+ # Get the bot user
+ bot_user = get_bot_for_entity(session, entity_type, entity_id)
+ if not bot_user:
+ return {"error": "No bot user found"}
+
+ # Get target user and channel
+ target_user = get_target_user_for_entity(session, entity_type, entity_id)
+ channel = get_channel_for_entity(session, entity_type, entity_id)
+
+ if not target_user and not channel:
+ return {"error": "No target user or channel for proactive check-in"}
+
+ # Get chattiness threshold
+ chattiness = entity.chattiness_threshold or 90
+
+ # Build the evaluation prompt
+ proactive_prompt = entity.proactive_prompt or ""
+ eval_prompt = textwrap.dedent("""
+ You are considering whether to proactively reach out to check in.
+
+ {proactive_prompt}
+
+ Based on your notes and the context of previous conversations:
+ 1. Is there anything worth checking in about?
+ 2. Has enough happened or enough time passed to warrant a check-in?
+ 3. Would reaching out now be welcome or intrusive?
+
+ Please return a number between 0 and 100 indicating how strongly you want to check in
+ (0 = definitely not, 100 = definitely yes).
+
+
+ 50
+ Your reasoning here
+
+ """).format(proactive_prompt=proactive_prompt)
+
+ # Build context
+ system_prompt = comm_channel_prompt(
+ session, bot_user, target_user, channel
+ )
+
+ # First, evaluate whether we should check in
+ eval_response = call_llm(
+ session,
+ bot_user=bot_user,
+ from_user=target_user,
+ channel=channel,
+ model=settings.SUMMARIZER_MODEL,
+ system_prompt=system_prompt,
+ messages=[eval_prompt],
+ allowed_tools=[
+ "update_channel_summary",
+ "update_user_summary",
+ "update_server_summary",
+ ],
+ )
+
+ if not eval_response:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
+
+ # Parse the interest score
+ match = re.search(r"(\d+)", eval_response)
+ if not match:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
+
+ interest_score = int(match.group(1))
+ threshold = 100 - chattiness
+
+ logger.info(
+ f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
+ )
+
+ if interest_score < threshold:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {
+ "status": "below_threshold",
+ "interest": interest_score,
+ "threshold": threshold,
+ "entity_type": entity_type,
+ "entity_id": entity_id,
+ }
+
+ # Generate the actual check-in message
+ checkin_prompt = textwrap.dedent("""
+ You've decided to proactively check in. Generate a natural, friendly check-in message.
+
+ {proactive_prompt}
+
+ Keep it brief and genuine. Don't be overly formal or robotic.
+ Reference specific things from your notes if relevant.
+ """).format(proactive_prompt=proactive_prompt)
+
+ response = call_llm(
+ session,
+ bot_user=bot_user,
+ from_user=target_user,
+ channel=channel,
+ model=settings.DISCORD_MODEL,
+ system_prompt=system_prompt,
+ messages=[checkin_prompt],
+ )
+
+ if not response:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
+
+ # Send the message
+ bot_id = bot_user.system_user.id if bot_user.system_user else None
+ if not bot_id:
+ return {"error": "No system user for bot"}
+
+ success = send_discord_response(
+ bot_id=bot_id,
+ response=response,
+ channel_id=channel.id if channel else None,
+ user_identifier=target_user.username if target_user else None,
+ )
+
+ # Update last_proactive_at
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+
+ return {
+ "status": "sent" if success else "send_failed",
+ "interest": interest_score,
+ "entity_type": entity_type,
+ "entity_id": entity_id,
+ "response_preview": response[:100] + "..." if len(response) > 100 else response,
+ }
diff --git a/tests/memory/discord_tests/test_commands.py b/tests/memory/discord_tests/test_commands.py
index fd2bf9d..9d05195 100644
--- a/tests/memory/discord_tests/test_commands.py
+++ b/tests/memory/discord_tests/test_commands.py
@@ -15,6 +15,7 @@ from memory.discord.commands import (
handle_chattiness,
handle_ignore,
handle_summary,
+ handle_proactive,
respond,
with_object_context,
handle_mcp_servers,
@@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
await handle_mcp_servers(context, action="list", url=None)
assert "Error: boom" in str(exc.value)
+
+
+# ============================================================================
+# Tests for handle_proactive
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_show_disabled(db_session, interaction, guild):
+ """Test showing proactive settings when disabled."""
+ server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context)
+
+ assert "disabled" in response.content.lower()
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_show_enabled(db_session, interaction, guild):
+ """Test showing proactive settings when enabled."""
+ server = DiscordServer(
+ id=guild.id,
+ name="Guild",
+ proactive_cron="0 9 * * *",
+ proactive_prompt="Check on projects",
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context)
+
+ assert "0 9 * * *" in response.content
+ assert "Check on projects" in response.content
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_set_cron(db_session, interaction, guild):
+ """Test setting proactive cron schedule."""
+ server = DiscordServer(id=guild.id, name="Guild")
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context, cron="0 9 * * *")
+
+ assert "Updated" in response.content
+ assert "0 9 * * *" in response.content
+ assert server.proactive_cron == "0 9 * * *"
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_set_prompt(db_session, interaction, guild):
+ """Test setting proactive prompt."""
+ server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context, prompt="Focus on daily standups")
+
+ assert "Updated" in response.content
+ assert server.proactive_prompt == "Focus on daily standups"
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_disable(db_session, interaction, guild):
+ """Test disabling proactive check-ins."""
+ server = DiscordServer(
+ id=guild.id,
+ name="Guild",
+ proactive_cron="0 9 * * *",
+ proactive_prompt="Some prompt",
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context, cron="off")
+
+ assert "disabled" in response.content.lower()
+ assert server.proactive_cron is None
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
+ """Test error on invalid cron expression."""
+ server = DiscordServer(id=guild.id, name="Guild")
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ with pytest.raises(CommandError) as exc:
+ await handle_proactive(context, cron="not a valid cron")
+
+ assert "Invalid cron expression" in str(exc.value)
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
+ """Test proactive settings for user scope."""
+ user_model = DiscordUser(
+ id=discord_user.id, username="testuser", proactive_cron=None
+ )
+ db_session.add(user_model)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="me",
+ target=user_model,
+ display_name="you (**testuser**)",
+ )
+
+ response = await handle_proactive(context, cron="0 9,17 * * 1-5")
+
+ assert "Updated" in response.content
+ assert user_model.proactive_cron == "0 9,17 * * 1-5"
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_channel_scope(
+ db_session, interaction, guild, text_channel
+):
+ """Test proactive settings for channel scope."""
+ server = DiscordServer(id=guild.id, name="Guild")
+ db_session.add(server)
+ db_session.flush()
+
+ channel_model = DiscordChannel(
+ id=text_channel.id,
+ name="general",
+ channel_type="text",
+ server_id=guild.id,
+ )
+ db_session.add(channel_model)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="channel",
+ target=channel_model,
+ display_name="channel **#general**",
+ )
+
+ response = await handle_proactive(context, cron="0 12 * * *")
+
+ assert "Updated" in response.content
+ assert channel_model.proactive_cron == "0 12 * * *"
diff --git a/tests/memory/workers/tasks/test_proactive.py b/tests/memory/workers/tasks/test_proactive.py
new file mode 100644
index 0000000..3edcd38
--- /dev/null
+++ b/tests/memory/workers/tasks/test_proactive.py
@@ -0,0 +1,536 @@
+"""Tests for proactive check-in tasks."""
+
+import pytest
+from datetime import datetime, timezone, timedelta
+from unittest.mock import Mock, patch, MagicMock
+
+from memory.common.db.models import (
+ DiscordBotUser,
+ DiscordUser,
+ DiscordChannel,
+ DiscordServer,
+)
+from memory.workers.tasks import proactive
+from memory.workers.tasks.proactive import is_cron_due
+
+
+# ============================================================================
+# Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def bot_user(db_session):
+ """Create a bot user for testing."""
+ bot_discord_user = DiscordUser(
+ id=999999999,
+ username="testbot",
+ )
+ db_session.add(bot_discord_user)
+ db_session.flush()
+
+ user = DiscordBotUser.create_with_api_key(
+ discord_users=[bot_discord_user],
+ name="testbot",
+ email="bot@example.com",
+ )
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+@pytest.fixture
+def target_user(db_session):
+ """Create a target Discord user for testing."""
+ discord_user = DiscordUser(
+ id=123456789,
+ username="targetuser",
+ proactive_cron="0 9 * * *", # 9am daily
+ chattiness_threshold=50,
+ )
+ db_session.add(discord_user)
+ db_session.commit()
+ return discord_user
+
+
+@pytest.fixture
+def target_user_no_cron(db_session):
+ """Create a target Discord user without proactive cron."""
+ discord_user = DiscordUser(
+ id=123456790,
+ username="nocronuser",
+ proactive_cron=None,
+ )
+ db_session.add(discord_user)
+ db_session.commit()
+ return discord_user
+
+
+@pytest.fixture
+def target_server(db_session):
+ """Create a target Discord server for testing."""
+ server = DiscordServer(
+ id=987654321,
+ name="Test Server",
+ proactive_cron="0 */4 * * *", # Every 4 hours
+ chattiness_threshold=30,
+ )
+ db_session.add(server)
+ db_session.commit()
+ return server
+
+
+@pytest.fixture
+def target_channel(db_session, target_server):
+ """Create a target Discord channel for testing."""
+ channel = DiscordChannel(
+ id=111222333,
+ name="test-channel",
+ channel_type="text",
+ server_id=target_server.id,
+ proactive_cron="0 12 * * 1-5", # Noon on weekdays
+ chattiness_threshold=70,
+ )
+ db_session.add(channel)
+ db_session.commit()
+ return channel
+
+
+# ============================================================================
+# Tests for is_cron_due helper
+# ============================================================================
+
+
+@pytest.mark.parametrize(
+ "cron_expr,now,last_run,expected",
+ [
+ # Cron is due when never run before and time matches
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
+ None,
+ True,
+ ),
+ # Cron is due when last run was before the scheduled time
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
+ datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
+ True,
+ ),
+ # Cron is NOT due when already run this period
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
+ datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
+ False,
+ ),
+ # Cron is NOT due when current time is before scheduled time
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
+ None,
+ False,
+ ),
+ # Hourly cron schedule
+ (
+ "0 * * * *",
+ datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
+ datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
+ True,
+ ),
+ # Every 4 hours cron schedule
+ (
+ "0 */4 * * *",
+ datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
+ datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
+ True,
+ ),
+ ],
+ ids=[
+ "due_never_run",
+ "due_last_run_before_schedule",
+ "not_due_already_run",
+ "not_due_too_early",
+ "due_hourly",
+ "due_every_4_hours",
+ ],
+)
+def test_is_cron_due(cron_expr, now, last_run, expected):
+ """Test is_cron_due with various scenarios."""
+ assert is_cron_due(cron_expr, last_run, now) is expected
+
+
+def test_is_cron_due_invalid_expression():
+ """Test invalid cron expression returns False."""
+ now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
+ assert is_cron_due("invalid cron", None, now) is False
+
+
+def test_is_cron_due_with_naive_last_run():
+ """Test cron handles naive datetime for last_run."""
+ now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
+ cron_expr = "0 9 * * *"
+ last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
+ assert is_cron_due(cron_expr, last_run, now) is True
+
+
+# ============================================================================
+# Tests for evaluate_proactive_checkins task
+# ============================================================================
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.is_cron_due")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_dispatches_due(
+ mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
+):
+ """Test that due check-ins are dispatched."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_is_cron_due.return_value = True
+
+ result = proactive.evaluate_proactive_checkins()
+
+ assert result["count"] >= 1
+ mock_execute.delay.assert_called()
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.is_cron_due")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_skips_not_due(
+ mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
+):
+ """Test that not-due check-ins are not dispatched."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_is_cron_due.return_value = False
+
+ result = proactive.evaluate_proactive_checkins()
+
+ assert result["count"] == 0
+ mock_execute.delay.assert_not_called()
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_skips_no_cron(
+ mock_make_session, mock_execute, db_session, target_user_no_cron
+):
+ """Test that entities without proactive_cron are skipped."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ result = proactive.evaluate_proactive_checkins()
+
+ for call in mock_execute.delay.call_args_list:
+ entity_type, entity_id = call[0]
+ assert entity_id != target_user_no_cron.id
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.is_cron_due")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_multiple_entity_types(
+ mock_make_session,
+ mock_is_cron_due,
+ mock_execute,
+ db_session,
+ target_user,
+ target_server,
+ target_channel,
+):
+ """Test that check-ins are dispatched for users, channels, and servers."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_is_cron_due.return_value = True
+
+ result = proactive.evaluate_proactive_checkins()
+
+ assert result["count"] == 3
+ dispatched_types = {d["type"] for d in result["dispatched"]}
+ assert "user" in dispatched_types
+ assert "channel" in dispatched_types
+ assert "server" in dispatched_types
+
+
+# ============================================================================
+# Tests for execute_proactive_checkin task
+# ============================================================================
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_sends_when_above_threshold(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test check-in is sent when interest exceeds threshold."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "80Should check in",
+ "Hey! Just checking in - how are things going?",
+ ]
+ mock_send.return_value = True
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert result["status"] == "sent"
+ assert result["interest"] == 80
+ mock_send.assert_called_once()
+
+ db_session.refresh(target_user)
+ assert target_user.last_proactive_at is not None
+
+
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_skips_below_threshold(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test check-in is skipped when interest is below threshold."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.return_value = (
+ "30Not much to say"
+ )
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert result["status"] == "below_threshold"
+ assert result["interest"] == 30
+ assert result["threshold"] == 50
+
+
+@pytest.mark.parametrize(
+ "llm_response,expected_status",
+ [
+ (None, "no_eval_response"),
+ ("I'm not sure what to say.", "no_score_in_response"),
+ ],
+ ids=["no_response", "malformed_response"],
+)
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_handles_bad_llm_response(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ llm_response,
+ expected_status,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test handling of missing or malformed LLM responses."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+ mock_call_llm.return_value = llm_response
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert result["status"] == expected_status
+
+
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
+ """Test handling when entity doesn't exist."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ result = proactive.execute_proactive_checkin("user", 999999)
+
+ assert "error" in result
+ assert "not found" in result["error"]
+
+
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_no_bot_user(
+ mock_make_session, mock_get_bot, db_session, target_user
+):
+ """Test handling when no bot user is found."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_get_bot.return_value = None
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert "error" in result
+ assert "No bot user" in result["error"]
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_uses_proactive_prompt(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ bot_user,
+):
+ """Test that proactive_prompt is included in the evaluation."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ user_with_prompt = DiscordUser(
+ id=555666777,
+ username="promptuser",
+ proactive_cron="0 9 * * *",
+ proactive_prompt="Focus on their coding projects",
+ chattiness_threshold=50,
+ )
+ db_session.add(user_with_prompt)
+ db_session.commit()
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "80Check on projects",
+ "How are your coding projects coming along?",
+ ]
+ mock_send.return_value = True
+
+ result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
+
+ assert result["status"] == "sent"
+ call_args = mock_call_llm.call_args_list[0]
+ messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
+ assert any("Focus on their coding projects" in str(m) for m in messages_arg)
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_channel(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ target_channel,
+ bot_user,
+):
+ """Test check-in to a channel."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "50Check channel",
+ "Good morning everyone!",
+ ]
+ mock_send.return_value = True
+
+ result = proactive.execute_proactive_checkin("channel", target_channel.id)
+
+ assert result["status"] == "sent"
+ assert result["entity_type"] == "channel"
+
+ send_call = mock_send.call_args
+ assert send_call.kwargs.get("channel_id") == target_channel.id
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_updates_last_proactive_at(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test that last_proactive_at is updated after successful check-in."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "80Check in",
+ "Hey there!",
+ ]
+ mock_send.return_value = True
+
+ before_time = datetime.now(timezone.utc)
+ proactive.execute_proactive_checkin("user", target_user.id)
+ after_time = datetime.now(timezone.utc)
+
+ db_session.refresh(target_user)
+ assert target_user.last_proactive_at is not None
+ assert before_time <= target_user.last_proactive_at <= after_time
+
+
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test that last_proactive_at is updated even when check-in is skipped."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.return_value = (
+ "10Nothing to say"
+ )
+
+ proactive.execute_proactive_checkin("user", target_user.id)
+
+ db_session.refresh(target_user)
+ assert target_user.last_proactive_at is not None