diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index c2e262b..7eb5887 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -34,6 +34,8 @@ from memory.common.db.models.discord import ( DiscordServer, DiscordChannel, DiscordUser, +) +from memory.common.db.models.mcp import ( MCPServer, MCPServerAssignment, ) diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index d6d54b4..752ce45 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -55,7 +55,7 @@ class MessageProcessor: def entity_type(self) -> str: return self.__class__.__tablename__[8:-1] # type: ignore - def to_xml(self, fields: list[str]) -> str: + def to_xml(self, *fields: str) -> str: def indent(key: str, text: str) -> str: res = textwrap.dedent(""" <{key}> @@ -78,13 +78,13 @@ class MessageProcessor: return indent(self.entity_type, "\n".join(vals)) # type: ignore def xml_prompt(self) -> str: - return self.to_xml(["name", "system_prompt"]) if self.system_prompt else "" + return self.to_xml("name", "system_prompt") if self.system_prompt else "" def xml_summary(self) -> str: - return self.to_xml(["name", "summary"]) + return self.to_xml("name", "summary") def xml_mcp_servers(self) -> str: - return self.to_xml(["mcp_servers"]) + return self.to_xml("mcp_servers") class DiscordServer(Base, MessageProcessor): @@ -152,95 +152,3 @@ class DiscordUser(Base, MessageProcessor): @property def name(self) -> str: return self.username - - -class MCPServer(Base): - """MCP server configuration and OAuth state.""" - - __tablename__ = "mcp_servers" - - id = Column(Integer, primary_key=True) - - # MCP server info - name = Column(Text, nullable=False) - mcp_server_url = Column(Text, nullable=False) - client_id = Column(Text, nullable=False) - available_tools = Column(ARRAY(Text), nullable=False, server_default="{}") - - # OAuth flow state (temporary, cleared after token exchange) - state = Column(Text, nullable=True, unique=True) - code_verifier = Column(Text, nullable=True) - - # OAuth tokens (set after successful authorization) - access_token = Column(Text, nullable=True) - refresh_token = Column(Text, nullable=True) - token_expires_at = Column(DateTime(timezone=True), nullable=True) - - # Timestamps - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now()) - - # Relationships - assignments = relationship( - "MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan" - ) - - __table_args__ = (Index("mcp_state_idx", "state"),) - - def as_xml(self) -> str: - tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip() - return textwrap.dedent(""" - - - {name} - - - {mcp_server_url} - - - {client_id} - - - {available_tools} - - - """).format( - name=self.name, - mcp_server_url=self.mcp_server_url, - client_id=self.client_id, - available_tools=tools, - ) - - -class MCPServerAssignment(Base): - """Assignment of MCP servers to entities (users, channels, servers, etc.).""" - - __tablename__ = "mcp_server_assignments" - - id = Column(Integer, primary_key=True) - mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False) - - # Polymorphic entity reference - entity_type = Column( - Text, nullable=False - ) # "User", "DiscordUser", "DiscordServer", "DiscordChannel" - entity_id = Column(BigInteger, nullable=False) - - # Timestamps - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now()) - - # Relationships - mcp_server = relationship("MCPServer", back_populates="assignments") - - __table_args__ = ( - Index("mcp_assignment_entity_idx", "entity_type", "entity_id"), - Index("mcp_assignment_server_idx", "mcp_server_id"), - Index( - "mcp_assignment_unique_idx", - "mcp_server_id", - "entity_type", - "entity_id", - unique=True, - ), - ) diff --git a/src/memory/common/db/models/mcp.py b/src/memory/common/db/models/mcp.py new file mode 100644 index 0000000..1bf6a8c --- /dev/null +++ b/src/memory/common/db/models/mcp.py @@ -0,0 +1,108 @@ +import textwrap + +from sqlalchemy import ( + ARRAY, + BigInteger, + Column, + DateTime, + ForeignKey, + Index, + Integer, + Text, + func, +) +from sqlalchemy.orm import relationship + +from memory.common.db.models.base import Base + + +class MCPServer(Base): + """MCP server configuration and OAuth state.""" + + __tablename__ = "mcp_servers" + + id = Column(Integer, primary_key=True) + + # MCP server info + name = Column(Text, nullable=False) + mcp_server_url = Column(Text, nullable=False) + client_id = Column(Text, nullable=False) + available_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + + # OAuth flow state (temporary, cleared after token exchange) + state = Column(Text, nullable=True, unique=True) + code_verifier = Column(Text, nullable=True) + + # OAuth tokens (set after successful authorization) + access_token = Column(Text, nullable=True) + refresh_token = Column(Text, nullable=True) + token_expires_at = Column(DateTime(timezone=True), nullable=True) + + # Timestamps + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + # Relationships + assignments = relationship( + "MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan" + ) + + __table_args__ = (Index("mcp_state_idx", "state"),) + + def as_xml(self) -> str: + tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip() + return textwrap.dedent(""" + + + {name} + + + {mcp_server_url} + + + {client_id} + + + {available_tools} + + + """).format( + name=self.name, + mcp_server_url=self.mcp_server_url, + client_id=self.client_id, + available_tools=tools, + ) + + +class MCPServerAssignment(Base): + """Assignment of MCP servers to entities (users, channels, servers, etc.).""" + + __tablename__ = "mcp_server_assignments" + + id = Column(Integer, primary_key=True) + mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False) + + # Polymorphic entity reference + entity_type = Column( + Text, nullable=False + ) # "User", "DiscordUser", "DiscordServer", "DiscordChannel" + entity_id = Column(BigInteger, nullable=False) + + # Timestamps + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + + # Relationships + mcp_server = relationship("MCPServer", back_populates="assignments") + + __table_args__ = ( + Index("mcp_assignment_entity_idx", "entity_type", "entity_id"), + Index("mcp_assignment_server_idx", "mcp_server_id"), + Index( + "mcp_assignment_unique_idx", + "mcp_server_id", + "entity_type", + "entity_id", + unique=True, + ), + ) diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index 8046ad6..fa35f17 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -36,6 +36,10 @@ from memory.common.db.models.source_item import ( clean_filename, chunk_mixed, ) +from memory.common.db.models.mcp import ( + MCPServer, + MCPServerAssignment, +) class MailMessagePayload(SourceItemPayload): @@ -392,6 +396,29 @@ class DiscordMessage(SourceItem): return content + def get_mcp_servers(self, session) -> list[MCPServer]: + entity_ids = list( + filter( + None, + [ + self.recipient_user.id, + self.from_user.id, + self.channel.id, + self.server.id, + ], + ) + ) + if not entity_ids: + return None + + return ( + session.query(MCPServer) + .filter( + MCPServerAssignment.entity_id.in_(entity_ids), + ) + .all() + ) + __mapper_args__ = { "polymorphic_identity": "discord_message", } diff --git a/src/memory/common/oauth.py b/src/memory/common/oauth.py index 82bd111..69dc8e6 100644 --- a/src/memory/common/oauth.py +++ b/src/memory/common/oauth.py @@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin import aiohttp from memory.common import settings -from memory.common.db.models.discord import MCPServer +from memory.common.db.models import MCPServer logger = logging.getLogger(__name__) diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py index 04eec54..0c432bf 100644 --- a/src/memory/discord/commands.py +++ b/src/memory/discord/commands.py @@ -363,16 +363,6 @@ def create_list_group( mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers] res = "\n\n".join(mcp_servers) await respond(interaction, res) - # def handler(objects: DiscordObjects) -> str: - # servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers] - # return "\n\n".join(servers) if servers else "No MCP servers configured." - - # try: - # res = with_object_context(bot, interaction, handler, user) - # except Exception as exc: - # logger.error(f"Error listing MCP servers: {exc}", exc_info=True) - # return CommandResponse(content="Error listing MCP servers.") - # await respond(interaction, res) return group diff --git a/src/memory/discord/mcp.py b/src/memory/discord/mcp.py index a0f42aa..bb72988 100644 --- a/src/memory/discord/mcp.py +++ b/src/memory/discord/mcp.py @@ -10,7 +10,7 @@ import discord from sqlalchemy.orm import Session, scoped_session from memory.common.db.connection import make_session -from memory.common.db.models.discord import MCPServer, MCPServerAssignment +from memory.common.db.models import MCPServer, MCPServerAssignment from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client logger = logging.getLogger(__name__) diff --git a/src/memory/discord/messages.py b/src/memory/discord/messages.py index d815dd3..01c5b28 100644 --- a/src/memory/discord/messages.py +++ b/src/memory/discord/messages.py @@ -12,9 +12,9 @@ from memory.common.db.models import ( DiscordMessage, DiscordUser, ScheduledLLMCall, + MCPServer as MCPServerModel, ) from memory.common.llms.base import create_provider -from memory.common.llms.tools import MCPServer logger = logging.getLogger(__name__) @@ -187,7 +187,7 @@ def comm_channel_prompt( {users} """).format( - users="\n".join({msg.from_user.as_xml() for msg in messages}), + users="\n".join({msg.from_user.xml_summary() for msg in messages}), ) return textwrap.dedent(""" @@ -212,7 +212,7 @@ def call_llm( system_prompt: str = "", messages: list[str | dict[str, Any]] = [], allowed_tools: Collection[str] | None = None, - mcp_servers: list[MCPServer] | None = None, + mcp_servers: list[MCPServerModel] | None = None, num_previous_messages: int = 10, ) -> str | None: """ @@ -251,6 +251,7 @@ def call_llm( from memory.common.llms.tools.discord import make_discord_tools from memory.common.llms.tools.base import WebSearchTool + from memory.common.llms.tools import MCPServer tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model) tools |= {"web_search": WebSearchTool()} @@ -266,7 +267,16 @@ def call_llm( messages=provider.as_messages(message_content), tools=tools, system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""), - mcp_servers=mcp_servers, + mcp_servers=[ + MCPServer( + name=str(server.name), + url=str(server.mcp_server_url), + token=str(server.access_token), + ) + for server in mcp_servers + ] + if mcp_servers + else None, max_iterations=settings.DISCORD_MAX_TOOL_CALLS, ).response diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index 9c70973..05e85b3 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -196,20 +196,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } - mcp_servers = None - if ( - discord_message.recipient_user - and discord_message.recipient_user.mcp_servers - ): - mcp_servers = [ - MCPServer( - name=server.mcp_server_url, - url=server.mcp_server_url, - token=server.access_token, - ) - for server in discord_message.recipient_user.mcp_servers - ] - + mcp_servers = discord_message.get_mcp_servers(session) system_prompt = discord_message.system_prompt or "" system_prompt += comm_channel_prompt( session, discord_message.recipient_user, discord_message.channel