mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 08:14:05 +01:00
tweaks
This commit is contained in:
parent
b568222e88
commit
56c0df9761
@ -34,6 +34,8 @@ from memory.common.db.models.discord import (
|
|||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
|
)
|
||||||
|
from memory.common.db.models.mcp import (
|
||||||
MCPServer,
|
MCPServer,
|
||||||
MCPServerAssignment,
|
MCPServerAssignment,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class MessageProcessor:
|
|||||||
def entity_type(self) -> str:
|
def entity_type(self) -> str:
|
||||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||||
|
|
||||||
def to_xml(self, fields: list[str]) -> str:
|
def to_xml(self, *fields: str) -> str:
|
||||||
def indent(key: str, text: str) -> str:
|
def indent(key: str, text: str) -> str:
|
||||||
res = textwrap.dedent("""
|
res = textwrap.dedent("""
|
||||||
<{key}>
|
<{key}>
|
||||||
@ -78,13 +78,13 @@ class MessageProcessor:
|
|||||||
return indent(self.entity_type, "\n".join(vals)) # type: ignore
|
return indent(self.entity_type, "\n".join(vals)) # type: ignore
|
||||||
|
|
||||||
def xml_prompt(self) -> str:
|
def xml_prompt(self) -> str:
|
||||||
return self.to_xml(["name", "system_prompt"]) if self.system_prompt else ""
|
return self.to_xml("name", "system_prompt") if self.system_prompt else ""
|
||||||
|
|
||||||
def xml_summary(self) -> str:
|
def xml_summary(self) -> str:
|
||||||
return self.to_xml(["name", "summary"])
|
return self.to_xml("name", "summary")
|
||||||
|
|
||||||
def xml_mcp_servers(self) -> str:
|
def xml_mcp_servers(self) -> str:
|
||||||
return self.to_xml(["mcp_servers"])
|
return self.to_xml("mcp_servers")
|
||||||
|
|
||||||
|
|
||||||
class DiscordServer(Base, MessageProcessor):
|
class DiscordServer(Base, MessageProcessor):
|
||||||
@ -152,95 +152,3 @@ class DiscordUser(Base, MessageProcessor):
|
|||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self.username
|
return self.username
|
||||||
|
|
||||||
|
|
||||||
class MCPServer(Base):
|
|
||||||
"""MCP server configuration and OAuth state."""
|
|
||||||
|
|
||||||
__tablename__ = "mcp_servers"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
|
|
||||||
# MCP server info
|
|
||||||
name = Column(Text, nullable=False)
|
|
||||||
mcp_server_url = Column(Text, nullable=False)
|
|
||||||
client_id = Column(Text, nullable=False)
|
|
||||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
|
||||||
|
|
||||||
# OAuth flow state (temporary, cleared after token exchange)
|
|
||||||
state = Column(Text, nullable=True, unique=True)
|
|
||||||
code_verifier = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
# OAuth tokens (set after successful authorization)
|
|
||||||
access_token = Column(Text, nullable=True)
|
|
||||||
refresh_token = Column(Text, nullable=True)
|
|
||||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
|
|
||||||
# Relationships
|
|
||||||
assignments = relationship(
|
|
||||||
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (Index("mcp_state_idx", "state"),)
|
|
||||||
|
|
||||||
def as_xml(self) -> str:
|
|
||||||
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
|
||||||
return textwrap.dedent("""
|
|
||||||
<mcp_server>
|
|
||||||
<name>
|
|
||||||
{name}
|
|
||||||
</name>
|
|
||||||
<mcp_server_url>
|
|
||||||
{mcp_server_url}
|
|
||||||
</mcp_server_url>
|
|
||||||
<client_id>
|
|
||||||
{client_id}
|
|
||||||
</client_id>
|
|
||||||
<available_tools>
|
|
||||||
{available_tools}
|
|
||||||
</available_tools>
|
|
||||||
</mcp_server>
|
|
||||||
""").format(
|
|
||||||
name=self.name,
|
|
||||||
mcp_server_url=self.mcp_server_url,
|
|
||||||
client_id=self.client_id,
|
|
||||||
available_tools=tools,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPServerAssignment(Base):
|
|
||||||
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
|
||||||
|
|
||||||
__tablename__ = "mcp_server_assignments"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
|
||||||
|
|
||||||
# Polymorphic entity reference
|
|
||||||
entity_type = Column(
|
|
||||||
Text, nullable=False
|
|
||||||
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
|
||||||
entity_id = Column(BigInteger, nullable=False)
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
|
|
||||||
# Relationships
|
|
||||||
mcp_server = relationship("MCPServer", back_populates="assignments")
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
|
||||||
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
|
||||||
Index(
|
|
||||||
"mcp_assignment_unique_idx",
|
|
||||||
"mcp_server_id",
|
|
||||||
"entity_type",
|
|
||||||
"entity_id",
|
|
||||||
unique=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|||||||
108
src/memory/common/db/models/mcp.py
Normal file
108
src/memory/common/db/models/mcp.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import textwrap
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
ARRAY,
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
Text,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServer(Base):
|
||||||
|
"""MCP server configuration and OAuth state."""
|
||||||
|
|
||||||
|
__tablename__ = "mcp_servers"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
|
||||||
|
# MCP server info
|
||||||
|
name = Column(Text, nullable=False)
|
||||||
|
mcp_server_url = Column(Text, nullable=False)
|
||||||
|
client_id = Column(Text, nullable=False)
|
||||||
|
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
# OAuth flow state (temporary, cleared after token exchange)
|
||||||
|
state = Column(Text, nullable=True, unique=True)
|
||||||
|
code_verifier = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# OAuth tokens (set after successful authorization)
|
||||||
|
access_token = Column(Text, nullable=True)
|
||||||
|
refresh_token = Column(Text, nullable=True)
|
||||||
|
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
assignments = relationship(
|
||||||
|
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||||
|
|
||||||
|
def as_xml(self) -> str:
|
||||||
|
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
||||||
|
return textwrap.dedent("""
|
||||||
|
<mcp_server>
|
||||||
|
<name>
|
||||||
|
{name}
|
||||||
|
</name>
|
||||||
|
<mcp_server_url>
|
||||||
|
{mcp_server_url}
|
||||||
|
</mcp_server_url>
|
||||||
|
<client_id>
|
||||||
|
{client_id}
|
||||||
|
</client_id>
|
||||||
|
<available_tools>
|
||||||
|
{available_tools}
|
||||||
|
</available_tools>
|
||||||
|
</mcp_server>
|
||||||
|
""").format(
|
||||||
|
name=self.name,
|
||||||
|
mcp_server_url=self.mcp_server_url,
|
||||||
|
client_id=self.client_id,
|
||||||
|
available_tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerAssignment(Base):
|
||||||
|
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||||
|
|
||||||
|
__tablename__ = "mcp_server_assignments"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||||
|
|
||||||
|
# Polymorphic entity reference
|
||||||
|
entity_type = Column(
|
||||||
|
Text, nullable=False
|
||||||
|
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||||
|
entity_id = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||||
|
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||||
|
Index(
|
||||||
|
"mcp_assignment_unique_idx",
|
||||||
|
"mcp_server_id",
|
||||||
|
"entity_type",
|
||||||
|
"entity_id",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
@ -36,6 +36,10 @@ from memory.common.db.models.source_item import (
|
|||||||
clean_filename,
|
clean_filename,
|
||||||
chunk_mixed,
|
chunk_mixed,
|
||||||
)
|
)
|
||||||
|
from memory.common.db.models.mcp import (
|
||||||
|
MCPServer,
|
||||||
|
MCPServerAssignment,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MailMessagePayload(SourceItemPayload):
|
class MailMessagePayload(SourceItemPayload):
|
||||||
@ -392,6 +396,29 @@ class DiscordMessage(SourceItem):
|
|||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
def get_mcp_servers(self, session) -> list[MCPServer]:
|
||||||
|
entity_ids = list(
|
||||||
|
filter(
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
self.recipient_user.id,
|
||||||
|
self.from_user.id,
|
||||||
|
self.channel.id,
|
||||||
|
self.server.id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not entity_ids:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (
|
||||||
|
session.query(MCPServer)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.entity_id.in_(entity_ids),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
__mapper_args__ = {
|
__mapper_args__ = {
|
||||||
"polymorphic_identity": "discord_message",
|
"polymorphic_identity": "discord_message",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.models.discord import MCPServer
|
from memory.common.db.models import MCPServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -363,16 +363,6 @@ def create_list_group(
|
|||||||
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
|
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
|
||||||
res = "\n\n".join(mcp_servers)
|
res = "\n\n".join(mcp_servers)
|
||||||
await respond(interaction, res)
|
await respond(interaction, res)
|
||||||
# def handler(objects: DiscordObjects) -> str:
|
|
||||||
# servers = [s.as_xml() for obj in objects.items for s in obj.mcp_servers]
|
|
||||||
# return "\n\n".join(servers) if servers else "No MCP servers configured."
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# res = with_object_context(bot, interaction, handler, user)
|
|
||||||
# except Exception as exc:
|
|
||||||
# logger.error(f"Error listing MCP servers: {exc}", exc_info=True)
|
|
||||||
# return CommandResponse(content="Error listing MCP servers.")
|
|
||||||
# await respond(interaction, res)
|
|
||||||
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import discord
|
|||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.discord import MCPServer, MCPServerAssignment
|
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -12,9 +12,9 @@ from memory.common.db.models import (
|
|||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
|
MCPServer as MCPServerModel,
|
||||||
)
|
)
|
||||||
from memory.common.llms.base import create_provider
|
from memory.common.llms.base import create_provider
|
||||||
from memory.common.llms.tools import MCPServer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -187,7 +187,7 @@ def comm_channel_prompt(
|
|||||||
{users}
|
{users}
|
||||||
</user_notes>
|
</user_notes>
|
||||||
""").format(
|
""").format(
|
||||||
users="\n".join({msg.from_user.as_xml() for msg in messages}),
|
users="\n".join({msg.from_user.xml_summary() for msg in messages}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return textwrap.dedent("""
|
return textwrap.dedent("""
|
||||||
@ -212,7 +212,7 @@ def call_llm(
|
|||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
messages: list[str | dict[str, Any]] = [],
|
messages: list[str | dict[str, Any]] = [],
|
||||||
allowed_tools: Collection[str] | None = None,
|
allowed_tools: Collection[str] | None = None,
|
||||||
mcp_servers: list[MCPServer] | None = None,
|
mcp_servers: list[MCPServerModel] | None = None,
|
||||||
num_previous_messages: int = 10,
|
num_previous_messages: int = 10,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""
|
"""
|
||||||
@ -251,6 +251,7 @@ def call_llm(
|
|||||||
|
|
||||||
from memory.common.llms.tools.discord import make_discord_tools
|
from memory.common.llms.tools.discord import make_discord_tools
|
||||||
from memory.common.llms.tools.base import WebSearchTool
|
from memory.common.llms.tools.base import WebSearchTool
|
||||||
|
from memory.common.llms.tools import MCPServer
|
||||||
|
|
||||||
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
||||||
tools |= {"web_search": WebSearchTool()}
|
tools |= {"web_search": WebSearchTool()}
|
||||||
@ -266,7 +267,16 @@ def call_llm(
|
|||||||
messages=provider.as_messages(message_content),
|
messages=provider.as_messages(message_content),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
||||||
mcp_servers=mcp_servers,
|
mcp_servers=[
|
||||||
|
MCPServer(
|
||||||
|
name=str(server.name),
|
||||||
|
url=str(server.mcp_server_url),
|
||||||
|
token=str(server.access_token),
|
||||||
|
)
|
||||||
|
for server in mcp_servers
|
||||||
|
]
|
||||||
|
if mcp_servers
|
||||||
|
else None,
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
).response
|
).response
|
||||||
|
|
||||||
|
|||||||
@ -196,20 +196,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
mcp_servers = None
|
mcp_servers = discord_message.get_mcp_servers(session)
|
||||||
if (
|
|
||||||
discord_message.recipient_user
|
|
||||||
and discord_message.recipient_user.mcp_servers
|
|
||||||
):
|
|
||||||
mcp_servers = [
|
|
||||||
MCPServer(
|
|
||||||
name=server.mcp_server_url,
|
|
||||||
url=server.mcp_server_url,
|
|
||||||
token=server.access_token,
|
|
||||||
)
|
|
||||||
for server in discord_message.recipient_user.mcp_servers
|
|
||||||
]
|
|
||||||
|
|
||||||
system_prompt = discord_message.system_prompt or ""
|
system_prompt = discord_message.system_prompt or ""
|
||||||
system_prompt += comm_channel_prompt(
|
system_prompt += comm_channel_prompt(
|
||||||
session, discord_message.recipient_user, discord_message.channel
|
session, discord_message.recipient_user, discord_message.channel
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user