mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Add proactive check-in functionality for Discord
- Add proactive_cron, proactive_prompt, last_proactive_at fields to Discord models - Add /proactive command handler for configuring check-in schedules - Add evaluate_proactive_checkins task (runs every minute via celery beat) - Add execute_proactive_checkin task that evaluates interest and sends messages - Smart bot selection finds the correct bot for each server - Channel selection defaults to "general" text channel for servers - Add database migration for new fields - Add comprehensive tests for commands and tasks 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
9088997295
commit
a238ca6329
@ -0,0 +1,45 @@
|
|||||||
|
"""Add proactive check-in fields to Discord entities
|
||||||
|
|
||||||
|
Revision ID: e1f2a3b4c5d6
|
||||||
|
Revises: d0e1f2a3b4c5
|
||||||
|
Create Date: 2025-12-24 16:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "e1f2a3b4c5d6"
|
||||||
|
down_revision: Union[str, None] = "d0e1f2a3b4c5"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add proactive fields to all MessageProcessor tables
|
||||||
|
for table in ["discord_servers", "discord_channels", "discord_users"]:
|
||||||
|
op.add_column(
|
||||||
|
table,
|
||||||
|
sa.Column("proactive_cron", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
table,
|
||||||
|
sa.Column("proactive_prompt", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
table,
|
||||||
|
sa.Column(
|
||||||
|
"last_proactive_at", sa.DateTime(timezone=True), nullable=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
for table in ["discord_servers", "discord_channels", "discord_users"]:
|
||||||
|
op.drop_column(table, "last_proactive_at")
|
||||||
|
op.drop_column(table, "proactive_prompt")
|
||||||
|
op.drop_column(table, "proactive_cron")
|
||||||
@ -10,5 +10,6 @@ openai==2.3.0
|
|||||||
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
|
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
|
||||||
httpx>=0.28.1
|
httpx>=0.28.1
|
||||||
celery[redis,sqs]==5.3.6
|
celery[redis,sqs]==5.3.6
|
||||||
|
croniter==2.0.1
|
||||||
cryptography==43.0.0
|
cryptography==43.0.0
|
||||||
bcrypt==4.1.2
|
bcrypt==4.1.2
|
||||||
@ -1,7 +1,9 @@
|
|||||||
pytest==7.4.4
|
pytest==7.4.4
|
||||||
pytest-cov==4.1.0
|
pytest-cov==4.1.0
|
||||||
|
pytest-asyncio==0.23.0
|
||||||
black==23.12.1
|
black==23.12.1
|
||||||
mypy==1.8.0
|
mypy==1.8.0
|
||||||
isort==5.13.2
|
isort==5.13.2
|
||||||
testcontainers[qdrant]==4.10.0
|
testcontainers[qdrant]==4.10.0
|
||||||
click==8.1.7
|
click==8.1.7
|
||||||
|
croniter==2.0.1
|
||||||
@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
|
|||||||
BACKUP_ROOT = "memory.workers.tasks.backup"
|
BACKUP_ROOT = "memory.workers.tasks.backup"
|
||||||
GITHUB_ROOT = "memory.workers.tasks.github"
|
GITHUB_ROOT = "memory.workers.tasks.github"
|
||||||
PEOPLE_ROOT = "memory.workers.tasks.people"
|
PEOPLE_ROOT = "memory.workers.tasks.people"
|
||||||
|
PROACTIVE_ROOT = "memory.workers.tasks.proactive"
|
||||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||||
@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
|
|||||||
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
|
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
|
||||||
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
|
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
|
||||||
|
|
||||||
|
# Proactive check-in tasks
|
||||||
|
EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
|
||||||
|
EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
|
||||||
|
|
||||||
|
|
||||||
def get_broker_url() -> str:
|
def get_broker_url() -> str:
|
||||||
protocol = settings.CELERY_BROKER_TYPE
|
protocol = settings.CELERY_BROKER_TYPE
|
||||||
@ -130,12 +135,17 @@ app.conf.update(
|
|||||||
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
|
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
|
||||||
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
|
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
|
||||||
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
|
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
|
||||||
|
f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
|
||||||
},
|
},
|
||||||
beat_schedule={
|
beat_schedule={
|
||||||
"sync-github-repos-hourly": {
|
"sync-github-repos-hourly": {
|
||||||
"task": SYNC_ALL_GITHUB_REPOS,
|
"task": SYNC_ALL_GITHUB_REPOS,
|
||||||
"schedule": crontab(minute=0), # Every hour at :00
|
"schedule": crontab(minute=0), # Every hour at :00
|
||||||
},
|
},
|
||||||
|
"evaluate-proactive-checkins": {
|
||||||
|
"task": EVALUATE_PROACTIVE_CHECKINS,
|
||||||
|
"schedule": crontab(), # Every minute
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -63,13 +63,29 @@ class MessageProcessor:
|
|||||||
doc=textwrap.dedent(
|
doc=textwrap.dedent(
|
||||||
"""
|
"""
|
||||||
A summary of this processor, made by and for AI systems.
|
A summary of this processor, made by and for AI systems.
|
||||||
|
|
||||||
The idea here is that AI systems can use this summary to keep notes on the given processor.
|
The idea here is that AI systems can use this summary to keep notes on the given processor.
|
||||||
These should automatically be injected into the context of the messages that are processed by this processor.
|
These should automatically be injected into the context of the messages that are processed by this processor.
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
proactive_cron = Column(
|
||||||
|
Text,
|
||||||
|
nullable=True,
|
||||||
|
doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
|
||||||
|
)
|
||||||
|
proactive_prompt = Column(
|
||||||
|
Text,
|
||||||
|
nullable=True,
|
||||||
|
doc="Custom instructions for proactive check-ins.",
|
||||||
|
)
|
||||||
|
last_proactive_at = Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=True,
|
||||||
|
doc="When the last proactive check-in was sent.",
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity_type(self) -> str:
|
def entity_type(self) -> str:
|
||||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||||
|
|||||||
@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
|
|||||||
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
|
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
|
||||||
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
|
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
|
||||||
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
|
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
|
||||||
|
PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
|
||||||
|
|
||||||
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
|
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
|
||||||
|
|
||||||
|
|||||||
@ -167,6 +167,25 @@ def _create_scope_group(
|
|||||||
url=url and url.strip(),
|
url=url and url.strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Proactive command
|
||||||
|
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
|
||||||
|
prompt="Optional custom instructions for check-ins",
|
||||||
|
)
|
||||||
|
async def proactive_cmd(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
cron: str | None = None,
|
||||||
|
prompt: str | None = None,
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_proactive,
|
||||||
|
cron=cron and cron.strip(),
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
|
||||||
@ -265,6 +284,28 @@ def _create_user_scope_group(
|
|||||||
url=url and url.strip(),
|
url=url and url.strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Proactive command
|
||||||
|
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
user="Target user",
|
||||||
|
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
|
||||||
|
prompt="Optional custom instructions for check-ins",
|
||||||
|
)
|
||||||
|
async def proactive_cmd(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
user: discord.User,
|
||||||
|
cron: str | None = None,
|
||||||
|
prompt: str | None = None,
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_proactive,
|
||||||
|
target_user=user,
|
||||||
|
cron=cron and cron.strip(),
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
|
||||||
@ -663,3 +704,68 @@ async def handle_mcp_servers(
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
||||||
raise CommandError(f"Error: {exc}") from exc
|
raise CommandError(f"Error: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_proactive(
|
||||||
|
context: CommandContext,
|
||||||
|
*,
|
||||||
|
cron: str | None = None,
|
||||||
|
prompt: str | None = None,
|
||||||
|
) -> CommandResponse:
|
||||||
|
"""Handle proactive check-in configuration."""
|
||||||
|
from croniter import croniter
|
||||||
|
|
||||||
|
model = context.target
|
||||||
|
|
||||||
|
# If no arguments, show current settings
|
||||||
|
if cron is None and prompt is None:
|
||||||
|
current_cron = getattr(model, "proactive_cron", None)
|
||||||
|
current_prompt = getattr(model, "proactive_prompt", None)
|
||||||
|
|
||||||
|
if not current_cron:
|
||||||
|
return CommandResponse(
|
||||||
|
content=f"Proactive check-ins are disabled for {context.display_name}."
|
||||||
|
)
|
||||||
|
|
||||||
|
lines = [f"Proactive check-ins for {context.display_name}:"]
|
||||||
|
lines.append(f" Schedule: `{current_cron}`")
|
||||||
|
if current_prompt:
|
||||||
|
lines.append(f" Prompt: {current_prompt}")
|
||||||
|
return CommandResponse(content="\n".join(lines))
|
||||||
|
|
||||||
|
# Handle cron setting
|
||||||
|
if cron is not None:
|
||||||
|
if cron.lower() == "off":
|
||||||
|
setattr(model, "proactive_cron", None)
|
||||||
|
return CommandResponse(
|
||||||
|
content=f"Proactive check-ins disabled for {context.display_name}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate cron expression
|
||||||
|
try:
|
||||||
|
croniter(cron)
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
raise CommandError(
|
||||||
|
f"Invalid cron expression: {cron}\n"
|
||||||
|
"Examples:\n"
|
||||||
|
" `0 9 * * *` - 9am daily\n"
|
||||||
|
" `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
|
||||||
|
" `0 */4 * * *` - every 4 hours"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
setattr(model, "proactive_cron", cron)
|
||||||
|
|
||||||
|
# Handle prompt setting
|
||||||
|
if prompt is not None:
|
||||||
|
setattr(model, "proactive_prompt", prompt or None)
|
||||||
|
|
||||||
|
# Build response
|
||||||
|
current_cron = getattr(model, "proactive_cron", None)
|
||||||
|
current_prompt = getattr(model, "proactive_prompt", None)
|
||||||
|
|
||||||
|
lines = [f"Updated proactive settings for {context.display_name}:"]
|
||||||
|
lines.append(f" Schedule: `{current_cron}`")
|
||||||
|
if current_prompt:
|
||||||
|
lines.append(f" Prompt: {current_prompt}")
|
||||||
|
|
||||||
|
return CommandResponse(content="\n".join(lines))
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from memory.common.celery_app import (
|
|||||||
SYNC_LESSWRONG,
|
SYNC_LESSWRONG,
|
||||||
RUN_SCHEDULED_CALLS,
|
RUN_SCHEDULED_CALLS,
|
||||||
BACKUP_ALL,
|
BACKUP_ALL,
|
||||||
|
EVALUATE_PROACTIVE_CHECKINS,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -53,4 +54,8 @@ app.conf.beat_schedule = {
|
|||||||
"task": BACKUP_ALL,
|
"task": BACKUP_ALL,
|
||||||
"schedule": settings.S3_BACKUP_INTERVAL,
|
"schedule": settings.S3_BACKUP_INTERVAL,
|
||||||
},
|
},
|
||||||
|
"evaluate-proactive-checkins": {
|
||||||
|
"task": EVALUATE_PROACTIVE_CHECKINS,
|
||||||
|
"schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from memory.workers.tasks import (
|
|||||||
notes,
|
notes,
|
||||||
observations,
|
observations,
|
||||||
people,
|
people,
|
||||||
|
proactive,
|
||||||
scheduled_calls,
|
scheduled_calls,
|
||||||
) # noqa
|
) # noqa
|
||||||
|
|
||||||
@ -31,5 +32,6 @@ __all__ = [
|
|||||||
"notes",
|
"notes",
|
||||||
"observations",
|
"observations",
|
||||||
"people",
|
"people",
|
||||||
|
"proactive",
|
||||||
"scheduled_calls",
|
"scheduled_calls",
|
||||||
]
|
]
|
||||||
|
|||||||
341
src/memory/workers/tasks/proactive.py
Normal file
341
src/memory/workers/tasks/proactive.py
Normal file
@ -0,0 +1,341 @@
|
|||||||
|
"""
|
||||||
|
Celery tasks for proactive Discord check-ins.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import textwrap
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
|
from croniter import croniter
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.celery_app import app
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||||
|
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
|
||||||
|
from memory.workers.tasks.content_processing import safe_task_execution
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
|
||||||
|
EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
|
||||||
|
|
||||||
|
EntityType = Literal["user", "channel", "server"]
|
||||||
|
|
||||||
|
|
||||||
|
def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
|
||||||
|
"""Check if a cron expression is due to run now.
|
||||||
|
|
||||||
|
Uses croniter to determine if the current time falls within the cron's schedule
|
||||||
|
and enough time has passed since the last run.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cron = croniter(cron_expr, now)
|
||||||
|
# Get the previous scheduled time from now
|
||||||
|
prev_run = cron.get_prev(datetime)
|
||||||
|
# Get the one before that to determine the interval
|
||||||
|
cron.get_prev(datetime)
|
||||||
|
prev_prev_run = cron.get_current(datetime)
|
||||||
|
|
||||||
|
# If we haven't run since the last scheduled time, we should run
|
||||||
|
if last_run is None:
|
||||||
|
# Never run before - check if current time is within a minute of prev_run
|
||||||
|
time_since_scheduled = (now - prev_run).total_seconds()
|
||||||
|
return time_since_scheduled < 120 # Within 2 minutes of scheduled time
|
||||||
|
|
||||||
|
# Make sure last_run is timezone aware
|
||||||
|
if last_run.tzinfo is None:
|
||||||
|
last_run = last_run.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
# We should run if last_run is before the previous scheduled time
|
||||||
|
return last_run < prev_run
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot_for_entity(
|
||||||
|
session, entity_type: EntityType, entity_id: int
|
||||||
|
) -> DiscordUser | None:
|
||||||
|
"""Get the bot user associated with an entity."""
|
||||||
|
from memory.common.db.models import DiscordBotUser, DiscordMessage
|
||||||
|
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
# For servers, find a bot that has sent messages in that server
|
||||||
|
if entity_type == "server":
|
||||||
|
# Find bots that have interacted with this server
|
||||||
|
bot_users = (
|
||||||
|
session.query(DiscordUser)
|
||||||
|
.options(joinedload(DiscordUser.system_user))
|
||||||
|
.join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
|
||||||
|
.filter(
|
||||||
|
DiscordMessage.server_id == entity_id,
|
||||||
|
DiscordUser.system_user_id.isnot(None),
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# Find one that's actually a bot
|
||||||
|
for user in bot_users:
|
||||||
|
if user.system_user and user.system_user.user_type == "discord_bot":
|
||||||
|
return user
|
||||||
|
|
||||||
|
# For channels, check the server the channel belongs to
|
||||||
|
if entity_type == "channel":
|
||||||
|
channel = session.get(DiscordChannel, entity_id)
|
||||||
|
if channel and channel.server_id:
|
||||||
|
return get_bot_for_entity(session, "server", channel.server_id)
|
||||||
|
|
||||||
|
# Fallback: use first available bot
|
||||||
|
bot = (
|
||||||
|
session.query(DiscordBotUser)
|
||||||
|
.options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if bot and bot.discord_users:
|
||||||
|
return bot.discord_users[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_target_user_for_entity(
|
||||||
|
session, entity_type: EntityType, entity_id: int
|
||||||
|
) -> DiscordUser | None:
|
||||||
|
"""Get the target user for sending a proactive message."""
|
||||||
|
if entity_type == "user":
|
||||||
|
return session.get(DiscordUser, entity_id)
|
||||||
|
# For channels and servers, we don't have a specific target user
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_for_entity(
|
||||||
|
session, entity_type: EntityType, entity_id: int
|
||||||
|
) -> DiscordChannel | None:
|
||||||
|
"""Get the channel for sending a proactive message."""
|
||||||
|
if entity_type == "channel":
|
||||||
|
return session.get(DiscordChannel, entity_id)
|
||||||
|
if entity_type == "server":
|
||||||
|
# For servers, find the first text channel (prefer "general")
|
||||||
|
channels = (
|
||||||
|
session.query(DiscordChannel)
|
||||||
|
.filter(
|
||||||
|
DiscordChannel.server_id == entity_id,
|
||||||
|
DiscordChannel.channel_type == "text",
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if not channels:
|
||||||
|
return None
|
||||||
|
# Prefer a channel named "general" if it exists
|
||||||
|
for channel in channels:
|
||||||
|
if channel.name and "general" in channel.name.lower():
|
||||||
|
return channel
|
||||||
|
return channels[0]
|
||||||
|
# For users, we use DMs (no channel)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
|
||||||
|
@safe_task_execution
|
||||||
|
def evaluate_proactive_checkins() -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Evaluate which entities need proactive check-ins.
|
||||||
|
|
||||||
|
This task runs every minute and checks all entities with proactive_cron set
|
||||||
|
to see if they're due for a check-in.
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dispatched = []
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Query all entities with proactive_cron set
|
||||||
|
for model, entity_type in [
|
||||||
|
(DiscordUser, "user"),
|
||||||
|
(DiscordChannel, "channel"),
|
||||||
|
(DiscordServer, "server"),
|
||||||
|
]:
|
||||||
|
entities = (
|
||||||
|
session.query(model)
|
||||||
|
.filter(model.proactive_cron.isnot(None))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
cron_expr = cast(str, entity.proactive_cron)
|
||||||
|
last_run = entity.last_proactive_at
|
||||||
|
|
||||||
|
if is_cron_due(cron_expr, last_run, now):
|
||||||
|
logger.info(
|
||||||
|
f"Proactive check-in due for {entity_type} {entity.id}"
|
||||||
|
)
|
||||||
|
execute_proactive_checkin.delay(entity_type, entity.id)
|
||||||
|
dispatched.append({"type": entity_type, "id": entity.id})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"evaluated_at": now.isoformat(),
|
||||||
|
"dispatched": dispatched,
|
||||||
|
"count": len(dispatched),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
|
||||||
|
@safe_task_execution
|
||||||
|
def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute a proactive check-in for a specific entity.
|
||||||
|
|
||||||
|
This evaluates whether the bot should reach out and, if so, generates
|
||||||
|
and sends a check-in message.
|
||||||
|
"""
|
||||||
|
logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Get the entity
|
||||||
|
model_class = {
|
||||||
|
"user": DiscordUser,
|
||||||
|
"channel": DiscordChannel,
|
||||||
|
"server": DiscordServer,
|
||||||
|
}[entity_type]
|
||||||
|
|
||||||
|
entity = session.get(model_class, entity_id)
|
||||||
|
if not entity:
|
||||||
|
return {"error": f"{entity_type} {entity_id} not found"}
|
||||||
|
|
||||||
|
# Get the bot user
|
||||||
|
bot_user = get_bot_for_entity(session, entity_type, entity_id)
|
||||||
|
if not bot_user:
|
||||||
|
return {"error": "No bot user found"}
|
||||||
|
|
||||||
|
# Get target user and channel
|
||||||
|
target_user = get_target_user_for_entity(session, entity_type, entity_id)
|
||||||
|
channel = get_channel_for_entity(session, entity_type, entity_id)
|
||||||
|
|
||||||
|
if not target_user and not channel:
|
||||||
|
return {"error": "No target user or channel for proactive check-in"}
|
||||||
|
|
||||||
|
# Get chattiness threshold
|
||||||
|
chattiness = entity.chattiness_threshold or 90
|
||||||
|
|
||||||
|
# Build the evaluation prompt
|
||||||
|
proactive_prompt = entity.proactive_prompt or ""
|
||||||
|
eval_prompt = textwrap.dedent("""
|
||||||
|
You are considering whether to proactively reach out to check in.
|
||||||
|
|
||||||
|
{proactive_prompt}
|
||||||
|
|
||||||
|
Based on your notes and the context of previous conversations:
|
||||||
|
1. Is there anything worth checking in about?
|
||||||
|
2. Has enough happened or enough time passed to warrant a check-in?
|
||||||
|
3. Would reaching out now be welcome or intrusive?
|
||||||
|
|
||||||
|
Please return a number between 0 and 100 indicating how strongly you want to check in
|
||||||
|
(0 = definitely not, 100 = definitely yes).
|
||||||
|
|
||||||
|
<response>
|
||||||
|
<number>50</number>
|
||||||
|
<reason>Your reasoning here</reason>
|
||||||
|
</response>
|
||||||
|
""").format(proactive_prompt=proactive_prompt)
|
||||||
|
|
||||||
|
# Build context
|
||||||
|
system_prompt = comm_channel_prompt(
|
||||||
|
session, bot_user, target_user, channel
|
||||||
|
)
|
||||||
|
|
||||||
|
# First, evaluate whether we should check in
|
||||||
|
eval_response = call_llm(
|
||||||
|
session,
|
||||||
|
bot_user=bot_user,
|
||||||
|
from_user=target_user,
|
||||||
|
channel=channel,
|
||||||
|
model=settings.SUMMARIZER_MODEL,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
messages=[eval_prompt],
|
||||||
|
allowed_tools=[
|
||||||
|
"update_channel_summary",
|
||||||
|
"update_user_summary",
|
||||||
|
"update_server_summary",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not eval_response:
|
||||||
|
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
|
||||||
|
|
||||||
|
# Parse the interest score
|
||||||
|
match = re.search(r"<number>(\d+)</number>", eval_response)
|
||||||
|
if not match:
|
||||||
|
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
|
||||||
|
|
||||||
|
interest_score = int(match.group(1))
|
||||||
|
threshold = 100 - chattiness
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if interest_score < threshold:
|
||||||
|
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
return {
|
||||||
|
"status": "below_threshold",
|
||||||
|
"interest": interest_score,
|
||||||
|
"threshold": threshold,
|
||||||
|
"entity_type": entity_type,
|
||||||
|
"entity_id": entity_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the actual check-in message
|
||||||
|
checkin_prompt = textwrap.dedent("""
|
||||||
|
You've decided to proactively check in. Generate a natural, friendly check-in message.
|
||||||
|
|
||||||
|
{proactive_prompt}
|
||||||
|
|
||||||
|
Keep it brief and genuine. Don't be overly formal or robotic.
|
||||||
|
Reference specific things from your notes if relevant.
|
||||||
|
""").format(proactive_prompt=proactive_prompt)
|
||||||
|
|
||||||
|
response = call_llm(
|
||||||
|
session,
|
||||||
|
bot_user=bot_user,
|
||||||
|
from_user=target_user,
|
||||||
|
channel=channel,
|
||||||
|
model=settings.DISCORD_MODEL,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
messages=[checkin_prompt],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response:
|
||||||
|
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
|
||||||
|
|
||||||
|
# Send the message
|
||||||
|
bot_id = bot_user.system_user.id if bot_user.system_user else None
|
||||||
|
if not bot_id:
|
||||||
|
return {"error": "No system user for bot"}
|
||||||
|
|
||||||
|
success = send_discord_response(
|
||||||
|
bot_id=bot_id,
|
||||||
|
response=response,
|
||||||
|
channel_id=channel.id if channel else None,
|
||||||
|
user_identifier=target_user.username if target_user else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update last_proactive_at
|
||||||
|
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "sent" if success else "send_failed",
|
||||||
|
"interest": interest_score,
|
||||||
|
"entity_type": entity_type,
|
||||||
|
"entity_id": entity_id,
|
||||||
|
"response_preview": response[:100] + "..." if len(response) > 100 else response,
|
||||||
|
}
|
||||||
@ -15,6 +15,7 @@ from memory.discord.commands import (
|
|||||||
handle_chattiness,
|
handle_chattiness,
|
||||||
handle_ignore,
|
handle_ignore,
|
||||||
handle_summary,
|
handle_summary,
|
||||||
|
handle_proactive,
|
||||||
respond,
|
respond,
|
||||||
with_object_context,
|
with_object_context,
|
||||||
handle_mcp_servers,
|
handle_mcp_servers,
|
||||||
@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
|
|||||||
await handle_mcp_servers(context, action="list", url=None)
|
await handle_mcp_servers(context, action="list", url=None)
|
||||||
|
|
||||||
assert "Error: boom" in str(exc.value)
|
assert "Error: boom" in str(exc.value)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for handle_proactive
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_show_disabled(db_session, interaction, guild):
|
||||||
|
"""Test showing proactive settings when disabled."""
|
||||||
|
server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context)
|
||||||
|
|
||||||
|
assert "disabled" in response.content.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_show_enabled(db_session, interaction, guild):
|
||||||
|
"""Test showing proactive settings when enabled."""
|
||||||
|
server = DiscordServer(
|
||||||
|
id=guild.id,
|
||||||
|
name="Guild",
|
||||||
|
proactive_cron="0 9 * * *",
|
||||||
|
proactive_prompt="Check on projects",
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context)
|
||||||
|
|
||||||
|
assert "0 9 * * *" in response.content
|
||||||
|
assert "Check on projects" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_set_cron(db_session, interaction, guild):
|
||||||
|
"""Test setting proactive cron schedule."""
|
||||||
|
server = DiscordServer(id=guild.id, name="Guild")
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context, cron="0 9 * * *")
|
||||||
|
|
||||||
|
assert "Updated" in response.content
|
||||||
|
assert "0 9 * * *" in response.content
|
||||||
|
assert server.proactive_cron == "0 9 * * *"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_set_prompt(db_session, interaction, guild):
|
||||||
|
"""Test setting proactive prompt."""
|
||||||
|
server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context, prompt="Focus on daily standups")
|
||||||
|
|
||||||
|
assert "Updated" in response.content
|
||||||
|
assert server.proactive_prompt == "Focus on daily standups"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_disable(db_session, interaction, guild):
|
||||||
|
"""Test disabling proactive check-ins."""
|
||||||
|
server = DiscordServer(
|
||||||
|
id=guild.id,
|
||||||
|
name="Guild",
|
||||||
|
proactive_cron="0 9 * * *",
|
||||||
|
proactive_prompt="Some prompt",
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context, cron="off")
|
||||||
|
|
||||||
|
assert "disabled" in response.content.lower()
|
||||||
|
assert server.proactive_cron is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
|
||||||
|
"""Test error on invalid cron expression."""
|
||||||
|
server = DiscordServer(id=guild.id, name="Guild")
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CommandError) as exc:
|
||||||
|
await handle_proactive(context, cron="not a valid cron")
|
||||||
|
|
||||||
|
assert "Invalid cron expression" in str(exc.value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
|
||||||
|
"""Test proactive settings for user scope."""
|
||||||
|
user_model = DiscordUser(
|
||||||
|
id=discord_user.id, username="testuser", proactive_cron=None
|
||||||
|
)
|
||||||
|
db_session.add(user_model)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="me",
|
||||||
|
target=user_model,
|
||||||
|
display_name="you (**testuser**)",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context, cron="0 9,17 * * 1-5")
|
||||||
|
|
||||||
|
assert "Updated" in response.content
|
||||||
|
assert user_model.proactive_cron == "0 9,17 * * 1-5"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_proactive_channel_scope(
|
||||||
|
db_session, interaction, guild, text_channel
|
||||||
|
):
|
||||||
|
"""Test proactive settings for channel scope."""
|
||||||
|
server = DiscordServer(id=guild.id, name="Guild")
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
channel_model = DiscordChannel(
|
||||||
|
id=text_channel.id,
|
||||||
|
name="general",
|
||||||
|
channel_type="text",
|
||||||
|
server_id=guild.id,
|
||||||
|
)
|
||||||
|
db_session.add(channel_model)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="channel",
|
||||||
|
target=channel_model,
|
||||||
|
display_name="channel **#general**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_proactive(context, cron="0 12 * * *")
|
||||||
|
|
||||||
|
assert "Updated" in response.content
|
||||||
|
assert channel_model.proactive_cron == "0 12 * * *"
|
||||||
|
|||||||
536
tests/memory/workers/tasks/test_proactive.py
Normal file
536
tests/memory/workers/tasks/test_proactive.py
Normal file
@ -0,0 +1,536 @@
|
|||||||
|
"""Tests for proactive check-in tasks."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
|
||||||
|
from memory.common.db.models import (
|
||||||
|
DiscordBotUser,
|
||||||
|
DiscordUser,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordServer,
|
||||||
|
)
|
||||||
|
from memory.workers.tasks import proactive
|
||||||
|
from memory.workers.tasks.proactive import is_cron_due
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def bot_user(db_session):
|
||||||
|
"""Create a bot user for testing."""
|
||||||
|
bot_discord_user = DiscordUser(
|
||||||
|
id=999999999,
|
||||||
|
username="testbot",
|
||||||
|
)
|
||||||
|
db_session.add(bot_discord_user)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
user = DiscordBotUser.create_with_api_key(
|
||||||
|
discord_users=[bot_discord_user],
|
||||||
|
name="testbot",
|
||||||
|
email="bot@example.com",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def target_user(db_session):
|
||||||
|
"""Create a target Discord user for testing."""
|
||||||
|
discord_user = DiscordUser(
|
||||||
|
id=123456789,
|
||||||
|
username="targetuser",
|
||||||
|
proactive_cron="0 9 * * *", # 9am daily
|
||||||
|
chattiness_threshold=50,
|
||||||
|
)
|
||||||
|
db_session.add(discord_user)
|
||||||
|
db_session.commit()
|
||||||
|
return discord_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def target_user_no_cron(db_session):
|
||||||
|
"""Create a target Discord user without proactive cron."""
|
||||||
|
discord_user = DiscordUser(
|
||||||
|
id=123456790,
|
||||||
|
username="nocronuser",
|
||||||
|
proactive_cron=None,
|
||||||
|
)
|
||||||
|
db_session.add(discord_user)
|
||||||
|
db_session.commit()
|
||||||
|
return discord_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def target_server(db_session):
|
||||||
|
"""Create a target Discord server for testing."""
|
||||||
|
server = DiscordServer(
|
||||||
|
id=987654321,
|
||||||
|
name="Test Server",
|
||||||
|
proactive_cron="0 */4 * * *", # Every 4 hours
|
||||||
|
chattiness_threshold=30,
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def target_channel(db_session, target_server):
|
||||||
|
"""Create a target Discord channel for testing."""
|
||||||
|
channel = DiscordChannel(
|
||||||
|
id=111222333,
|
||||||
|
name="test-channel",
|
||||||
|
channel_type="text",
|
||||||
|
server_id=target_server.id,
|
||||||
|
proactive_cron="0 12 * * 1-5", # Noon on weekdays
|
||||||
|
chattiness_threshold=70,
|
||||||
|
)
|
||||||
|
db_session.add(channel)
|
||||||
|
db_session.commit()
|
||||||
|
return channel
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for is_cron_due helper
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"cron_expr,now,last_run,expected",
|
||||||
|
[
|
||||||
|
# Cron is due when never run before and time matches
|
||||||
|
(
|
||||||
|
"0 9 * * *",
|
||||||
|
datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
|
||||||
|
None,
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
# Cron is due when last run was before the scheduled time
|
||||||
|
(
|
||||||
|
"0 9 * * *",
|
||||||
|
datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
|
||||||
|
datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
# Cron is NOT due when already run this period
|
||||||
|
(
|
||||||
|
"0 9 * * *",
|
||||||
|
datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
|
||||||
|
datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
# Cron is NOT due when current time is before scheduled time
|
||||||
|
(
|
||||||
|
"0 9 * * *",
|
||||||
|
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
# Hourly cron schedule
|
||||||
|
(
|
||||||
|
"0 * * * *",
|
||||||
|
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
|
||||||
|
datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
# Every 4 hours cron schedule
|
||||||
|
(
|
||||||
|
"0 */4 * * *",
|
||||||
|
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
|
||||||
|
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"due_never_run",
|
||||||
|
"due_last_run_before_schedule",
|
||||||
|
"not_due_already_run",
|
||||||
|
"not_due_too_early",
|
||||||
|
"due_hourly",
|
||||||
|
"due_every_4_hours",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_cron_due(cron_expr, now, last_run, expected):
|
||||||
|
"""Test is_cron_due with various scenarios."""
|
||||||
|
assert is_cron_due(cron_expr, last_run, now) is expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_cron_due_invalid_expression():
|
||||||
|
"""Test invalid cron expression returns False."""
|
||||||
|
now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
|
||||||
|
assert is_cron_due("invalid cron", None, now) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_cron_due_with_naive_last_run():
|
||||||
|
"""Test cron handles naive datetime for last_run."""
|
||||||
|
now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
|
||||||
|
cron_expr = "0 9 * * *"
|
||||||
|
last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
|
||||||
|
assert is_cron_due(cron_expr, last_run, now) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for evaluate_proactive_checkins task
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||||
|
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_evaluate_proactive_checkins_dispatches_due(
|
||||||
|
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
|
||||||
|
):
|
||||||
|
"""Test that due check-ins are dispatched."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
mock_is_cron_due.return_value = True
|
||||||
|
|
||||||
|
result = proactive.evaluate_proactive_checkins()
|
||||||
|
|
||||||
|
assert result["count"] >= 1
|
||||||
|
mock_execute.delay.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||||
|
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_evaluate_proactive_checkins_skips_not_due(
|
||||||
|
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
|
||||||
|
):
|
||||||
|
"""Test that not-due check-ins are not dispatched."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
mock_is_cron_due.return_value = False
|
||||||
|
|
||||||
|
result = proactive.evaluate_proactive_checkins()
|
||||||
|
|
||||||
|
assert result["count"] == 0
|
||||||
|
mock_execute.delay.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_evaluate_proactive_checkins_skips_no_cron(
|
||||||
|
mock_make_session, mock_execute, db_session, target_user_no_cron
|
||||||
|
):
|
||||||
|
"""Test that entities without proactive_cron are skipped."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
result = proactive.evaluate_proactive_checkins()
|
||||||
|
|
||||||
|
for call in mock_execute.delay.call_args_list:
|
||||||
|
entity_type, entity_id = call[0]
|
||||||
|
assert entity_id != target_user_no_cron.id
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||||
|
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_evaluate_proactive_checkins_multiple_entity_types(
|
||||||
|
mock_make_session,
|
||||||
|
mock_is_cron_due,
|
||||||
|
mock_execute,
|
||||||
|
db_session,
|
||||||
|
target_user,
|
||||||
|
target_server,
|
||||||
|
target_channel,
|
||||||
|
):
|
||||||
|
"""Test that check-ins are dispatched for users, channels, and servers."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
mock_is_cron_due.return_value = True
|
||||||
|
|
||||||
|
result = proactive.evaluate_proactive_checkins()
|
||||||
|
|
||||||
|
assert result["count"] == 3
|
||||||
|
dispatched_types = {d["type"] for d in result["dispatched"]}
|
||||||
|
assert "user" in dispatched_types
|
||||||
|
assert "channel" in dispatched_types
|
||||||
|
assert "server" in dispatched_types
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for execute_proactive_checkin task
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_sends_when_above_threshold(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
mock_send,
|
||||||
|
db_session,
|
||||||
|
target_user,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test check-in is sent when interest exceeds threshold."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
|
||||||
|
mock_call_llm.side_effect = [
|
||||||
|
"<response><number>80</number><reason>Should check in</reason></response>",
|
||||||
|
"Hey! Just checking in - how are things going?",
|
||||||
|
]
|
||||||
|
mock_send.return_value = True
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||||
|
|
||||||
|
assert result["status"] == "sent"
|
||||||
|
assert result["interest"] == 80
|
||||||
|
mock_send.assert_called_once()
|
||||||
|
|
||||||
|
db_session.refresh(target_user)
|
||||||
|
assert target_user.last_proactive_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_skips_below_threshold(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
db_session,
|
||||||
|
target_user,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test check-in is skipped when interest is below threshold."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
|
||||||
|
mock_call_llm.return_value = (
|
||||||
|
"<response><number>30</number><reason>Not much to say</reason></response>"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||||
|
|
||||||
|
assert result["status"] == "below_threshold"
|
||||||
|
assert result["interest"] == 30
|
||||||
|
assert result["threshold"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"llm_response,expected_status",
|
||||||
|
[
|
||||||
|
(None, "no_eval_response"),
|
||||||
|
("I'm not sure what to say.", "no_score_in_response"),
|
||||||
|
],
|
||||||
|
ids=["no_response", "malformed_response"],
|
||||||
|
)
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_handles_bad_llm_response(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
llm_response,
|
||||||
|
expected_status,
|
||||||
|
db_session,
|
||||||
|
target_user,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test handling of missing or malformed LLM responses."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
mock_call_llm.return_value = llm_response
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||||
|
|
||||||
|
assert result["status"] == expected_status
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
|
||||||
|
"""Test handling when entity doesn't exist."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("user", 999999)
|
||||||
|
|
||||||
|
assert "error" in result
|
||||||
|
assert "not found" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_no_bot_user(
|
||||||
|
mock_make_session, mock_get_bot, db_session, target_user
|
||||||
|
):
|
||||||
|
"""Test handling when no bot user is found."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
mock_get_bot.return_value = None
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||||
|
|
||||||
|
assert "error" in result
|
||||||
|
assert "No bot user" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_uses_proactive_prompt(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
mock_send,
|
||||||
|
db_session,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test that proactive_prompt is included in the evaluation."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
user_with_prompt = DiscordUser(
|
||||||
|
id=555666777,
|
||||||
|
username="promptuser",
|
||||||
|
proactive_cron="0 9 * * *",
|
||||||
|
proactive_prompt="Focus on their coding projects",
|
||||||
|
chattiness_threshold=50,
|
||||||
|
)
|
||||||
|
db_session.add(user_with_prompt)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
|
||||||
|
mock_call_llm.side_effect = [
|
||||||
|
"<response><number>80</number><reason>Check on projects</reason></response>",
|
||||||
|
"How are your coding projects coming along?",
|
||||||
|
]
|
||||||
|
mock_send.return_value = True
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
|
||||||
|
|
||||||
|
assert result["status"] == "sent"
|
||||||
|
call_args = mock_call_llm.call_args_list[0]
|
||||||
|
messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||||
|
assert any("Focus on their coding projects" in str(m) for m in messages_arg)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_channel(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
mock_send,
|
||||||
|
db_session,
|
||||||
|
target_channel,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test check-in to a channel."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
|
||||||
|
mock_call_llm.side_effect = [
|
||||||
|
"<response><number>50</number><reason>Check channel</reason></response>",
|
||||||
|
"Good morning everyone!",
|
||||||
|
]
|
||||||
|
mock_send.return_value = True
|
||||||
|
|
||||||
|
result = proactive.execute_proactive_checkin("channel", target_channel.id)
|
||||||
|
|
||||||
|
assert result["status"] == "sent"
|
||||||
|
assert result["entity_type"] == "channel"
|
||||||
|
|
||||||
|
send_call = mock_send.call_args
|
||||||
|
assert send_call.kwargs.get("channel_id") == target_channel.id
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_updates_last_proactive_at(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
mock_send,
|
||||||
|
db_session,
|
||||||
|
target_user,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test that last_proactive_at is updated after successful check-in."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
|
||||||
|
mock_call_llm.side_effect = [
|
||||||
|
"<response><number>80</number><reason>Check in</reason></response>",
|
||||||
|
"Hey there!",
|
||||||
|
]
|
||||||
|
mock_send.return_value = True
|
||||||
|
|
||||||
|
before_time = datetime.now(timezone.utc)
|
||||||
|
proactive.execute_proactive_checkin("user", target_user.id)
|
||||||
|
after_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
db_session.refresh(target_user)
|
||||||
|
assert target_user.last_proactive_at is not None
|
||||||
|
assert before_time <= target_user.last_proactive_at <= after_time
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.proactive.call_llm")
|
||||||
|
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||||
|
@patch("memory.workers.tasks.proactive.make_session")
|
||||||
|
def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
|
||||||
|
mock_make_session,
|
||||||
|
mock_get_bot,
|
||||||
|
mock_call_llm,
|
||||||
|
db_session,
|
||||||
|
target_user,
|
||||||
|
bot_user,
|
||||||
|
):
|
||||||
|
"""Test that last_proactive_at is updated even when check-in is skipped."""
|
||||||
|
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||||
|
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||||
|
|
||||||
|
bot_discord_user = bot_user.discord_users[0]
|
||||||
|
bot_discord_user.system_user = bot_user
|
||||||
|
mock_get_bot.return_value = bot_discord_user
|
||||||
|
|
||||||
|
mock_call_llm.return_value = (
|
||||||
|
"<response><number>10</number><reason>Nothing to say</reason></response>"
|
||||||
|
)
|
||||||
|
|
||||||
|
proactive.execute_proactive_checkin("user", target_user.id)
|
||||||
|
|
||||||
|
db_session.refresh(target_user)
|
||||||
|
assert target_user.last_proactive_at is not None
|
||||||
Loading…
x
Reference in New Issue
Block a user