From f454aa9afaf463dfebf02d03d0579d5df2b5a62b Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 12 Oct 2025 10:14:01 +0200 Subject: [PATCH] change schedule call signature --- ..._rename_prompt_to_message_in_scheduled_.py | 29 +++++++++++++++++++ src/memory/api/MCP/schedules.py | 26 +++++++++++------ .../common/db/models/scheduled_calls.py | 8 ++--- src/memory/workers/tasks/scheduled_calls.py | 15 ++++++---- .../workers/tasks/test_scheduled_calls.py | 26 ++++++++--------- 5 files changed, 73 insertions(+), 31 deletions(-) create mode 100644 db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py diff --git a/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py b/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py new file mode 100644 index 0000000..a11ae23 --- /dev/null +++ b/db/migrations/versions/20251012_101257_rename_prompt_to_message_in_scheduled_.py @@ -0,0 +1,29 @@ +"""rename_prompt_to_message_in_scheduled_calls + +Revision ID: c86079073c1d +Revises: 2fb3223dc71b +Create Date: 2025-10-12 10:12:57.421009 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "c86079073c1d" +down_revision: Union[str, None] = "2fb3223dc71b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Rename prompt column to message in scheduled_llm_calls table + op.alter_column("scheduled_llm_calls", "prompt", new_column_name="message") + + +def downgrade() -> None: + # Rename message column back to prompt in scheduled_llm_calls table + op.alter_column("scheduled_llm_calls", "message", new_column_name="prompt") diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/schedules.py index 485f5f7..78740b6 100644 --- a/src/memory/api/MCP/schedules.py +++ b/src/memory/api/MCP/schedules.py @@ -15,23 +15,29 @@ logger = logging.getLogger(__name__) @mcp.tool() -async def schedule_llm_call( +async def schedule_message( scheduled_time: str, - model: str, - prompt: str, + message: str | None = None, + model: str | None = None, topic: str | None = None, discord_channel: str | None = None, system_prompt: str | None = None, metadata: dict[str, Any] | None = None, ) -> dict[str, Any]: """ - Schedule an LLM call to be executed at a specific time with response sent to Discord. + Schedule an message to be sent to the user's Discord at specific time. + + This can either be a string to be sent, or a prompt that should be + be first sent to an LLM to generate the final message to be sent. + + If `model` is empty, the message will be sent as is. If a model is provided, the message will first be sent to that AI system, and the user + will be sent whatever the AI system generates. Args: scheduled_time: ISO format datetime string (e.g., "2024-12-20T15:30:00Z") - model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the message will be sent to the user directly. - prompt: The prompt to send to the LLM - topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt. + message: A raw message to be sent to the user, or prompt to the LLM if `model` is set + model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the message will be sent to the user directly. Currently only OpenAI and Anthropic models are supported + topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt (if provided). discord_channel: Discord channel name where the response should be sent. If not provided, the message will be sent to the user directly. system_prompt: Optional system prompt metadata: Optional metadata dict for tracking @@ -39,7 +45,9 @@ async def schedule_llm_call( Returns: Dict with scheduled call ID and status """ - logger.info("schedule_llm_call tool called") + logger.info("schedule_message tool called") + if not message: + raise ValueError("You must provide `message`") current_user = get_current_user() if not current_user["authenticated"]: @@ -72,9 +80,9 @@ async def schedule_llm_call( scheduled_call = ScheduledLLMCall( user_id=user_id, scheduled_time=scheduled_dt, + message=message, topic=topic, model=model, - prompt=prompt, system_prompt=system_prompt, discord_channel=discord_channel, discord_user=discord_user, diff --git a/src/memory/common/db/models/scheduled_calls.py b/src/memory/common/db/models/scheduled_calls.py index b4238ab..097fdd9 100644 --- a/src/memory/common/db/models/scheduled_calls.py +++ b/src/memory/common/db/models/scheduled_calls.py @@ -30,9 +30,9 @@ class ScheduledLLMCall(Base): # LLM call configuration model = Column( - String, nullable=True - ) # e.g., "anthropic/claude-3-5-sonnet-20241022" - prompt = Column(Text, nullable=False) + String, nullable=True, doc='e.g., "anthropic/claude-3-5-sonnet-20241022"' + ) + message = Column(Text, nullable=False) system_prompt = Column(Text, nullable=True) allowed_tools = Column(JSON, nullable=True) # List of allowed tool names @@ -70,7 +70,7 @@ class ScheduledLLMCall(Base): "created_at": print_datetime(cast(datetime, self.created_at)), "executed_at": print_datetime(cast(datetime, self.executed_at)), "model": self.model, - "prompt": self.prompt, + "prompt": self.message, "system_prompt": self.system_prompt, "allowed_tools": self.allowed_tools, "discord_channel": self.discord_channel, diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index e2df55a..200cd2a 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -1,6 +1,5 @@ import logging from datetime import datetime, timezone -import textwrap from typing import cast from memory.common.db.connection import make_session @@ -24,9 +23,15 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): scheduled_call: The scheduled call object response: The LLM response to send """ - message = response + # Format the message with topic, model, and response + message_parts = [] if cast(str, scheduled_call.topic): - message = f"**{scheduled_call.topic}**\n\n{message}" + message_parts.append(f"**Topic:** {scheduled_call.topic}") + if cast(str, scheduled_call.model): + message_parts.append(f"**Model:** {scheduled_call.model}") + message_parts.append(f"**Response:** {response}") + + message = "\n".join(message_parts) # Discord has a 2000 character limit, so we may need to split the message if len(message) > 1900: # Leave some buffer @@ -84,13 +89,13 @@ def execute_scheduled_call(self, scheduled_call_id: str): # Make the LLM call if scheduled_call.model: response = llms.call( - prompt=cast(str, scheduled_call.prompt), + prompt=cast(str, scheduled_call.message), model=cast(str, scheduled_call.model), system_prompt=cast(str, scheduled_call.system_prompt) or llms.SYSTEM_PROMPT, ) else: - response = cast(str, scheduled_call.prompt) + response = cast(str, scheduled_call.message) # Store the response scheduled_call.response = response diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py index 9dc417e..8b9be2a 100644 --- a/tests/memory/workers/tasks/test_scheduled_calls.py +++ b/tests/memory/workers/tasks/test_scheduled_calls.py @@ -30,7 +30,7 @@ def pending_scheduled_call(db_session, sample_user): topic="Test Topic", scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="anthropic/claude-3-5-sonnet-20241022", - prompt="What is the weather like today?", + message="What is the weather like today?", system_prompt="You are a helpful assistant.", discord_user="123456789", status="pending", @@ -50,7 +50,7 @@ def completed_scheduled_call(db_session, sample_user): scheduled_time=datetime.now(timezone.utc) - timedelta(hours=1), executed_at=datetime.now(timezone.utc) - timedelta(minutes=30), model="anthropic/claude-3-5-sonnet-20241022", - prompt="Tell me a joke.", + message="Tell me a joke.", system_prompt="You are a funny assistant.", discord_channel="987654321", status="completed", @@ -70,7 +70,7 @@ def future_scheduled_call(db_session, sample_user): topic="Future Topic", scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1), model="anthropic/claude-3-5-sonnet-20241022", - prompt="What will happen tomorrow?", + message="What will happen tomorrow?", discord_user="123456789", status="pending", ) @@ -195,7 +195,7 @@ def test_execute_scheduled_call_with_default_system_prompt( topic="No System Prompt", scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="anthropic/claude-3-5-sonnet-20241022", - prompt="Test prompt", + message="Test prompt", system_prompt=None, discord_user="123456789", status="pending", @@ -288,7 +288,7 @@ def test_run_scheduled_calls_with_due_calls( user_id=sample_user.id, scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10), model="test-model", - prompt="Test 1", + message="Test 1", discord_user="123", status="pending", ) @@ -297,7 +297,7 @@ def test_run_scheduled_calls_with_due_calls( user_id=sample_user.id, scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="test-model", - prompt="Test 2", + message="Test 2", discord_user="123", status="pending", ) @@ -346,7 +346,7 @@ def test_run_scheduled_calls_mixed_statuses( user_id=sample_user.id, scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="test-model", - prompt="Pending", + message="Pending", discord_user="123", status="pending", ) @@ -355,7 +355,7 @@ def test_run_scheduled_calls_mixed_statuses( user_id=sample_user.id, scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="test-model", - prompt="Executing", + message="Executing", discord_user="123", status="executing", ) @@ -364,7 +364,7 @@ def test_run_scheduled_calls_mixed_statuses( user_id=sample_user.id, scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="test-model", - prompt="Completed", + message="Completed", discord_user="123", status="completed", ) @@ -397,7 +397,7 @@ def test_run_scheduled_calls_timezone_handling( user_id=sample_user.id, scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime model="test-model", - prompt="Due call", + message="Due call", discord_user="123", status="pending", ) @@ -409,7 +409,7 @@ def test_run_scheduled_calls_timezone_handling( user_id=sample_user.id, scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime model="test-model", - prompt="Future call", + message="Future call", discord_user="123", status="pending", ) @@ -477,7 +477,7 @@ def test_discord_destination_priority( topic="Priority Test", scheduled_time=datetime.now(timezone.utc), model="test-model", - prompt="Test", + message="Test", discord_user=discord_user, discord_channel=discord_channel, status="pending", @@ -568,7 +568,7 @@ def test_execute_scheduled_call_status_check( topic="Status Test", scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5), model="test-model", - prompt="Test", + message="Test", discord_user="123", status=status, )