fix scheduler

This commit is contained in:
mruwnik 2025-11-04 12:46:38 +00:00
parent 470061bd43
commit 56ed7b7d8f
16 changed files with 263 additions and 128 deletions

View File

@ -4,11 +4,11 @@ MCP tools for the epistemic sparring partner system.
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any, cast
from memory.api.MCP.base import get_current_user, mcp from memory.api.MCP.base import get_current_user, mcp
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
from memory.discord.messages import schedule_discord_message from memory.discord.messages import schedule_discord_message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,11 +71,15 @@ async def schedule_message(
raise ValueError("Invalid datetime format") raise ValueError("Invalid datetime format")
with make_session() as session: with make_session() as session:
bot = session.query(DiscordBotUser).first()
if not bot:
return {"error": "No bot found"}
scheduled_call = schedule_discord_message( scheduled_call = schedule_discord_message(
session=session, session=session,
scheduled_time=scheduled_dt, scheduled_time=scheduled_dt,
message=message, message=message,
user_id=current_user.get("user", {}).get("user_id"), user_id=cast(int, bot.id),
model=model, model=model,
topic=topic, topic=topic,
discord_channel=discord_channel, discord_channel=discord_channel,

View File

@ -1,5 +1,6 @@
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import cast
from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy.orm import Session as DBSession from sqlalchemy.orm import Session as DBSession
@ -15,6 +16,7 @@ from memory.common.db.models import (
User, User,
UserSession, UserSession,
) )
from memory.common.mcp import mcp_tools_list
from memory.common.oauth import complete_oauth_flow from memory.common.oauth import complete_oauth_flow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -171,9 +173,24 @@ async def oauth_callback_discord(request: Request):
mcp_server = ( mcp_server = (
session.query(MCPServer).filter(MCPServer.state == state).first() session.query(MCPServer).filter(MCPServer.state == state).first()
) )
if not mcp_server:
return Response(
content="MCP server not found",
status_code=404,
)
status_code, message = await complete_oauth_flow(mcp_server, code, state) status_code, message = await complete_oauth_flow(mcp_server, code, state)
session.commit() session.commit()
tools = await mcp_tools_list(
cast(str, mcp_server.mcp_server_url), cast(str, mcp_server.access_token)
)
mcp_server.available_tools = [
name for tool in tools if (name := tool.get("name"))
]
session.commit()
logger.info(f"MCP server tools: {tools}")
if 200 <= status_code < 300: if 200 <= status_code < 300:
title = "✅ Authorization Successful!" title = "✅ Authorization Successful!"
close = "You can close this window and return to the MCP server." close = "You can close this window and return to the MCP server."

View File

@ -28,6 +28,7 @@ class MCPServer(Base):
mcp_server_url = Column(Text, nullable=False) mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False) client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}") available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
disabled_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange) # OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True) state = Column(Text, nullable=True, unique=True)

View File

@ -401,10 +401,10 @@ class DiscordMessage(SourceItem):
filter( filter(
None, None,
[ [
self.recipient_user.id, self.recipient_id,
self.from_user.id, self.from_id,
self.channel.id, self.channel_id,
self.server.id, self.server_id,
], ],
) )
) )

View File

@ -145,6 +145,12 @@ class DiscordBotUser(BotUser):
bot.discord_users = discord_users bot.discord_users = discord_users
return bot return bot
@property
def discord_id(self) -> int | None:
if not self.discord_users:
return None
return self.discord_users[0].id
class UserSession(Base): class UserSession(Base):
__tablename__ = "user_sessions" __tablename__ = "user_sessions"

View File

@ -2,6 +2,7 @@
import json import json
import logging import logging
from urllib.parse import urlparse
from typing import Any, AsyncIterator, Iterator from typing import Any, AsyncIterator, Iterator
import anthropic import anthropic
@ -283,9 +284,10 @@ class AnthropicProvider(BaseLLMProvider):
# Include server info if present # Include server info if present
if current_tool_use.get("server_name"): if current_tool_use.get("server_name"):
tool_data["server_name"] = current_tool_use["server_name"] tool_data["server_name"] = current_tool_use["server_name"]
if current_tool_use.get("is_server_call"):
tool_data["is_server_call"] = current_tool_use["is_server_call"]
# Emit different event type for MCP server tools
if current_tool_use.get("is_server_call"):
return StreamEvent(type="server_tool_use", data=tool_data), None
return StreamEvent(type="tool_use", data=tool_data), None return StreamEvent(type="tool_use", data=tool_data), None
elif event_type == "message_delta": elif event_type == "message_delta":

View File

@ -189,7 +189,15 @@ class Message:
class StreamEvent: class StreamEvent:
"""An event from the streaming response.""" """An event from the streaming response."""
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"] type: Literal[
"text",
"tool_use",
"server_tool_use",
"tool_result",
"thinking",
"error",
"done",
]
data: Any = None data: Any = None
signature: str | None = None signature: str | None = None
@ -565,9 +573,31 @@ class BaseLLMProvider(ABC):
elif event.type == "thinking": elif event.type == "thinking":
thinking.thinking += event.data thinking.thinking += event.data
yield event yield event
elif event.type == "tool_use": elif event.type == "server_tool_use":
# MCP server tools are executed by Anthropic's backend
# Results will come as separate tool_result events
yield event
# Track the MCP tool call but don't execute locally
messages.append(
Message.assistant(
response,
thinking,
ToolUseContent(
id=event.data["id"],
name=event.data["name"],
input=event.data["input"],
),
)
)
# Reset response for next turn
response = TextContent(text="")
thinking = ThinkingContent(thinking="")
# Continue streaming to get the tool result from Anthropic
elif event.type == "tool_use":
# Execute local tools
yield event yield event
# Execute the tool and yield the result
tool_result = self.execute_tool(event.data, tools) tool_result = self.execute_tool(event.data, tools)
yield StreamEvent(type="tool_result", data=tool_result.to_dict()) yield StreamEvent(type="tool_result", data=tool_result.to_dict())
@ -598,6 +628,27 @@ class BaseLLMProvider(ABC):
) )
return # Exit after recursive call completes return # Exit after recursive call completes
elif event.type == "tool_result":
yield event
# Add user message with the result
tool_result_content = ToolResultContent(
tool_use_id=event.data["id"],
content=str(event.data.get("result", "")),
is_error=event.data.get("is_error", False),
)
messages.append(Message.user(tool_result=tool_result_content))
# Continue conversation with reduced iterations
yield from self.stream_with_tools(
messages,
tools,
mcp_servers,
settings,
system_prompt,
max_iterations - 1,
)
return # Exit after recursive call completes
elif event.type == "error": elif event.type == "error":
logger.error(f"LLM error: {event.data}") logger.error(f"LLM error: {event.data}")
raise RuntimeError(f"LLM error: {event.data}") raise RuntimeError(f"LLM error: {event.data}")
@ -625,7 +676,7 @@ class BaseLLMProvider(ABC):
): ):
if event.type == "thinking": if event.type == "thinking":
thinking += event.data thinking += event.data
elif event.type == "tool_use": elif event.type == "tool_use" or event.type == "server_tool_use":
tool_calls[event.data["id"]] = { tool_calls[event.data["id"]] = {
"name": event.data["name"], "name": event.data["name"],
"input": event.data["input"], "input": event.data["input"],
@ -634,11 +685,15 @@ class BaseLLMProvider(ABC):
elif event.type == "text": elif event.type == "text":
response += event.data response += event.data
elif event.type == "tool_result": elif event.type == "tool_result":
current = tool_calls.get(event.data["tool_use_id"]) or {} tool_id = event.data.get("id") or event.data.get("tool_use_id")
tool_calls[event.data["tool_use_id"]] = { if not tool_id:
logger.warning(f"tool_result event missing id: {event.data}")
continue
current = tool_calls.get(tool_id) or {}
tool_calls[tool_id] = {
"name": event.data.get("name") or current.get("name"), "name": event.data.get("name") or current.get("name"),
"input": event.data.get("input") or current.get("input"), "input": event.data.get("input") or current.get("input"),
"output": event.data.get("content"), "output": event.data.get("content") or event.data.get("result"),
} }
return Turn( return Turn(
thinking=thinking or None, thinking=thinking or None,

View File

@ -32,6 +32,18 @@ class MCPServer:
token: str token: str
allowed_tools: list[str] | None = None allowed_tools: list[str] | None = None
@classmethod
def from_model(cls, model) -> "MCPServer":
allowed_tools = None
if model.disabled_tools:
allowed_tools = list(set(model.available_tools) - set(model.disabled_tools))
return cls(
name=model.name,
url=model.mcp_server_url,
token=model.access_token,
allowed_tools=allowed_tools,
)
@dataclass @dataclass
class ToolDefinition: class ToolDefinition:

View File

@ -14,7 +14,7 @@ from memory.common.db.models import (
DiscordServer, DiscordServer,
DiscordChannel, DiscordChannel,
DiscordUser, DiscordUser,
BotUser, DiscordBotUser,
) )
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
from memory.common.discord import add_reaction from memory.common.discord import add_reaction
@ -83,9 +83,9 @@ def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefin
def schedule_message( def schedule_message(
user_id: int, bot_id: int,
user: int | None, recipient_id: int | None,
channel: int | None, channel_id: int | None,
model: str, model: str,
message: str, message: str,
date_time: datetime, date_time: datetime,
@ -95,29 +95,40 @@ def schedule_message(
session, session,
scheduled_time=date_time, scheduled_time=date_time,
message=message, message=message,
user_id=user_id, user_id=bot_id,
model=model, model=model,
discord_user=user, discord_user=recipient_id,
discord_channel=channel, discord_channel=channel_id,
system_prompt=comm_channel_prompt(session, user, channel),
) )
import logging
logger = logging.getLogger(__name__)
logger.error(f"Scheduled message: {call}")
logger.error(f"Scheduled message: {call.id}")
logger.error(f"Scheduled message time: {call.scheduled_time}")
logger.error(f"Scheduled message message: {call.message}")
logger.error(f"Scheduled message model: {call.model}")
logger.error(f"Scheduled message user id: {call.user_id}")
logger.error(f"Scheduled message discord user id: {call.discord_user_id}")
logger.error(f"Scheduled message discord channel id: {call.discord_channel_id}")
session.commit() session.commit()
return cast(str, call.id) return cast(str, call.id)
def make_message_scheduler( def make_message_scheduler(
bot: BotUser, user: int | None, channel: int | None, model: str bot: DiscordBotUser, user_id: int | None, channel_id: int | None, model: str
) -> ToolDefinition: ) -> ToolDefinition:
bot_id = cast(int, bot.id) bot_id = cast(int, bot.id)
if user: if user_id:
channel_type = "from your chat with this user" channel_type = "from your chat with this user"
elif channel: elif channel_id:
channel_type = "in this channel" channel_type = "in this channel"
else: else:
raise ValueError("Either user or channel must be provided") raise ValueError("Either user or channel must be provided")
def handler(input: ToolInput) -> str: def handler(input: ToolInput) -> str:
try:
if not isinstance(input, dict): if not isinstance(input, dict):
raise ValueError("Input must be a dictionary") raise ValueError("Input must be a dictionary")
@ -128,10 +139,18 @@ def make_message_scheduler(
except KeyError: except KeyError:
raise ValueError("Date time is required") raise ValueError("Date time is required")
return schedule_message(bot_id, user, channel, model, input["message"], time) return schedule_message(
bot_id, user_id, channel_id, model, input["message"], time
)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error scheduling message: {e}")
raise e
return ToolDefinition( return ToolDefinition(
name="schedule_message", name="schedule_discord_message",
description=textwrap.dedent(""" description=textwrap.dedent("""
Use this to schedule a message to be sent to yourself. Use this to schedule a message to be sent to yourself.
@ -162,10 +181,13 @@ def make_message_scheduler(
) )
def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition: def make_prev_messages_tool(
if user: bot: DiscordBotUser, user_id: int | None, channel_id: int | None
) -> ToolDefinition:
bot_id = bot.discord_id
if user_id:
channel_type = "from your chat with this user" channel_type = "from your chat with this user"
elif channel: elif channel_id:
channel_type = "in this channel" channel_type = "in this channel"
else: else:
raise ValueError("Either user or channel must be provided") raise ValueError("Either user or channel must be provided")
@ -185,7 +207,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
raise ValueError("Offset must be greater than or equal to 0") raise ValueError("Offset must be greater than or equal to 0")
with make_session() as session: with make_session() as session:
messages = previous_messages(session, user, channel, max_messages, offset) messages = previous_messages(
session, bot_id, user_id, channel_id, max_messages, offset
)
return "\n\n".join([msg.title for msg in messages]) return "\n\n".join([msg.title for msg in messages])
return ToolDefinition( return ToolDefinition(
@ -210,7 +234,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
) )
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition: def make_add_reaction_tool(
bot: DiscordBotUser, channel: DiscordChannel
) -> ToolDefinition:
bot_id = cast(int, bot.id) bot_id = cast(int, bot.id)
channel_id = channel and channel.id channel_id = channel and channel.id
@ -255,7 +281,7 @@ def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinit
def make_discord_tools( def make_discord_tools(
bot: BotUser, bot: DiscordBotUser,
author: DiscordUser | None, author: DiscordUser | None,
channel: DiscordChannel | None, channel: DiscordChannel | None,
model: str, model: str,
@ -264,7 +290,7 @@ def make_discord_tools(
channel_id = channel and channel.id channel_id = channel and channel.id
tools = [ tools = [
make_message_scheduler(bot, author_id, channel_id, model), make_message_scheduler(bot, author_id, channel_id, model),
make_prev_messages_tool(author_id, channel_id), make_prev_messages_tool(bot, author_id, channel_id),
make_summary_tool("channel", channel_id), make_summary_tool("channel", channel_id),
] ]
if author: if author:

58
src/memory/common/mcp.py Normal file
View File

@ -0,0 +1,58 @@
import json
import logging
import time
from typing import Any, AsyncGenerator
import aiohttp
logger = logging.getLogger(__name__)
async def mcp_call(
url: str, access_token: str, method: str, params: dict = {}
) -> AsyncGenerator[Any, None]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Authorization": f"Bearer {access_token}",
}
payload = {
"jsonrpc": "2.0",
"id": int(time.time() * 1000),
"method": method,
"params": params,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Tools list failed: {resp.status} - {error_text}")
raise ValueError(
f"Failed to call MCP server: {resp.status} - {error_text}"
)
# Parse SSE stream
async for line in resp.content:
line_str = line.decode("utf-8").strip()
# SSE format: "data: {json}"
if line_str.startswith("data: "):
json_str = line_str[6:] # Remove "data: " prefix
try:
yield json.loads(json_str)
except json.JSONDecodeError:
continue # Skip invalid JSON lines
async def mcp_tools_list(url: str, access_token: str) -> list[dict]:
async for data in mcp_call(url, access_token, "tools/list"):
if "result" in data and "tools" in data["result"]:
return data["result"]["tools"]
return []

View File

@ -108,6 +108,7 @@ async def send_dm_endpoint(request: SendDMRequest):
"""Send a DM via the collector's Discord client""" """Send a DM via the collector's Discord client"""
collector = app.bots.get(request.bot_id) collector = app.bots.get(request.bot_id)
if not collector: if not collector:
logger.error(f"Bot not found: {request.bot_id}")
raise HTTPException(status_code=404, detail="Bot not found") raise HTTPException(status_code=404, detail="Bot not found")
try: try:

View File

@ -1,9 +1,7 @@
"""Lightweight slash-command helpers for the Discord collector.""" """Lightweight slash-command helpers for the Discord collector."""
import json
import logging import logging
import time from typing import Literal, cast
from typing import Any, AsyncGenerator, Literal, cast
import aiohttp import aiohttp
import discord import discord
@ -12,6 +10,7 @@ from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import MCPServer, MCPServerAssignment from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
from memory.common.mcp import mcp_tools_list
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,49 +32,6 @@ def find_mcp_server(
return assignment and assignment.mcp_server return assignment and assignment.mcp_server
async def call_mcp_server(
url: str, access_token: str, method: str, params: dict = {}
) -> AsyncGenerator[Any, None]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Authorization": f"Bearer {access_token}",
}
payload = {
"jsonrpc": "2.0",
"id": int(time.time() * 1000),
"method": method,
"params": params,
}
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"Tools list failed: {resp.status} - {error_text}")
raise ValueError(
f"Failed to call MCP server: {resp.status} - {error_text}"
)
# Parse SSE stream
async for line in resp.content:
line_str = line.decode("utf-8").strip()
# SSE format: "data: {json}"
if line_str.startswith("data: "):
json_str = line_str[6:] # Remove "data: " prefix
try:
yield json.loads(json_str)
except json.JSONDecodeError:
continue # Skip invalid JSON lines
async def handle_mcp_list(entity_type: str, entity_id: int) -> str: async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
"""List all MCP servers for the user.""" """List all MCP servers for the user."""
with make_session() as session: with make_session() as session:
@ -258,10 +214,7 @@ async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
# Make JSON-RPC request to MCP server # Make JSON-RPC request to MCP server
tools = None tools = None
try: try:
async for data in call_mcp_server(url, access_token, "tools/list"): tools = await mcp_tools_list(url, access_token)
if "result" in data and "tools" in data["result"]:
tools = data["result"]["tools"]
break
except aiohttp.ClientError as exc: except aiohttp.ClientError as exc:
logger.exception(f"Failed to connect to MCP server: {exc}") logger.exception(f"Failed to connect to MCP server: {exc}")
raise ValueError( raise ValueError(

View File

@ -132,14 +132,17 @@ def upsert_scheduled_message(
def previous_messages( def previous_messages(
session: Session | scoped_session, session: Session | scoped_session,
bot_id: int,
user_id: int | None, user_id: int | None,
channel_id: int | None, channel_id: int | None,
max_messages: int = 10, max_messages: int = 10,
offset: int = 0, offset: int = 0,
) -> list[DiscordMessage]: ) -> list[DiscordMessage]:
messages = session.query(DiscordMessage) messages = session.query(DiscordMessage).filter(
DiscordMessage.recipient_id == bot_id
)
if user_id: if user_id:
messages = messages.filter(DiscordMessage.recipient_id == user_id) messages = messages.filter(DiscordMessage.from_id == user_id)
if channel_id: if channel_id:
messages = messages.filter(DiscordMessage.channel_id == channel_id) messages = messages.filter(DiscordMessage.channel_id == channel_id)
return list( return list(
@ -154,15 +157,23 @@ def previous_messages(
def comm_channel_prompt( def comm_channel_prompt(
session: Session | scoped_session, session: Session | scoped_session,
bot: DiscordEntity,
user: DiscordEntity, user: DiscordEntity,
channel: DiscordEntity, channel: DiscordEntity,
max_messages: int = 10, max_messages: int = 10,
) -> str: ) -> str:
user = resolve_discord_user(session, user) user = resolve_discord_user(session, user)
channel = resolve_discord_channel(session, channel) channel = resolve_discord_channel(session, channel)
bot = resolve_discord_user(session, bot)
if not bot:
raise ValueError("Bot not found")
messages = previous_messages( messages = previous_messages(
session, user and user.id, channel and channel.id, max_messages session,
cast(int, bot.id),
user and user.id,
channel and channel.id,
max_messages,
) )
server_context = "" server_context = ""
@ -244,6 +255,7 @@ def call_llm(
user_id = cast(int, from_user.id) user_id = cast(int, from_user.id)
prev_messages = previous_messages( prev_messages = previous_messages(
session, session,
bot_user.system_user.discord_id,
user_id, user_id,
channel and channel.id, channel and channel.id,
max_messages=num_previous_messages, max_messages=num_previous_messages,
@ -263,18 +275,12 @@ def call_llm(
if bot_user.system_prompt: if bot_user.system_prompt:
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "") system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
message_content = [m.as_content() for m in prev_messages] + messages message_content = [m.as_content() for m in prev_messages] + messages
return provider.run_with_tools( return provider.run_with_tools(
messages=provider.as_messages(message_content), messages=provider.as_messages(message_content),
tools=tools, tools=tools,
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""), system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
mcp_servers=[ mcp_servers=[MCPServer.from_model(server) for server in mcp_servers]
MCPServer(
name=str(server.name),
url=str(server.mcp_server_url),
token=str(server.access_token),
)
for server in mcp_servers
]
if mcp_servers if mcp_servers
else None, else None,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS, max_iterations=settings.DISCORD_MAX_TOOL_CALLS,

View File

@ -113,7 +113,7 @@ def should_process(message: DiscordMessage) -> bool:
system_prompt = message.system_prompt or "" system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt( system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel session, message.recipient_user, message.from_user, message.channel
) )
allowed_tools = [ allowed_tools = [
"update_channel_summary", "update_channel_summary",
@ -199,7 +199,10 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
mcp_servers = discord_message.get_mcp_servers(session) mcp_servers = discord_message.get_mcp_servers(session)
system_prompt = discord_message.system_prompt or "" system_prompt = discord_message.system_prompt or ""
system_prompt += comm_channel_prompt( system_prompt += comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel session,
discord_message.recipient_user,
discord_message.from_user,
discord_message.channel,
) )
try: try:

View File

@ -9,14 +9,14 @@ from memory.common.celery_app import (
app, app,
) )
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
from memory.discord.messages import call_llm, send_discord_response from memory.discord.messages import call_llm, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None: def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
"""Call LLM with tools support for scheduled calls.""" """Call LLM with tools support for scheduled calls."""
if not scheduled_call.discord_user: if not scheduled_call.discord_user:
logger.warning("No discord_user for scheduled call - cannot execute") logger.warning("No discord_user for scheduled call - cannot execute")
@ -30,6 +30,7 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
bot_user = ( bot_user = (
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0] scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
) )
return call_llm( return call_llm(
session=session, session=session,
bot_user=bot_user, bot_user=bot_user,
@ -42,18 +43,8 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
) )
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): def send_to_discord(bot_id: int, scheduled_call: ScheduledLLMCall, response: str):
"""Send the LLM response to Discord user or channel.""" """Send the LLM response to Discord user or channel."""
bot_id_value = scheduled_call.user_id
if bot_id_value is None:
logger.warning(
"Scheduled call %s has no associated bot user; skipping Discord send",
scheduled_call.id,
)
return
bot_id = cast(int, bot_id_value)
send_discord_response( send_discord_response(
bot_id=bot_id, bot_id=bot_id,
response=response, response=response,
@ -101,7 +92,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Make the LLM call with tools support # Make the LLM call with tools support
try: try:
response = _call_llm_for_scheduled(session, scheduled_call) response = call_llm_for_scheduled(session, scheduled_call)
except Exception: except Exception:
logger.exception("Failed to generate LLM response") logger.exception("Failed to generate LLM response")
scheduled_call.status = "failed" scheduled_call.status = "failed"
@ -132,7 +123,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
# Send to Discord # Send to Discord
try: try:
_send_to_discord(scheduled_call, response) send_to_discord(cast(int, scheduled_call.user_id), scheduled_call, response)
logger.info(f"Response sent to Discord for {scheduled_call_id}") logger.info(f"Response sent to Discord for {scheduled_call_id}")
except Exception as discord_error: except Exception as discord_error:
logger.error(f"Failed to send to Discord: {discord_error}") logger.error(f"Failed to send to Discord: {discord_error}")

View File

@ -127,7 +127,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user.""" """Test sending to Discord user."""
response = "This is a test response." response = "This is a test response."
scheduled_calls._send_to_discord(pending_scheduled_call, response) scheduled_calls.send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with( mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id, pending_scheduled_call.user_id,
@ -141,7 +141,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
"""Test sending to Discord channel.""" """Test sending to Discord channel."""
response = "This is a channel response." response = "This is a channel response."
scheduled_calls._send_to_discord(completed_scheduled_call, response) scheduled_calls.send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with( mock_broadcast.assert_called_once_with(
completed_scheduled_call.user_id, completed_scheduled_call.user_id,
@ -155,7 +155,7 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
"""Test message truncation for long responses.""" """Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response long_response = "A" * 2500 # Very long response
scheduled_calls._send_to_discord(pending_scheduled_call, long_response) scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
# Verify the message was truncated # Verify the message was truncated
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
@ -170,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
"""Test that normal length messages are not truncated.""" """Test that normal length messages are not truncated."""
normal_response = "This is a normal length response." normal_response = "This is a normal length response."
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response) scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
assert args[0] == pending_scheduled_call.user_id assert args[0] == pending_scheduled_call.user_id
@ -535,7 +535,7 @@ def test_discord_destination_priority(
db_session.commit() db_session.commit()
response = "Test response" response = "Test response"
scheduled_calls._send_to_discord(call, response) scheduled_calls.send_to_discord(call, response)
if expected_method == "send_dm": if expected_method == "send_dm":
mock_send_dm.assert_called_once() mock_send_dm.assert_called_once()
@ -590,7 +590,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_call.discord_user = mock_discord_user mock_call.discord_user = mock_discord_user
mock_call.discord_channel = None mock_call.discord_channel = None
scheduled_calls._send_to_discord(mock_call, response) scheduled_calls.send_to_discord(mock_call, response)
# Get the actual message that was sent # Get the actual message that was sent
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args