mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Add proactive check-in functionality for Discord
- Add proactive_cron, proactive_prompt, last_proactive_at fields to Discord models - Add /proactive command handler for configuring check-in schedules - Add evaluate_proactive_checkins task (runs every minute via celery beat) - Add execute_proactive_checkin task that evaluates interest and sends messages - Smart bot selection finds the correct bot for each server - Channel selection defaults to "general" text channel for servers - Add database migration for new fields - Add comprehensive tests for commands and tasks 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
9088997295
commit
a238ca6329
@ -0,0 +1,45 @@
|
||||
"""Add proactive check-in fields to Discord entities
|
||||
|
||||
Revision ID: e1f2a3b4c5d6
|
||||
Revises: d0e1f2a3b4c5
|
||||
Create Date: 2025-12-24 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e1f2a3b4c5d6"
|
||||
down_revision: Union[str, None] = "d0e1f2a3b4c5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add proactive fields to all MessageProcessor tables
|
||||
for table in ["discord_servers", "discord_channels", "discord_users"]:
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column("proactive_cron", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column("proactive_prompt", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column(
|
||||
"last_proactive_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ["discord_servers", "discord_channels", "discord_users"]:
|
||||
op.drop_column(table, "last_proactive_at")
|
||||
op.drop_column(table, "proactive_prompt")
|
||||
op.drop_column(table, "proactive_cron")
|
||||
@ -10,5 +10,6 @@ openai==2.3.0
|
||||
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
|
||||
httpx>=0.28.1
|
||||
celery[redis,sqs]==5.3.6
|
||||
croniter==2.0.1
|
||||
cryptography==43.0.0
|
||||
bcrypt==4.1.2
|
||||
@ -1,7 +1,9 @@
|
||||
pytest==7.4.4
|
||||
pytest-cov==4.1.0
|
||||
pytest-asyncio==0.23.0
|
||||
black==23.12.1
|
||||
mypy==1.8.0
|
||||
isort==5.13.2
|
||||
testcontainers[qdrant]==4.10.0
|
||||
click==8.1.7
|
||||
croniter==2.0.1
|
||||
@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||
BACKUP_ROOT = "memory.workers.tasks.backup"
|
||||
GITHUB_ROOT = "memory.workers.tasks.github"
|
||||
PEOPLE_ROOT = "memory.workers.tasks.people"
|
||||
PROACTIVE_ROOT = "memory.workers.tasks.proactive"
|
||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||
@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
|
||||
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
|
||||
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
|
||||
|
||||
# Proactive check-in tasks
|
||||
EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
|
||||
EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
|
||||
|
||||
|
||||
def get_broker_url() -> str:
|
||||
protocol = settings.CELERY_BROKER_TYPE
|
||||
@ -130,12 +135,17 @@ app.conf.update(
|
||||
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
|
||||
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
|
||||
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
|
||||
f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
|
||||
},
|
||||
beat_schedule={
|
||||
"sync-github-repos-hourly": {
|
||||
"task": SYNC_ALL_GITHUB_REPOS,
|
||||
"schedule": crontab(minute=0), # Every hour at :00
|
||||
},
|
||||
"evaluate-proactive-checkins": {
|
||||
"task": EVALUATE_PROACTIVE_CHECKINS,
|
||||
"schedule": crontab(), # Every minute
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -70,6 +70,22 @@ class MessageProcessor:
|
||||
),
|
||||
)
|
||||
|
||||
proactive_cron = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
|
||||
)
|
||||
proactive_prompt = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
doc="Custom instructions for proactive check-ins.",
|
||||
)
|
||||
last_proactive_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
doc="When the last proactive check-in was sent.",
|
||||
)
|
||||
|
||||
@property
|
||||
def entity_type(self) -> str:
|
||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||
|
||||
@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
|
||||
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
|
||||
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
|
||||
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
|
||||
PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
|
||||
|
||||
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
|
||||
|
||||
|
||||
@ -167,6 +167,25 @@ def _create_scope_group(
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
# Proactive command
|
||||
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
|
||||
@discord.app_commands.describe(
|
||||
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
|
||||
prompt="Optional custom instructions for check-ins",
|
||||
)
|
||||
async def proactive_cmd(
|
||||
interaction: discord.Interaction,
|
||||
cron: str | None = None,
|
||||
prompt: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_proactive,
|
||||
cron=cron and cron.strip(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
@ -265,6 +284,28 @@ def _create_user_scope_group(
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
# Proactive command
|
||||
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user",
|
||||
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
|
||||
prompt="Optional custom instructions for check-ins",
|
||||
)
|
||||
async def proactive_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User,
|
||||
cron: str | None = None,
|
||||
prompt: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_proactive,
|
||||
target_user=user,
|
||||
cron=cron and cron.strip(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
@ -663,3 +704,68 @@ async def handle_mcp_servers(
|
||||
except Exception as exc:
|
||||
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
||||
raise CommandError(f"Error: {exc}") from exc
|
||||
|
||||
|
||||
async def handle_proactive(
|
||||
context: CommandContext,
|
||||
*,
|
||||
cron: str | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> CommandResponse:
|
||||
"""Handle proactive check-in configuration."""
|
||||
from croniter import croniter
|
||||
|
||||
model = context.target
|
||||
|
||||
# If no arguments, show current settings
|
||||
if cron is None and prompt is None:
|
||||
current_cron = getattr(model, "proactive_cron", None)
|
||||
current_prompt = getattr(model, "proactive_prompt", None)
|
||||
|
||||
if not current_cron:
|
||||
return CommandResponse(
|
||||
content=f"Proactive check-ins are disabled for {context.display_name}."
|
||||
)
|
||||
|
||||
lines = [f"Proactive check-ins for {context.display_name}:"]
|
||||
lines.append(f" Schedule: `{current_cron}`")
|
||||
if current_prompt:
|
||||
lines.append(f" Prompt: {current_prompt}")
|
||||
return CommandResponse(content="\n".join(lines))
|
||||
|
||||
# Handle cron setting
|
||||
if cron is not None:
|
||||
if cron.lower() == "off":
|
||||
setattr(model, "proactive_cron", None)
|
||||
return CommandResponse(
|
||||
content=f"Proactive check-ins disabled for {context.display_name}."
|
||||
)
|
||||
|
||||
# Validate cron expression
|
||||
try:
|
||||
croniter(cron)
|
||||
except (ValueError, KeyError) as e:
|
||||
raise CommandError(
|
||||
f"Invalid cron expression: {cron}\n"
|
||||
"Examples:\n"
|
||||
" `0 9 * * *` - 9am daily\n"
|
||||
" `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
|
||||
" `0 */4 * * *` - every 4 hours"
|
||||
) from e
|
||||
|
||||
setattr(model, "proactive_cron", cron)
|
||||
|
||||
# Handle prompt setting
|
||||
if prompt is not None:
|
||||
setattr(model, "proactive_prompt", prompt or None)
|
||||
|
||||
# Build response
|
||||
current_cron = getattr(model, "proactive_cron", None)
|
||||
current_prompt = getattr(model, "proactive_prompt", None)
|
||||
|
||||
lines = [f"Updated proactive settings for {context.display_name}:"]
|
||||
lines.append(f" Schedule: `{current_cron}`")
|
||||
if current_prompt:
|
||||
lines.append(f" Prompt: {current_prompt}")
|
||||
|
||||
return CommandResponse(content="\n".join(lines))
|
||||
|
||||
@ -11,6 +11,7 @@ from memory.common.celery_app import (
|
||||
SYNC_LESSWRONG,
|
||||
RUN_SCHEDULED_CALLS,
|
||||
BACKUP_ALL,
|
||||
EVALUATE_PROACTIVE_CHECKINS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -53,4 +54,8 @@ app.conf.beat_schedule = {
|
||||
"task": BACKUP_ALL,
|
||||
"schedule": settings.S3_BACKUP_INTERVAL,
|
||||
},
|
||||
"evaluate-proactive-checkins": {
|
||||
"task": EVALUATE_PROACTIVE_CHECKINS,
|
||||
"schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
|
||||
},
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ from memory.workers.tasks import (
|
||||
notes,
|
||||
observations,
|
||||
people,
|
||||
proactive,
|
||||
scheduled_calls,
|
||||
) # noqa
|
||||
|
||||
@ -31,5 +32,6 @@ __all__ = [
|
||||
"notes",
|
||||
"observations",
|
||||
"people",
|
||||
"proactive",
|
||||
"scheduled_calls",
|
||||
]
|
||||
|
||||
341
src/memory/workers/tasks/proactive.py
Normal file
341
src/memory/workers/tasks/proactive.py
Normal file
@ -0,0 +1,341 @@
|
||||
"""
|
||||
Celery tasks for proactive Discord check-ins.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from croniter import croniter
|
||||
from sqlalchemy import or_
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.celery_app import app
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
|
||||
from memory.workers.tasks.content_processing import safe_task_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
|
||||
EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
|
||||
|
||||
EntityType = Literal["user", "channel", "server"]
|
||||
|
||||
|
||||
def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
|
||||
"""Check if a cron expression is due to run now.
|
||||
|
||||
Uses croniter to determine if the current time falls within the cron's schedule
|
||||
and enough time has passed since the last run.
|
||||
"""
|
||||
try:
|
||||
cron = croniter(cron_expr, now)
|
||||
# Get the previous scheduled time from now
|
||||
prev_run = cron.get_prev(datetime)
|
||||
# Get the one before that to determine the interval
|
||||
cron.get_prev(datetime)
|
||||
prev_prev_run = cron.get_current(datetime)
|
||||
|
||||
# If we haven't run since the last scheduled time, we should run
|
||||
if last_run is None:
|
||||
# Never run before - check if current time is within a minute of prev_run
|
||||
time_since_scheduled = (now - prev_run).total_seconds()
|
||||
return time_since_scheduled < 120 # Within 2 minutes of scheduled time
|
||||
|
||||
# Make sure last_run is timezone aware
|
||||
if last_run.tzinfo is None:
|
||||
last_run = last_run.replace(tzinfo=timezone.utc)
|
||||
|
||||
# We should run if last_run is before the previous scheduled time
|
||||
return last_run < prev_run
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_bot_for_entity(
|
||||
session, entity_type: EntityType, entity_id: int
|
||||
) -> DiscordUser | None:
|
||||
"""Get the bot user associated with an entity."""
|
||||
from memory.common.db.models import DiscordBotUser, DiscordMessage
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
# For servers, find a bot that has sent messages in that server
|
||||
if entity_type == "server":
|
||||
# Find bots that have interacted with this server
|
||||
bot_users = (
|
||||
session.query(DiscordUser)
|
||||
.options(joinedload(DiscordUser.system_user))
|
||||
.join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
|
||||
.filter(
|
||||
DiscordMessage.server_id == entity_id,
|
||||
DiscordUser.system_user_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
# Find one that's actually a bot
|
||||
for user in bot_users:
|
||||
if user.system_user and user.system_user.user_type == "discord_bot":
|
||||
return user
|
||||
|
||||
# For channels, check the server the channel belongs to
|
||||
if entity_type == "channel":
|
||||
channel = session.get(DiscordChannel, entity_id)
|
||||
if channel and channel.server_id:
|
||||
return get_bot_for_entity(session, "server", channel.server_id)
|
||||
|
||||
# Fallback: use first available bot
|
||||
bot = (
|
||||
session.query(DiscordBotUser)
|
||||
.options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
|
||||
.first()
|
||||
)
|
||||
if bot and bot.discord_users:
|
||||
return bot.discord_users[0]
|
||||
return None
|
||||
|
||||
|
||||
def get_target_user_for_entity(
|
||||
session, entity_type: EntityType, entity_id: int
|
||||
) -> DiscordUser | None:
|
||||
"""Get the target user for sending a proactive message."""
|
||||
if entity_type == "user":
|
||||
return session.get(DiscordUser, entity_id)
|
||||
# For channels and servers, we don't have a specific target user
|
||||
return None
|
||||
|
||||
|
||||
def get_channel_for_entity(
|
||||
session, entity_type: EntityType, entity_id: int
|
||||
) -> DiscordChannel | None:
|
||||
"""Get the channel for sending a proactive message."""
|
||||
if entity_type == "channel":
|
||||
return session.get(DiscordChannel, entity_id)
|
||||
if entity_type == "server":
|
||||
# For servers, find the first text channel (prefer "general")
|
||||
channels = (
|
||||
session.query(DiscordChannel)
|
||||
.filter(
|
||||
DiscordChannel.server_id == entity_id,
|
||||
DiscordChannel.channel_type == "text",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not channels:
|
||||
return None
|
||||
# Prefer a channel named "general" if it exists
|
||||
for channel in channels:
|
||||
if channel.name and "general" in channel.name.lower():
|
||||
return channel
|
||||
return channels[0]
|
||||
# For users, we use DMs (no channel)
|
||||
return None
|
||||
|
||||
|
||||
@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
|
||||
@safe_task_execution
|
||||
def evaluate_proactive_checkins() -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate which entities need proactive check-ins.
|
||||
|
||||
This task runs every minute and checks all entities with proactive_cron set
|
||||
to see if they're due for a check-in.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
dispatched = []
|
||||
|
||||
with make_session() as session:
|
||||
# Query all entities with proactive_cron set
|
||||
for model, entity_type in [
|
||||
(DiscordUser, "user"),
|
||||
(DiscordChannel, "channel"),
|
||||
(DiscordServer, "server"),
|
||||
]:
|
||||
entities = (
|
||||
session.query(model)
|
||||
.filter(model.proactive_cron.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
for entity in entities:
|
||||
cron_expr = cast(str, entity.proactive_cron)
|
||||
last_run = entity.last_proactive_at
|
||||
|
||||
if is_cron_due(cron_expr, last_run, now):
|
||||
logger.info(
|
||||
f"Proactive check-in due for {entity_type} {entity.id}"
|
||||
)
|
||||
execute_proactive_checkin.delay(entity_type, entity.id)
|
||||
dispatched.append({"type": entity_type, "id": entity.id})
|
||||
|
||||
return {
|
||||
"evaluated_at": now.isoformat(),
|
||||
"dispatched": dispatched,
|
||||
"count": len(dispatched),
|
||||
}
|
||||
|
||||
|
||||
@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
|
||||
@safe_task_execution
|
||||
def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a proactive check-in for a specific entity.
|
||||
|
||||
This evaluates whether the bot should reach out and, if so, generates
|
||||
and sends a check-in message.
|
||||
"""
|
||||
logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
|
||||
|
||||
with make_session() as session:
|
||||
# Get the entity
|
||||
model_class = {
|
||||
"user": DiscordUser,
|
||||
"channel": DiscordChannel,
|
||||
"server": DiscordServer,
|
||||
}[entity_type]
|
||||
|
||||
entity = session.get(model_class, entity_id)
|
||||
if not entity:
|
||||
return {"error": f"{entity_type} {entity_id} not found"}
|
||||
|
||||
# Get the bot user
|
||||
bot_user = get_bot_for_entity(session, entity_type, entity_id)
|
||||
if not bot_user:
|
||||
return {"error": "No bot user found"}
|
||||
|
||||
# Get target user and channel
|
||||
target_user = get_target_user_for_entity(session, entity_type, entity_id)
|
||||
channel = get_channel_for_entity(session, entity_type, entity_id)
|
||||
|
||||
if not target_user and not channel:
|
||||
return {"error": "No target user or channel for proactive check-in"}
|
||||
|
||||
# Get chattiness threshold
|
||||
chattiness = entity.chattiness_threshold or 90
|
||||
|
||||
# Build the evaluation prompt
|
||||
proactive_prompt = entity.proactive_prompt or ""
|
||||
eval_prompt = textwrap.dedent("""
|
||||
You are considering whether to proactively reach out to check in.
|
||||
|
||||
{proactive_prompt}
|
||||
|
||||
Based on your notes and the context of previous conversations:
|
||||
1. Is there anything worth checking in about?
|
||||
2. Has enough happened or enough time passed to warrant a check-in?
|
||||
3. Would reaching out now be welcome or intrusive?
|
||||
|
||||
Please return a number between 0 and 100 indicating how strongly you want to check in
|
||||
(0 = definitely not, 100 = definitely yes).
|
||||
|
||||
<response>
|
||||
<number>50</number>
|
||||
<reason>Your reasoning here</reason>
|
||||
</response>
|
||||
""").format(proactive_prompt=proactive_prompt)
|
||||
|
||||
# Build context
|
||||
system_prompt = comm_channel_prompt(
|
||||
session, bot_user, target_user, channel
|
||||
)
|
||||
|
||||
# First, evaluate whether we should check in
|
||||
eval_response = call_llm(
|
||||
session,
|
||||
bot_user=bot_user,
|
||||
from_user=target_user,
|
||||
channel=channel,
|
||||
model=settings.SUMMARIZER_MODEL,
|
||||
system_prompt=system_prompt,
|
||||
messages=[eval_prompt],
|
||||
allowed_tools=[
|
||||
"update_channel_summary",
|
||||
"update_user_summary",
|
||||
"update_server_summary",
|
||||
],
|
||||
)
|
||||
|
||||
if not eval_response:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
|
||||
|
||||
# Parse the interest score
|
||||
match = re.search(r"<number>(\d+)</number>", eval_response)
|
||||
if not match:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
|
||||
|
||||
interest_score = int(match.group(1))
|
||||
threshold = 100 - chattiness
|
||||
|
||||
logger.info(
|
||||
f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
|
||||
)
|
||||
|
||||
if interest_score < threshold:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {
|
||||
"status": "below_threshold",
|
||||
"interest": interest_score,
|
||||
"threshold": threshold,
|
||||
"entity_type": entity_type,
|
||||
"entity_id": entity_id,
|
||||
}
|
||||
|
||||
# Generate the actual check-in message
|
||||
checkin_prompt = textwrap.dedent("""
|
||||
You've decided to proactively check in. Generate a natural, friendly check-in message.
|
||||
|
||||
{proactive_prompt}
|
||||
|
||||
Keep it brief and genuine. Don't be overly formal or robotic.
|
||||
Reference specific things from your notes if relevant.
|
||||
""").format(proactive_prompt=proactive_prompt)
|
||||
|
||||
response = call_llm(
|
||||
session,
|
||||
bot_user=bot_user,
|
||||
from_user=target_user,
|
||||
channel=channel,
|
||||
model=settings.DISCORD_MODEL,
|
||||
system_prompt=system_prompt,
|
||||
messages=[checkin_prompt],
|
||||
)
|
||||
|
||||
if not response:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
|
||||
|
||||
# Send the message
|
||||
bot_id = bot_user.system_user.id if bot_user.system_user else None
|
||||
if not bot_id:
|
||||
return {"error": "No system user for bot"}
|
||||
|
||||
success = send_discord_response(
|
||||
bot_id=bot_id,
|
||||
response=response,
|
||||
channel_id=channel.id if channel else None,
|
||||
user_identifier=target_user.username if target_user else None,
|
||||
)
|
||||
|
||||
# Update last_proactive_at
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"status": "sent" if success else "send_failed",
|
||||
"interest": interest_score,
|
||||
"entity_type": entity_type,
|
||||
"entity_id": entity_id,
|
||||
"response_preview": response[:100] + "..." if len(response) > 100 else response,
|
||||
}
|
||||
@ -15,6 +15,7 @@ from memory.discord.commands import (
|
||||
handle_chattiness,
|
||||
handle_ignore,
|
||||
handle_summary,
|
||||
handle_proactive,
|
||||
respond,
|
||||
with_object_context,
|
||||
handle_mcp_servers,
|
||||
@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
|
||||
await handle_mcp_servers(context, action="list", url=None)
|
||||
|
||||
assert "Error: boom" in str(exc.value)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for handle_proactive
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_show_disabled(db_session, interaction, guild):
|
||||
"""Test showing proactive settings when disabled."""
|
||||
server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context)
|
||||
|
||||
assert "disabled" in response.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_show_enabled(db_session, interaction, guild):
|
||||
"""Test showing proactive settings when enabled."""
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name="Guild",
|
||||
proactive_cron="0 9 * * *",
|
||||
proactive_prompt="Check on projects",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context)
|
||||
|
||||
assert "0 9 * * *" in response.content
|
||||
assert "Check on projects" in response.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_set_cron(db_session, interaction, guild):
|
||||
"""Test setting proactive cron schedule."""
|
||||
server = DiscordServer(id=guild.id, name="Guild")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="0 9 * * *")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert "0 9 * * *" in response.content
|
||||
assert server.proactive_cron == "0 9 * * *"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_set_prompt(db_session, interaction, guild):
|
||||
"""Test setting proactive prompt."""
|
||||
server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, prompt="Focus on daily standups")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert server.proactive_prompt == "Focus on daily standups"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_disable(db_session, interaction, guild):
|
||||
"""Test disabling proactive check-ins."""
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name="Guild",
|
||||
proactive_cron="0 9 * * *",
|
||||
proactive_prompt="Some prompt",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="off")
|
||||
|
||||
assert "disabled" in response.content.lower()
|
||||
assert server.proactive_cron is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
|
||||
"""Test error on invalid cron expression."""
|
||||
server = DiscordServer(id=guild.id, name="Guild")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
with pytest.raises(CommandError) as exc:
|
||||
await handle_proactive(context, cron="not a valid cron")
|
||||
|
||||
assert "Invalid cron expression" in str(exc.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
|
||||
"""Test proactive settings for user scope."""
|
||||
user_model = DiscordUser(
|
||||
id=discord_user.id, username="testuser", proactive_cron=None
|
||||
)
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="me",
|
||||
target=user_model,
|
||||
display_name="you (**testuser**)",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="0 9,17 * * 1-5")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert user_model.proactive_cron == "0 9,17 * * 1-5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_channel_scope(
|
||||
db_session, interaction, guild, text_channel
|
||||
):
|
||||
"""Test proactive settings for channel scope."""
|
||||
server = DiscordServer(id=guild.id, name="Guild")
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
channel_model = DiscordChannel(
|
||||
id=text_channel.id,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
server_id=guild.id,
|
||||
)
|
||||
db_session.add(channel_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="channel",
|
||||
target=channel_model,
|
||||
display_name="channel **#general**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="0 12 * * *")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert channel_model.proactive_cron == "0 12 * * *"
|
||||
|
||||
536
tests/memory/workers/tasks/test_proactive.py
Normal file
536
tests/memory/workers/tasks/test_proactive.py
Normal file
@ -0,0 +1,536 @@
|
||||
"""Tests for proactive check-in tasks."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from memory.common.db.models import (
|
||||
DiscordBotUser,
|
||||
DiscordUser,
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
)
|
||||
from memory.workers.tasks import proactive
|
||||
from memory.workers.tasks.proactive import is_cron_due
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bot_user(db_session):
|
||||
"""Create a bot user for testing."""
|
||||
bot_discord_user = DiscordUser(
|
||||
id=999999999,
|
||||
username="testbot",
|
||||
)
|
||||
db_session.add(bot_discord_user)
|
||||
db_session.flush()
|
||||
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[bot_discord_user],
|
||||
name="testbot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user(db_session):
|
||||
"""Create a target Discord user for testing."""
|
||||
discord_user = DiscordUser(
|
||||
id=123456789,
|
||||
username="targetuser",
|
||||
proactive_cron="0 9 * * *", # 9am daily
|
||||
chattiness_threshold=50,
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
return discord_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_no_cron(db_session):
|
||||
"""Create a target Discord user without proactive cron."""
|
||||
discord_user = DiscordUser(
|
||||
id=123456790,
|
||||
username="nocronuser",
|
||||
proactive_cron=None,
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
return discord_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_server(db_session):
|
||||
"""Create a target Discord server for testing."""
|
||||
server = DiscordServer(
|
||||
id=987654321,
|
||||
name="Test Server",
|
||||
proactive_cron="0 */4 * * *", # Every 4 hours
|
||||
chattiness_threshold=30,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_channel(db_session, target_server):
|
||||
"""Create a target Discord channel for testing."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="test-channel",
|
||||
channel_type="text",
|
||||
server_id=target_server.id,
|
||||
proactive_cron="0 12 * * 1-5", # Noon on weekdays
|
||||
chattiness_threshold=70,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for is_cron_due helper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cron_expr,now,last_run,expected",
|
||||
[
|
||||
# Cron is due when never run before and time matches
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
|
||||
None,
|
||||
True,
|
||||
),
|
||||
# Cron is due when last run was before the scheduled time
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
|
||||
True,
|
||||
),
|
||||
# Cron is NOT due when already run this period
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
|
||||
False,
|
||||
),
|
||||
# Cron is NOT due when current time is before scheduled time
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
|
||||
None,
|
||||
False,
|
||||
),
|
||||
# Hourly cron schedule
|
||||
(
|
||||
"0 * * * *",
|
||||
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
|
||||
True,
|
||||
),
|
||||
# Every 4 hours cron schedule
|
||||
(
|
||||
"0 */4 * * *",
|
||||
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
|
||||
True,
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"due_never_run",
|
||||
"due_last_run_before_schedule",
|
||||
"not_due_already_run",
|
||||
"not_due_too_early",
|
||||
"due_hourly",
|
||||
"due_every_4_hours",
|
||||
],
|
||||
)
|
||||
def test_is_cron_due(cron_expr, now, last_run, expected):
|
||||
"""Test is_cron_due with various scenarios."""
|
||||
assert is_cron_due(cron_expr, last_run, now) is expected
|
||||
|
||||
|
||||
def test_is_cron_due_invalid_expression():
|
||||
"""Test invalid cron expression returns False."""
|
||||
now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_cron_due("invalid cron", None, now) is False
|
||||
|
||||
|
||||
def test_is_cron_due_with_naive_last_run():
|
||||
"""Test cron handles naive datetime for last_run."""
|
||||
now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
|
||||
cron_expr = "0 9 * * *"
|
||||
last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
|
||||
assert is_cron_due(cron_expr, last_run, now) is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for evaluate_proactive_checkins task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_dispatches_due(
|
||||
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
|
||||
):
|
||||
"""Test that due check-ins are dispatched."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_is_cron_due.return_value = True
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
assert result["count"] >= 1
|
||||
mock_execute.delay.assert_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_skips_not_due(
|
||||
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
|
||||
):
|
||||
"""Test that not-due check-ins are not dispatched."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_is_cron_due.return_value = False
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
assert result["count"] == 0
|
||||
mock_execute.delay.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_skips_no_cron(
|
||||
mock_make_session, mock_execute, db_session, target_user_no_cron
|
||||
):
|
||||
"""Test that entities without proactive_cron are skipped."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
for call in mock_execute.delay.call_args_list:
|
||||
entity_type, entity_id = call[0]
|
||||
assert entity_id != target_user_no_cron.id
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_multiple_entity_types(
|
||||
mock_make_session,
|
||||
mock_is_cron_due,
|
||||
mock_execute,
|
||||
db_session,
|
||||
target_user,
|
||||
target_server,
|
||||
target_channel,
|
||||
):
|
||||
"""Test that check-ins are dispatched for users, channels, and servers."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_is_cron_due.return_value = True
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
assert result["count"] == 3
|
||||
dispatched_types = {d["type"] for d in result["dispatched"]}
|
||||
assert "user" in dispatched_types
|
||||
assert "channel" in dispatched_types
|
||||
assert "server" in dispatched_types
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for execute_proactive_checkin task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_sends_when_above_threshold(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test check-in is sent when interest exceeds threshold."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>80</number><reason>Should check in</reason></response>",
|
||||
"Hey! Just checking in - how are things going?",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert result["status"] == "sent"
|
||||
assert result["interest"] == 80
|
||||
mock_send.assert_called_once()
|
||||
|
||||
db_session.refresh(target_user)
|
||||
assert target_user.last_proactive_at is not None
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_skips_below_threshold(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test check-in is skipped when interest is below threshold."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.return_value = (
|
||||
"<response><number>30</number><reason>Not much to say</reason></response>"
|
||||
)
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert result["status"] == "below_threshold"
|
||||
assert result["interest"] == 30
|
||||
assert result["threshold"] == 50
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_response,expected_status",
|
||||
[
|
||||
(None, "no_eval_response"),
|
||||
("I'm not sure what to say.", "no_score_in_response"),
|
||||
],
|
||||
ids=["no_response", "malformed_response"],
|
||||
)
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_handles_bad_llm_response(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
llm_response,
|
||||
expected_status,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test handling of missing or malformed LLM responses."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
mock_call_llm.return_value = llm_response
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert result["status"] == expected_status
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
|
||||
"""Test handling when entity doesn't exist."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", 999999)
|
||||
|
||||
assert "error" in result
|
||||
assert "not found" in result["error"]
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_no_bot_user(
|
||||
mock_make_session, mock_get_bot, db_session, target_user
|
||||
):
|
||||
"""Test handling when no bot user is found."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_get_bot.return_value = None
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert "error" in result
|
||||
assert "No bot user" in result["error"]
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_uses_proactive_prompt(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
bot_user,
|
||||
):
|
||||
"""Test that proactive_prompt is included in the evaluation."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
user_with_prompt = DiscordUser(
|
||||
id=555666777,
|
||||
username="promptuser",
|
||||
proactive_cron="0 9 * * *",
|
||||
proactive_prompt="Focus on their coding projects",
|
||||
chattiness_threshold=50,
|
||||
)
|
||||
db_session.add(user_with_prompt)
|
||||
db_session.commit()
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>80</number><reason>Check on projects</reason></response>",
|
||||
"How are your coding projects coming along?",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
|
||||
|
||||
assert result["status"] == "sent"
|
||||
call_args = mock_call_llm.call_args_list[0]
|
||||
messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
assert any("Focus on their coding projects" in str(m) for m in messages_arg)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_channel(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
target_channel,
|
||||
bot_user,
|
||||
):
|
||||
"""Test check-in to a channel."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>50</number><reason>Check channel</reason></response>",
|
||||
"Good morning everyone!",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
result = proactive.execute_proactive_checkin("channel", target_channel.id)
|
||||
|
||||
assert result["status"] == "sent"
|
||||
assert result["entity_type"] == "channel"
|
||||
|
||||
send_call = mock_send.call_args
|
||||
assert send_call.kwargs.get("channel_id") == target_channel.id
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_updates_last_proactive_at(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test that last_proactive_at is updated after successful check-in."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>80</number><reason>Check in</reason></response>",
|
||||
"Hey there!",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
before_time = datetime.now(timezone.utc)
|
||||
proactive.execute_proactive_checkin("user", target_user.id)
|
||||
after_time = datetime.now(timezone.utc)
|
||||
|
||||
db_session.refresh(target_user)
|
||||
assert target_user.last_proactive_at is not None
|
||||
assert before_time <= target_user.last_proactive_at <= after_time
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test that last_proactive_at is updated even when check-in is skipped."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.return_value = (
|
||||
"<response><number>10</number><reason>Nothing to say</reason></response>"
|
||||
)
|
||||
|
||||
proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
db_session.refresh(target_user)
|
||||
assert target_user.last_proactive_at is not None
|
||||
Loading…
x
Reference in New Issue
Block a user