mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
tweaks
This commit is contained in:
parent
b568222e88
commit
56c0df9761
@ -34,6 +34,8 @@ from memory.common.db.models.discord import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
from memory.common.db.models.mcp import (
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
)
|
||||
|
||||
@ -55,7 +55,7 @@ class MessageProcessor:
|
||||
def entity_type(self) -> str:
|
||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||
|
||||
def to_xml(self, fields: list[str]) -> str:
|
||||
def to_xml(self, *fields: str) -> str:
|
||||
def indent(key: str, text: str) -> str:
|
||||
res = textwrap.dedent("""
|
||||
<{key}>
|
||||
@ -78,13 +78,13 @@ class MessageProcessor:
|
||||
return indent(self.entity_type, "\n".join(vals)) # type: ignore
|
||||
|
||||
def xml_prompt(self) -> str:
|
||||
return self.to_xml(["name", "system_prompt"]) if self.system_prompt else ""
|
||||
return self.to_xml("name", "system_prompt") if self.system_prompt else ""
|
||||
|
||||
def xml_summary(self) -> str:
|
||||
return self.to_xml(["name", "summary"])
|
||||
return self.to_xml("name", "summary")
|
||||
|
||||
def xml_mcp_servers(self) -> str:
|
||||
return self.to_xml(["mcp_servers"])
|
||||
return self.to_xml("mcp_servers")
|
||||
|
||||
|
||||
class DiscordServer(Base, MessageProcessor):
|
||||
@ -152,95 +152,3 @@ class DiscordUser(Base, MessageProcessor):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.username
|
||||
|
||||
|
||||
class MCPServer(Base):
|
||||
"""MCP server configuration and OAuth state."""
|
||||
|
||||
__tablename__ = "mcp_servers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
# MCP server info
|
||||
name = Column(Text, nullable=False)
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
code_verifier = Column(Text, nullable=True)
|
||||
|
||||
# OAuth tokens (set after successful authorization)
|
||||
access_token = Column(Text, nullable=True)
|
||||
refresh_token = Column(Text, nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
assignments = relationship(
|
||||
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||
|
||||
def as_xml(self) -> str:
|
||||
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
||||
return textwrap.dedent("""
|
||||
<mcp_server>
|
||||
<name>
|
||||
{name}
|
||||
</name>
|
||||
<mcp_server_url>
|
||||
{mcp_server_url}
|
||||
</mcp_server_url>
|
||||
<client_id>
|
||||
{client_id}
|
||||
</client_id>
|
||||
<available_tools>
|
||||
{available_tools}
|
||||
</available_tools>
|
||||
</mcp_server>
|
||||
""").format(
|
||||
name=self.name,
|
||||
mcp_server_url=self.mcp_server_url,
|
||||
client_id=self.client_id,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
|
||||
class MCPServerAssignment(Base):
|
||||
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||
|
||||
__tablename__ = "mcp_server_assignments"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||
|
||||
# Polymorphic entity reference
|
||||
entity_type = Column(
|
||||
Text, nullable=False
|
||||
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||
entity_id = Column(BigInteger, nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||
|
||||
__table_args__ = (
|
||||
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||
Index(
|
||||
"mcp_assignment_unique_idx",
|
||||
"mcp_server_id",
|
||||
"entity_type",
|
||||
"entity_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
108
src/memory/common/db/models/mcp.py
Normal file
108
src/memory/common/db/models/mcp.py
Normal file
@ -0,0 +1,108 @@
|
||||
import textwrap
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
class MCPServer(Base):
|
||||
"""MCP server configuration and OAuth state."""
|
||||
|
||||
__tablename__ = "mcp_servers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
# MCP server info
|
||||
name = Column(Text, nullable=False)
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
code_verifier = Column(Text, nullable=True)
|
||||
|
||||
# OAuth tokens (set after successful authorization)
|
||||
access_token = Column(Text, nullable=True)
|
||||
refresh_token = Column(Text, nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
assignments = relationship(
|
||||
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||
|
||||
def as_xml(self) -> str:
|
||||
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
||||
return textwrap.dedent("""
|
||||
<mcp_server>
|
||||
<name>
|
||||
{name}
|
||||
</name>
|
||||
<mcp_server_url>
|
||||
{mcp_server_url}
|
||||
</mcp_server_url>
|
||||
<client_id>
|
||||
{client_id}
|
||||
</client_id>
|
||||
<available_tools>
|
||||
{available_tools}
|
||||
</available_tools>
|
||||
</mcp_server>
|
||||
""").format(
|
||||
name=self.name,
|
||||
mcp_server_url=self.mcp_server_url,
|
||||
client_id=self.client_id,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
|
||||
class MCPServerAssignment(Base):
|
||||
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||
|
||||
__tablename__ = "mcp_server_assignments"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||
|
||||
# Polymorphic entity reference
|
||||
entity_type = Column(
|
||||
Text, nullable=False
|
||||
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||
entity_id = Column(BigInteger, nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||
|
||||
__table_args__ = (
|
||||
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||
Index(
|
||||
"mcp_assignment_unique_idx",
|
||||
"mcp_server_id",
|
||||
"entity_type",
|
||||
"entity_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
@ -36,6 +36,10 @@ from memory.common.db.models.source_item import (
|
||||
clean_filename,
|
||||
chunk_mixed,
|
||||
)
|
||||
from memory.common.db.models.mcp import (
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
)
|
||||
|
||||
|
||||
class MailMessagePayload(SourceItemPayload):
|
||||
@ -392,6 +396,29 @@ class DiscordMessage(SourceItem):
|
||||
|
||||
return content
|
||||
|
||||
def get_mcp_servers(self, session) -> list[MCPServer]:
|
||||
entity_ids = list(
|
||||
filter(
|
||||
None,
|
||||
[
|
||||
self.recipient_user.id,
|
||||
self.from_user.id,
|
||||
self.channel.id,
|
||||
self.server.id,
|
||||
],
|
||||
)
|
||||
)
|
||||
if not entity_ids:
|
||||
return None
|
||||
|
||||
return (
|
||||
session.query(MCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_id.in_(entity_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "discord_message",
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
|
||||
|
||||
import aiohttp
|
||||
from memory.common import settings
|
||||
from memory.common.db.models.discord import MCPServer
|
||||
from memory.common.db.models import MCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -363,16 +363,6 @@ def create_list_group(
|
||||
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
|
||||
res = "\n\n".join(mcp_servers)
|
||||
await respond(interaction, res)
|
||||
# def handler(objects: DiscordObjects) -> str:
|
||||
# servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers]
|
||||
# return "\n\n".join(servers) if servers else "No MCP servers configured."
|
||||
|
||||
# try:
|
||||
# res = with_object_context(bot, interaction, handler, user)
|
||||
# except Exception as exc:
|
||||
# logger.error(f"Error listing MCP servers: {exc}", exc_info=True)
|
||||
# return CommandResponse(content="Error listing MCP servers.")
|
||||
# await respond(interaction, res)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import discord
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.discord import MCPServer, MCPServerAssignment
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -12,9 +12,9 @@ from memory.common.db.models import (
|
||||
DiscordMessage,
|
||||
DiscordUser,
|
||||
ScheduledLLMCall,
|
||||
MCPServer as MCPServerModel,
|
||||
)
|
||||
from memory.common.llms.base import create_provider
|
||||
from memory.common.llms.tools import MCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -187,7 +187,7 @@ def comm_channel_prompt(
|
||||
{users}
|
||||
</user_notes>
|
||||
""").format(
|
||||
users="\n".join({msg.from_user.as_xml() for msg in messages}),
|
||||
users="\n".join({msg.from_user.xml_summary() for msg in messages}),
|
||||
)
|
||||
|
||||
return textwrap.dedent("""
|
||||
@ -212,7 +212,7 @@ def call_llm(
|
||||
system_prompt: str = "",
|
||||
messages: list[str | dict[str, Any]] = [],
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
mcp_servers: list[MCPServerModel] | None = None,
|
||||
num_previous_messages: int = 10,
|
||||
) -> str | None:
|
||||
"""
|
||||
@ -251,6 +251,7 @@ def call_llm(
|
||||
|
||||
from memory.common.llms.tools.discord import make_discord_tools
|
||||
from memory.common.llms.tools.base import WebSearchTool
|
||||
from memory.common.llms.tools import MCPServer
|
||||
|
||||
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
||||
tools |= {"web_search": WebSearchTool()}
|
||||
@ -266,7 +267,16 @@ def call_llm(
|
||||
messages=provider.as_messages(message_content),
|
||||
tools=tools,
|
||||
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
||||
mcp_servers=mcp_servers,
|
||||
mcp_servers=[
|
||||
MCPServer(
|
||||
name=str(server.name),
|
||||
url=str(server.mcp_server_url),
|
||||
token=str(server.access_token),
|
||||
)
|
||||
for server in mcp_servers
|
||||
]
|
||||
if mcp_servers
|
||||
else None,
|
||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||
).response
|
||||
|
||||
|
||||
@ -196,20 +196,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
mcp_servers = None
|
||||
if (
|
||||
discord_message.recipient_user
|
||||
and discord_message.recipient_user.mcp_servers
|
||||
):
|
||||
mcp_servers = [
|
||||
MCPServer(
|
||||
name=server.mcp_server_url,
|
||||
url=server.mcp_server_url,
|
||||
token=server.access_token,
|
||||
)
|
||||
for server in discord_message.recipient_user.mcp_servers
|
||||
]
|
||||
|
||||
mcp_servers = discord_message.get_mcp_servers(session)
|
||||
system_prompt = discord_message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, discord_message.recipient_user, discord_message.channel
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user