diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py
index c2e262b..7eb5887 100644
--- a/src/memory/common/db/models/__init__.py
+++ b/src/memory/common/db/models/__init__.py
@@ -34,6 +34,8 @@ from memory.common.db.models.discord import (
DiscordServer,
DiscordChannel,
DiscordUser,
+)
+from memory.common.db.models.mcp import (
MCPServer,
MCPServerAssignment,
)
diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py
index d6d54b4..752ce45 100644
--- a/src/memory/common/db/models/discord.py
+++ b/src/memory/common/db/models/discord.py
@@ -55,7 +55,7 @@ class MessageProcessor:
def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore
- def to_xml(self, fields: list[str]) -> str:
+ def to_xml(self, *fields: str) -> str:
def indent(key: str, text: str) -> str:
res = textwrap.dedent("""
<{key}>
@@ -78,13 +78,13 @@ class MessageProcessor:
return indent(self.entity_type, "\n".join(vals)) # type: ignore
def xml_prompt(self) -> str:
- return self.to_xml(["name", "system_prompt"]) if self.system_prompt else ""
+ return self.to_xml("name", "system_prompt") if self.system_prompt else ""
def xml_summary(self) -> str:
- return self.to_xml(["name", "summary"])
+ return self.to_xml("name", "summary")
def xml_mcp_servers(self) -> str:
- return self.to_xml(["mcp_servers"])
+ return self.to_xml("mcp_servers")
class DiscordServer(Base, MessageProcessor):
@@ -152,95 +152,3 @@ class DiscordUser(Base, MessageProcessor):
@property
def name(self) -> str:
return self.username
-
-
-class MCPServer(Base):
- """MCP server configuration and OAuth state."""
-
- __tablename__ = "mcp_servers"
-
- id = Column(Integer, primary_key=True)
-
- # MCP server info
- name = Column(Text, nullable=False)
- mcp_server_url = Column(Text, nullable=False)
- client_id = Column(Text, nullable=False)
- available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
-
- # OAuth flow state (temporary, cleared after token exchange)
- state = Column(Text, nullable=True, unique=True)
- code_verifier = Column(Text, nullable=True)
-
- # OAuth tokens (set after successful authorization)
- access_token = Column(Text, nullable=True)
- refresh_token = Column(Text, nullable=True)
- token_expires_at = Column(DateTime(timezone=True), nullable=True)
-
- # Timestamps
- created_at = Column(DateTime(timezone=True), server_default=func.now())
- updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
- # Relationships
- assignments = relationship(
- "MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
- )
-
- __table_args__ = (Index("mcp_state_idx", "state"),)
-
- def as_xml(self) -> str:
- tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
- return textwrap.dedent("""
-
-
- {name}
-
-
- {mcp_server_url}
-
-
- {client_id}
-
-
- {available_tools}
-
-
- """).format(
- name=self.name,
- mcp_server_url=self.mcp_server_url,
- client_id=self.client_id,
- available_tools=tools,
- )
-
-
-class MCPServerAssignment(Base):
- """Assignment of MCP servers to entities (users, channels, servers, etc.)."""
-
- __tablename__ = "mcp_server_assignments"
-
- id = Column(Integer, primary_key=True)
- mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
-
- # Polymorphic entity reference
- entity_type = Column(
- Text, nullable=False
- ) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
- entity_id = Column(BigInteger, nullable=False)
-
- # Timestamps
- created_at = Column(DateTime(timezone=True), server_default=func.now())
- updated_at = Column(DateTime(timezone=True), server_default=func.now())
-
- # Relationships
- mcp_server = relationship("MCPServer", back_populates="assignments")
-
- __table_args__ = (
- Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
- Index("mcp_assignment_server_idx", "mcp_server_id"),
- Index(
- "mcp_assignment_unique_idx",
- "mcp_server_id",
- "entity_type",
- "entity_id",
- unique=True,
- ),
- )
diff --git a/src/memory/common/db/models/mcp.py b/src/memory/common/db/models/mcp.py
new file mode 100644
index 0000000..1bf6a8c
--- /dev/null
+++ b/src/memory/common/db/models/mcp.py
@@ -0,0 +1,108 @@
+import textwrap
+
+from sqlalchemy import (
+ ARRAY,
+ BigInteger,
+ Column,
+ DateTime,
+ ForeignKey,
+ Index,
+ Integer,
+ Text,
+ func,
+)
+from sqlalchemy.orm import relationship
+
+from memory.common.db.models.base import Base
+
+
+class MCPServer(Base):
+ """MCP server configuration and OAuth state."""
+
+ __tablename__ = "mcp_servers"
+
+ id = Column(Integer, primary_key=True)
+
+ # MCP server info
+ name = Column(Text, nullable=False)
+ mcp_server_url = Column(Text, nullable=False)
+ client_id = Column(Text, nullable=False)
+ available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
+
+ # OAuth flow state (temporary, cleared after token exchange)
+ state = Column(Text, nullable=True, unique=True)
+ code_verifier = Column(Text, nullable=True)
+
+ # OAuth tokens (set after successful authorization)
+ access_token = Column(Text, nullable=True)
+ refresh_token = Column(Text, nullable=True)
+ token_expires_at = Column(DateTime(timezone=True), nullable=True)
+
+ # Timestamps
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
+ updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+ # Relationships
+ assignments = relationship(
+ "MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
+ )
+
+ __table_args__ = (Index("mcp_state_idx", "state"),)
+
+ def as_xml(self) -> str:
+ tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
+ return textwrap.dedent("""
+
+
+ {name}
+
+
+ {mcp_server_url}
+
+
+ {client_id}
+
+
+ {available_tools}
+
+
+ """).format(
+ name=self.name,
+ mcp_server_url=self.mcp_server_url,
+ client_id=self.client_id,
+ available_tools=tools,
+ )
+
+
+class MCPServerAssignment(Base):
+ """Assignment of MCP servers to entities (users, channels, servers, etc.)."""
+
+ __tablename__ = "mcp_server_assignments"
+
+ id = Column(Integer, primary_key=True)
+ mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
+
+ # Polymorphic entity reference
+ entity_type = Column(
+ Text, nullable=False
+ ) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
+ entity_id = Column(BigInteger, nullable=False)
+
+ # Timestamps
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
+ updated_at = Column(DateTime(timezone=True), server_default=func.now())
+
+ # Relationships
+ mcp_server = relationship("MCPServer", back_populates="assignments")
+
+ __table_args__ = (
+ Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
+ Index("mcp_assignment_server_idx", "mcp_server_id"),
+ Index(
+ "mcp_assignment_unique_idx",
+ "mcp_server_id",
+ "entity_type",
+ "entity_id",
+ unique=True,
+ ),
+ )
diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index 8046ad6..fa35f17 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -36,6 +36,10 @@ from memory.common.db.models.source_item import (
clean_filename,
chunk_mixed,
)
+from memory.common.db.models.mcp import (
+ MCPServer,
+ MCPServerAssignment,
+)
class MailMessagePayload(SourceItemPayload):
@@ -392,6 +396,29 @@ class DiscordMessage(SourceItem):
return content
+ def get_mcp_servers(self, session) -> list[MCPServer]:
+ entity_ids = list(
+ filter(
+ None,
+ [
+ self.recipient_user.id,
+ self.from_user.id,
+ self.channel.id,
+ self.server.id,
+ ],
+ )
+ )
+ if not entity_ids:
+ return None
+
+ return (
+ session.query(MCPServer)
+ .filter(
+ MCPServerAssignment.entity_id.in_(entity_ids),
+ )
+ .all()
+ )
+
__mapper_args__ = {
"polymorphic_identity": "discord_message",
}
diff --git a/src/memory/common/oauth.py b/src/memory/common/oauth.py
index 82bd111..69dc8e6 100644
--- a/src/memory/common/oauth.py
+++ b/src/memory/common/oauth.py
@@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
import aiohttp
from memory.common import settings
-from memory.common.db.models.discord import MCPServer
+from memory.common.db.models import MCPServer
logger = logging.getLogger(__name__)
diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py
index 04eec54..0c432bf 100644
--- a/src/memory/discord/commands.py
+++ b/src/memory/discord/commands.py
@@ -363,16 +363,6 @@ def create_list_group(
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
res = "\n\n".join(mcp_servers)
await respond(interaction, res)
- # def handler(objects: DiscordObjects) -> str:
- # servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers]
- # return "\n\n".join(servers) if servers else "No MCP servers configured."
-
- # try:
- # res = with_object_context(bot, interaction, handler, user)
- # except Exception as exc:
- # logger.error(f"Error listing MCP servers: {exc}", exc_info=True)
- # return CommandResponse(content="Error listing MCP servers.")
- # await respond(interaction, res)
return group
diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py
index a0f42aa..bb72988 100644
--- a/src/memory/discord/mcp.py
+++ b/src/memory/discord/mcp.py
@@ -10,7 +10,7 @@ import discord
from sqlalchemy.orm import Session, scoped_session
from memory.common.db.connection import make_session
-from memory.common.db.models.discord import MCPServer, MCPServerAssignment
+from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
logger = logging.getLogger(__name__)
diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py
index d815dd3..01c5b28 100644
--- a/src/memory/discord/messages.py
+++ b/src/memory/discord/messages.py
@@ -12,9 +12,9 @@ from memory.common.db.models import (
DiscordMessage,
DiscordUser,
ScheduledLLMCall,
+ MCPServer as MCPServerModel,
)
from memory.common.llms.base import create_provider
-from memory.common.llms.tools import MCPServer
logger = logging.getLogger(__name__)
@@ -187,7 +187,7 @@ def comm_channel_prompt(
{users}
""").format(
- users="\n".join({msg.from_user.as_xml() for msg in messages}),
+ users="\n".join({msg.from_user.xml_summary() for msg in messages}),
)
return textwrap.dedent("""
@@ -212,7 +212,7 @@ def call_llm(
system_prompt: str = "",
messages: list[str | dict[str, Any]] = [],
allowed_tools: Collection[str] | None = None,
- mcp_servers: list[MCPServer] | None = None,
+ mcp_servers: list[MCPServerModel] | None = None,
num_previous_messages: int = 10,
) -> str | None:
"""
@@ -251,6 +251,7 @@ def call_llm(
from memory.common.llms.tools.discord import make_discord_tools
from memory.common.llms.tools.base import WebSearchTool
+ from memory.common.llms.tools import MCPServer
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
tools |= {"web_search": WebSearchTool()}
@@ -266,7 +267,16 @@ def call_llm(
messages=provider.as_messages(message_content),
tools=tools,
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
- mcp_servers=mcp_servers,
+ mcp_servers=[
+ MCPServer(
+ name=str(server.name),
+ url=str(server.mcp_server_url),
+ token=str(server.access_token),
+ )
+ for server in mcp_servers
+ ]
+ if mcp_servers
+ else None,
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
).response
diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py
index 9c70973..05e85b3 100644
--- a/src/memory/workers/tasks/discord.py
+++ b/src/memory/workers/tasks/discord.py
@@ -196,20 +196,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
- mcp_servers = None
- if (
- discord_message.recipient_user
- and discord_message.recipient_user.mcp_servers
- ):
- mcp_servers = [
- MCPServer(
- name=server.mcp_server_url,
- url=server.mcp_server_url,
- token=server.access_token,
- )
- for server in discord_message.recipient_user.mcp_servers
- ]
-
+ mcp_servers = discord_message.get_mcp_servers(session)
system_prompt = discord_message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, discord_message.recipient_user, discord_message.channel