add scheduled calls

This commit is contained in:
Daniel O'Connell 2025-08-12 23:48:13 +02:00 committed by EC2 Default User
parent a2d107fad7
commit a3544222e7
19 changed files with 1218 additions and 53 deletions

View File

@ -0,0 +1,55 @@
"""discord schedules
Revision ID: 2fb3223dc71b
Revises: 1d6bc8015ea9
Create Date: 2025-08-12 23:43:27.671182
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "2fb3223dc71b"
down_revision: Union[str, None] = "1d6bc8015ea9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"scheduled_llm_calls",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("topic", sa.Text(), nullable=True),
sa.Column("scheduled_time", sa.DateTime(), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("executed_at", sa.DateTime(), nullable=True),
sa.Column("model", sa.String(), nullable=True),
sa.Column("prompt", sa.Text(), nullable=False),
sa.Column("system_prompt", sa.Text(), nullable=True),
sa.Column("allowed_tools", sa.JSON(), nullable=True),
sa.Column("discord_channel", sa.String(), nullable=True),
sa.Column("discord_user", sa.String(), nullable=True),
sa.Column("status", sa.String(), nullable=False),
sa.Column("response", sa.Text(), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("celery_task_id", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("users", sa.Column("discord_user_id", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("users", "discord_user_id")
op.drop_table("scheduled_llm_calls")

View File

@ -174,7 +174,7 @@ services:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes"
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
ingest-hub:
<<: *worker-base

View File

@ -1,3 +1,6 @@
import memory.api.MCP.manifest
import memory.api.MCP.tools
import memory.api.MCP.memory
import memory.api.MCP.metadata
import memory.api.MCP.schedules
import memory.api.MCP.books
import memory.api.MCP.manifest

View File

@ -8,23 +8,24 @@ from mcp.server.auth.handlers.token import (
RefreshTokenRequest,
TokenRequest,
)
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.shared.auth import OAuthClientMetadata
from memory.common.db.models.users import User
from pydantic import AnyHttpUrl
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
from memory.api.MCP.oauth_provider import (
SimpleOAuthProvider,
ALLOWED_SCOPES,
BASE_SCOPES,
SimpleOAuthProvider,
)
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import OAuthState
from memory.common.db.models import OAuthState, UserSession
from memory.common.db.models.users import User
logger = logging.getLogger(__name__)
@ -134,3 +135,30 @@ async def handle_login(request: Request):
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
redirect_url = redirect_url.replace("http://", "cursor://")
return RedirectResponse(url=redirect_url, status_code=302)
def get_current_user() -> dict:
access_token = get_access_token()
if not access_token:
return {"authenticated": False}
with make_session() as session:
user_session = (
session.query(UserSession)
.filter(UserSession.id == access_token.token)
.first()
)
if user_session and user_session.user:
user_info = user_session.user.serialize()
else:
user_info = {"error": "User not found"}
return {
"authenticated": True,
"token_type": "Bearer",
"scopes": access_token.scopes,
"client_id": access_token.client_id,
"user": user_info,
}

View File

@ -13,7 +13,7 @@ from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.tools import mcp
from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchFilters, SearchConfig
from memory.common import extract, settings

View File

@ -0,0 +1,186 @@
"""
MCP tools for the epistemic sparring partner system.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from memory.api.MCP.base import get_current_user
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.api.MCP.base import mcp
logger = logging.getLogger(__name__)
@mcp.tool()
async def schedule_llm_call(
scheduled_time: str,
model: str,
prompt: str,
topic: str | None = None,
discord_channel: str | None = None,
system_prompt: str | None = None,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Schedule an LLM call to be executed at a specific time with response sent to Discord.
Args:
scheduled_time: ISO format datetime string (e.g., "2024-12-20T15:30:00Z")
model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the message will be sent to the user directly.
prompt: The prompt to send to the LLM
topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt.
discord_channel: Discord channel name where the response should be sent. If not provided, the message will be sent to the user directly.
system_prompt: Optional system prompt
metadata: Optional metadata dict for tracking
Returns:
Dict with scheduled call ID and status
"""
logger.info("schedule_llm_call tool called")
current_user = get_current_user()
if not current_user["authenticated"]:
raise ValueError("Not authenticated")
user_id = current_user.get("user", {}).get("user_id")
if not user_id:
raise ValueError("User not found")
discord_user = current_user.get("user", {}).get("discord_user_id")
if not discord_user and not discord_channel:
raise ValueError("Either discord_user or discord_channel must be provided")
# Parse scheduled time
try:
scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
# Ensure we store as naive datetime (remove timezone info for database storage)
if scheduled_dt.tzinfo is not None:
scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
except ValueError:
raise ValueError("Invalid datetime format")
# Validate that the scheduled time is in the future
# Compare with naive datetime since we store naive in the database
current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
if scheduled_dt <= current_time_naive:
raise ValueError("Scheduled time must be in the future")
with make_session() as session:
# Create the scheduled call
scheduled_call = ScheduledLLMCall(
user_id=user_id,
scheduled_time=scheduled_dt,
topic=topic,
model=model,
prompt=prompt,
system_prompt=system_prompt,
discord_channel=discord_channel,
discord_user=discord_user,
data=metadata or {},
)
session.add(scheduled_call)
session.commit()
return {
"success": True,
"scheduled_call_id": scheduled_call.id,
"scheduled_time": scheduled_dt.isoformat(),
"message": f"LLM call scheduled for {scheduled_dt.isoformat()}",
}
@mcp.tool()
async def list_scheduled_llm_calls(
status: str | None = None, limit: int | None = 50
) -> dict[str, Any]:
"""
List scheduled LLM calls for the authenticated user.
Args:
status: Optional status filter ("pending", "executing", "completed", "failed", "cancelled")
limit: Maximum number of calls to return (default: 50)
Returns:
Dict with list of scheduled calls
"""
logger.info("list_scheduled_llm_calls tool called")
current_user = get_current_user()
if not current_user["authenticated"]:
return {"error": "Not authenticated", "user": current_user}
user_id = current_user.get("user", {}).get("user_id")
if not user_id:
return {"error": "User not found", "user": current_user}
with make_session() as session:
query = (
session.query(ScheduledLLMCall)
.filter(ScheduledLLMCall.user_id == user_id)
.order_by(ScheduledLLMCall.scheduled_time.desc())
)
if status:
query = query.filter(ScheduledLLMCall.status == status)
if limit:
query = query.limit(limit)
calls = query.all()
return {
"success": True,
"scheduled_calls": [call.serialize() for call in calls],
"count": len(calls),
}
@mcp.tool()
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
"""
Cancel a scheduled LLM call.
Args:
scheduled_call_id: ID of the scheduled call to cancel
Returns:
Dict with cancellation status
"""
logger.info(f"cancel_scheduled_llm_call tool called for ID: {scheduled_call_id}")
current_user = get_current_user()
if not current_user["authenticated"]:
return {"error": "Not authenticated", "user": current_user}
user_id = current_user.get("user", {}).get("user_id")
if not user_id:
return {"error": "User not found", "user": current_user}
with make_session() as session:
# Find the scheduled call
scheduled_call = (
session.query(ScheduledLLMCall)
.filter(
ScheduledLLMCall.id == scheduled_call_id,
ScheduledLLMCall.user_id == user_id,
)
.first()
)
if not scheduled_call:
return {"error": "Scheduled call not found"}
if not scheduled_call.can_be_cancelled():
return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
# Update the status
scheduled_call.status = "cancelled"
session.commit()
logger.info(f"Scheduled LLM call {scheduled_call_id} cancelled")
return {
"success": True,
"message": f"Scheduled call {scheduled_call_id} has been cancelled",
}

View File

@ -5,18 +5,13 @@ MCP tools for the epistemic sparring partner system.
import logging
from datetime import datetime, timezone
from mcp.server.auth.middleware.auth_context import get_access_token
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.common.db.connection import make_session
from memory.common.db.models import (
AgentObservation,
SourceItem,
UserSession,
)
from memory.api.MCP.base import mcp
from memory.common.db.models import AgentObservation, SourceItem
from memory.api.MCP.base import mcp, get_current_user
logger = logging.getLogger(__name__)
@ -76,35 +71,4 @@ async def get_current_time() -> dict:
@mcp.tool()
async def get_authenticated_user() -> dict:
"""Get information about the authenticated user."""
logger.info("🔧 get_authenticated_user tool called")
access_token = get_access_token()
logger.info(f"🔧 Access token from MCP context: {access_token}")
if not access_token:
logger.warning("❌ No access token found in MCP context!")
return {"error": "Not authenticated"}
logger.info(
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
)
# Look up the actual user from the session token
with make_session() as session:
user_session = (
session.query(UserSession)
.filter(UserSession.id == access_token.token)
.first()
)
if user_session and user_session.user:
user_info = user_session.user.serialize()
else:
user_info = {"error": "User not found"}
return {
"authenticated": True,
"token_type": "Bearer",
"scopes": access_token.scopes,
"client_id": access_token.client_id,
"user": user_info,
}
return get_current_user()

View File

@ -11,6 +11,7 @@ EBOOK_ROOT = "memory.workers.tasks.ebook"
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
@ -44,6 +45,10 @@ SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
ADD_ARTICLE_FEED = f"{BLOGS_ROOT}.add_article_feed"
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
# Scheduled calls tasks
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
@ -78,6 +83,9 @@ app.conf.update(
},
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
f"{SCHEDULED_CALLS_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
},
},
)

View File

@ -48,6 +48,9 @@ from memory.common.db.models.users import (
OAuthState,
OAuthRefreshToken,
)
from memory.common.db.models.scheduled_calls import (
ScheduledLLMCall,
)
Payload = (
SourceItemPayload
@ -96,6 +99,8 @@ __all__ = [
"OAuthClientInformation",
"OAuthState",
"OAuthRefreshToken",
# Scheduled Calls
"ScheduledLLMCall",
# Payloads
"Payload",
]

View File

@ -0,0 +1,92 @@
from datetime import datetime
import uuid
from typing import Any, Dict, cast
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
JSON,
Text,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
class ScheduledLLMCall(Base):
__tablename__ = "scheduled_llm_calls"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
topic = Column(Text, nullable=True)
# Scheduling info
scheduled_time = Column(DateTime, nullable=False)
created_at = Column(DateTime, server_default=func.now())
executed_at = Column(DateTime, nullable=True)
# LLM call configuration
model = Column(
String, nullable=True
) # e.g., "anthropic/claude-3-5-sonnet-20241022"
prompt = Column(Text, nullable=False)
system_prompt = Column(Text, nullable=True)
allowed_tools = Column(JSON, nullable=True) # List of allowed tool names
# Discord configuration
discord_channel = Column(String, nullable=True)
discord_user = Column(String, nullable=True)
# Execution status and results
status = Column(
String, nullable=False, default="pending"
) # pending, executing, completed, failed, cancelled
response = Column(Text, nullable=True) # LLM response content
error_message = Column(Text, nullable=True)
# Additional metadata
data = Column(JSON, nullable=True) # For extensibility
# Celery task tracking
celery_task_id = Column(String, nullable=True) # Track the Celery Beat task
# Relationships
user = relationship("User")
def serialize(self) -> Dict[str, Any]:
def print_datetime(dt: datetime | None) -> str | None:
if dt:
return dt.isoformat()
return None
return {
"id": self.id,
"user_id": self.user_id,
"topic": self.topic,
"scheduled_time": print_datetime(cast(datetime, self.scheduled_time)),
"created_at": print_datetime(cast(datetime, self.created_at)),
"executed_at": print_datetime(cast(datetime, self.executed_at)),
"model": self.model,
"prompt": self.prompt,
"system_prompt": self.system_prompt,
"allowed_tools": self.allowed_tools,
"discord_channel": self.discord_channel,
"discord_user": self.discord_user,
"status": self.status,
"response": self.response,
"error_message": self.error_message,
"metadata": self.data,
"celery_task_id": self.celery_task_id,
}
def is_pending(self) -> bool:
return cast(str, self.status) == "pending"
def is_completed(self) -> bool:
return cast(str, self.status) in ("completed", "failed", "cancelled")
def can_be_cancelled(self) -> bool:
return cast(str, self.status) in ("pending",)

View File

@ -2,6 +2,8 @@ import hashlib
import secrets
from typing import cast
import uuid
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from memory.common.db.models.base import Base
from sqlalchemy import (
Column,
@ -39,6 +41,7 @@ class User(Base):
name = Column(String, nullable=False)
email = Column(String, nullable=False, unique=True)
password_hash = Column(String, nullable=False)
discord_user_id = Column(String, nullable=True)
# Relationship to sessions
sessions = relationship(
@ -53,6 +56,7 @@ class User(Base):
"user_id": self.id,
"name": self.name,
"email": self.email,
"discord_user_id": self.discord_user_id,
}
def is_valid_password(self, password: str) -> bool:
@ -193,3 +197,15 @@ class OAuthRefreshToken(Base, OAuthToken):
"expires_at": self.expires_at.timestamp(),
"revoked": self.revoked,
} | super().serialize()
def purge_oauth(session: Session):
for token in session.query(OAuthRefreshToken).all():
session.delete(token)
for user_session in session.query(UserSession).all():
session.delete(user_session)
for oauth_state in session.query(OAuthState).all():
session.delete(oauth_state)
for oauth_client in session.query(OAuthClientInformation).all():
session.delete(oauth_client)

View File

@ -1,6 +1,7 @@
import logging
import requests
from typing import Any, Dict, List
import re
from typing import Any
from memory.common import settings
@ -19,6 +20,7 @@ class DiscordServer(requests.Session):
self.channels = {}
super().__init__(*args, **kwargs)
self.setup_channels()
self.members = self.fetch_all_members()
def setup_channels(self):
resp = self.get(self.channels_url)
@ -63,9 +65,20 @@ class DiscordServer(requests.Session):
return channel_id
def send_message(self, channel_id: str, content: str):
self.post(
payload: dict[str, Any] = {"content": content}
mentions = re.findall(r"@(\S*)", content)
users = {u: i for u, i in self.members.items() if u in mentions}
if users:
for u, i in users.items():
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
payload["allowed_mentions"] = {
"parse": [],
"users": list(users.values()),
}
return self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json={"content": content},
json=payload,
)
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
@ -91,8 +104,64 @@ class DiscordServer(requests.Session):
def channels_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
@property
def members_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/members"
def get_bot_servers() -> List[Dict]:
@property
def dm_create_url(self) -> str:
return "https://discord.com/api/v10/users/@me/channels"
def list_members(
self, limit: int = 1000, after: str | None = None
) -> list[dict[str, Any]]:
"""List up to `limit` members in this guild, starting after a user ID.
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
"""
params: dict[str, Any] = {"limit": limit}
if after:
params["after"] = after
resp = self.get(self.members_url, params=params)
resp.raise_for_status()
return resp.json()
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
"""Retrieve all members in the guild by paginating the members list.
Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
"""
members: dict[str, str] = {}
after: str | None = None
while batch := self.list_members(limit=page_size, after=after):
for member in batch:
user = member.get("user", {})
members[user.get("global_name") or user.get("username", "")] = user.get(
"id", ""
)
after = user.get("id", "")
return members
def create_dm_channel(self, user_id: str) -> str:
"""Create (or retrieve) a DM channel with the given user and return the channel ID.
The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
"""
resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
resp.raise_for_status()
data = resp.json()
return data["id"]
def send_dm(self, user_id: str, content: str):
"""Send a direct message to a specific user by ID."""
channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
return self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json={"content": content},
)
def get_bot_servers() -> list[dict[str, Any]]:
"""Get list of servers the bot is in."""
if not settings.DISCORD_BOT_TOKEN:
return []
@ -141,6 +210,14 @@ def send_chat_message(message: str):
broadcast_message(CHAT_CHANNEL, message)
def send_dm(user_id: str, message: str):
for server in servers.values():
if not server.members.get(user_id) and user_id not in server.members.values():
continue
server.send_dm(user_id, message)
def notify_task_failure(
task_name: str,
error_message: str,

View File

@ -107,6 +107,7 @@ CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 *
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
@ -168,6 +169,6 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
# Enable Discord notifications if bot token is set
DISCORD_NOTIFICATIONS_ENABLED = (
DISCORD_NOTIFICATIONS_ENABLED = bool(
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass, field
import logging
import time
from datetime import datetime, timedelta

View File

@ -9,6 +9,7 @@ from memory.common.celery_app import (
SYNC_ALL_ARTICLE_FEEDS,
TRACK_GIT_CHANGES,
SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS,
)
logger = logging.getLogger(__name__)
@ -43,4 +44,8 @@ app.conf.beat_schedule = {
"task": SYNC_LESSWRONG,
"schedule": settings.LESSWRONG_SYNC_INTERVAL,
},
"run-scheduled-calls": {
"task": RUN_SCHEDULED_CALLS,
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
},
}

View File

@ -11,6 +11,7 @@ from memory.workers.tasks import (
maintenance,
notes,
observations,
scheduled_calls,
) # noqa
@ -23,4 +24,5 @@ __all__ = [
"maintenance",
"notes",
"observations",
"scheduled_calls",
]

View File

@ -0,0 +1,141 @@
import logging
from datetime import datetime, timezone
import textwrap
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.common.celery_app import (
app,
EXECUTE_SCHEDULED_CALL,
RUN_SCHEDULED_CALLS,
)
from memory.common import llms, discord
from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__)
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
"""
Send the LLM response to the specified Discord user.
Args:
scheduled_call: The scheduled call object
response: The LLM response to send
"""
message = response
if cast(str, scheduled_call.topic):
message = f"**{scheduled_call.topic}**\n\n{message}"
# Discord has a 2000 character limit, so we may need to split the message
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
if discord_user := cast(str, scheduled_call.discord_user):
logger.info(f"Sending DM to {discord_user}: {message}")
discord.send_dm(discord_user, message)
elif discord_channel := cast(str, scheduled_call.discord_channel):
logger.info(f"Broadcasting message to {discord_channel}: {message}")
discord.broadcast_message(discord_channel, message)
else:
logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
)
@app.task(bind=True, name=EXECUTE_SCHEDULED_CALL)
@safe_task_execution
def execute_scheduled_call(self, scheduled_call_id: str):
"""
Execute a scheduled LLM call and send the response to Discord.
Args:
scheduled_call_id: The ID of the scheduled call to execute
"""
logger.info(f"Executing scheduled LLM call: {scheduled_call_id}")
with make_session() as session:
# Fetch the scheduled call
scheduled_call = (
session.query(ScheduledLLMCall)
.filter(ScheduledLLMCall.id == scheduled_call_id)
.first()
)
if not scheduled_call:
logger.error(f"Scheduled call {scheduled_call_id} not found")
return {"error": "Scheduled call not found"}
# Check if the call is still pending
if not scheduled_call.is_pending():
logger.warning(
f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})"
)
return {"error": f"Call is not pending (status: {scheduled_call.status})"}
# Update status to executing
scheduled_call.status = "executing"
scheduled_call.executed_at = datetime.now(timezone.utc)
session.commit()
logger.info(f"Calling LLM with model {scheduled_call.model}")
# Make the LLM call
if scheduled_call.model:
response = llms.call(
prompt=cast(str, scheduled_call.prompt),
model=cast(str, scheduled_call.model),
system_prompt=cast(str, scheduled_call.system_prompt)
or llms.SYSTEM_PROMPT,
)
else:
response = cast(str, scheduled_call.prompt)
# Store the response
scheduled_call.response = response
scheduled_call.status = "completed"
session.commit()
logger.info(f"LLM call completed for {scheduled_call_id}")
# Send to Discord
try:
_send_to_discord(scheduled_call, response)
logger.info(f"Response sent to Discord for {scheduled_call_id}")
except Exception as discord_error:
logger.error(f"Failed to send to Discord: {discord_error}")
# Don't mark as failed since the LLM call succeeded
scheduled_call.data = scheduled_call.data or {}
scheduled_call.data["discord_error"] = str(discord_error)
session.commit()
return {
"success": True,
"scheduled_call_id": scheduled_call_id,
"response": response[:100] + "..." if len(response) > 100 else response,
"discord_sent": True,
}
@app.task(name=RUN_SCHEDULED_CALLS)
@safe_task_execution
def run_scheduled_calls():
"""Run scheduled calls that are due."""
with make_session() as session:
calls = (
session.query(ScheduledLLMCall)
.filter(
ScheduledLLMCall.status.in_(["pending"]),
ScheduledLLMCall.scheduled_time
< datetime.now(timezone.utc).replace(tzinfo=None),
)
.all()
)
for call in calls:
execute_scheduled_call.delay(call.id)
return {
"calls": [call.id for call in calls],
"count": len(calls),
}

View File

@ -1,8 +1,6 @@
import logging
import pytest
from unittest.mock import Mock, patch
import requests
import json
from memory.common import discord, settings

View File

@ -0,0 +1,585 @@
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch
import uuid
from memory.common.db.models import ScheduledLLMCall, User
from memory.workers.tasks import scheduled_calls
@pytest.fixture
def sample_user(db_session):
"""Create a sample user for testing."""
user = User(
name="testuser",
email="test@example.com",
discord_user_id="123456789",
password_hash="password",
)
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def pending_scheduled_call(db_session, sample_user):
"""Create a pending scheduled call for testing."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
topic="Test Topic",
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="anthropic/claude-3-5-sonnet-20241022",
prompt="What is the weather like today?",
system_prompt="You are a helpful assistant.",
discord_user="123456789",
status="pending",
)
db_session.add(call)
db_session.commit()
return call
@pytest.fixture
def completed_scheduled_call(db_session, sample_user):
"""Create a completed scheduled call for testing."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
topic="Completed Topic",
scheduled_time=datetime.now(timezone.utc) - timedelta(hours=1),
executed_at=datetime.now(timezone.utc) - timedelta(minutes=30),
model="anthropic/claude-3-5-sonnet-20241022",
prompt="Tell me a joke.",
system_prompt="You are a funny assistant.",
discord_channel="987654321",
status="completed",
response="Why did the chicken cross the road? To get to the other side!",
)
db_session.add(call)
db_session.commit()
return call
@pytest.fixture
def future_scheduled_call(db_session, sample_user):
"""Create a future scheduled call for testing."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
topic="Future Topic",
scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1),
model="anthropic/claude-3-5-sonnet-20241022",
prompt="What will happen tomorrow?",
discord_user="123456789",
status="pending",
)
db_session.add(call)
db_session.commit()
return call
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user."""
response = "This is a test response."
scheduled_calls._send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
"123456789",
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
)
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
"""Test sending to Discord channel."""
response = "This is a channel response."
scheduled_calls._send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with(
"987654321",
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
"""Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response
scheduled_calls._send_to_discord(pending_scheduled_call, long_response)
# Verify the message was truncated
args, kwargs = mock_send_dm.call_args
message = args[1]
assert len(message) <= 1950 # Should be truncated
assert message.endswith("... (response truncated)")
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
"""Test that normal length messages are not truncated."""
normal_response = "This is a normal length response."
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args
message = args[1]
assert not message.endswith("... (response truncated)")
assert "This is a normal length response." in message
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_success(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test successful execution of a scheduled LLM call."""
mock_llm_call.return_value = "The weather is sunny today!"
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
# Verify LLM was called with correct parameters
mock_llm_call.assert_called_once_with(
prompt="What is the weather like today?",
model="anthropic/claude-3-5-sonnet-20241022",
system_prompt="You are a helpful assistant.",
)
# Verify result
assert result["success"] is True
assert result["scheduled_call_id"] == pending_scheduled_call.id
assert result["response"] == "The weather is sunny today!"
assert result["discord_sent"] is True
# Verify database was updated
db_session.refresh(pending_scheduled_call)
assert pending_scheduled_call.status == "completed"
assert pending_scheduled_call.response == "The weather is sunny today!"
assert pending_scheduled_call.executed_at is not None
def test_execute_scheduled_call_not_found(db_session):
"""Test execution with non-existent call ID."""
fake_id = str(uuid.uuid4())
result = scheduled_calls.execute_scheduled_call(fake_id)
assert result == {"error": "Scheduled call not found"}
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_not_pending(
mock_llm_call, completed_scheduled_call, db_session
):
"""Test execution of a call that is not pending."""
result = scheduled_calls.execute_scheduled_call(completed_scheduled_call.id)
assert result == {"error": "Call is not pending (status: completed)"}
mock_llm_call.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_with_default_system_prompt(
mock_llm_call, mock_send_discord, db_session, sample_user
):
"""Test execution when system_prompt is None, should use default."""
# Create call without system prompt
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
topic="No System Prompt",
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="anthropic/claude-3-5-sonnet-20241022",
prompt="Test prompt",
system_prompt=None,
discord_user="123456789",
status="pending",
)
db_session.add(call)
db_session.commit()
mock_llm_call.return_value = "Response"
scheduled_calls.execute_scheduled_call(call.id)
# Verify default system prompt was used
mock_llm_call.assert_called_once_with(
prompt="Test prompt",
model="anthropic/claude-3-5-sonnet-20241022",
system_prompt=scheduled_calls.llms.SYSTEM_PROMPT,
)
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_discord_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test execution when Discord sending fails."""
mock_llm_call.return_value = "LLM response"
mock_send_discord.side_effect = Exception("Discord API error")
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
# Should still return success since LLM call succeeded
assert result["success"] is True
assert (
result["discord_sent"] is True
) # This is always True in current implementation
# Verify the call was marked as completed despite Discord error
db_session.refresh(pending_scheduled_call)
assert pending_scheduled_call.status == "completed"
assert pending_scheduled_call.response == "LLM response"
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_llm_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test execution when LLM call fails."""
mock_llm_call.side_effect = Exception("LLM API error")
# The safe_task_execution decorator should catch this
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
assert result["status"] == "error"
assert "LLM API error" in result["error"]
# Discord should not be called
mock_send_discord.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_long_response_truncation(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test that long responses are truncated in the result."""
long_response = "A" * 500 # Long response
mock_llm_call.return_value = long_response
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
# Response in result should be truncated
assert len(result["response"]) <= 103 # 100 chars + "..."
assert result["response"].endswith("...")
# But full response should be stored in database
db_session.refresh(pending_scheduled_call)
assert pending_scheduled_call.response == long_response
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_with_due_calls(
mock_execute_delay, db_session, sample_user
):
"""Test running scheduled calls with due calls."""
# Create multiple due calls
due_call1 = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10),
model="test-model",
prompt="Test 1",
discord_user="123",
status="pending",
)
due_call2 = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
prompt="Test 2",
discord_user="123",
status="pending",
)
db_session.add_all([due_call1, due_call2])
db_session.commit()
mock_task = Mock()
mock_task.id = "task-123"
mock_execute_delay.delay.return_value = mock_task
result = scheduled_calls.run_scheduled_calls()
assert result["count"] == 2
assert due_call1.id in result["calls"]
assert due_call2.id in result["calls"]
# Verify execute_scheduled_call.delay was called for both
assert mock_execute_delay.delay.call_count == 2
mock_execute_delay.delay.assert_any_call(due_call1.id)
mock_execute_delay.delay.assert_any_call(due_call2.id)
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_no_due_calls(
mock_execute_delay, future_scheduled_call, db_session
):
"""Test running scheduled calls when no calls are due."""
result = scheduled_calls.run_scheduled_calls()
assert result["count"] == 0
assert result["calls"] == []
# No tasks should be scheduled
mock_execute_delay.delay.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_mixed_statuses(
mock_execute_delay, db_session, sample_user
):
"""Test that only pending calls are processed."""
# Create calls with different statuses
pending_call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
prompt="Pending",
discord_user="123",
status="pending",
)
executing_call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
prompt="Executing",
discord_user="123",
status="executing",
)
completed_call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
prompt="Completed",
discord_user="123",
status="completed",
)
db_session.add_all([pending_call, executing_call, completed_call])
db_session.commit()
mock_task = Mock()
mock_task.id = "task-123"
mock_execute_delay.delay.return_value = mock_task
result = scheduled_calls.run_scheduled_calls()
# Only the pending call should be processed
assert result["count"] == 1
assert result["calls"] == [pending_call.id]
mock_execute_delay.delay.assert_called_once_with(pending_call.id)
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
def test_run_scheduled_calls_timezone_handling(
mock_execute_delay, db_session, sample_user
):
"""Test that timezone handling works correctly."""
# Create a call that's due (scheduled time in the past)
past_time = datetime.now(timezone.utc) - timedelta(minutes=5)
due_call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime
model="test-model",
prompt="Due call",
discord_user="123",
status="pending",
)
# Create a call that's not due yet
future_time = datetime.now(timezone.utc) + timedelta(minutes=5)
future_call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime
model="test-model",
prompt="Future call",
discord_user="123",
status="pending",
)
db_session.add_all([due_call, future_call])
db_session.commit()
mock_task = Mock()
mock_task.id = "task-123"
mock_execute_delay.delay.return_value = mock_task
result = scheduled_calls.run_scheduled_calls()
# Only the due call should be processed
assert result["count"] == 1
assert result["calls"] == [due_call.id]
mock_execute_delay.delay.assert_called_once_with(due_call.id)
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_status_transition_pending_to_executing_to_completed(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test that status transitions correctly during execution."""
mock_llm_call.return_value = "Response"
# Initial status should be pending
assert pending_scheduled_call.status == "pending"
assert pending_scheduled_call.executed_at is None
scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
# Final status should be completed
db_session.refresh(pending_scheduled_call)
assert pending_scheduled_call.status == "completed"
assert pending_scheduled_call.executed_at is not None
assert pending_scheduled_call.response == "Response"
@pytest.mark.parametrize(
"discord_user,discord_channel,expected_method",
[
("123456789", None, "send_dm"),
(None, "987654321", "broadcast_message"),
("123456789", "987654321", "send_dm"), # User takes precedence
],
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
def test_discord_destination_priority(
mock_broadcast,
mock_send_dm,
discord_user,
discord_channel,
expected_method,
db_session,
sample_user,
):
"""Test that Discord user takes precedence over channel."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
topic="Priority Test",
scheduled_time=datetime.now(timezone.utc),
model="test-model",
prompt="Test",
discord_user=discord_user,
discord_channel=discord_channel,
status="pending",
)
db_session.add(call)
db_session.commit()
response = "Test response"
scheduled_calls._send_to_discord(call, response)
if expected_method == "send_dm":
mock_send_dm.assert_called_once()
mock_broadcast.assert_not_called()
else:
mock_broadcast.assert_called_once()
mock_send_dm.assert_not_called()
@pytest.mark.parametrize(
"topic,model,response,expected_in_message",
[
(
"Weather Check",
"anthropic/claude-3-5-sonnet-20241022",
"It's sunny!",
[
"**Topic:** Weather Check",
"**Model:** anthropic/claude-3-5-sonnet-20241022",
"**Response:** It's sunny!",
],
),
(
"Test Topic",
"gpt-4",
"Hello world",
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
),
(
"Long Topic Name Here",
"claude-2",
"Short",
[
"**Topic:** Long Topic Name Here",
"**Model:** claude-2",
"**Response:** Short",
],
),
],
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
"""Test the Discord message formatting with different inputs."""
# Create a mock scheduled call
mock_call = Mock()
mock_call.topic = topic
mock_call.model = model
mock_call.discord_user = "123456789"
scheduled_calls._send_to_discord(mock_call, response)
# Get the actual message that was sent
args, kwargs = mock_send_dm.call_args
actual_message = args[1]
# Verify all expected parts are in the message
for expected_part in expected_in_message:
assert expected_part in actual_message
@pytest.mark.parametrize(
"status,should_execute",
[
("pending", True),
("executing", False),
("completed", False),
("failed", False),
("cancelled", False),
],
)
@patch("memory.workers.tasks.scheduled_calls.llms.call")
def test_execute_scheduled_call_status_check(
mock_llm_call, status, should_execute, db_session, sample_user
):
"""Test that only pending calls are executed."""
call = ScheduledLLMCall(
id=str(uuid.uuid4()),
user_id=sample_user.id,
topic="Status Test",
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model",
prompt="Test",
discord_user="123",
status=status,
)
db_session.add(call)
db_session.commit()
result = scheduled_calls.execute_scheduled_call(call.id)
if should_execute:
mock_llm_call.assert_called_once()
# We don't check the full result here since it depends on mocking more functions
else:
assert result == {"error": f"Call is not pending (status: {status})"}
mock_llm_call.assert_not_called()