mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
multiple mcp servers
This commit is contained in:
parent
2d3dc06fdf
commit
8893018af1
@ -1,5 +1,6 @@
|
||||
# Agent Guidance
|
||||
|
||||
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
||||
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
|
||||
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
||||
- Tests should be written with @pytest.mark.parametrize where applicable and should avoid test classes
|
||||
- Make sure linting errors get fixed
|
||||
|
||||
@ -1,67 +0,0 @@
|
||||
"""discord mcp servers
|
||||
|
||||
Revision ID: 9b887449ea92
|
||||
Revises: 1954477b25f4
|
||||
Create Date: 2025-11-02 22:04:26.259323
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9b887449ea92"
|
||||
down_revision: Union[str, None] = "1954477b25f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"discord_mcp_servers",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||
sa.Column("client_id", sa.Text(), nullable=False),
|
||||
sa.Column("state", sa.Text(), nullable=True),
|
||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["discord_bot_user_id"],
|
||||
["discord_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("state"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"discord_mcp_user_url_idx",
|
||||
"discord_mcp_servers",
|
||||
["discord_bot_user_id", "mcp_server_url"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
|
||||
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
|
||||
op.drop_table("discord_mcp_servers")
|
||||
103
db/migrations/versions/20251103_154126_mcp_servers.py
Normal file
103
db/migrations/versions/20251103_154126_mcp_servers.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""mcp servers
|
||||
|
||||
Revision ID: 89861d5f1102
|
||||
Revises: 1954477b25f4
|
||||
Create Date: 2025-11-03 15:41:26.254854
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "89861d5f1102"
|
||||
down_revision: Union[str, None] = "1954477b25f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"mcp_servers",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||
sa.Column("client_id", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"available_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
|
||||
),
|
||||
sa.Column("state", sa.Text(), nullable=True),
|
||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("state"),
|
||||
)
|
||||
op.create_index("mcp_state_idx", "mcp_servers", ["state"], unique=False)
|
||||
op.create_table(
|
||||
"mcp_server_assignments",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("mcp_server_id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_type", sa.Text(), nullable=False),
|
||||
sa.Column("entity_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"],
|
||||
["mcp_servers.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"mcp_assignment_entity_idx",
|
||||
"mcp_server_assignments",
|
||||
["entity_type", "entity_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"mcp_assignment_server_idx",
|
||||
"mcp_server_assignments",
|
||||
["mcp_server_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"mcp_assignment_unique_idx",
|
||||
"mcp_server_assignments",
|
||||
["mcp_server_id", "entity_type", "entity_id"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("mcp_assignment_unique_idx", table_name="mcp_server_assignments")
|
||||
op.drop_index("mcp_assignment_server_idx", table_name="mcp_server_assignments")
|
||||
op.drop_index("mcp_assignment_entity_idx", table_name="mcp_server_assignments")
|
||||
op.drop_table("mcp_server_assignments")
|
||||
op.drop_index("mcp_state_idx", table_name="mcp_servers")
|
||||
op.drop_table("mcp_servers")
|
||||
@ -14,7 +14,7 @@ from memory.common.db.models import (
|
||||
BookSection,
|
||||
Chunk,
|
||||
Comic,
|
||||
DiscordMCPServer,
|
||||
MCPServer,
|
||||
DiscordMessage,
|
||||
EmailAccount,
|
||||
EmailAttachment,
|
||||
@ -167,17 +167,17 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
||||
column_sortable_list = ["sent_at"]
|
||||
|
||||
|
||||
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
||||
class MCPServerAdmin(ModelView, model=MCPServer):
|
||||
column_list = [
|
||||
"id",
|
||||
"mcp_server_url",
|
||||
"client_id",
|
||||
"discord_bot_user_id",
|
||||
"state",
|
||||
"code_verifier",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"token_expires_at",
|
||||
"available_tools",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
@ -186,7 +186,6 @@ class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
||||
"client_id",
|
||||
"state",
|
||||
"id",
|
||||
"discord_bot_user_id",
|
||||
]
|
||||
column_sortable_list = [
|
||||
"created_at",
|
||||
@ -360,5 +359,5 @@ def setup_admin(admin: Admin):
|
||||
admin.add_view(DiscordUserAdmin)
|
||||
admin.add_view(DiscordServerAdmin)
|
||||
admin.add_view(DiscordChannelAdmin)
|
||||
admin.add_view(DiscordMCPServerAdmin)
|
||||
admin.add_view(MCPServerAdmin)
|
||||
admin.add_view(ScheduledLLMCallAdmin)
|
||||
|
||||
@ -10,7 +10,7 @@ from memory.common import settings
|
||||
from memory.common.db.connection import get_session, make_session
|
||||
from memory.common.db.models import (
|
||||
BotUser,
|
||||
DiscordMCPServer,
|
||||
MCPServer,
|
||||
HumanUser,
|
||||
User,
|
||||
UserSession,
|
||||
@ -169,9 +169,7 @@ async def oauth_callback_discord(request: Request):
|
||||
# Complete the OAuth flow (exchange code for token)
|
||||
with make_session() as session:
|
||||
mcp_server = (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(DiscordMCPServer.state == state)
|
||||
.first()
|
||||
session.query(MCPServer).filter(MCPServer.state == state).first()
|
||||
)
|
||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||
session.commit()
|
||||
|
||||
@ -34,7 +34,8 @@ from memory.common.db.models.discord import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
DiscordMCPServer,
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
)
|
||||
from memory.common.db.models.observations import (
|
||||
ObservationContradiction,
|
||||
@ -107,7 +108,8 @@ __all__ = [
|
||||
"DiscordServer",
|
||||
"DiscordChannel",
|
||||
"DiscordUser",
|
||||
"DiscordMCPServer",
|
||||
"MCPServer",
|
||||
"MCPServerAssignment",
|
||||
# Users
|
||||
"User",
|
||||
"HumanUser",
|
||||
|
||||
@ -127,26 +127,22 @@ class DiscordUser(Base, MessageProcessor):
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
system_user = relationship("User", back_populates="discord_users")
|
||||
mcp_servers = relationship(
|
||||
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||
|
||||
|
||||
class DiscordMCPServer(Base):
|
||||
"""MCP server configuration and OAuth state for Discord users."""
|
||||
class MCPServer(Base):
|
||||
"""MCP server configuration and OAuth state."""
|
||||
|
||||
__tablename__ = "discord_mcp_servers"
|
||||
__tablename__ = "mcp_servers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
discord_bot_user_id = Column(
|
||||
BigInteger, ForeignKey("discord_users.id"), nullable=False
|
||||
)
|
||||
|
||||
# MCP server info
|
||||
name = Column(Text, nullable=False)
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
@ -162,9 +158,42 @@ class DiscordMCPServer(Base):
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
|
||||
assignments = relationship(
|
||||
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||
|
||||
|
||||
class MCPServerAssignment(Base):
|
||||
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||
|
||||
__tablename__ = "mcp_server_assignments"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||
|
||||
# Polymorphic entity reference
|
||||
entity_type = Column(
|
||||
Text, nullable=False
|
||||
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||
entity_id = Column(BigInteger, nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||
|
||||
__table_args__ = (
|
||||
Index("discord_mcp_state_idx", "state"),
|
||||
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
|
||||
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||
Index(
|
||||
"mcp_assignment_unique_idx",
|
||||
"mcp_server_id",
|
||||
"entity_type",
|
||||
"entity_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
@ -69,7 +69,6 @@ def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
print("Result", result)
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
|
||||
@ -55,7 +55,6 @@ CUSTOM_EXTENSIONS = {
|
||||
def get_mime_type(path: pathlib.Path) -> str:
|
||||
mime_type, _ = mimetypes.guess_type(str(path))
|
||||
if mime_type:
|
||||
print(f"mime_type: {mime_type}")
|
||||
return mime_type
|
||||
ext = path.suffix.lower()
|
||||
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")
|
||||
|
||||
@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
|
||||
|
||||
import aiohttp
|
||||
from memory.common import settings
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
from memory.common.db.models.discord import MCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -148,7 +148,7 @@ async def register_oauth_client(
|
||||
|
||||
|
||||
async def issue_challenge(
|
||||
mcp_server: DiscordMCPServer,
|
||||
mcp_server: MCPServer,
|
||||
endpoints: OAuthEndpoints,
|
||||
) -> str:
|
||||
"""Generate OAuth challenge and store state in mcp_server object."""
|
||||
@ -160,7 +160,7 @@ async def issue_challenge(
|
||||
mcp_server.code_verifier = code_verifier # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
|
||||
f"Generated OAuth state for MCP server {mcp_server.mcp_server_url}: "
|
||||
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
||||
)
|
||||
|
||||
@ -179,7 +179,7 @@ async def issue_challenge(
|
||||
|
||||
|
||||
async def complete_oauth_flow(
|
||||
mcp_server: DiscordMCPServer, code: str, state: str
|
||||
mcp_server: MCPServer, code: str, state: str
|
||||
) -> tuple[int, str]:
|
||||
"""Complete OAuth flow by exchanging code for token.
|
||||
|
||||
@ -196,7 +196,7 @@ async def complete_oauth_flow(
|
||||
return 400, "Invalid or expired OAuth state"
|
||||
|
||||
logger.info(
|
||||
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
|
||||
f"Found MCP server config: id={mcp_server.id}, "
|
||||
f"url={mcp_server.mcp_server_url}"
|
||||
)
|
||||
|
||||
@ -247,8 +247,8 @@ async def complete_oauth_flow(
|
||||
mcp_server.code_verifier = None # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
||||
f"server {mcp_server.mcp_server_url}"
|
||||
f"Stored tokens for MCP server id={mcp_server.id}, "
|
||||
f"url={mcp_server.mcp_server_url}"
|
||||
)
|
||||
|
||||
return 200, "✅ Authorization successful! You can now use this MCP server."
|
||||
|
||||
@ -12,14 +12,12 @@ from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
|
||||
from memory.common.oauth import complete_oauth_flow
|
||||
from memory.common.db.models import DiscordBotUser
|
||||
from memory.discord.collector import MessageCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
from calendar import c
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal
|
||||
@ -14,7 +13,7 @@ from memory.discord.mcp import run_mcp_server_command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ScopeLiteral = Literal["server", "channel", "user"]
|
||||
ScopeLiteral = Literal["bot", "server", "channel", "user"]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
@ -173,18 +172,29 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_mcp_servers",
|
||||
description="Manage MCP servers for your account",
|
||||
description="Manage MCP servers for a scope",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify (server, channel, or user)",
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def mcp_servers_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_mcp_servers,
|
||||
target_user=user,
|
||||
action=action,
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
@ -199,7 +209,7 @@ async def _run_interaction_command(
|
||||
try:
|
||||
with make_session() as session:
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
response = handler(context, **handler_kwargs)
|
||||
response = await handler(context, **handler_kwargs)
|
||||
session.commit()
|
||||
except CommandError as exc: # pragma: no cover - passthrough
|
||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||
@ -435,3 +445,29 @@ def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
return CommandResponse(
|
||||
content=f"No summary stored for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_servers(
|
||||
context: CommandContext,
|
||||
*,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||
url: str | None,
|
||||
) -> CommandResponse:
|
||||
"""Handle MCP server commands for a specific scope."""
|
||||
entity_type_map = {
|
||||
"server": "DiscordServer",
|
||||
"channel": "DiscordChannel",
|
||||
"user": "DiscordUser",
|
||||
}
|
||||
entity_type = entity_type_map[context.scope]
|
||||
entity_id = context.target.id
|
||||
try:
|
||||
res = await run_mcp_server_command(
|
||||
context.interaction.user, action, url, entity_type, entity_id
|
||||
)
|
||||
return CommandResponse(content=res)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error running MCP server command: {traceback.format_exc()}")
|
||||
raise CommandError(f"Error: {exc}") from exc
|
||||
|
||||
@ -10,23 +10,27 @@ import discord
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
from memory.common.db.models.discord import MCPServer, MCPServerAssignment
|
||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_mcp_server(
|
||||
session: Session | scoped_session, user_id: int, url: str
|
||||
) -> DiscordMCPServer | None:
|
||||
return (
|
||||
session.query(DiscordMCPServer)
|
||||
session: Session | scoped_session, entity_type: str, entity_id: int, url: str
|
||||
) -> MCPServer | None:
|
||||
"""Find an MCP server assigned to an entity."""
|
||||
assignment = (
|
||||
session.query(MCPServerAssignment)
|
||||
.join(MCPServer)
|
||||
.filter(
|
||||
DiscordMCPServer.discord_bot_user_id == user_id,
|
||||
DiscordMCPServer.mcp_server_url == url,
|
||||
MCPServerAssignment.entity_type == entity_type,
|
||||
MCPServerAssignment.entity_id == entity_id,
|
||||
MCPServer.mcp_server_url == url,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return assignment and assignment.mcp_server
|
||||
|
||||
|
||||
async def call_mcp_server(
|
||||
@ -72,35 +76,39 @@ async def call_mcp_server(
|
||||
continue # Skip invalid JSON lines
|
||||
|
||||
|
||||
async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
||||
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
||||
"""List all MCP servers for the user."""
|
||||
with make_session() as session:
|
||||
servers = (
|
||||
session.query(DiscordMCPServer)
|
||||
assignments = (
|
||||
session.query(MCPServerAssignment)
|
||||
.join(MCPServer)
|
||||
.filter(
|
||||
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
|
||||
MCPServerAssignment.entity_type == entity_type,
|
||||
MCPServerAssignment.entity_id == entity_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not servers:
|
||||
if not assignments:
|
||||
return (
|
||||
"📋 **Your MCP Servers**\n\n"
|
||||
"You don't have any MCP servers configured yet.\n"
|
||||
"Use `/memory_mcp_servers add <url>` to add one."
|
||||
)
|
||||
|
||||
def format_server(server: DiscordMCPServer) -> str:
|
||||
def format_server(assignment: MCPServerAssignment) -> str:
|
||||
server = assignment.mcp_server
|
||||
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
||||
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
||||
|
||||
server_list = "\n".join(format_server(s) for s in servers)
|
||||
server_list = "\n".join(format_server(a) for a in assignments)
|
||||
|
||||
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
||||
|
||||
|
||||
async def handle_mcp_add(
|
||||
interaction: discord.Interaction,
|
||||
entity_type: str,
|
||||
entity_id: int,
|
||||
bot_user: discord.User | None,
|
||||
url: str,
|
||||
) -> str:
|
||||
@ -108,7 +116,7 @@ async def handle_mcp_add(
|
||||
if not bot_user:
|
||||
raise ValueError("Bot user is required")
|
||||
with make_session() as session:
|
||||
if find_mcp_server(session, bot_user.id, url):
|
||||
if find_mcp_server(session, entity_type, entity_id, url):
|
||||
return (
|
||||
f"**MCP Server Already Exists**\n\n"
|
||||
f"You already have an MCP server configured at `{url}`.\n"
|
||||
@ -116,25 +124,32 @@ async def handle_mcp_add(
|
||||
)
|
||||
|
||||
endpoints = await get_endpoints(url)
|
||||
client_id = await register_oauth_client(
|
||||
endpoints,
|
||||
url,
|
||||
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
|
||||
)
|
||||
mcp_server = DiscordMCPServer(
|
||||
discord_bot_user_id=bot_user.id,
|
||||
name = f"Discord Bot - {bot_user.name} ({entity_type} {entity_id})"
|
||||
client_id = await register_oauth_client(endpoints, url, name)
|
||||
|
||||
# Create MCP server
|
||||
mcp_server = MCPServer(
|
||||
mcp_server_url=url,
|
||||
client_id=client_id,
|
||||
name=name,
|
||||
)
|
||||
session.add(mcp_server)
|
||||
session.flush()
|
||||
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=mcp_server.id,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
session.add(assignment)
|
||||
session.flush()
|
||||
|
||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Created MCP server record: id={mcp_server.id}, "
|
||||
f"user={interaction.user.id}, url={url}"
|
||||
f"{entity_type}={entity_id}, url={url}"
|
||||
)
|
||||
|
||||
return (
|
||||
@ -146,32 +161,54 @@ async def handle_mcp_add(
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
|
||||
"""Delete an MCP server."""
|
||||
async def handle_mcp_delete(entity_type: str, entity_id: int, url: str) -> str:
|
||||
"""Delete an MCP server assignment."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
if not mcp_server:
|
||||
# Find the assignment
|
||||
assignment = (
|
||||
session.query(MCPServerAssignment)
|
||||
.join(MCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_type == entity_type,
|
||||
MCPServerAssignment.entity_id == entity_id,
|
||||
MCPServer.mcp_server_url == url,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not assignment:
|
||||
return (
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
)
|
||||
session.delete(mcp_server)
|
||||
|
||||
# Delete the assignment (server will cascade delete if no other assignments exist)
|
||||
session.delete(assignment)
|
||||
|
||||
# Check if server has other assignments
|
||||
mcp_server = assignment.mcp_server
|
||||
other_assignments = (
|
||||
session.query(MCPServerAssignment)
|
||||
.filter(
|
||||
MCPServerAssignment.mcp_server_id == mcp_server.id,
|
||||
MCPServerAssignment.id != assignment.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# If no other assignments, delete the server too
|
||||
if other_assignments == 0:
|
||||
session.delete(mcp_server)
|
||||
|
||||
session.commit()
|
||||
|
||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||
|
||||
|
||||
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
||||
async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
|
||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||
)
|
||||
|
||||
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
@ -184,7 +221,9 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
|
||||
logger.info(
|
||||
f"Regenerated OAuth challenge for {entity_type}={entity_id}, url={url}"
|
||||
)
|
||||
|
||||
return (
|
||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||
@ -195,10 +234,10 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
||||
async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
||||
"""List tools available on an MCP server."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
@ -265,37 +304,28 @@ async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
||||
|
||||
|
||||
async def run_mcp_server_command(
|
||||
interaction: discord.Interaction,
|
||||
bot_user: discord.User | None,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||
url: str | None,
|
||||
entity_type: str,
|
||||
entity_id: int,
|
||||
) -> None:
|
||||
"""Handle MCP server management commands."""
|
||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
|
||||
return
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
if action != "list" and not url:
|
||||
await interaction.response.send_message(
|
||||
"❌ URL is required for this action", ephemeral=True
|
||||
)
|
||||
return
|
||||
raise ValueError("URL is required for this action")
|
||||
if not bot_user:
|
||||
await interaction.response.send_message(
|
||||
"❌ Bot user is required", ephemeral=True
|
||||
)
|
||||
return
|
||||
raise ValueError("Bot user is required")
|
||||
|
||||
try:
|
||||
if action == "list" or not url:
|
||||
result = await handle_mcp_list(interaction)
|
||||
elif action == "add":
|
||||
result = await handle_mcp_add(interaction, bot_user, url)
|
||||
elif action == "delete":
|
||||
result = await handle_mcp_delete(bot_user, url)
|
||||
elif action == "connect":
|
||||
result = await handle_mcp_connect(bot_user, url)
|
||||
elif action == "tools":
|
||||
result = await handle_mcp_tools(bot_user, url)
|
||||
except Exception as exc:
|
||||
result = f"❌ Error: {exc}"
|
||||
await interaction.response.send_message(result, ephemeral=True)
|
||||
if action == "list" or not url:
|
||||
return await handle_mcp_list(entity_type, entity_id)
|
||||
elif action == "add":
|
||||
return await handle_mcp_add(entity_type, entity_id, bot_user, url)
|
||||
elif action == "delete":
|
||||
return await handle_mcp_delete(entity_type, entity_id, url)
|
||||
elif action == "connect":
|
||||
return await handle_mcp_connect(entity_type, entity_id, url)
|
||||
elif action == "tools":
|
||||
return await handle_mcp_tools(entity_type, entity_id, url)
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
@ -113,8 +113,6 @@ def upsert_scheduled_message(
|
||||
.first()
|
||||
)
|
||||
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
||||
print(f"naive_scheduled_time: {naive_scheduled_time}")
|
||||
print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}")
|
||||
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
||||
prev_call.status = "cancelled" # type: ignore
|
||||
|
||||
@ -141,10 +139,8 @@ def previous_messages(
|
||||
) -> list[DiscordMessage]:
|
||||
messages = session.query(DiscordMessage)
|
||||
if user_id:
|
||||
print(f"user_id: {user_id}")
|
||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||
if channel_id:
|
||||
print(f"channel_id: {channel_id}")
|
||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||
return list(
|
||||
reversed(
|
||||
@ -294,10 +290,8 @@ def send_discord_response(
|
||||
True if sent successfully
|
||||
"""
|
||||
if channel_id is not None:
|
||||
logger.info(f"Sending message to channel {channel_id}")
|
||||
return discord.send_to_channel(bot_id, channel_id, response)
|
||||
elif user_identifier is not None:
|
||||
logger.info(f"Sending DM to {user_identifier}")
|
||||
return discord.send_dm(bot_id, user_identifier, response)
|
||||
else:
|
||||
logger.error("Neither channel_id nor user_identifier provided")
|
||||
|
||||
@ -91,6 +91,10 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
):
|
||||
return False
|
||||
|
||||
if f"<@{message.recipient_user.id}>" in message.content:
|
||||
logger.info("Direct mention of the bot, processing message")
|
||||
return True
|
||||
|
||||
if message.from_user == message.recipient_user:
|
||||
logger.info("Skipping message because from_user == recipient_user")
|
||||
return False
|
||||
@ -132,6 +136,8 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||
return False
|
||||
try:
|
||||
logger.info(f"chattiness_threshold: {message.chattiness_threshold}")
|
||||
logger.info(f"haiku desire: {res.group(1)}")
|
||||
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
||||
return False
|
||||
except ValueError:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user