This commit is contained in:
mruwnik 2025-11-03 19:42:13 +00:00
parent b568222e88
commit 56c0df9761
9 changed files with 158 additions and 126 deletions

View File

@ -34,6 +34,8 @@ from memory.common.db.models.discord import (
DiscordServer, DiscordServer,
DiscordChannel, DiscordChannel,
DiscordUser, DiscordUser,
)
from memory.common.db.models.mcp import (
MCPServer, MCPServer,
MCPServerAssignment, MCPServerAssignment,
) )

View File

@ -55,7 +55,7 @@ class MessageProcessor:
def entity_type(self) -> str: def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore return self.__class__.__tablename__[8:-1] # type: ignore
def to_xml(self, fields: list[str]) -> str: def to_xml(self, *fields: str) -> str:
def indent(key: str, text: str) -> str: def indent(key: str, text: str) -> str:
res = textwrap.dedent(""" res = textwrap.dedent("""
<{key}> <{key}>
@ -78,13 +78,13 @@ class MessageProcessor:
return indent(self.entity_type, "\n".join(vals)) # type: ignore return indent(self.entity_type, "\n".join(vals)) # type: ignore
def xml_prompt(self) -> str: def xml_prompt(self) -> str:
return self.to_xml(["name", "system_prompt"]) if self.system_prompt else "" return self.to_xml("name", "system_prompt") if self.system_prompt else ""
def xml_summary(self) -> str: def xml_summary(self) -> str:
return self.to_xml(["name", "summary"]) return self.to_xml("name", "summary")
def xml_mcp_servers(self) -> str: def xml_mcp_servers(self) -> str:
return self.to_xml(["mcp_servers"]) return self.to_xml("mcp_servers")
class DiscordServer(Base, MessageProcessor): class DiscordServer(Base, MessageProcessor):
@ -152,95 +152,3 @@ class DiscordUser(Base, MessageProcessor):
@property @property
def name(self) -> str: def name(self) -> str:
return self.username return self.username
class MCPServer(Base):
"""MCP server configuration and OAuth state."""
__tablename__ = "mcp_servers"
id = Column(Integer, primary_key=True)
# MCP server info
name = Column(Text, nullable=False)
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)
code_verifier = Column(Text, nullable=True)
# OAuth tokens (set after successful authorization)
access_token = Column(Text, nullable=True)
refresh_token = Column(Text, nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
assignments = relationship(
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
)
__table_args__ = (Index("mcp_state_idx", "state"),)
def as_xml(self) -> str:
tools = "\n".join(f"{tool}" for tool in self.available_tools).strip()
return textwrap.dedent("""
<mcp_server>
<name>
{name}
</name>
<mcp_server_url>
{mcp_server_url}
</mcp_server_url>
<client_id>
{client_id}
</client_id>
<available_tools>
{available_tools}
</available_tools>
</mcp_server>
""").format(
name=self.name,
mcp_server_url=self.mcp_server_url,
client_id=self.client_id,
available_tools=tools,
)
class MCPServerAssignment(Base):
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
__tablename__ = "mcp_server_assignments"
id = Column(Integer, primary_key=True)
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
# Polymorphic entity reference
entity_type = Column(
Text, nullable=False
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
entity_id = Column(BigInteger, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
mcp_server = relationship("MCPServer", back_populates="assignments")
__table_args__ = (
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
Index("mcp_assignment_server_idx", "mcp_server_id"),
Index(
"mcp_assignment_unique_idx",
"mcp_server_id",
"entity_type",
"entity_id",
unique=True,
),
)

View File

@ -0,0 +1,108 @@
import textwrap
from sqlalchemy import (
ARRAY,
BigInteger,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Text,
func,
)
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
class MCPServer(Base):
"""MCP server configuration and OAuth state."""
__tablename__ = "mcp_servers"
id = Column(Integer, primary_key=True)
# MCP server info
name = Column(Text, nullable=False)
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)
code_verifier = Column(Text, nullable=True)
# OAuth tokens (set after successful authorization)
access_token = Column(Text, nullable=True)
refresh_token = Column(Text, nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
assignments = relationship(
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
)
__table_args__ = (Index("mcp_state_idx", "state"),)
def as_xml(self) -> str:
tools = "\n".join(f"{tool}" for tool in self.available_tools).strip()
return textwrap.dedent("""
<mcp_server>
<name>
{name}
</name>
<mcp_server_url>
{mcp_server_url}
</mcp_server_url>
<client_id>
{client_id}
</client_id>
<available_tools>
{available_tools}
</available_tools>
</mcp_server>
""").format(
name=self.name,
mcp_server_url=self.mcp_server_url,
client_id=self.client_id,
available_tools=tools,
)
class MCPServerAssignment(Base):
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
__tablename__ = "mcp_server_assignments"
id = Column(Integer, primary_key=True)
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
# Polymorphic entity reference
entity_type = Column(
Text, nullable=False
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
entity_id = Column(BigInteger, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
mcp_server = relationship("MCPServer", back_populates="assignments")
__table_args__ = (
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
Index("mcp_assignment_server_idx", "mcp_server_id"),
Index(
"mcp_assignment_unique_idx",
"mcp_server_id",
"entity_type",
"entity_id",
unique=True,
),
)

View File

@ -36,6 +36,10 @@ from memory.common.db.models.source_item import (
clean_filename, clean_filename,
chunk_mixed, chunk_mixed,
) )
from memory.common.db.models.mcp import (
MCPServer,
MCPServerAssignment,
)
class MailMessagePayload(SourceItemPayload): class MailMessagePayload(SourceItemPayload):
@ -392,6 +396,29 @@ class DiscordMessage(SourceItem):
return content return content
def get_mcp_servers(self, session) -> list[MCPServer]:
entity_ids = list(
filter(
None,
[
self.recipient_user.id,
self.from_user.id,
self.channel.id,
self.server.id,
],
)
)
if not entity_ids:
return None
return (
session.query(MCPServer)
.filter(
MCPServerAssignment.entity_id.in_(entity_ids),
)
.all()
)
__mapper_args__ = { __mapper_args__ = {
"polymorphic_identity": "discord_message", "polymorphic_identity": "discord_message",
} }

View File

@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
import aiohttp import aiohttp
from memory.common import settings from memory.common import settings
from memory.common.db.models.discord import MCPServer from memory.common.db.models import MCPServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -363,16 +363,6 @@ def create_list_group(
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers] mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
res = "\n\n".join(mcp_servers) res = "\n\n".join(mcp_servers)
await respond(interaction, res) await respond(interaction, res)
# def handler(objects: DiscordObjects) -> str:
# servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers]
# return "\n\n".join(servers) if servers else "No MCP servers configured."
# try:
# res = with_object_context(bot, interaction, handler, user)
# except Exception as exc:
# logger.error(f"Error listing MCP servers: {exc}", exc_info=True)
# return CommandResponse(content="Error listing MCP servers.")
# await respond(interaction, res)
return group return group

View File

@ -10,7 +10,7 @@ import discord
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models.discord import MCPServer, MCPServerAssignment from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -12,9 +12,9 @@ from memory.common.db.models import (
DiscordMessage, DiscordMessage,
DiscordUser, DiscordUser,
ScheduledLLMCall, ScheduledLLMCall,
MCPServer as MCPServerModel,
) )
from memory.common.llms.base import create_provider from memory.common.llms.base import create_provider
from memory.common.llms.tools import MCPServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -187,7 +187,7 @@ def comm_channel_prompt(
{users} {users}
</user_notes> </user_notes>
""").format( """).format(
users="\n".join({msg.from_user.as_xml() for msg in messages}), users="\n".join({msg.from_user.xml_summary() for msg in messages}),
) )
return textwrap.dedent(""" return textwrap.dedent("""
@ -212,7 +212,7 @@ def call_llm(
system_prompt: str = "", system_prompt: str = "",
messages: list[str | dict[str, Any]] = [], messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None, allowed_tools: Collection[str] | None = None,
mcp_servers: list[MCPServer] | None = None, mcp_servers: list[MCPServerModel] | None = None,
num_previous_messages: int = 10, num_previous_messages: int = 10,
) -> str | None: ) -> str | None:
""" """
@ -251,6 +251,7 @@ def call_llm(
from memory.common.llms.tools.discord import make_discord_tools from memory.common.llms.tools.discord import make_discord_tools
from memory.common.llms.tools.base import WebSearchTool from memory.common.llms.tools.base import WebSearchTool
from memory.common.llms.tools import MCPServer
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model) tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
tools |= {"web_search": WebSearchTool()} tools |= {"web_search": WebSearchTool()}
@ -266,7 +267,16 @@ def call_llm(
messages=provider.as_messages(message_content), messages=provider.as_messages(message_content),
tools=tools, tools=tools,
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""), system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
mcp_servers=mcp_servers, mcp_servers=[
MCPServer(
name=str(server.name),
url=str(server.mcp_server_url),
token=str(server.access_token),
)
for server in mcp_servers
]
if mcp_servers
else None,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS, max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response ).response

View File

@ -196,20 +196,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id, "message_id": message_id,
} }
mcp_servers = None mcp_servers = discord_message.get_mcp_servers(session)
if (
discord_message.recipient_user
and discord_message.recipient_user.mcp_servers
):
mcp_servers = [
MCPServer(
name=server.mcp_server_url,
url=server.mcp_server_url,
token=server.access_token,
)
for server in discord_message.recipient_user.mcp_servers
]
system_prompt = discord_message.system_prompt or "" system_prompt = discord_message.system_prompt or ""
system_prompt += comm_channel_prompt( system_prompt += comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel session, discord_message.recipient_user, discord_message.channel