diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/schedules.py index 22157dc..26b50ad 100644 --- a/src/memory/api/MCP/schedules.py +++ b/src/memory/api/MCP/schedules.py @@ -4,11 +4,11 @@ MCP tools for the epistemic sparring partner system. import logging from datetime import datetime, timezone -from typing import Any +from typing import Any, cast from memory.api.MCP.base import get_current_user, mcp from memory.common.db.connection import make_session -from memory.common.db.models import ScheduledLLMCall +from memory.common.db.models import ScheduledLLMCall, DiscordBotUser from memory.discord.messages import schedule_discord_message logger = logging.getLogger(__name__) @@ -71,11 +71,15 @@ async def schedule_message( raise ValueError("Invalid datetime format") with make_session() as session: + bot = session.query(DiscordBotUser).first() + if not bot: + return {"error": "No bot found"} + scheduled_call = schedule_discord_message( session=session, scheduled_time=scheduled_dt, message=message, - user_id=current_user.get("user", {}).get("user_id"), + user_id=cast(int, bot.id), model=model, topic=topic, discord_channel=discord_channel, diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index 0567762..151e6d6 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -1,5 +1,6 @@ import logging from datetime import datetime, timedelta, timezone +from typing import cast from fastapi import APIRouter, Depends, HTTPException, Request, Response from sqlalchemy.orm import Session as DBSession @@ -15,6 +16,7 @@ from memory.common.db.models import ( User, UserSession, ) +from memory.common.mcp import mcp_tools_list from memory.common.oauth import complete_oauth_flow logger = logging.getLogger(__name__) @@ -171,9 +173,24 @@ async def oauth_callback_discord(request: Request): mcp_server = ( session.query(MCPServer).filter(MCPServer.state == state).first() ) + if not mcp_server: + return Response( + content="MCP server not found", + status_code=404, + ) + status_code, message = await complete_oauth_flow(mcp_server, code, state) session.commit() + tools = await mcp_tools_list( + cast(str, mcp_server.mcp_server_url), cast(str, mcp_server.access_token) + ) + mcp_server.available_tools = [ + name for tool in tools if (name := tool.get("name")) + ] + session.commit() + logger.info(f"MCP server tools: {tools}") + if 200 <= status_code < 300: title = "✅ Authorization Successful!" close = "You can close this window and return to the MCP server." diff --git a/src/memory/common/db/models/mcp.py b/src/memory/common/db/models/mcp.py index 1bf6a8c..ff0f6a9 100644 --- a/src/memory/common/db/models/mcp.py +++ b/src/memory/common/db/models/mcp.py @@ -28,6 +28,7 @@ class MCPServer(Base): mcp_server_url = Column(Text, nullable=False) client_id = Column(Text, nullable=False) available_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + disabled_tools = Column(ARRAY(Text), nullable=False, server_default="{}") # OAuth flow state (temporary, cleared after token exchange) state = Column(Text, nullable=True, unique=True) diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index fa35f17..a6d1ba9 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -401,10 +401,10 @@ class DiscordMessage(SourceItem): filter( None, [ - self.recipient_user.id, - self.from_user.id, - self.channel.id, - self.server.id, + self.recipient_id, + self.from_id, + self.channel_id, + self.server_id, ], ) ) diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index af75f8c..a79682d 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -145,6 +145,12 @@ class DiscordBotUser(BotUser): bot.discord_users = discord_users return bot + @property + def discord_id(self) -> int | None: + if not self.discord_users: + return None + return self.discord_users[0].id + class UserSession(Base): __tablename__ = "user_sessions" diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py index f7bac60..172baf1 100644 --- a/src/memory/common/llms/anthropic_provider.py +++ b/src/memory/common/llms/anthropic_provider.py @@ -2,6 +2,7 @@ import json import logging +from urllib.parse import urlparse from typing import Any, AsyncIterator, Iterator import anthropic @@ -283,9 +284,10 @@ class AnthropicProvider(BaseLLMProvider): # Include server info if present if current_tool_use.get("server_name"): tool_data["server_name"] = current_tool_use["server_name"] - if current_tool_use.get("is_server_call"): - tool_data["is_server_call"] = current_tool_use["is_server_call"] + # Emit different event type for MCP server tools + if current_tool_use.get("is_server_call"): + return StreamEvent(type="server_tool_use", data=tool_data), None return StreamEvent(type="tool_use", data=tool_data), None elif event_type == "message_delta": diff --git a/src/memory/common/llms/base.py b/src/memory/common/llms/base.py index ac31185..cfbbd9e 100644 --- a/src/memory/common/llms/base.py +++ b/src/memory/common/llms/base.py @@ -189,7 +189,15 @@ class Message: class StreamEvent: """An event from the streaming response.""" - type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"] + type: Literal[ + "text", + "tool_use", + "server_tool_use", + "tool_result", + "thinking", + "error", + "done", + ] data: Any = None signature: str | None = None @@ -565,9 +573,31 @@ class BaseLLMProvider(ABC): elif event.type == "thinking": thinking.thinking += event.data yield event - elif event.type == "tool_use": + elif event.type == "server_tool_use": + # MCP server tools are executed by Anthropic's backend + # Results will come as separate tool_result events + yield event + + # Track the MCP tool call but don't execute locally + messages.append( + Message.assistant( + response, + thinking, + ToolUseContent( + id=event.data["id"], + name=event.data["name"], + input=event.data["input"], + ), + ) + ) + # Reset response for next turn + response = TextContent(text="") + thinking = ThinkingContent(thinking="") + # Continue streaming to get the tool result from Anthropic + + elif event.type == "tool_use": + # Execute local tools yield event - # Execute the tool and yield the result tool_result = self.execute_tool(event.data, tools) yield StreamEvent(type="tool_result", data=tool_result.to_dict()) @@ -598,6 +628,27 @@ class BaseLLMProvider(ABC): ) return # Exit after recursive call completes + elif event.type == "tool_result": + yield event + # Add user message with the result + tool_result_content = ToolResultContent( + tool_use_id=event.data["id"], + content=str(event.data.get("result", "")), + is_error=event.data.get("is_error", False), + ) + messages.append(Message.user(tool_result=tool_result_content)) + + # Continue conversation with reduced iterations + yield from self.stream_with_tools( + messages, + tools, + mcp_servers, + settings, + system_prompt, + max_iterations - 1, + ) + return # Exit after recursive call completes + elif event.type == "error": logger.error(f"LLM error: {event.data}") raise RuntimeError(f"LLM error: {event.data}") @@ -625,7 +676,7 @@ class BaseLLMProvider(ABC): ): if event.type == "thinking": thinking += event.data - elif event.type == "tool_use": + elif event.type == "tool_use" or event.type == "server_tool_use": tool_calls[event.data["id"]] = { "name": event.data["name"], "input": event.data["input"], @@ -634,11 +685,15 @@ class BaseLLMProvider(ABC): elif event.type == "text": response += event.data elif event.type == "tool_result": - current = tool_calls.get(event.data["tool_use_id"]) or {} - tool_calls[event.data["tool_use_id"]] = { + tool_id = event.data.get("id") or event.data.get("tool_use_id") + if not tool_id: + logger.warning(f"tool_result event missing id: {event.data}") + continue + current = tool_calls.get(tool_id) or {} + tool_calls[tool_id] = { "name": event.data.get("name") or current.get("name"), "input": event.data.get("input") or current.get("input"), - "output": event.data.get("content"), + "output": event.data.get("content") or event.data.get("result"), } return Turn( thinking=thinking or None, diff --git a/src/memory/common/llms/tools/__init__.py b/src/memory/common/llms/tools/__init__.py index a04bd8a..1f8957d 100644 --- a/src/memory/common/llms/tools/__init__.py +++ b/src/memory/common/llms/tools/__init__.py @@ -32,6 +32,18 @@ class MCPServer: token: str allowed_tools: list[str] | None = None + @classmethod + def from_model(cls, model) -> "MCPServer": + allowed_tools = None + if model.disabled_tools: + allowed_tools = list(set(model.available_tools) - set(model.disabled_tools)) + return cls( + name=model.name, + url=model.mcp_server_url, + token=model.access_token, + allowed_tools=allowed_tools, + ) + @dataclass class ToolDefinition: diff --git a/src/memory/common/llms/tools/discord.py b/src/memory/common/llms/tools/discord.py index d334bb5..978f75a 100644 --- a/src/memory/common/llms/tools/discord.py +++ b/src/memory/common/llms/tools/discord.py @@ -14,7 +14,7 @@ from memory.common.db.models import ( DiscordServer, DiscordChannel, DiscordUser, - BotUser, + DiscordBotUser, ) from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler from memory.common.discord import add_reaction @@ -83,9 +83,9 @@ def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefin def schedule_message( - user_id: int, - user: int | None, - channel: int | None, + bot_id: int, + recipient_id: int | None, + channel_id: int | None, model: str, message: str, date_time: datetime, @@ -95,43 +95,62 @@ def schedule_message( session, scheduled_time=date_time, message=message, - user_id=user_id, + user_id=bot_id, model=model, - discord_user=user, - discord_channel=channel, - system_prompt=comm_channel_prompt(session, user, channel), + discord_user=recipient_id, + discord_channel=channel_id, ) + import logging + + logger = logging.getLogger(__name__) + logger.error(f"Scheduled message: {call}") + logger.error(f"Scheduled message: {call.id}") + logger.error(f"Scheduled message time: {call.scheduled_time}") + logger.error(f"Scheduled message message: {call.message}") + logger.error(f"Scheduled message model: {call.model}") + logger.error(f"Scheduled message user id: {call.user_id}") + logger.error(f"Scheduled message discord user id: {call.discord_user_id}") + logger.error(f"Scheduled message discord channel id: {call.discord_channel_id}") session.commit() return cast(str, call.id) def make_message_scheduler( - bot: BotUser, user: int | None, channel: int | None, model: str + bot: DiscordBotUser, user_id: int | None, channel_id: int | None, model: str ) -> ToolDefinition: bot_id = cast(int, bot.id) - if user: + if user_id: channel_type = "from your chat with this user" - elif channel: + elif channel_id: channel_type = "in this channel" else: raise ValueError("Either user or channel must be provided") def handler(input: ToolInput) -> str: - if not isinstance(input, dict): - raise ValueError("Input must be a dictionary") - try: - time = datetime.fromisoformat(input["date_time"]) - except ValueError: - raise ValueError("Invalid date time format") - except KeyError: - raise ValueError("Date time is required") + if not isinstance(input, dict): + raise ValueError("Input must be a dictionary") - return schedule_message(bot_id, user, channel, model, input["message"], time) + try: + time = datetime.fromisoformat(input["date_time"]) + except ValueError: + raise ValueError("Invalid date time format") + except KeyError: + raise ValueError("Date time is required") + + return schedule_message( + bot_id, user_id, channel_id, model, input["message"], time + ) + except Exception as e: + import logging + + logger = logging.getLogger(__name__) + logger.error(f"Error scheduling message: {e}") + raise e return ToolDefinition( - name="schedule_message", + name="schedule_discord_message", description=textwrap.dedent(""" Use this to schedule a message to be sent to yourself. @@ -162,10 +181,13 @@ def make_message_scheduler( ) -def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition: - if user: +def make_prev_messages_tool( + bot: DiscordBotUser, user_id: int | None, channel_id: int | None +) -> ToolDefinition: + bot_id = bot.discord_id + if user_id: channel_type = "from your chat with this user" - elif channel: + elif channel_id: channel_type = "in this channel" else: raise ValueError("Either user or channel must be provided") @@ -185,7 +207,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini raise ValueError("Offset must be greater than or equal to 0") with make_session() as session: - messages = previous_messages(session, user, channel, max_messages, offset) + messages = previous_messages( + session, bot_id, user_id, channel_id, max_messages, offset + ) return "\n\n".join([msg.title for msg in messages]) return ToolDefinition( @@ -210,7 +234,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini ) -def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition: +def make_add_reaction_tool( + bot: DiscordBotUser, channel: DiscordChannel +) -> ToolDefinition: bot_id = cast(int, bot.id) channel_id = channel and channel.id @@ -255,7 +281,7 @@ def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinit def make_discord_tools( - bot: BotUser, + bot: DiscordBotUser, author: DiscordUser | None, channel: DiscordChannel | None, model: str, @@ -264,7 +290,7 @@ def make_discord_tools( channel_id = channel and channel.id tools = [ make_message_scheduler(bot, author_id, channel_id, model), - make_prev_messages_tool(author_id, channel_id), + make_prev_messages_tool(bot, author_id, channel_id), make_summary_tool("channel", channel_id), ] if author: diff --git a/src/memory/common/mcp.py b/src/memory/common/mcp.py new file mode 100644 index 0000000..6fbe46f --- /dev/null +++ b/src/memory/common/mcp.py @@ -0,0 +1,58 @@ +import json +import logging +import time +from typing import Any, AsyncGenerator + +import aiohttp + +logger = logging.getLogger(__name__) + + +async def mcp_call( + url: str, access_token: str, method: str, params: dict = {} +) -> AsyncGenerator[Any, None]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Authorization": f"Bearer {access_token}", + } + + payload = { + "jsonrpc": "2.0", + "id": int(time.time() * 1000), + "method": method, + "params": params, + } + + async with aiohttp.ClientSession() as http_session: + async with http_session.post( + url, + json=payload, + headers=headers, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status != 200: + error_text = await resp.text() + logger.error(f"Tools list failed: {resp.status} - {error_text}") + raise ValueError( + f"Failed to call MCP server: {resp.status} - {error_text}" + ) + + # Parse SSE stream + async for line in resp.content: + line_str = line.decode("utf-8").strip() + + # SSE format: "data: {json}" + if line_str.startswith("data: "): + json_str = line_str[6:] # Remove "data: " prefix + try: + yield json.loads(json_str) + except json.JSONDecodeError: + continue # Skip invalid JSON lines + + +async def mcp_tools_list(url: str, access_token: str) -> list[dict]: + async for data in mcp_call(url, access_token, "tools/list"): + if "result" in data and "tools" in data["result"]: + return data["result"]["tools"] + return [] diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index 93f8d27..532018f 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -108,6 +108,7 @@ async def send_dm_endpoint(request: SendDMRequest): """Send a DM via the collector's Discord client""" collector = app.bots.get(request.bot_id) if not collector: + logger.error(f"Bot not found: {request.bot_id}") raise HTTPException(status_code=404, detail="Bot not found") try: diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py index bb72988..37432e5 100644 --- a/src/memory/discord/mcp.py +++ b/src/memory/discord/mcp.py @@ -1,9 +1,7 @@ """Lightweight slash-command helpers for the Discord collector.""" -import json import logging -import time -from typing import Any, AsyncGenerator, Literal, cast +from typing import Literal, cast import aiohttp import discord @@ -12,6 +10,7 @@ from sqlalchemy.orm import Session, scoped_session from memory.common.db.connection import make_session from memory.common.db.models import MCPServer, MCPServerAssignment from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client +from memory.common.mcp import mcp_tools_list logger = logging.getLogger(__name__) @@ -33,49 +32,6 @@ def find_mcp_server( return assignment and assignment.mcp_server -async def call_mcp_server( - url: str, access_token: str, method: str, params: dict = {} -) -> AsyncGenerator[Any, None]: - headers = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - "Authorization": f"Bearer {access_token}", - } - - payload = { - "jsonrpc": "2.0", - "id": int(time.time() * 1000), - "method": method, - "params": params, - } - - async with aiohttp.ClientSession() as http_session: - async with http_session.post( - url, - json=payload, - headers=headers, - timeout=aiohttp.ClientTimeout(total=10), - ) as resp: - if resp.status != 200: - error_text = await resp.text() - logger.error(f"Tools list failed: {resp.status} - {error_text}") - raise ValueError( - f"Failed to call MCP server: {resp.status} - {error_text}" - ) - - # Parse SSE stream - async for line in resp.content: - line_str = line.decode("utf-8").strip() - - # SSE format: "data: {json}" - if line_str.startswith("data: "): - json_str = line_str[6:] # Remove "data: " prefix - try: - yield json.loads(json_str) - except json.JSONDecodeError: - continue # Skip invalid JSON lines - - async def handle_mcp_list(entity_type: str, entity_id: int) -> str: """List all MCP servers for the user.""" with make_session() as session: @@ -258,10 +214,7 @@ async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str: # Make JSON-RPC request to MCP server tools = None try: - async for data in call_mcp_server(url, access_token, "tools/list"): - if "result" in data and "tools" in data["result"]: - tools = data["result"]["tools"] - break + tools = await mcp_tools_list(url, access_token) except aiohttp.ClientError as exc: logger.exception(f"Failed to connect to MCP server: {exc}") raise ValueError( diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py index 01c5b28..5579bfd 100644 --- a/src/memory/discord/messages.py +++ b/src/memory/discord/messages.py @@ -132,14 +132,17 @@ def upsert_scheduled_message( def previous_messages( session: Session | scoped_session, + bot_id: int, user_id: int | None, channel_id: int | None, max_messages: int = 10, offset: int = 0, ) -> list[DiscordMessage]: - messages = session.query(DiscordMessage) + messages = session.query(DiscordMessage).filter( + DiscordMessage.recipient_id == bot_id + ) if user_id: - messages = messages.filter(DiscordMessage.recipient_id == user_id) + messages = messages.filter(DiscordMessage.from_id == user_id) if channel_id: messages = messages.filter(DiscordMessage.channel_id == channel_id) return list( @@ -154,15 +157,23 @@ def previous_messages( def comm_channel_prompt( session: Session | scoped_session, + bot: DiscordEntity, user: DiscordEntity, channel: DiscordEntity, max_messages: int = 10, ) -> str: user = resolve_discord_user(session, user) channel = resolve_discord_channel(session, channel) + bot = resolve_discord_user(session, bot) + if not bot: + raise ValueError("Bot not found") messages = previous_messages( - session, user and user.id, channel and channel.id, max_messages + session, + cast(int, bot.id), + user and user.id, + channel and channel.id, + max_messages, ) server_context = "" @@ -244,6 +255,7 @@ def call_llm( user_id = cast(int, from_user.id) prev_messages = previous_messages( session, + bot_user.system_user.discord_id, user_id, channel and channel.id, max_messages=num_previous_messages, @@ -263,18 +275,12 @@ def call_llm( if bot_user.system_prompt: system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "") message_content = [m.as_content() for m in prev_messages] + messages + return provider.run_with_tools( messages=provider.as_messages(message_content), tools=tools, system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""), - mcp_servers=[ - MCPServer( - name=str(server.name), - url=str(server.mcp_server_url), - token=str(server.access_token), - ) - for server in mcp_servers - ] + mcp_servers=[MCPServer.from_model(server) for server in mcp_servers] if mcp_servers else None, max_iterations=settings.DISCORD_MAX_TOOL_CALLS, diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index 05e85b3..1fc0348 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -113,7 +113,7 @@ def should_process(message: DiscordMessage) -> bool: system_prompt = message.system_prompt or "" system_prompt += comm_channel_prompt( - session, message.recipient_user, message.channel + session, message.recipient_user, message.from_user, message.channel ) allowed_tools = [ "update_channel_summary", @@ -199,7 +199,10 @@ def process_discord_message(message_id: int) -> dict[str, Any]: mcp_servers = discord_message.get_mcp_servers(session) system_prompt = discord_message.system_prompt or "" system_prompt += comm_channel_prompt( - session, discord_message.recipient_user, discord_message.channel + session, + discord_message.recipient_user, + discord_message.from_user, + discord_message.channel, ) try: diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index b6c0ed0..8d83867 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -9,14 +9,14 @@ from memory.common.celery_app import ( app, ) from memory.common.db.connection import make_session -from memory.common.db.models import ScheduledLLMCall +from memory.common.db.models import ScheduledLLMCall, DiscordBotUser from memory.discord.messages import call_llm, send_discord_response from memory.workers.tasks.content_processing import safe_task_execution logger = logging.getLogger(__name__) -def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None: +def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None: """Call LLM with tools support for scheduled calls.""" if not scheduled_call.discord_user: logger.warning("No discord_user for scheduled call - cannot execute") @@ -30,6 +30,7 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | bot_user = ( scheduled_call.user.discord_users and scheduled_call.user.discord_users[0] ) + return call_llm( session=session, bot_user=bot_user, @@ -42,18 +43,8 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | ) -def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str): +def send_to_discord(bot_id: int, scheduled_call: ScheduledLLMCall, response: str): """Send the LLM response to Discord user or channel.""" - bot_id_value = scheduled_call.user_id - if bot_id_value is None: - logger.warning( - "Scheduled call %s has no associated bot user; skipping Discord send", - scheduled_call.id, - ) - return - - bot_id = cast(int, bot_id_value) - send_discord_response( bot_id=bot_id, response=response, @@ -101,7 +92,7 @@ def execute_scheduled_call(self, scheduled_call_id: str): # Make the LLM call with tools support try: - response = _call_llm_for_scheduled(session, scheduled_call) + response = call_llm_for_scheduled(session, scheduled_call) except Exception: logger.exception("Failed to generate LLM response") scheduled_call.status = "failed" @@ -132,7 +123,7 @@ def execute_scheduled_call(self, scheduled_call_id: str): # Send to Discord try: - _send_to_discord(scheduled_call, response) + send_to_discord(cast(int, scheduled_call.user_id), scheduled_call, response) logger.info(f"Response sent to Discord for {scheduled_call_id}") except Exception as discord_error: logger.error(f"Failed to send to Discord: {discord_error}") diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py index 5b49a8a..44d567d 100644 --- a/tests/memory/workers/tasks/test_scheduled_calls.py +++ b/tests/memory/workers/tasks/test_scheduled_calls.py @@ -127,7 +127,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call): """Test sending to Discord user.""" response = "This is a test response." - scheduled_calls._send_to_discord(pending_scheduled_call, response) + scheduled_calls.send_to_discord(pending_scheduled_call, response) mock_send_dm.assert_called_once_with( pending_scheduled_call.user_id, @@ -141,7 +141,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call): """Test sending to Discord channel.""" response = "This is a channel response." - scheduled_calls._send_to_discord(completed_scheduled_call, response) + scheduled_calls.send_to_discord(completed_scheduled_call, response) mock_broadcast.assert_called_once_with( completed_scheduled_call.user_id, @@ -155,7 +155,7 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled """Test message truncation for long responses.""" long_response = "A" * 2500 # Very long response - scheduled_calls._send_to_discord(pending_scheduled_call, long_response) + scheduled_calls.send_to_discord(pending_scheduled_call, long_response) # Verify the message was truncated args, kwargs = mock_send_dm.call_args @@ -170,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c """Test that normal length messages are not truncated.""" normal_response = "This is a normal length response." - scheduled_calls._send_to_discord(pending_scheduled_call, normal_response) + scheduled_calls.send_to_discord(pending_scheduled_call, normal_response) args, kwargs = mock_send_dm.call_args assert args[0] == pending_scheduled_call.user_id @@ -535,7 +535,7 @@ def test_discord_destination_priority( db_session.commit() response = "Test response" - scheduled_calls._send_to_discord(call, response) + scheduled_calls.send_to_discord(call, response) if expected_method == "send_dm": mock_send_dm.assert_called_once() @@ -590,7 +590,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me mock_call.discord_user = mock_discord_user mock_call.discord_channel = None - scheduled_calls._send_to_discord(mock_call, response) + scheduled_calls.send_to_discord(mock_call, response) # Get the actual message that was sent args, kwargs = mock_send_dm.call_args