fix scheduler

This commit is contained in:
mruwnik 2025-11-04 12:46:38 +00:00
parent 470061bd43
commit 56ed7b7d8f
16 changed files with 263 additions and 128 deletions

View File

@ -4,11 +4,11 @@ MCP tools for the epistemic sparring partner system.
import logging
from datetime import datetime, timezone
from typing import Any
from typing import Any, cast
from memory.api.MCP.base import get_current_user, mcp
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
from memory.discord.messages import schedule_discord_message
logger = logging.getLogger(__name__)
@ -71,11 +71,15 @@ async def schedule_message(
raise ValueError("Invalid datetime format")
with make_session() as session:
bot = session.query(DiscordBotUser).first()
if not bot:
return {"error": "No bot found"}
scheduled_call = schedule_discord_message(
session=session,
scheduled_time=scheduled_dt,
message=message,
user_id=current_user.get("user", {}).get("user_id"),
user_id=cast(int, bot.id),
model=model,
topic=topic,
discord_channel=discord_channel,

View File

@ -1,5 +1,6 @@
import logging
from datetime import datetime, timedelta, timezone
from typing import cast
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy.orm import Session as DBSession
@ -15,6 +16,7 @@ from memory.common.db.models import (
User,
UserSession,
)
from memory.common.mcp import mcp_tools_list
from memory.common.oauth import complete_oauth_flow
logger = logging.getLogger(__name__)
@ -171,9 +173,24 @@ async def oauth_callback_discord(request: Request):
mcp_server = (
session.query(MCPServer).filter(MCPServer.state == state).first()
)
if not mcp_server:
return Response(
content="MCP server not found",
status_code=404,
)
status_code, message = await complete_oauth_flow(mcp_server, code, state)
session.commit()
tools = await mcp_tools_list(
cast(str, mcp_server.mcp_server_url), cast(str, mcp_server.access_token)
)
mcp_server.available_tools = [
name for tool in tools if (name := tool.get("name"))
]
session.commit()
logger.info(f"MCP server tools: {tools}")
if 200 <= status_code < 300:
title = "✅ Authorization Successful!"
close = "You can close this window and return to the MCP server."

View File

@ -28,6 +28,7 @@ class MCPServer(Base):
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disabled_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)

View File

@ -401,10 +401,10 @@ class DiscordMessage(SourceItem):
filter(
None,
[
self.recipient_user.id,
self.from_user.id,
self.channel.id,
self.server.id,
self.recipient_id,
self.from_id,
self.channel_id,
self.server_id,
],
)
)

View File

@ -145,6 +145,12 @@ class DiscordBotUser(BotUser):
bot.discord_users = discord_users
return bot
@property
def discord_id(self) -> int | None:
if not self.discord_users:
return None
return self.discord_users[0].id
class UserSession(Base):
__tablename__ = "user_sessions"

View File

@ -2,6 +2,7 @@
import json
import logging
from urllib.parse import urlparse
from typing import Any, AsyncIterator, Iterator
import anthropic
@ -283,9 +284,10 @@ class AnthropicProvider(BaseLLMProvider):
# Include server info if present
if current_tool_use.get("server_name"):
tool_data["server_name"] = current_tool_use["server_name"]
if current_tool_use.get("is_server_call"):
tool_data["is_server_call"] = current_tool_use["is_server_call"]
# Emit different event type for MCP server tools
if current_tool_use.get("is_server_call"):
return StreamEvent(type="server_tool_use", data=tool_data), None
return StreamEvent(type="tool_use", data=tool_data), None
elif event_type == "message_delta":

View File

@ -189,7 +189,15 @@ class Message:
class StreamEvent:
"""An event from the streaming response."""
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
type: Literal[
"text",
"tool_use",
"server_tool_use",
"tool_result",
"thinking",
"error",
"done",
]
data: Any = None
signature: str | None = None
@ -565,9 +573,31 @@ class BaseLLMProvider(ABC):
elif event.type == "thinking":
thinking.thinking += event.data
yield event
elif event.type == "tool_use":
elif event.type == "server_tool_use":
# MCP server tools are executed by Anthropic's backend
# Results will come as separate tool_result events
yield event
# Track the MCP tool call but don't execute locally
messages.append(
Message.assistant(
response,
thinking,
ToolUseContent(
id=event.data["id"],
name=event.data["name"],
input=event.data["input"],
),
)
)
# Reset response for next turn
response = TextContent(text="")
thinking = ThinkingContent(thinking="")
# Continue streaming to get the tool result from Anthropic
elif event.type == "tool_use":
# Execute local tools
yield event
# Execute the tool and yield the result
tool_result = self.execute_tool(event.data, tools)
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
@ -598,6 +628,27 @@ class BaseLLMProvider(ABC):
)
return # Exit after recursive call completes
elif event.type == "tool_result":
yield event
# Add user message with the result
tool_result_content = ToolResultContent(
tool_use_id=event.data["id"],
content=str(event.data.get("result", "")),
is_error=event.data.get("is_error", False),
)
messages.append(Message.user(tool_result=tool_result_content))
# Continue conversation with reduced iterations
yield from self.stream_with_tools(
messages,
tools,
mcp_servers,
settings,
system_prompt,
max_iterations - 1,
)
return # Exit after recursive call completes
elif event.type == "error":
logger.error(f"LLM error: {event.data}")
raise RuntimeError(f"LLM error: {event.data}")
@ -625,7 +676,7 @@ class BaseLLMProvider(ABC):
):
if event.type == "thinking":
thinking += event.data
elif event.type == "tool_use":
elif event.type == "tool_use" or event.type == "server_tool_use":
tool_calls[event.data["id"]] = {
"name": event.data["name"],
"input": event.data["input"],
@ -634,11 +685,15 @@ class BaseLLMProvider(ABC):
elif event.type == "text":
response += event.data
elif event.type == "tool_result":
current = tool_calls.get(event.data["tool_use_id"]) or {}
tool_calls[event.data["tool_use_id"]] = {
tool_id = event.data.get("id") or event.data.get("tool_use_id")
if not tool_id:
logger.warning(f"tool_result event missing id: {event.data}")
continue
current = tool_calls.get(tool_id) or {}
tool_calls[tool_id] = {
"name": event.data.get("name") or current.get("name"),
"input": event.data.get("input") or current.get("input"),
"output": event.data.get("content"),
"output": event.data.get("content") or event.data.get("result"),
}
return Turn(
thinking=thinking or None,

View File

@ -32,6 +32,18 @@ class MCPServer:
token: str
allowed_tools: list[str] | None = None
@classmethod
def from_model(cls, model) -> "MCPServer":
allowed_tools = None
if model.disabled_tools:
allowed_tools = list(set(model.available_tools) - set(model.disabled_tools))
return cls(
name=model.name,
url=model.mcp_server_url,
token=model.access_token,
allowed_tools=allowed_tools,
)
@dataclass
class ToolDefinition:

View File

@ -14,7 +14,7 @@ from memory.common.db.models import (
DiscordServer,
DiscordChannel,
DiscordUser,
BotUser,
DiscordBotUser,
)
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
from memory.common.discord import add_reaction
@ -83,9 +83,9 @@ def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefin
def schedule_message(
user_id: int,
user: int | None,
channel: int | None,
bot_id: int,
recipient_id: int | None,
channel_id: int | None,
model: str,
message: str,
date_time: datetime,
@ -95,43 +95,62 @@ def schedule_message(
session,
scheduled_time=date_time,
message=message,
user_id=user_id,
user_id=bot_id,
model=model,
discord_user=user,
discord_channel=channel,
system_prompt=comm_channel_prompt(session, user, channel),
discord_user=recipient_id,
discord_channel=channel_id,
)
import logging
logger = logging.getLogger(__name__)
logger.error(f"Scheduled message: {call}")
logger.error(f"Scheduled message: {call.id}")
logger.error(f"Scheduled message time: {call.scheduled_time}")
logger.error(f"Scheduled message message: {call.message}")
logger.error(f"Scheduled message model: {call.model}")
logger.error(f"Scheduled message user id: {call.user_id}")
logger.error(f"Scheduled message discord user id: {call.discord_user_id}")
logger.error(f"Scheduled message discord channel id: {call.discord_channel_id}")
session.commit()
return cast(str, call.id)
def make_message_scheduler(
bot: BotUser, user: int | None, channel: int | None, model: str
bot: DiscordBotUser, user_id: int | None, channel_id: int | None, model: str
) -> ToolDefinition:
bot_id = cast(int, bot.id)
if user:
if user_id:
channel_type = "from your chat with this user"
elif channel:
elif channel_id:
channel_type = "in this channel"
else:
raise ValueError("Either user or channel must be provided")
def handler(input: ToolInput) -> str:
if not isinstance(input, dict):
raise ValueError("Input must be a dictionary")
try:
time = datetime.fromisoformat(input["date_time"])
except ValueError:
raise ValueError("Invalid date time format")
except KeyError:
raise ValueError("Date time is required")
if not isinstance(input, dict):
raise ValueError("Input must be a dictionary")
return schedule_message(bot_id, user, channel, model, input["message"], time)
try:
time = datetime.fromisoformat(input["date_time"])
except ValueError:
raise ValueError("Invalid date time format")
except KeyError:
raise ValueError("Date time is required")
return schedule_message(
bot_id, user_id, channel_id, model, input["message"], time
)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error scheduling message: {e}")
raise e
return ToolDefinition(
name="schedule_message",
name="schedule_discord_message",
description=textwrap.dedent("""
Use this to schedule a message to be sent to yourself.
@ -162,10 +181,13 @@ def make_message_scheduler(
)
def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition:
if user:
def make_prev_messages_tool(
bot: DiscordBotUser, user_id: int | None, channel_id: int | None
) -> ToolDefinition:
bot_id = bot.discord_id
if user_id:
channel_type = "from your chat with this user"
elif channel:
elif channel_id:
channel_type = "in this channel"
else:
raise ValueError("Either user or channel must be provided")
@ -185,7 +207,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
raise ValueError("Offset must be greater than or equal to 0")
with make_session() as session:
messages = previous_messages(session, user, channel, max_messages, offset)
messages = previous_messages(
session, bot_id, user_id, channel_id, max_messages, offset
)
return "\n\n".join([msg.title for msg in messages])
return ToolDefinition(
@ -210,7 +234,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
)
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
def make_add_reaction_tool(
bot: DiscordBotUser, channel: DiscordChannel
) -> ToolDefinition:
bot_id = cast(int, bot.id)
channel_id = channel and channel.id
@ -255,7 +281,7 @@ def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinit
def make_discord_tools(
bot: BotUser,
bot: DiscordBotUser,
author: DiscordUser | None,
channel: DiscordChannel | None,
model: str,
@ -264,7 +290,7 @@ def make_discord_tools(
channel_id = channel and channel.id
tools = [
make_message_scheduler(bot, author_id, channel_id, model),
make_prev_messages_tool(author_id, channel_id),
make_prev_messages_tool(bot, author_id, channel_id),
make_summary_tool("channel", channel_id),
]
if author:

58
src/memory/common/mcp.py Normal file
View File

@ -0,0 +1,58 @@
import json
import logging
import time
from typing import Any, AsyncGenerator
import aiohttp
logger = logging.getLogger(__name__)
async def mcp_call(
url: str, access_token: str, method: str, params: dict = {}
) -> AsyncGenerator[Any, None]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Authorization": f"Bearer {access_token}",
}
payload = {
"jsonrpc": "2.0",
"id": int(time.time() * 1000),
"method": method,
"params": params,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Tools list failed: {resp.status} - {error_text}")
raise ValueError(
f"Failed to call MCP server: {resp.status} - {error_text}"
)
# Parse SSE stream
async for line in resp.content:
line_str = line.decode("utf-8").strip()
# SSE format: "data: {json}"
if line_str.startswith("data: "):
json_str = line_str[6:] # Remove "data: " prefix
try:
yield json.loads(json_str)
except json.JSONDecodeError:
continue # Skip invalid JSON lines
async def mcp_tools_list(url: str, access_token: str) -> list[dict]:
async for data in mcp_call(url, access_token, "tools/list"):
if "result" in data and "tools" in data["result"]:
return data["result"]["tools"]
return []

View File

@ -108,6 +108,7 @@ async def send_dm_endpoint(request: SendDMRequest):
"""Send a DM via the collector's Discord client"""
collector = app.bots.get(request.bot_id)
if not collector:
logger.error(f"Bot not found: {request.bot_id}")
raise HTTPException(status_code=404, detail="Bot not found")
try:

View File

@ -1,9 +1,7 @@
"""Lightweight slash-command helpers for the Discord collector."""
import json
import logging
import time
from typing import Any, AsyncGenerator, Literal, cast
from typing import Literal, cast
import aiohttp
import discord
@ -12,6 +10,7 @@ from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session
from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
from memory.common.mcp import mcp_tools_list
logger = logging.getLogger(__name__)
@ -33,49 +32,6 @@ def find_mcp_server(
return assignment and assignment.mcp_server
async def call_mcp_server(
url: str, access_token: str, method: str, params: dict = {}
) -> AsyncGenerator[Any, None]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Authorization": f"Bearer {access_token}",
}
payload = {
"jsonrpc": "2.0",
"id": int(time.time() * 1000),
"method": method,
"params": params,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Tools list failed: {resp.status} - {error_text}")
raise ValueError(
f"Failed to call MCP server: {resp.status} - {error_text}"
)
# Parse SSE stream
async for line in resp.content:
line_str = line.decode("utf-8").strip()
# SSE format: "data: {json}"
if line_str.startswith("data: "):
json_str = line_str[6:] # Remove "data: " prefix
try:
yield json.loads(json_str)
except json.JSONDecodeError:
continue # Skip invalid JSON lines
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
"""List all MCP servers for the user."""
with make_session() as session:
@ -258,10 +214,7 @@ async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
# Make JSON-RPC request to MCP server
tools = None
try:
async for data in call_mcp_server(url, access_token, "tools/list"):
if "result" in data and "tools" in data["result"]:
tools = data["result"]["tools"]
break
tools = await mcp_tools_list(url, access_token)
except aiohttp.ClientError as exc:
logger.exception(f"Failed to connect to MCP server: {exc}")
raise ValueError(

View File

@ -132,14 +132,17 @@ def upsert_scheduled_message(
def previous_messages(
session: Session | scoped_session,
bot_id: int,
user_id: int | None,
channel_id: int | None,
max_messages: int = 10,
offset: int = 0,
) -> list[DiscordMessage]:
messages = session.query(DiscordMessage)
messages = session.query(DiscordMessage).filter(
DiscordMessage.recipient_id == bot_id
)
if user_id:
messages = messages.filter(DiscordMessage.recipient_id == user_id)
messages = messages.filter(DiscordMessage.from_id == user_id)
if channel_id:
messages = messages.filter(DiscordMessage.channel_id == channel_id)
return list(
@ -154,15 +157,23 @@ def previous_messages(
def comm_channel_prompt(
session: Session | scoped_session,
bot: DiscordEntity,
user: DiscordEntity,
channel: DiscordEntity,
max_messages: int = 10,
) -> str:
user = resolve_discord_user(session, user)
channel = resolve_discord_channel(session, channel)
bot = resolve_discord_user(session, bot)
if not bot:
raise ValueError("Bot not found")
messages = previous_messages(
session, user and user.id, channel and channel.id, max_messages
session,
cast(int, bot.id),
user and user.id,
channel and channel.id,
max_messages,
)
server_context = ""
@ -244,6 +255,7 @@ def call_llm(
user_id = cast(int, from_user.id)
prev_messages = previous_messages(
session,
bot_user.system_user.discord_id,
user_id,
channel and channel.id,
max_messages=num_previous_messages,
@ -263,18 +275,12 @@ def call_llm(
if bot_user.system_prompt:
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
message_content = [m.as_content() for m in prev_messages] + messages
return provider.run_with_tools(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
mcp_servers=[
MCPServer(
name=str(server.name),
url=str(server.mcp_server_url),
token=str(server.access_token),
)
for server in mcp_servers
]
mcp_servers=[MCPServer.from_model(server) for server in mcp_servers]
if mcp_servers
else None,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,

View File

@ -113,7 +113,7 @@ def should_process(message: DiscordMessage) -> bool:
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
session, message.recipient_user, message.from_user, message.channel
)
allowed_tools = [
"update_channel_summary",
@ -199,7 +199,10 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
mcp_servers = discord_message.get_mcp_servers(session)
system_prompt = discord_message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel
session,
discord_message.recipient_user,
discord_message.from_user,
discord_message.channel,
)
try:

View File

@ -9,14 +9,14 @@ from memory.common.celery_app import (
app,
)
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
from memory.discord.messages import call_llm, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__)
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
"""Call LLM with tools support for scheduled calls."""
if not scheduled_call.discord_user:
logger.warning("No discord_user for scheduled call - cannot execute")
@ -30,6 +30,7 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
bot_user = (
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
)
return call_llm(
session=session,
bot_user=bot_user,
@ -42,18 +43,8 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
)
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
def send_to_discord(bot_id: int, scheduled_call: ScheduledLLMCall, response: str):
"""Send the LLM response to Discord user or channel."""
bot_id_value = scheduled_call.user_id
if bot_id_value is None:
logger.warning(
"Scheduled call %s has no associated bot user; skipping Discord send",
scheduled_call.id,
)
return
bot_id = cast(int, bot_id_value)
send_discord_response(
bot_id=bot_id,
response=response,
@ -101,7 +92,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Make the LLM call with tools support
try:
response = _call_llm_for_scheduled(session, scheduled_call)
response = call_llm_for_scheduled(session, scheduled_call)
except Exception:
logger.exception("Failed to generate LLM response")
scheduled_call.status = "failed"
@ -132,7 +123,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Send to Discord
try:
_send_to_discord(scheduled_call, response)
send_to_discord(cast(int, scheduled_call.user_id), scheduled_call, response)
logger.info(f"Response sent to Discord for {scheduled_call_id}")
except Exception as discord_error:
logger.error(f"Failed to send to Discord: {discord_error}")

View File

@ -127,7 +127,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user."""
response = "This is a test response."
scheduled_calls._send_to_discord(pending_scheduled_call, response)
scheduled_calls.send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id,
@ -141,7 +141,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
"""Test sending to Discord channel."""
response = "This is a channel response."
scheduled_calls._send_to_discord(completed_scheduled_call, response)
scheduled_calls.send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with(
completed_scheduled_call.user_id,
@ -155,7 +155,7 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
"""Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response
scheduled_calls._send_to_discord(pending_scheduled_call, long_response)
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
# Verify the message was truncated
args, kwargs = mock_send_dm.call_args
@ -170,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
"""Test that normal length messages are not truncated."""
normal_response = "This is a normal length response."
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args
assert args[0] == pending_scheduled_call.user_id
@ -535,7 +535,7 @@ def test_discord_destination_priority(
db_session.commit()
response = "Test response"
scheduled_calls._send_to_discord(call, response)
scheduled_calls.send_to_discord(call, response)
if expected_method == "send_dm":
mock_send_dm.assert_called_once()
@ -590,7 +590,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_call.discord_user = mock_discord_user
mock_call.discord_channel = None
scheduled_calls._send_to_discord(mock_call, response)
scheduled_calls.send_to_discord(mock_call, response)
# Get the actual message that was sent
args, kwargs = mock_send_dm.call_args