This commit is contained in:
mruwnik 2025-11-03 19:42:13 +00:00
parent b568222e88
commit 56c0df9761
9 changed files with 158 additions and 126 deletions

View File

@ -34,6 +34,8 @@ from memory.common.db.models.discord import (
DiscordServer,
DiscordChannel,
DiscordUser,
)
from memory.common.db.models.mcp import (
MCPServer,
MCPServerAssignment,
)

View File

@ -55,7 +55,7 @@ class MessageProcessor:
def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore
def to_xml(self, fields: list[str]) -> str:
def to_xml(self, *fields: str) -> str:
def indent(key: str, text: str) -> str:
res = textwrap.dedent("""
<{key}>
@ -78,13 +78,13 @@ class MessageProcessor:
return indent(self.entity_type, "\n".join(vals)) # type: ignore
def xml_prompt(self) -> str:
return self.to_xml(["name", "system_prompt"]) if self.system_prompt else ""
return self.to_xml("name", "system_prompt") if self.system_prompt else ""
def xml_summary(self) -> str:
return self.to_xml(["name", "summary"])
return self.to_xml("name", "summary")
def xml_mcp_servers(self) -> str:
return self.to_xml(["mcp_servers"])
return self.to_xml("mcp_servers")
class DiscordServer(Base, MessageProcessor):
@ -152,95 +152,3 @@ class DiscordUser(Base, MessageProcessor):
@property
def name(self) -> str:
return self.username
class MCPServer(Base):
"""MCP server configuration and OAuth state."""
__tablename__ = "mcp_servers"
id = Column(Integer, primary_key=True)
# MCP server info
name = Column(Text, nullable=False)
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)
code_verifier = Column(Text, nullable=True)
# OAuth tokens (set after successful authorization)
access_token = Column(Text, nullable=True)
refresh_token = Column(Text, nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
assignments = relationship(
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
)
__table_args__ = (Index("mcp_state_idx", "state"),)
def as_xml(self) -> str:
tools = "\n".join(f"{tool}" for tool in self.available_tools).strip()
return textwrap.dedent("""
<mcp_server>
<name>
{name}
</name>
<mcp_server_url>
{mcp_server_url}
</mcp_server_url>
<client_id>
{client_id}
</client_id>
<available_tools>
{available_tools}
</available_tools>
</mcp_server>
""").format(
name=self.name,
mcp_server_url=self.mcp_server_url,
client_id=self.client_id,
available_tools=tools,
)
class MCPServerAssignment(Base):
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
__tablename__ = "mcp_server_assignments"
id = Column(Integer, primary_key=True)
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
# Polymorphic entity reference
entity_type = Column(
Text, nullable=False
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
entity_id = Column(BigInteger, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
mcp_server = relationship("MCPServer", back_populates="assignments")
__table_args__ = (
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
Index("mcp_assignment_server_idx", "mcp_server_id"),
Index(
"mcp_assignment_unique_idx",
"mcp_server_id",
"entity_type",
"entity_id",
unique=True,
),
)

View File

@ -0,0 +1,108 @@
import textwrap
from sqlalchemy import (
ARRAY,
BigInteger,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Text,
func,
)
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
class MCPServer(Base):
"""MCP server configuration and OAuth state."""
__tablename__ = "mcp_servers"
id = Column(Integer, primary_key=True)
# MCP server info
name = Column(Text, nullable=False)
mcp_server_url = Column(Text, nullable=False)
client_id = Column(Text, nullable=False)
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
# OAuth flow state (temporary, cleared after token exchange)
state = Column(Text, nullable=True, unique=True)
code_verifier = Column(Text, nullable=True)
# OAuth tokens (set after successful authorization)
access_token = Column(Text, nullable=True)
refresh_token = Column(Text, nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
assignments = relationship(
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
)
__table_args__ = (Index("mcp_state_idx", "state"),)
def as_xml(self) -> str:
tools = "\n".join(f"{tool}" for tool in self.available_tools).strip()
return textwrap.dedent("""
<mcp_server>
<name>
{name}
</name>
<mcp_server_url>
{mcp_server_url}
</mcp_server_url>
<client_id>
{client_id}
</client_id>
<available_tools>
{available_tools}
</available_tools>
</mcp_server>
""").format(
name=self.name,
mcp_server_url=self.mcp_server_url,
client_id=self.client_id,
available_tools=tools,
)
class MCPServerAssignment(Base):
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
__tablename__ = "mcp_server_assignments"
id = Column(Integer, primary_key=True)
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
# Polymorphic entity reference
entity_type = Column(
Text, nullable=False
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
entity_id = Column(BigInteger, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships
mcp_server = relationship("MCPServer", back_populates="assignments")
__table_args__ = (
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
Index("mcp_assignment_server_idx", "mcp_server_id"),
Index(
"mcp_assignment_unique_idx",
"mcp_server_id",
"entity_type",
"entity_id",
unique=True,
),
)

View File

@ -36,6 +36,10 @@ from memory.common.db.models.source_item import (
clean_filename,
chunk_mixed,
)
from memory.common.db.models.mcp import (
MCPServer,
MCPServerAssignment,
)
class MailMessagePayload(SourceItemPayload):
@ -392,6 +396,29 @@ class DiscordMessage(SourceItem):
return content
def get_mcp_servers(self, session) -> list[MCPServer]:
entity_ids = list(
filter(
None,
[
self.recipient_user.id,
self.from_user.id,
self.channel.id,
self.server.id,
],
)
)
if not entity_ids:
return None
return (
session.query(MCPServer)
.filter(
MCPServerAssignment.entity_id.in_(entity_ids),
)
.all()
)
__mapper_args__ = {
"polymorphic_identity": "discord_message",
}

View File

@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
import aiohttp
from memory.common import settings
from memory.common.db.models.discord import MCPServer
from memory.common.db.models import MCPServer
logger = logging.getLogger(__name__)

View File

@ -363,16 +363,6 @@ def create_list_group(
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
res = "\n\n".join(mcp_servers)
await respond(interaction, res)
# def handler(objects: DiscordObjects) -> str:
# servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers]
# return "\n\n".join(servers) if servers else "No MCP servers configured."
# try:
# res = with_object_context(bot, interaction, handler, user)
# except Exception as exc:
# logger.error(f"Error listing MCP servers: {exc}", exc_info=True)
# return CommandResponse(content="Error listing MCP servers.")
# await respond(interaction, res)
return group

View File

@ -10,7 +10,7 @@ import discord
from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session
from memory.common.db.models.discord import MCPServer, MCPServerAssignment
from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__)

View File

@ -12,9 +12,9 @@ from memory.common.db.models import (
DiscordMessage,
DiscordUser,
ScheduledLLMCall,
MCPServer as MCPServerModel,
)
from memory.common.llms.base import create_provider
from memory.common.llms.tools import MCPServer
logger = logging.getLogger(__name__)
@ -187,7 +187,7 @@ def comm_channel_prompt(
{users}
</user_notes>
""").format(
users="\n".join({msg.from_user.as_xml() for msg in messages}),
users="\n".join({msg.from_user.xml_summary() for msg in messages}),
)
return textwrap.dedent("""
@ -212,7 +212,7 @@ def call_llm(
system_prompt: str = "",
messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None,
mcp_servers: list[MCPServer] | None = None,
mcp_servers: list[MCPServerModel] | None = None,
num_previous_messages: int = 10,
) -> str | None:
"""
@ -251,6 +251,7 @@ def call_llm(
from memory.common.llms.tools.discord import make_discord_tools
from memory.common.llms.tools.base import WebSearchTool
from memory.common.llms.tools import MCPServer
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
tools |= {"web_search": WebSearchTool()}
@ -266,7 +267,16 @@ def call_llm(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
mcp_servers=mcp_servers,
mcp_servers=[
MCPServer(
name=str(server.name),
url=str(server.mcp_server_url),
token=str(server.access_token),
)
for server in mcp_servers
]
if mcp_servers
else None,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response

View File

@ -196,20 +196,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
mcp_servers = None
if (
discord_message.recipient_user
and discord_message.recipient_user.mcp_servers
):
mcp_servers = [
MCPServer(
name=server.mcp_server_url,
url=server.mcp_server_url,
token=server.access_token,
)
for server in discord_message.recipient_user.mcp_servers
]
mcp_servers = discord_message.get_mcp_servers(session)
system_prompt = discord_message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel