mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
fix scheduler
This commit is contained in:
parent
470061bd43
commit
56ed7b7d8f
@ -4,11 +4,11 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from memory.api.MCP.base import get_current_user, mcp
|
from memory.api.MCP.base import get_current_user, mcp
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import ScheduledLLMCall
|
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
|
||||||
from memory.discord.messages import schedule_discord_message
|
from memory.discord.messages import schedule_discord_message
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -71,11 +71,15 @@ async def schedule_message(
|
|||||||
raise ValueError("Invalid datetime format")
|
raise ValueError("Invalid datetime format")
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
|
bot = session.query(DiscordBotUser).first()
|
||||||
|
if not bot:
|
||||||
|
return {"error": "No bot found"}
|
||||||
|
|
||||||
scheduled_call = schedule_discord_message(
|
scheduled_call = schedule_discord_message(
|
||||||
session=session,
|
session=session,
|
||||||
scheduled_time=scheduled_dt,
|
scheduled_time=scheduled_dt,
|
||||||
message=message,
|
message=message,
|
||||||
user_id=current_user.get("user", {}).get("user_id"),
|
user_id=cast(int, bot.id),
|
||||||
model=model,
|
model=model,
|
||||||
topic=topic,
|
topic=topic,
|
||||||
discord_channel=discord_channel,
|
discord_channel=discord_channel,
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from sqlalchemy.orm import Session as DBSession
|
from sqlalchemy.orm import Session as DBSession
|
||||||
@ -15,6 +16,7 @@ from memory.common.db.models import (
|
|||||||
User,
|
User,
|
||||||
UserSession,
|
UserSession,
|
||||||
)
|
)
|
||||||
|
from memory.common.mcp import mcp_tools_list
|
||||||
from memory.common.oauth import complete_oauth_flow
|
from memory.common.oauth import complete_oauth_flow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -171,9 +173,24 @@ async def oauth_callback_discord(request: Request):
|
|||||||
mcp_server = (
|
mcp_server = (
|
||||||
session.query(MCPServer).filter(MCPServer.state == state).first()
|
session.query(MCPServer).filter(MCPServer.state == state).first()
|
||||||
)
|
)
|
||||||
|
if not mcp_server:
|
||||||
|
return Response(
|
||||||
|
content="MCP server not found",
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
tools = await mcp_tools_list(
|
||||||
|
cast(str, mcp_server.mcp_server_url), cast(str, mcp_server.access_token)
|
||||||
|
)
|
||||||
|
mcp_server.available_tools = [
|
||||||
|
name for tool in tools if (name := tool.get("name"))
|
||||||
|
]
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"MCP server tools: {tools}")
|
||||||
|
|
||||||
if 200 <= status_code < 300:
|
if 200 <= status_code < 300:
|
||||||
title = "✅ Authorization Successful!"
|
title = "✅ Authorization Successful!"
|
||||||
close = "You can close this window and return to the MCP server."
|
close = "You can close this window and return to the MCP server."
|
||||||
|
|||||||
@ -28,6 +28,7 @@ class MCPServer(Base):
|
|||||||
mcp_server_url = Column(Text, nullable=False)
|
mcp_server_url = Column(Text, nullable=False)
|
||||||
client_id = Column(Text, nullable=False)
|
client_id = Column(Text, nullable=False)
|
||||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
disabled_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
# OAuth flow state (temporary, cleared after token exchange)
|
# OAuth flow state (temporary, cleared after token exchange)
|
||||||
state = Column(Text, nullable=True, unique=True)
|
state = Column(Text, nullable=True, unique=True)
|
||||||
|
|||||||
@ -401,10 +401,10 @@ class DiscordMessage(SourceItem):
|
|||||||
filter(
|
filter(
|
||||||
None,
|
None,
|
||||||
[
|
[
|
||||||
self.recipient_user.id,
|
self.recipient_id,
|
||||||
self.from_user.id,
|
self.from_id,
|
||||||
self.channel.id,
|
self.channel_id,
|
||||||
self.server.id,
|
self.server_id,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -145,6 +145,12 @@ class DiscordBotUser(BotUser):
|
|||||||
bot.discord_users = discord_users
|
bot.discord_users = discord_users
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
|
@property
|
||||||
|
def discord_id(self) -> int | None:
|
||||||
|
if not self.discord_users:
|
||||||
|
return None
|
||||||
|
return self.discord_users[0].id
|
||||||
|
|
||||||
|
|
||||||
class UserSession(Base):
|
class UserSession(Base):
|
||||||
__tablename__ = "user_sessions"
|
__tablename__ = "user_sessions"
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from urllib.parse import urlparse
|
||||||
from typing import Any, AsyncIterator, Iterator
|
from typing import Any, AsyncIterator, Iterator
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
@ -283,9 +284,10 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
# Include server info if present
|
# Include server info if present
|
||||||
if current_tool_use.get("server_name"):
|
if current_tool_use.get("server_name"):
|
||||||
tool_data["server_name"] = current_tool_use["server_name"]
|
tool_data["server_name"] = current_tool_use["server_name"]
|
||||||
if current_tool_use.get("is_server_call"):
|
|
||||||
tool_data["is_server_call"] = current_tool_use["is_server_call"]
|
|
||||||
|
|
||||||
|
# Emit different event type for MCP server tools
|
||||||
|
if current_tool_use.get("is_server_call"):
|
||||||
|
return StreamEvent(type="server_tool_use", data=tool_data), None
|
||||||
return StreamEvent(type="tool_use", data=tool_data), None
|
return StreamEvent(type="tool_use", data=tool_data), None
|
||||||
|
|
||||||
elif event_type == "message_delta":
|
elif event_type == "message_delta":
|
||||||
|
|||||||
@ -189,7 +189,15 @@ class Message:
|
|||||||
class StreamEvent:
|
class StreamEvent:
|
||||||
"""An event from the streaming response."""
|
"""An event from the streaming response."""
|
||||||
|
|
||||||
type: Literal["text", "tool_use", "tool_result", "thinking", "error", "done"]
|
type: Literal[
|
||||||
|
"text",
|
||||||
|
"tool_use",
|
||||||
|
"server_tool_use",
|
||||||
|
"tool_result",
|
||||||
|
"thinking",
|
||||||
|
"error",
|
||||||
|
"done",
|
||||||
|
]
|
||||||
data: Any = None
|
data: Any = None
|
||||||
signature: str | None = None
|
signature: str | None = None
|
||||||
|
|
||||||
@ -565,9 +573,31 @@ class BaseLLMProvider(ABC):
|
|||||||
elif event.type == "thinking":
|
elif event.type == "thinking":
|
||||||
thinking.thinking += event.data
|
thinking.thinking += event.data
|
||||||
yield event
|
yield event
|
||||||
elif event.type == "tool_use":
|
elif event.type == "server_tool_use":
|
||||||
|
# MCP server tools are executed by Anthropic's backend
|
||||||
|
# Results will come as separate tool_result events
|
||||||
|
yield event
|
||||||
|
|
||||||
|
# Track the MCP tool call but don't execute locally
|
||||||
|
messages.append(
|
||||||
|
Message.assistant(
|
||||||
|
response,
|
||||||
|
thinking,
|
||||||
|
ToolUseContent(
|
||||||
|
id=event.data["id"],
|
||||||
|
name=event.data["name"],
|
||||||
|
input=event.data["input"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Reset response for next turn
|
||||||
|
response = TextContent(text="")
|
||||||
|
thinking = ThinkingContent(thinking="")
|
||||||
|
# Continue streaming to get the tool result from Anthropic
|
||||||
|
|
||||||
|
elif event.type == "tool_use":
|
||||||
|
# Execute local tools
|
||||||
yield event
|
yield event
|
||||||
# Execute the tool and yield the result
|
|
||||||
tool_result = self.execute_tool(event.data, tools)
|
tool_result = self.execute_tool(event.data, tools)
|
||||||
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
|
yield StreamEvent(type="tool_result", data=tool_result.to_dict())
|
||||||
|
|
||||||
@ -598,6 +628,27 @@ class BaseLLMProvider(ABC):
|
|||||||
)
|
)
|
||||||
return # Exit after recursive call completes
|
return # Exit after recursive call completes
|
||||||
|
|
||||||
|
elif event.type == "tool_result":
|
||||||
|
yield event
|
||||||
|
# Add user message with the result
|
||||||
|
tool_result_content = ToolResultContent(
|
||||||
|
tool_use_id=event.data["id"],
|
||||||
|
content=str(event.data.get("result", "")),
|
||||||
|
is_error=event.data.get("is_error", False),
|
||||||
|
)
|
||||||
|
messages.append(Message.user(tool_result=tool_result_content))
|
||||||
|
|
||||||
|
# Continue conversation with reduced iterations
|
||||||
|
yield from self.stream_with_tools(
|
||||||
|
messages,
|
||||||
|
tools,
|
||||||
|
mcp_servers,
|
||||||
|
settings,
|
||||||
|
system_prompt,
|
||||||
|
max_iterations - 1,
|
||||||
|
)
|
||||||
|
return # Exit after recursive call completes
|
||||||
|
|
||||||
elif event.type == "error":
|
elif event.type == "error":
|
||||||
logger.error(f"LLM error: {event.data}")
|
logger.error(f"LLM error: {event.data}")
|
||||||
raise RuntimeError(f"LLM error: {event.data}")
|
raise RuntimeError(f"LLM error: {event.data}")
|
||||||
@ -625,7 +676,7 @@ class BaseLLMProvider(ABC):
|
|||||||
):
|
):
|
||||||
if event.type == "thinking":
|
if event.type == "thinking":
|
||||||
thinking += event.data
|
thinking += event.data
|
||||||
elif event.type == "tool_use":
|
elif event.type == "tool_use" or event.type == "server_tool_use":
|
||||||
tool_calls[event.data["id"]] = {
|
tool_calls[event.data["id"]] = {
|
||||||
"name": event.data["name"],
|
"name": event.data["name"],
|
||||||
"input": event.data["input"],
|
"input": event.data["input"],
|
||||||
@ -634,11 +685,15 @@ class BaseLLMProvider(ABC):
|
|||||||
elif event.type == "text":
|
elif event.type == "text":
|
||||||
response += event.data
|
response += event.data
|
||||||
elif event.type == "tool_result":
|
elif event.type == "tool_result":
|
||||||
current = tool_calls.get(event.data["tool_use_id"]) or {}
|
tool_id = event.data.get("id") or event.data.get("tool_use_id")
|
||||||
tool_calls[event.data["tool_use_id"]] = {
|
if not tool_id:
|
||||||
|
logger.warning(f"tool_result event missing id: {event.data}")
|
||||||
|
continue
|
||||||
|
current = tool_calls.get(tool_id) or {}
|
||||||
|
tool_calls[tool_id] = {
|
||||||
"name": event.data.get("name") or current.get("name"),
|
"name": event.data.get("name") or current.get("name"),
|
||||||
"input": event.data.get("input") or current.get("input"),
|
"input": event.data.get("input") or current.get("input"),
|
||||||
"output": event.data.get("content"),
|
"output": event.data.get("content") or event.data.get("result"),
|
||||||
}
|
}
|
||||||
return Turn(
|
return Turn(
|
||||||
thinking=thinking or None,
|
thinking=thinking or None,
|
||||||
|
|||||||
@ -32,6 +32,18 @@ class MCPServer:
|
|||||||
token: str
|
token: str
|
||||||
allowed_tools: list[str] | None = None
|
allowed_tools: list[str] | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_model(cls, model) -> "MCPServer":
|
||||||
|
allowed_tools = None
|
||||||
|
if model.disabled_tools:
|
||||||
|
allowed_tools = list(set(model.available_tools) - set(model.disabled_tools))
|
||||||
|
return cls(
|
||||||
|
name=model.name,
|
||||||
|
url=model.mcp_server_url,
|
||||||
|
token=model.access_token,
|
||||||
|
allowed_tools=allowed_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolDefinition:
|
class ToolDefinition:
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from memory.common.db.models import (
|
|||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
BotUser,
|
DiscordBotUser,
|
||||||
)
|
)
|
||||||
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
|
from memory.common.llms.tools import ToolDefinition, ToolInput, ToolHandler
|
||||||
from memory.common.discord import add_reaction
|
from memory.common.discord import add_reaction
|
||||||
@ -83,9 +83,9 @@ def make_summary_tool(type: UpdateSummaryType, item_id: BigInteger) -> ToolDefin
|
|||||||
|
|
||||||
|
|
||||||
def schedule_message(
|
def schedule_message(
|
||||||
user_id: int,
|
bot_id: int,
|
||||||
user: int | None,
|
recipient_id: int | None,
|
||||||
channel: int | None,
|
channel_id: int | None,
|
||||||
model: str,
|
model: str,
|
||||||
message: str,
|
message: str,
|
||||||
date_time: datetime,
|
date_time: datetime,
|
||||||
@ -95,43 +95,62 @@ def schedule_message(
|
|||||||
session,
|
session,
|
||||||
scheduled_time=date_time,
|
scheduled_time=date_time,
|
||||||
message=message,
|
message=message,
|
||||||
user_id=user_id,
|
user_id=bot_id,
|
||||||
model=model,
|
model=model,
|
||||||
discord_user=user,
|
discord_user=recipient_id,
|
||||||
discord_channel=channel,
|
discord_channel=channel_id,
|
||||||
system_prompt=comm_channel_prompt(session, user, channel),
|
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.error(f"Scheduled message: {call}")
|
||||||
|
logger.error(f"Scheduled message: {call.id}")
|
||||||
|
logger.error(f"Scheduled message time: {call.scheduled_time}")
|
||||||
|
logger.error(f"Scheduled message message: {call.message}")
|
||||||
|
logger.error(f"Scheduled message model: {call.model}")
|
||||||
|
logger.error(f"Scheduled message user id: {call.user_id}")
|
||||||
|
logger.error(f"Scheduled message discord user id: {call.discord_user_id}")
|
||||||
|
logger.error(f"Scheduled message discord channel id: {call.discord_channel_id}")
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
return cast(str, call.id)
|
return cast(str, call.id)
|
||||||
|
|
||||||
|
|
||||||
def make_message_scheduler(
|
def make_message_scheduler(
|
||||||
bot: BotUser, user: int | None, channel: int | None, model: str
|
bot: DiscordBotUser, user_id: int | None, channel_id: int | None, model: str
|
||||||
) -> ToolDefinition:
|
) -> ToolDefinition:
|
||||||
bot_id = cast(int, bot.id)
|
bot_id = cast(int, bot.id)
|
||||||
if user:
|
if user_id:
|
||||||
channel_type = "from your chat with this user"
|
channel_type = "from your chat with this user"
|
||||||
elif channel:
|
elif channel_id:
|
||||||
channel_type = "in this channel"
|
channel_type = "in this channel"
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either user or channel must be provided")
|
raise ValueError("Either user or channel must be provided")
|
||||||
|
|
||||||
def handler(input: ToolInput) -> str:
|
def handler(input: ToolInput) -> str:
|
||||||
if not isinstance(input, dict):
|
|
||||||
raise ValueError("Input must be a dictionary")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
time = datetime.fromisoformat(input["date_time"])
|
if not isinstance(input, dict):
|
||||||
except ValueError:
|
raise ValueError("Input must be a dictionary")
|
||||||
raise ValueError("Invalid date time format")
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError("Date time is required")
|
|
||||||
|
|
||||||
return schedule_message(bot_id, user, channel, model, input["message"], time)
|
try:
|
||||||
|
time = datetime.fromisoformat(input["date_time"])
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Invalid date time format")
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError("Date time is required")
|
||||||
|
|
||||||
|
return schedule_message(
|
||||||
|
bot_id, user_id, channel_id, model, input["message"], time
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.error(f"Error scheduling message: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
return ToolDefinition(
|
return ToolDefinition(
|
||||||
name="schedule_message",
|
name="schedule_discord_message",
|
||||||
description=textwrap.dedent("""
|
description=textwrap.dedent("""
|
||||||
Use this to schedule a message to be sent to yourself.
|
Use this to schedule a message to be sent to yourself.
|
||||||
|
|
||||||
@ -162,10 +181,13 @@ def make_message_scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefinition:
|
def make_prev_messages_tool(
|
||||||
if user:
|
bot: DiscordBotUser, user_id: int | None, channel_id: int | None
|
||||||
|
) -> ToolDefinition:
|
||||||
|
bot_id = bot.discord_id
|
||||||
|
if user_id:
|
||||||
channel_type = "from your chat with this user"
|
channel_type = "from your chat with this user"
|
||||||
elif channel:
|
elif channel_id:
|
||||||
channel_type = "in this channel"
|
channel_type = "in this channel"
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either user or channel must be provided")
|
raise ValueError("Either user or channel must be provided")
|
||||||
@ -185,7 +207,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
|
|||||||
raise ValueError("Offset must be greater than or equal to 0")
|
raise ValueError("Offset must be greater than or equal to 0")
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
messages = previous_messages(session, user, channel, max_messages, offset)
|
messages = previous_messages(
|
||||||
|
session, bot_id, user_id, channel_id, max_messages, offset
|
||||||
|
)
|
||||||
return "\n\n".join([msg.title for msg in messages])
|
return "\n\n".join([msg.title for msg in messages])
|
||||||
|
|
||||||
return ToolDefinition(
|
return ToolDefinition(
|
||||||
@ -210,7 +234,9 @@ def make_prev_messages_tool(user: int | None, channel: int | None) -> ToolDefini
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinition:
|
def make_add_reaction_tool(
|
||||||
|
bot: DiscordBotUser, channel: DiscordChannel
|
||||||
|
) -> ToolDefinition:
|
||||||
bot_id = cast(int, bot.id)
|
bot_id = cast(int, bot.id)
|
||||||
channel_id = channel and channel.id
|
channel_id = channel and channel.id
|
||||||
|
|
||||||
@ -255,7 +281,7 @@ def make_add_reaction_tool(bot: BotUser, channel: DiscordChannel) -> ToolDefinit
|
|||||||
|
|
||||||
|
|
||||||
def make_discord_tools(
|
def make_discord_tools(
|
||||||
bot: BotUser,
|
bot: DiscordBotUser,
|
||||||
author: DiscordUser | None,
|
author: DiscordUser | None,
|
||||||
channel: DiscordChannel | None,
|
channel: DiscordChannel | None,
|
||||||
model: str,
|
model: str,
|
||||||
@ -264,7 +290,7 @@ def make_discord_tools(
|
|||||||
channel_id = channel and channel.id
|
channel_id = channel and channel.id
|
||||||
tools = [
|
tools = [
|
||||||
make_message_scheduler(bot, author_id, channel_id, model),
|
make_message_scheduler(bot, author_id, channel_id, model),
|
||||||
make_prev_messages_tool(author_id, channel_id),
|
make_prev_messages_tool(bot, author_id, channel_id),
|
||||||
make_summary_tool("channel", channel_id),
|
make_summary_tool("channel", channel_id),
|
||||||
]
|
]
|
||||||
if author:
|
if author:
|
||||||
|
|||||||
58
src/memory/common/mcp.py
Normal file
58
src/memory/common/mcp.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def mcp_call(
|
||||||
|
url: str, access_token: str, method: str, params: dict = {}
|
||||||
|
) -> AsyncGenerator[Any, None]:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": int(time.time() * 1000),
|
||||||
|
"method": method,
|
||||||
|
"params": params,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as http_session:
|
||||||
|
async with http_session.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=10),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
logger.error(f"Tools list failed: {resp.status} - {error_text}")
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to call MCP server: {resp.status} - {error_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse SSE stream
|
||||||
|
async for line in resp.content:
|
||||||
|
line_str = line.decode("utf-8").strip()
|
||||||
|
|
||||||
|
# SSE format: "data: {json}"
|
||||||
|
if line_str.startswith("data: "):
|
||||||
|
json_str = line_str[6:] # Remove "data: " prefix
|
||||||
|
try:
|
||||||
|
yield json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue # Skip invalid JSON lines
|
||||||
|
|
||||||
|
|
||||||
|
async def mcp_tools_list(url: str, access_token: str) -> list[dict]:
|
||||||
|
async for data in mcp_call(url, access_token, "tools/list"):
|
||||||
|
if "result" in data and "tools" in data["result"]:
|
||||||
|
return data["result"]["tools"]
|
||||||
|
return []
|
||||||
@ -108,6 +108,7 @@ async def send_dm_endpoint(request: SendDMRequest):
|
|||||||
"""Send a DM via the collector's Discord client"""
|
"""Send a DM via the collector's Discord client"""
|
||||||
collector = app.bots.get(request.bot_id)
|
collector = app.bots.get(request.bot_id)
|
||||||
if not collector:
|
if not collector:
|
||||||
|
logger.error(f"Bot not found: {request.bot_id}")
|
||||||
raise HTTPException(status_code=404, detail="Bot not found")
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,9 +1,7 @@
|
|||||||
"""Lightweight slash-command helpers for the Discord collector."""
|
"""Lightweight slash-command helpers for the Discord collector."""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
from typing import Literal, cast
|
||||||
from typing import Any, AsyncGenerator, Literal, cast
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import discord
|
import discord
|
||||||
@ -12,6 +10,7 @@ from sqlalchemy.orm import Session, scoped_session
|
|||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||||
|
from memory.common.mcp import mcp_tools_list
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -33,49 +32,6 @@ def find_mcp_server(
|
|||||||
return assignment and assignment.mcp_server
|
return assignment and assignment.mcp_server
|
||||||
|
|
||||||
|
|
||||||
async def call_mcp_server(
|
|
||||||
url: str, access_token: str, method: str, params: dict = {}
|
|
||||||
) -> AsyncGenerator[Any, None]:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept": "application/json, text/event-stream",
|
|
||||||
"Authorization": f"Bearer {access_token}",
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": int(time.time() * 1000),
|
|
||||||
"method": method,
|
|
||||||
"params": params,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as http_session:
|
|
||||||
async with http_session.post(
|
|
||||||
url,
|
|
||||||
json=payload,
|
|
||||||
headers=headers,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=10),
|
|
||||||
) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
error_text = await resp.text()
|
|
||||||
logger.error(f"Tools list failed: {resp.status} - {error_text}")
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to call MCP server: {resp.status} - {error_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse SSE stream
|
|
||||||
async for line in resp.content:
|
|
||||||
line_str = line.decode("utf-8").strip()
|
|
||||||
|
|
||||||
# SSE format: "data: {json}"
|
|
||||||
if line_str.startswith("data: "):
|
|
||||||
json_str = line_str[6:] # Remove "data: " prefix
|
|
||||||
try:
|
|
||||||
yield json.loads(json_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue # Skip invalid JSON lines
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
||||||
"""List all MCP servers for the user."""
|
"""List all MCP servers for the user."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
@ -258,10 +214,7 @@ async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
|||||||
# Make JSON-RPC request to MCP server
|
# Make JSON-RPC request to MCP server
|
||||||
tools = None
|
tools = None
|
||||||
try:
|
try:
|
||||||
async for data in call_mcp_server(url, access_token, "tools/list"):
|
tools = await mcp_tools_list(url, access_token)
|
||||||
if "result" in data and "tools" in data["result"]:
|
|
||||||
tools = data["result"]["tools"]
|
|
||||||
break
|
|
||||||
except aiohttp.ClientError as exc:
|
except aiohttp.ClientError as exc:
|
||||||
logger.exception(f"Failed to connect to MCP server: {exc}")
|
logger.exception(f"Failed to connect to MCP server: {exc}")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@ -132,14 +132,17 @@ def upsert_scheduled_message(
|
|||||||
|
|
||||||
def previous_messages(
|
def previous_messages(
|
||||||
session: Session | scoped_session,
|
session: Session | scoped_session,
|
||||||
|
bot_id: int,
|
||||||
user_id: int | None,
|
user_id: int | None,
|
||||||
channel_id: int | None,
|
channel_id: int | None,
|
||||||
max_messages: int = 10,
|
max_messages: int = 10,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[DiscordMessage]:
|
) -> list[DiscordMessage]:
|
||||||
messages = session.query(DiscordMessage)
|
messages = session.query(DiscordMessage).filter(
|
||||||
|
DiscordMessage.recipient_id == bot_id
|
||||||
|
)
|
||||||
if user_id:
|
if user_id:
|
||||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
messages = messages.filter(DiscordMessage.from_id == user_id)
|
||||||
if channel_id:
|
if channel_id:
|
||||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||||
return list(
|
return list(
|
||||||
@ -154,15 +157,23 @@ def previous_messages(
|
|||||||
|
|
||||||
def comm_channel_prompt(
|
def comm_channel_prompt(
|
||||||
session: Session | scoped_session,
|
session: Session | scoped_session,
|
||||||
|
bot: DiscordEntity,
|
||||||
user: DiscordEntity,
|
user: DiscordEntity,
|
||||||
channel: DiscordEntity,
|
channel: DiscordEntity,
|
||||||
max_messages: int = 10,
|
max_messages: int = 10,
|
||||||
) -> str:
|
) -> str:
|
||||||
user = resolve_discord_user(session, user)
|
user = resolve_discord_user(session, user)
|
||||||
channel = resolve_discord_channel(session, channel)
|
channel = resolve_discord_channel(session, channel)
|
||||||
|
bot = resolve_discord_user(session, bot)
|
||||||
|
if not bot:
|
||||||
|
raise ValueError("Bot not found")
|
||||||
|
|
||||||
messages = previous_messages(
|
messages = previous_messages(
|
||||||
session, user and user.id, channel and channel.id, max_messages
|
session,
|
||||||
|
cast(int, bot.id),
|
||||||
|
user and user.id,
|
||||||
|
channel and channel.id,
|
||||||
|
max_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
server_context = ""
|
server_context = ""
|
||||||
@ -244,6 +255,7 @@ def call_llm(
|
|||||||
user_id = cast(int, from_user.id)
|
user_id = cast(int, from_user.id)
|
||||||
prev_messages = previous_messages(
|
prev_messages = previous_messages(
|
||||||
session,
|
session,
|
||||||
|
bot_user.system_user.discord_id,
|
||||||
user_id,
|
user_id,
|
||||||
channel and channel.id,
|
channel and channel.id,
|
||||||
max_messages=num_previous_messages,
|
max_messages=num_previous_messages,
|
||||||
@ -263,18 +275,12 @@ def call_llm(
|
|||||||
if bot_user.system_prompt:
|
if bot_user.system_prompt:
|
||||||
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
|
system_prompt = bot_user.system_prompt + "\n\n" + (system_prompt or "")
|
||||||
message_content = [m.as_content() for m in prev_messages] + messages
|
message_content = [m.as_content() for m in prev_messages] + messages
|
||||||
|
|
||||||
return provider.run_with_tools(
|
return provider.run_with_tools(
|
||||||
messages=provider.as_messages(message_content),
|
messages=provider.as_messages(message_content),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
||||||
mcp_servers=[
|
mcp_servers=[MCPServer.from_model(server) for server in mcp_servers]
|
||||||
MCPServer(
|
|
||||||
name=str(server.name),
|
|
||||||
url=str(server.mcp_server_url),
|
|
||||||
token=str(server.access_token),
|
|
||||||
)
|
|
||||||
for server in mcp_servers
|
|
||||||
]
|
|
||||||
if mcp_servers
|
if mcp_servers
|
||||||
else None,
|
else None,
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
|
|||||||
@ -113,7 +113,7 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
|
|
||||||
system_prompt = message.system_prompt or ""
|
system_prompt = message.system_prompt or ""
|
||||||
system_prompt += comm_channel_prompt(
|
system_prompt += comm_channel_prompt(
|
||||||
session, message.recipient_user, message.channel
|
session, message.recipient_user, message.from_user, message.channel
|
||||||
)
|
)
|
||||||
allowed_tools = [
|
allowed_tools = [
|
||||||
"update_channel_summary",
|
"update_channel_summary",
|
||||||
@ -199,7 +199,10 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
mcp_servers = discord_message.get_mcp_servers(session)
|
mcp_servers = discord_message.get_mcp_servers(session)
|
||||||
system_prompt = discord_message.system_prompt or ""
|
system_prompt = discord_message.system_prompt or ""
|
||||||
system_prompt += comm_channel_prompt(
|
system_prompt += comm_channel_prompt(
|
||||||
session, discord_message.recipient_user, discord_message.channel
|
session,
|
||||||
|
discord_message.recipient_user,
|
||||||
|
discord_message.from_user,
|
||||||
|
discord_message.channel,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -9,14 +9,14 @@ from memory.common.celery_app import (
|
|||||||
app,
|
app,
|
||||||
)
|
)
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import ScheduledLLMCall
|
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
|
||||||
from memory.discord.messages import call_llm, send_discord_response
|
from memory.discord.messages import call_llm, send_discord_response
|
||||||
from memory.workers.tasks.content_processing import safe_task_execution
|
from memory.workers.tasks.content_processing import safe_task_execution
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
|
def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
|
||||||
"""Call LLM with tools support for scheduled calls."""
|
"""Call LLM with tools support for scheduled calls."""
|
||||||
if not scheduled_call.discord_user:
|
if not scheduled_call.discord_user:
|
||||||
logger.warning("No discord_user for scheduled call - cannot execute")
|
logger.warning("No discord_user for scheduled call - cannot execute")
|
||||||
@ -30,6 +30,7 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
|
|||||||
bot_user = (
|
bot_user = (
|
||||||
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
|
scheduled_call.user.discord_users and scheduled_call.user.discord_users[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
return call_llm(
|
return call_llm(
|
||||||
session=session,
|
session=session,
|
||||||
bot_user=bot_user,
|
bot_user=bot_user,
|
||||||
@ -42,18 +43,8 @@ def _call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str |
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
def send_to_discord(bot_id: int, scheduled_call: ScheduledLLMCall, response: str):
|
||||||
"""Send the LLM response to Discord user or channel."""
|
"""Send the LLM response to Discord user or channel."""
|
||||||
bot_id_value = scheduled_call.user_id
|
|
||||||
if bot_id_value is None:
|
|
||||||
logger.warning(
|
|
||||||
"Scheduled call %s has no associated bot user; skipping Discord send",
|
|
||||||
scheduled_call.id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
bot_id = cast(int, bot_id_value)
|
|
||||||
|
|
||||||
send_discord_response(
|
send_discord_response(
|
||||||
bot_id=bot_id,
|
bot_id=bot_id,
|
||||||
response=response,
|
response=response,
|
||||||
@ -101,7 +92,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
|
|
||||||
# Make the LLM call with tools support
|
# Make the LLM call with tools support
|
||||||
try:
|
try:
|
||||||
response = _call_llm_for_scheduled(session, scheduled_call)
|
response = call_llm_for_scheduled(session, scheduled_call)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to generate LLM response")
|
logger.exception("Failed to generate LLM response")
|
||||||
scheduled_call.status = "failed"
|
scheduled_call.status = "failed"
|
||||||
@ -132,7 +123,7 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
|
|
||||||
# Send to Discord
|
# Send to Discord
|
||||||
try:
|
try:
|
||||||
_send_to_discord(scheduled_call, response)
|
send_to_discord(cast(int, scheduled_call.user_id), scheduled_call, response)
|
||||||
logger.info(f"Response sent to Discord for {scheduled_call_id}")
|
logger.info(f"Response sent to Discord for {scheduled_call_id}")
|
||||||
except Exception as discord_error:
|
except Exception as discord_error:
|
||||||
logger.error(f"Failed to send to Discord: {discord_error}")
|
logger.error(f"Failed to send to Discord: {discord_error}")
|
||||||
|
|||||||
@ -127,7 +127,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
|||||||
"""Test sending to Discord user."""
|
"""Test sending to Discord user."""
|
||||||
response = "This is a test response."
|
response = "This is a test response."
|
||||||
|
|
||||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
scheduled_calls.send_to_discord(pending_scheduled_call, response)
|
||||||
|
|
||||||
mock_send_dm.assert_called_once_with(
|
mock_send_dm.assert_called_once_with(
|
||||||
pending_scheduled_call.user_id,
|
pending_scheduled_call.user_id,
|
||||||
@ -141,7 +141,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
|||||||
"""Test sending to Discord channel."""
|
"""Test sending to Discord channel."""
|
||||||
response = "This is a channel response."
|
response = "This is a channel response."
|
||||||
|
|
||||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
scheduled_calls.send_to_discord(completed_scheduled_call, response)
|
||||||
|
|
||||||
mock_broadcast.assert_called_once_with(
|
mock_broadcast.assert_called_once_with(
|
||||||
completed_scheduled_call.user_id,
|
completed_scheduled_call.user_id,
|
||||||
@ -155,7 +155,7 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
|
|||||||
"""Test message truncation for long responses."""
|
"""Test message truncation for long responses."""
|
||||||
long_response = "A" * 2500 # Very long response
|
long_response = "A" * 2500 # Very long response
|
||||||
|
|
||||||
scheduled_calls._send_to_discord(pending_scheduled_call, long_response)
|
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
|
||||||
|
|
||||||
# Verify the message was truncated
|
# Verify the message was truncated
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
@ -170,7 +170,7 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
|||||||
"""Test that normal length messages are not truncated."""
|
"""Test that normal length messages are not truncated."""
|
||||||
normal_response = "This is a normal length response."
|
normal_response = "This is a normal length response."
|
||||||
|
|
||||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
|
||||||
|
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
assert args[0] == pending_scheduled_call.user_id
|
assert args[0] == pending_scheduled_call.user_id
|
||||||
@ -535,7 +535,7 @@ def test_discord_destination_priority(
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
response = "Test response"
|
response = "Test response"
|
||||||
scheduled_calls._send_to_discord(call, response)
|
scheduled_calls.send_to_discord(call, response)
|
||||||
|
|
||||||
if expected_method == "send_dm":
|
if expected_method == "send_dm":
|
||||||
mock_send_dm.assert_called_once()
|
mock_send_dm.assert_called_once()
|
||||||
@ -590,7 +590,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
|||||||
mock_call.discord_user = mock_discord_user
|
mock_call.discord_user = mock_discord_user
|
||||||
mock_call.discord_channel = None
|
mock_call.discord_channel = None
|
||||||
|
|
||||||
scheduled_calls._send_to_discord(mock_call, response)
|
scheduled_calls.send_to_discord(mock_call, response)
|
||||||
|
|
||||||
# Get the actual message that was sent
|
# Get the actual message that was sent
|
||||||
args, kwargs = mock_send_dm.call_args
|
args, kwargs = mock_send_dm.call_args
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user