change schedule call signature

This commit is contained in:
Daniel O'Connell 2025-10-12 10:14:01 +02:00
parent a3544222e7
commit f454aa9afa
5 changed files with 73 additions and 31 deletions

View File

@ -0,0 +1,29 @@
"""rename_prompt_to_message_in_scheduled_calls
Revision ID: c86079073c1d
Revises: 2fb3223dc71b
Create Date: 2025-10-12 10:12:57.421009
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c86079073c1d"
down_revision: Union[str, None] = "2fb3223dc71b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Rename prompt column to message in scheduled_llm_calls table
op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message")
def downgrade() -> None:
# Rename message column back to prompt in scheduled_llm_calls table
op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt")

View File

@ -15,23 +15,29 @@ logger = logging.getLogger(__name__)
@mcp.tool() @mcp.tool()
async def schedule_llm_call( async def schedule_message(
scheduled_time: str, scheduled_time: str,
model: str, message: str | None = None,
prompt: str, model: str | None = None,
topic: str | None = None, topic: str | None = None,
discord_channel: str | None = None, discord_channel: str | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Schedule an LLM call to be executed at a specific time with response sent to Discord. Schedule an message to be sent to the user's Discord at specific time.
This can either be a string to be sent, or a prompt that should be
be first sent to an LLM to generate the final message to be sent.
If `model` is empty, the message will be sent as is. If a model is provided, the message will first be sent to that AI system, and the user
will be sent whatever the AI system generates.
Args: Args:
scheduled_time: ISO format datetime string (e.g., "2024-12-20T15:30:00Z") scheduled_time: ISO format datetime string (e.g., "2024-12-20T15:30:00Z")
model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the message will be sent to the user directly. message: A raw message to be sent to the user, or prompt to the LLM if `model` is set
prompt: The prompt to send to the LLM model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the message will be sent to the user directly. Currently only OpenAI and Anthropic models are supported
topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt. topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt (if provided).
discord_channel: Discord channel name where the response should be sent. If not provided, the message will be sent to the user directly. discord_channel: Discord channel name where the response should be sent. If not provided, the message will be sent to the user directly.
system_prompt: Optional system prompt system_prompt: Optional system prompt
metadata: Optional metadata dict for tracking metadata: Optional metadata dict for tracking
@ -39,7 +45,9 @@ async def schedule_llm_call(
Returns: Returns:
Dict with scheduled call ID and status Dict with scheduled call ID and status
""" """
logger.info("schedule_llm_call tool called") logger.info("schedule_message tool called")
if not message:
raise ValueError("You must provide `message`")
current_user = get_current_user() current_user = get_current_user()
if not current_user["authenticated"]: if not current_user["authenticated"]:
@ -72,9 +80,9 @@ async def schedule_llm_call(
scheduled_call = ScheduledLLMCall( scheduled_call = ScheduledLLMCall(
user_id=user_id, user_id=user_id,
scheduled_time=scheduled_dt, scheduled_time=scheduled_dt,
message=message,
topic=topic, topic=topic,
model=model, model=model,
prompt=prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
discord_channel=discord_channel, discord_channel=discord_channel,
discord_user=discord_user, discord_user=discord_user,

View File

@ -30,9 +30,9 @@ class ScheduledLLMCall(Base):
# LLM call configuration # LLM call configuration
model = Column( model = Column(
String, nullable=True String, nullable=True, doc='e.g., "anthropic/claude-3-5-sonnet-20241022"'
) # e.g., "anthropic/claude-3-5-sonnet-20241022" )
prompt = Column(Text, nullable=False) message = Column(Text, nullable=False)
system_prompt = Column(Text, nullable=True) system_prompt = Column(Text, nullable=True)
allowed_tools = Column(JSON, nullable=True) # List of allowed tool names allowed_tools = Column(JSON, nullable=True) # List of allowed tool names
@ -70,7 +70,7 @@ class ScheduledLLMCall(Base):
"created_at": print_datetime(cast(datetime, self.created_at)), "created_at": print_datetime(cast(datetime, self.created_at)),
"executed_at": print_datetime(cast(datetime, self.executed_at)), "executed_at": print_datetime(cast(datetime, self.executed_at)),
"model": self.model, "model": self.model,
"prompt": self.prompt, "prompt": self.message,
"system_prompt": self.system_prompt, "system_prompt": self.system_prompt,
"allowed_tools": self.allowed_tools, "allowed_tools": self.allowed_tools,
"discord_channel": self.discord_channel, "discord_channel": self.discord_channel,

View File

@ -1,6 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
import textwrap
from typing import cast from typing import cast
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
@ -24,9 +23,15 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
scheduled_call: The scheduled call object scheduled_call: The scheduled call object
response: The LLM response to send response: The LLM response to send
""" """
message = response # Format the message with topic, model, and response
message_parts = []
if cast(str, scheduled_call.topic): if cast(str, scheduled_call.topic):
message = f"**{scheduled_call.topic}**\n\n{message}" message_parts.append(f"**Topic:** {scheduled_call.topic}")
if cast(str, scheduled_call.model):
message_parts.append(f"**Model:** {scheduled_call.model}")
message_parts.append(f"**Response:** {response}")
message = "\n".join(message_parts)
# Discord has a 2000 character limit, so we may need to split the message # Discord has a 2000 character limit, so we may need to split the message
if len(message) > 1900: # Leave some buffer if len(message) > 1900: # Leave some buffer
@ -84,13 +89,13 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Make the LLM call # Make the LLM call
if scheduled_call.model: if scheduled_call.model:
response = llms.call( response = llms.call(
prompt=cast(str, scheduled_call.prompt), prompt=cast(str, scheduled_call.message),
model=cast(str, scheduled_call.model), model=cast(str, scheduled_call.model),
system_prompt=cast(str, scheduled_call.system_prompt) system_prompt=cast(str, scheduled_call.system_prompt)
or llms.SYSTEM_PROMPT, or llms.SYSTEM_PROMPT,
) )
else: else:
response = cast(str, scheduled_call.prompt) response = cast(str, scheduled_call.message)
# Store the response # Store the response
scheduled_call.response = response scheduled_call.response = response

View File

@ -30,7 +30,7 @@ def pending_scheduled_call(db_session, sample_user):
topic="Test Topic", topic="Test Topic",
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="anthropic/claude-3-5-sonnet-20241022", model="anthropic/claude-3-5-sonnet-20241022",
prompt="What is the weather like today?", message="What is the weather like today?",
system_prompt="You are a helpful assistant.", system_prompt="You are a helpful assistant.",
discord_user="123456789", discord_user="123456789",
status="pending", status="pending",
@ -50,7 +50,7 @@ def completed_scheduled_call(db_session, sample_user):
scheduled_time=datetime.now(timezone.utc) - timedelta(hours=1), scheduled_time=datetime.now(timezone.utc) - timedelta(hours=1),
executed_at=datetime.now(timezone.utc) - timedelta(minutes=30), executed_at=datetime.now(timezone.utc) - timedelta(minutes=30),
model="anthropic/claude-3-5-sonnet-20241022", model="anthropic/claude-3-5-sonnet-20241022",
prompt="Tell me a joke.", message="Tell me a joke.",
system_prompt="You are a funny assistant.", system_prompt="You are a funny assistant.",
discord_channel="987654321", discord_channel="987654321",
status="completed", status="completed",
@ -70,7 +70,7 @@ def future_scheduled_call(db_session, sample_user):
topic="Future Topic", topic="Future Topic",
scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1), scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1),
model="anthropic/claude-3-5-sonnet-20241022", model="anthropic/claude-3-5-sonnet-20241022",
prompt="What will happen tomorrow?", message="What will happen tomorrow?",
discord_user="123456789", discord_user="123456789",
status="pending", status="pending",
) )
@ -195,7 +195,7 @@ def test_execute_scheduled_call_with_default_system_prompt(
topic="No System Prompt", topic="No System Prompt",
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="anthropic/claude-3-5-sonnet-20241022", model="anthropic/claude-3-5-sonnet-20241022",
prompt="Test prompt", message="Test prompt",
system_prompt=None, system_prompt=None,
discord_user="123456789", discord_user="123456789",
status="pending", status="pending",
@ -288,7 +288,7 @@ def test_run_scheduled_calls_with_due_calls(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10),
model="test-model", model="test-model",
prompt="Test 1", message="Test 1",
discord_user="123", discord_user="123",
status="pending", status="pending",
) )
@ -297,7 +297,7 @@ def test_run_scheduled_calls_with_due_calls(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model", model="test-model",
prompt="Test 2", message="Test 2",
discord_user="123", discord_user="123",
status="pending", status="pending",
) )
@ -346,7 +346,7 @@ def test_run_scheduled_calls_mixed_statuses(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model", model="test-model",
prompt="Pending", message="Pending",
discord_user="123", discord_user="123",
status="pending", status="pending",
) )
@ -355,7 +355,7 @@ def test_run_scheduled_calls_mixed_statuses(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model", model="test-model",
prompt="Executing", message="Executing",
discord_user="123", discord_user="123",
status="executing", status="executing",
) )
@ -364,7 +364,7 @@ def test_run_scheduled_calls_mixed_statuses(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model", model="test-model",
prompt="Completed", message="Completed",
discord_user="123", discord_user="123",
status="completed", status="completed",
) )
@ -397,7 +397,7 @@ def test_run_scheduled_calls_timezone_handling(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime
model="test-model", model="test-model",
prompt="Due call", message="Due call",
discord_user="123", discord_user="123",
status="pending", status="pending",
) )
@ -409,7 +409,7 @@ def test_run_scheduled_calls_timezone_handling(
user_id=sample_user.id, user_id=sample_user.id,
scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime
model="test-model", model="test-model",
prompt="Future call", message="Future call",
discord_user="123", discord_user="123",
status="pending", status="pending",
) )
@ -477,7 +477,7 @@ def test_discord_destination_priority(
topic="Priority Test", topic="Priority Test",
scheduled_time=datetime.now(timezone.utc), scheduled_time=datetime.now(timezone.utc),
model="test-model", model="test-model",
prompt="Test", message="Test",
discord_user=discord_user, discord_user=discord_user,
discord_channel=discord_channel, discord_channel=discord_channel,
status="pending", status="pending",
@ -568,7 +568,7 @@ def test_execute_scheduled_call_status_check(
topic="Status Test", topic="Status Test",
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
model="test-model", model="test-model",
prompt="Test", message="Test",
discord_user="123", discord_user="123",
status=status, status=status,
) )