mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
fix scheduler
This commit is contained in:
parent
470061bd43
commit
56ed7b7d8f
@ -4,11 +4,11 @@ MCP tools for the epistemic sparring partner system.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from memory.api.MCP.base import get_current_user, mcp
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import ScheduledLLMCall
|
||||
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
|
||||
from memory.discord.messages import schedule_discord_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -71,11 +71,15 @@ async def schedule_message(
|
||||
raise ValueError("Invalid datetime format")
|
||||
|
||||
with make_session() as session:
|
||||
bot = session.query(DiscordBotUser).first()
|
||||
if not bot:
|
||||
return {"error": "No bot found"}
|
||||
|
||||
scheduled_call = schedule_discord_message(
|
||||
session=session,
|
||||
scheduled_time=scheduled_dt,
|
||||
message=message,
|
||||
user_id=current_user.get("user", {}).get("user_id"),
|
||||
user_id=cast(int, bot.id),
|
||||
model=model,
|
||||
topic=topic,
|
||||
discord_channel=discord_channel,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from sqlalchemy.orm import Session as DBSession
|
||||
@ -15,6 +16,7 @@ from memory.common.db.models import (
|
||||
User,
|
||||
UserSession,
|
||||
)
|
||||
from memory.common.mcp import mcp_tools_list
|
||||
from memory.common.oauth import complete_oauth_flow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -171,9 +173,24 @@ async def oauth_callback_discord(request: Request):
|
||||
mcp_server = (
|
||||
session.query(MCPServer).filter(MCPServer.state == state).first()
|
||||
)
|
||||
if not mcp_server:
|
||||
return Response(
|
||||
content="MCP server not found",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||
session.commit()
|
||||
|
||||
tools = await mcp_tools_list(
|
||||
cast(str, mcp_server.mcp_server_url), cast(str, mcp_server.access_token)
|
||||
)
|
||||
mcp_server.available_tools = [
|
||||
name for tool in tools if (name := tool.get("name"))
|
||||
]
|
||||
session.commit()
|
||||
logger.info(f"MCP server tools: {tools}")
|
||||
|
||||
if 200 <= status_code < 300:
|
||||
title = "✅ Authorization Successful!"
|
||||
close = "You can close this window and return to the MCP server."
|
||||
|
||||
@ -28,6 +28,7 @@ class MCPServer(Base):
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
disabled_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
|
||||
@ -401,10 +401,10 @@ class DiscordMessage(SourceItem):
|
||||
filter(
|
||||
None,
|
||||
[
|
||||
self.recipient_user.id,
|
||||
self.from_user.id,
|
||||
self.channel.id,
|
||||
self.server.id,
|
||||
self.recipient_id,
|
||||
self.from_id,
|
||||
self.channel_id,
|
||||
self.server_id,
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@ -145,6 +145,12 @@ class DiscordBotUser(BotUser):
|
||||
bot.discord_users = discord_users
|
||||
return bot
|
||||
|
||||
@property
|
||||
def discord_id(self) -> int | None:
|
||||
if not self.discord_users:
|
||||
return None
|
||||
return self.discord_users[0].id
|
||||
|
||||
|
||||
class UserSession(Base):
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, AsyncIterator, Iterator
|
||||
|
||||
import anthropic
|
||||
@ -283,9 +284,10 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
# Include server info if present
|
||||
if current_tool_use.get("server_name"):
|
||||
tool_data["server_name"] = current_tool_use["server_name"]
|
||||
if current_tool_use.get("is_server_call"):
|
||||
tool_data["is_server_call"] = current_tool_use["is_server_call"]
|
||||
|
||||
# Emit different event type for MCP server tools
|
||||
if current_tool_use.get("is_server_call"):
|
||||
return StreamEvent(type="server_tool_use", data=tool_data), None
|
||||
return StreamEvent(type="tool_use", data=tool_data), None
|
||||
|
||||
elif event_type == "message_delta":
|
||||
|
||||
@ -189,7 +189,15 @@ class Message:
|
||||
class StreamEvent:
|
||||
"""An event from the streaming response."""
|
||||
|
||||
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
|
||||
type: Literal[
|
||||
"text",
|
||||
"tool_use",
|
||||
"server_tool_use",
|
||||
"tool_result",
|
||||
"thinking",
|
||||
"error",
|
||||
"done",
|
||||
]
|
||||
data: Any = None
|
||||
signature: str | None = None
|
||||
|
||||
@ -565,9 +573,31 @@ class BaseLLMProvider(ABC):
|
||||
elif event.type == "thinking":
|
||||
thinking.thinking += event.data
|
||||
yield event
|
||||
elif event.type == "tool_use":
|
||||
elif event.type == "server_tool_use":
|
||||
# MCP server tools are executed by Anthropic's backend
|
||||
# Results will come as separate tool_result events
|
||||
yield event
|
||||
|
||||
# Track the MCP tool call but don't execute locally
|
||||
messages.append(
|
||||
Message.assistant(
|
||||
response,
|
||||
thinking,
|
||||
ToolUseContent(
|
||||
id=event.data["id"],
|
||||
name=event.data["name"],
|
||||
input=event.data["input"],
|
||||
),
|
||||
)
|
||||
)
|
||||
# Reset response for next turn
|
||||
response = TextContent(text="")
|
||||
thinking = ThinkingContent(thinking="")
|
||||
# Continue streaming to get the tool result from Anthropic
|
||||
|
||||
elif event.type == "tool_use":
|
||||
# Execute local tools
|
||||
yield event
|
||||
# Execute the tool and yield the result
|
||||
tool_result = self.execute_tool(event.data, tools)
|
||||
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
|
||||
|
||||
@ -598,6 +628,27 @@ class BaseLLMProvider(ABC):
|
||||
)
|
||||
return # Exit after recursive call completes
|
||||
|
||||
elif event.type == "tool_result":
|
||||
yield event
|
||||
# Add user message with the result
|
||||
tool_result_content = ToolResultContent(
|
||||
tool_use_id=event.data["id"],
|
||||
content=str(event.data.get("result", "")),
|
||||
is_error=event.data.get("is_error", False),
|
||||
)
|
||||
messages.append(Message.user(tool_result=tool_result_content))
|
||||
|
||||
# Continue conversation with reduced iterations
|
||||
yield from self.stream_with_tools(
|
||||
messages,
|
||||
tools,
|
||||
mcp_servers,
|
||||
settings,
|
||||
system_prompt,
|
||||
max_iterations - 1,
|
||||
)
|
||||
return # Exit after recursive call completes
|
||||
|
||||
elif event.type == "error":
|
||||
logger.error(f"LLM error: {event.data}")
|
||||
raise RuntimeError(f"LLM error: {event.data}")
|
||||
@ -625,7 +676,7 @@ class BaseLLMProvider(ABC):
|
||||
):
|
||||
if event.type == "thinking":
|
||||
thinking += event.data
|
||||
elif event.type == "tool_use":
|
||||
elif event.type == "tool_use" or event.type == "server_tool_use":
|
||||
tool_calls[event.data["id"]] = {
|
||||
"name": event.data["name"],
|
||||
"input": event.data["input"],
|
||||
@ -634,11 +685,15 @@ class BaseLLMProvider(ABC):
|
||||
elif event.type == "text":
|
||||
response += event.data
|
||||
elif event.type == "tool_result":
|
||||
current = tool_calls.get(event.data["tool_use_id"]) or {}
|
||||
tool_calls[event.data["tool_use_id"]] = {
|
||||
tool_id = event.data.get("id") or event.data.get("tool_use_id")
|
||||
if not tool_id:
|
||||
logger.warning(f"tool_result event missing id: {event.data}")
|
||||
continue
|
||||
current = tool_calls.get(tool_id) or {}
|
||||
tool_calls[tool_id] = {
|
||||
"name": event.data.get("name") or current.get("name"),
|
||||
"input": event.data.get("input") or current.get("input"),
|
||||
"output": event.data.get("content"),
|
||||
"output": event.data.get("content") or event.data.get("result"),
|
||||
}
|
||||
return Turn(
|
||||
thinking=thinking or None,
|
||||
|
||||
@ -32,6 +32,18 @@ class MCPServer:
|
||||
token: str
|
||||
allowed_tools: list[str] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model) -> "MCPServer":
|
||||
allowed_tools = None
|
||||
if model.disabled_tools:
|
||||
allowed_tools = list(set(model.available_tools) - set(model.disabled_tools))
|
||||
return cls(
|
||||
name=model.name,
|
||||
url=model.mcp_server_url,
|
||||
token=model.access_token,
|
||||
allowed_tools=allowed_tools,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
|
||||
@ -14,7 +14,7 @@ from memory.common.db.models import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
BotUser,
|
||||
DiscordBotUser,
|
||||
)
|
||||
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
|
||||
from memory.common.discord import add_reaction
|
||||
@ -83,9 +83,9 @@ def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefin
|
||||
|
||||
|
||||
def schedule_message(
|
||||
user_id: int,
|
||||
user: int | None,
|
||||
channel: int | None,
|
||||
bot_id: int,
|
||||
recipient_id: int | None,
|
||||
channel_id: int | None,
|
||||
model: str,
|
||||
message: str,
|
||||
date_time: datetime,
|
||||
@ -95,29 +95,40 @@ def schedule_message(
|
||||
session,
|
||||
scheduled_time=date_time,
|
||||
message=message,
|
||||
user_id=user_id,
|
||||
user_id=bot_id,
|
||||
model=model,
|
||||
discord_user=user,
|
||||
discord_channel=channel,
|
||||
system_prompt=comm_channel_prompt(session, user, channel),
|
||||
discord_user=recipient_id,
|
||||
discord_channel=channel_id,
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Scheduled message: {call}")
|
||||
logger.error(f"Scheduled message: {call.id}")
|
||||
logger.error(f"Scheduled message time: {call.scheduled_time}")
|
||||
logger.error(f"Scheduled message message: {call.message}")
|
||||
logger.error(f"Scheduled message model: {call.model}")
|
||||
logger.error(f"Scheduled message user id: {call.user_id}")
|
||||
logger.error(f"Scheduled message discord user id: {call.discord_user_id}")
|
||||
logger.error(f"Scheduled message discord channel id: {call.discord_channel_id}")
|
||||
|
||||
session.commit()
|
||||
return cast(str, call.id)
|
||||
|
||||
|
||||
def make_message_scheduler(
|
||||
bot: BotUser, user: int | None, channel: int | None, model: str
|
||||
bot: DiscordBotUser, user_id: int | None, channel_id: int | None, model: str
|
||||
) -> ToolDefinition:
|
||||
bot_id = cast(int, bot.id)
|
||||
if user:
|
||||
if user_id:
|
||||
channel_type = "from your chat with this user"
|
||||
elif channel:
|
||||
elif channel_id:
|
||||
channel_type = "in this channel"
|
||||
else:
|
||||
raise ValueError("Either user or channel must be provided")
|
||||
|
||||
def handler(input: ToolInput) -> str:
|
||||
try:
|
||||
if not isinstance(input, dict):
|
||||
raise ValueError("Input must be a dictionary")
|
||||
|
||||
@ -128,10 +139,18 @@ def make_message_scheduler(
|
||||
except KeyError:
|
||||
raise ValueError("Date time is required")
|
||||
|
||||
return schedule_message(bot_id, user, channel, model, input["message"], time)
|
||||
return schedule_message(
|
||||
bot_id, user_id, channel_id, model, input["message"], time
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error scheduling message: {e}")
|
||||
raise e
|
||||
|
||||
return ToolDefinition(
|
||||
name="schedule_message",
|
||||
name="schedule_discord_message",
|
||||
description=textwrap.dedent("""
|
||||
Use this to schedule a message to be sent to yourself.
|
||||
|
||||
@ -162,10 +181,13 @@ def make_message_scheduler(
|
||||
)
|
||||
|
||||
|
||||
def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition:
|
||||
if user:
|
||||
def make_prev_messages_tool(
|
||||
bot: DiscordBotUser, user_id: int | None, channel_id: int | None
|
||||
) -> ToolDefinition:
|
||||
bot_id = bot.discord_id
|
||||
if user_id:
|
||||
channel_type = "from your chat with this user"
|
||||
elif channel:
|
||||
elif channel_id:
|
||||
channel_type = "in this channel"
|
||||
else:
|
||||
raise ValueError("Either user or channel must be provided")
|
||||
@ -185,7 +207,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
|
||||
raise ValueError("Offset must be greater than or equal to 0")
|
||||
|
||||
with make_session() as session:
|
||||
messages = previous_messages(session, user, channel, max_messages, offset)
|
||||
messages = previous_messages(
|
||||
session, bot_id, user_id, channel_id, max_messages, offset
|
||||
)
|
||||
return "\n\n".join([msg.title for msg in messages])
|
||||
|
||||
return ToolDefinition(
|
||||
@ -210,7 +234,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
|
||||
)
|
||||
|
||||
|
||||
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
|
||||
def make_add_reaction_tool(
|
||||
bot: DiscordBotUser, channel: DiscordChannel
|
||||
) -> ToolDefinition:
|
||||
bot_id = cast(int, bot.id)
|
||||
channel_id = channel and channel.id
|
||||
|
||||
@ -255,7 +281,7 @@ def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinit
|
||||
|
||||
|
||||
def make_discord_tools(
|
||||
bot: BotUser,
|
||||
bot: DiscordBotUser,
|
||||
author: DiscordUser | None,
|
||||
channel: DiscordChannel | None,
|
||||
model: str,
|
||||
@ -264,7 +290,7 @@ def make_discord_tools(
|
||||
channel_id = channel and channel.id
|
||||
tools = [
|
||||
make_message_scheduler(bot, author_id, channel_id, model),
|
||||
make_prev_messages_tool(author_id, channel_id),
|
||||
make_prev_messages_tool(bot, author_id, channel_id),
|
||||
make_summary_tool("channel", channel_id),
|
||||
]
|
||||
if author:
|
||||
|
||||
58
src/memory/common/mcp.py
Normal file
58
src/memory/common/mcp.py
Normal file
@ -0,0 +1,58 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def mcp_call(
|
||||
url: str, access_token: str, method: str, params: dict = {}
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": int(time.time() * 1000),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"Tools list failed: {resp.status} - {error_text}")
|
||||
raise ValueError(
|
||||
f"Failed to call MCP server: {resp.status} - {error_text}"
|
||||
)
|
||||
|
||||
# Parse SSE stream
|
||||
async for line in resp.content:
|
||||
line_str = line.decode("utf-8").strip()
|
||||
|
||||
# SSE format: "data: {json}"
|
||||
if line_str.startswith("data: "):
|
||||
json_str = line_str[6:] # Remove "data: " prefix
|
||||
try:
|
||||
yield json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
continue # Skip invalid JSON lines
|
||||
|
||||
|
||||
async def mcp_tools_list(url: str, access_token: str) -> list[dict]:
|
||||
async for data in mcp_call(url, access_token, "tools/list"):
|
||||
if "result" in data and "tools" in data["result"]:
|
||||
return data["result"]["tools"]
|
||||
return []
|
||||
@ -108,6 +108,7 @@ async def send_dm_endpoint(request: SendDMRequest):
|
||||
"""Send a DM via the collector's Discord client"""
|
||||
collector = app.bots.get(request.bot_id)
|
||||
if not collector:
|
||||
logger.error(f"Bot not found: {request.bot_id}")
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
|
||||
try:
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Literal, cast
|
||||
from typing import Literal, cast
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
@ -12,6 +10,7 @@ from sqlalchemy.orm import Session, scoped_session
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
from memory.common.mcp import mcp_tools_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -33,49 +32,6 @@ def find_mcp_server(
|
||||
return assignment and assignment.mcp_server
|
||||
|
||||
|
||||
async def call_mcp_server(
|
||||
url: str, access_token: str, method: str, params: dict = {}
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": int(time.time() * 1000),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"Tools list failed: {resp.status} - {error_text}")
|
||||
raise ValueError(
|
||||
f"Failed to call MCP server: {resp.status} - {error_text}"
|
||||
)
|
||||
|
||||
# Parse SSE stream
|
||||
async for line in resp.content:
|
||||
line_str = line.decode("utf-8").strip()
|
||||
|
||||
# SSE format: "data: {json}"
|
||||
if line_str.startswith("data: "):
|
||||
json_str = line_str[6:] # Remove "data: " prefix
|
||||
try:
|
||||
yield json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
continue # Skip invalid JSON lines
|
||||
|
||||
|
||||
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
||||
"""List all MCP servers for the user."""
|
||||
with make_session() as session:
|
||||
@ -258,10 +214,7 @@ async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
||||
# Make JSON-RPC request to MCP server
|
||||
tools = None
|
||||
try:
|
||||
async for data in call_mcp_server(url, access_token, "tools/list"):
|
||||
if "result" in data and "tools" in data["result"]:
|
||||
tools = data["result"]["tools"]
|
||||
break
|
||||
tools = await mcp_tools_list(url, access_token)
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.exception(f"Failed to connect to MCP server: {exc}")
|
||||
raise ValueError(
|
||||
|
||||
@ -132,14 +132,17 @@ def upsert_scheduled_message(
|
||||
|
||||
def previous_messages(
|
||||
session: Session | scoped_session,
|
||||
bot_id: int,
|
||||
user_id: int | None,
|
||||
channel_id: int | None,
|
||||
max_messages: int = 10,
|
||||
offset: int = 0,
|
||||
) -> list[DiscordMessage]:
|
||||
messages = session.query(DiscordMessage)
|
||||
messages = session.query(DiscordMessage).filter(
|
||||
DiscordMessage.recipient_id == bot_id
|
||||
)
|
||||
if user_id:
|
||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||
messages = messages.filter(DiscordMessage.from_id == user_id)
|
||||
if channel_id:
|
||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||
return list(
|
||||
@ -154,15 +157,23 @@ def previous_messages(
|
||||
|
||||
def comm_channel_prompt(
|
||||
session: Session | scoped_session,
|
||||
bot: DiscordEntity,
|
||||
user: DiscordEntity,
|
||||
channel: DiscordEntity,
|
||||
max_messages: int = 10,
|
||||
) -> str:
|
||||
user = resolve_discord_user(session, user)
|
||||
channel = resolve_discord_channel(session, channel)
|
||||
bot = resolve_discord_user(session, bot)
|
||||
if not bot:
|
||||
raise ValueError("Bot not found")
|
||||
|
||||
messages = previous_messages(
|
||||
session, user and user.id, channel and channel.id, max_messages
|
||||
session,
|
||||
cast(int, bot.id),
|
||||
user and user.id,
|
||||
channel and channel.id,
|
||||
max_messages,
|
||||
)
|
||||
|
||||
server_context = ""
|
||||
@ -244,6 +255,7 @@ def call_llm(
|
||||
user_id = cast(int, from_user.id)
|
||||
prev_messages = previous_messages(
|
||||
session,
|
||||
bot_user.system_user.discord_id,
|
||||
user_id,
|
||||
channel and channel.id,
|
||||
max_messages=num_previous_messages,
|
||||
@ -263,18 +275,12 @@ def call_llm(
|
||||
if bot_user.system_prompt:
|
||||
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
|
||||
message_content = [m.as_content() for m in prev_messages] + messages
|
||||
|
||||
return provider.run_with_tools(
|
||||
messages=provider.as_messages(message_content),
|
||||
tools=tools,
|
||||
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
||||
mcp_servers=[
|
||||
MCPServer(
|
||||
name=str(server.name),
|
||||
url=str(server.mcp_server_url),
|
||||
token=str(server.access_token),
|
||||
)
|
||||
for server in mcp_servers
|
||||
]
|
||||
mcp_servers=[MCPServer.from_model(server) for server in mcp_servers]
|
||||
if mcp_servers
|
||||
else None,
|
||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||
|
||||
@ -113,7 +113,7 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
|
||||
system_prompt = message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, message.recipient_user, message.channel
|
||||
session, message.recipient_user, message.from_user, message.channel
|
||||
)
|
||||
allowed_tools = [
|
||||
"update_channel_summary",
|
||||
@ -199,7 +199,10 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
mcp_servers = discord_message.get_mcp_servers(session)
|
||||
system_prompt = discord_message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, discord_message.recipient_user, discord_message.channel
|
||||
session,
|
||||
discord_message.recipient_user,
|
||||
discord_message.from_user,
|
||||
discord_message.channel,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -9,14 +9,14 @@ from memory.common.celery_app import (
|
||||
app,
|
||||
)
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import ScheduledLLMCall
|
||||
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
|
||||
from memory.discord.messages import call_llm, send_discord_response
|
||||
from memory.workers.tasks.content_processing import safe_task_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
|
||||
def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
|
||||
"""Call LLM with tools support for scheduled calls."""
|
||||
if not scheduled_call.discord_user:
|
||||
logger.warning("No discord_user for scheduled call - cannot execute")
|
||||
@ -30,6 +30,7 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
|
||||
bot_user = (
|
||||
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
|
||||
)
|
||||
|
||||
return call_llm(
|
||||
session=session,
|
||||
bot_user=bot_user,
|
||||
@ -42,18 +43,8 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
|
||||
)
|
||||
|
||||
|
||||
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
||||
def send_to_discord(bot_id: int, scheduled_call: ScheduledLLMCall, response: str):
|
||||
"""Send the LLM response to Discord user or channel."""
|
||||
bot_id_value = scheduled_call.user_id
|
||||
if bot_id_value is None:
|
||||
logger.warning(
|
||||
"Scheduled call %s has no associated bot user; skipping Discord send",
|
||||
scheduled_call.id,
|
||||
)
|
||||
return
|
||||
|
||||
bot_id = cast(int, bot_id_value)
|
||||
|
||||
send_discord_response(
|
||||
bot_id=bot_id,
|
||||
response=response,
|
||||
@ -101,7 +92,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
||||
|
||||
# Make the LLM call with tools support
|
||||
try:
|
||||
response = _call_llm_for_scheduled(session, scheduled_call)
|
||||
response = call_llm_for_scheduled(session, scheduled_call)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate LLM response")
|
||||
scheduled_call.status = "failed"
|
||||
@ -132,7 +123,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
||||
|
||||
# Send to Discord
|
||||
try:
|
||||
_send_to_discord(scheduled_call, response)
|
||||
send_to_discord(cast(int, scheduled_call.user_id), scheduled_call, response)
|
||||
logger.info(f"Response sent to Discord for {scheduled_call_id}")
|
||||
except Exception as discord_error:
|
||||
logger.error(f"Failed to send to Discord: {discord_error}")
|
||||
|
||||
@ -127,7 +127,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
"""Test sending to Discord user."""
|
||||
response = "This is a test response."
|
||||
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
pending_scheduled_call.user_id,
|
||||
@ -141,7 +141,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
"""Test sending to Discord channel."""
|
||||
response = "This is a channel response."
|
||||
|
||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||
scheduled_calls.send_to_discord(completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
completed_scheduled_call.user_id,
|
||||
@ -155,7 +155,7 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
|
||||
"""Test message truncation for long responses."""
|
||||
long_response = "A" * 2500 # Very long response
|
||||
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, long_response)
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
|
||||
|
||||
# Verify the message was truncated
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
@ -170,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
||||
"""Test that normal length messages are not truncated."""
|
||||
normal_response = "This is a normal length response."
|
||||
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
|
||||
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
@ -535,7 +535,7 @@ def test_discord_destination_priority(
|
||||
db_session.commit()
|
||||
|
||||
response = "Test response"
|
||||
scheduled_calls._send_to_discord(call, response)
|
||||
scheduled_calls.send_to_discord(call, response)
|
||||
|
||||
if expected_method == "send_dm":
|
||||
mock_send_dm.assert_called_once()
|
||||
@ -590,7 +590,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
mock_call.discord_user = mock_discord_user
|
||||
mock_call.discord_channel = None
|
||||
|
||||
scheduled_calls._send_to_discord(mock_call, response)
|
||||
scheduled_calls.send_to_discord(mock_call, response)
|
||||
|
||||
# Get the actual message that was sent
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user