mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
multiple mcp servers
This commit is contained in:
parent
2d3dc06fdf
commit
8893018af1
@ -1,5 +1,6 @@
|
|||||||
# Agent Guidance
|
# Agent Guidance
|
||||||
|
|
||||||
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
||||||
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
|
|
||||||
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
||||||
|
- Tests should be written with @pytest.mark.parametrize where applicable and should avoid test classes
|
||||||
|
- Make sure linting errors get fixed
|
||||||
|
|||||||
@ -1,67 +0,0 @@
|
|||||||
"""discord mcp servers
|
|
||||||
|
|
||||||
Revision ID: 9b887449ea92
|
|
||||||
Revises: 1954477b25f4
|
|
||||||
Create Date: 2025-11-02 22:04:26.259323
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "9b887449ea92"
|
|
||||||
down_revision: Union[str, None] = "1954477b25f4"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"discord_mcp_servers",
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
|
||||||
sa.Column("client_id", sa.Text(), nullable=False),
|
|
||||||
sa.Column("state", sa.Text(), nullable=True),
|
|
||||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
|
||||||
sa.Column("access_token", sa.Text(), nullable=True),
|
|
||||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
|
||||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column(
|
|
||||||
"created_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
server_default=sa.text("now()"),
|
|
||||||
nullable=True,
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"updated_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
server_default=sa.text("now()"),
|
|
||||||
nullable=True,
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["discord_bot_user_id"],
|
|
||||||
["discord_users.id"],
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.UniqueConstraint("state"),
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"discord_mcp_user_url_idx",
|
|
||||||
"discord_mcp_servers",
|
|
||||||
["discord_bot_user_id", "mcp_server_url"],
|
|
||||||
unique=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
|
|
||||||
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
|
|
||||||
op.drop_table("discord_mcp_servers")
|
|
||||||
103
db/migrations/versions/20251103_154126_mcp_servers.py
Normal file
103
db/migrations/versions/20251103_154126_mcp_servers.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
"""mcp servers
|
||||||
|
|
||||||
|
Revision ID: 89861d5f1102
|
||||||
|
Revises: 1954477b25f4
|
||||||
|
Create Date: 2025-11-03 15:41:26.254854
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "89861d5f1102"
|
||||||
|
down_revision: Union[str, None] = "1954477b25f4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"mcp_servers",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("name", sa.Text(), nullable=False),
|
||||||
|
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"available_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("state", sa.Text(), nullable=True),
|
||||||
|
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||||
|
sa.Column("access_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("state"),
|
||||||
|
)
|
||||||
|
op.create_index("mcp_state_idx", "mcp_servers", ["state"], unique=False)
|
||||||
|
op.create_table(
|
||||||
|
"mcp_server_assignments",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("mcp_server_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("entity_type", sa.Text(), nullable=False),
|
||||||
|
sa.Column("entity_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["mcp_server_id"],
|
||||||
|
["mcp_servers.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"mcp_assignment_entity_idx",
|
||||||
|
"mcp_server_assignments",
|
||||||
|
["entity_type", "entity_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"mcp_assignment_server_idx",
|
||||||
|
"mcp_server_assignments",
|
||||||
|
["mcp_server_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"mcp_assignment_unique_idx",
|
||||||
|
"mcp_server_assignments",
|
||||||
|
["mcp_server_id", "entity_type", "entity_id"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("mcp_assignment_unique_idx", table_name="mcp_server_assignments")
|
||||||
|
op.drop_index("mcp_assignment_server_idx", table_name="mcp_server_assignments")
|
||||||
|
op.drop_index("mcp_assignment_entity_idx", table_name="mcp_server_assignments")
|
||||||
|
op.drop_table("mcp_server_assignments")
|
||||||
|
op.drop_index("mcp_state_idx", table_name="mcp_servers")
|
||||||
|
op.drop_table("mcp_servers")
|
||||||
@ -14,7 +14,7 @@ from memory.common.db.models import (
|
|||||||
BookSection,
|
BookSection,
|
||||||
Chunk,
|
Chunk,
|
||||||
Comic,
|
Comic,
|
||||||
DiscordMCPServer,
|
MCPServer,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
@ -167,17 +167,17 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
|||||||
column_sortable_list = ["sent_at"]
|
column_sortable_list = ["sent_at"]
|
||||||
|
|
||||||
|
|
||||||
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
class MCPServerAdmin(ModelView, model=MCPServer):
|
||||||
column_list = [
|
column_list = [
|
||||||
"id",
|
"id",
|
||||||
"mcp_server_url",
|
"mcp_server_url",
|
||||||
"client_id",
|
"client_id",
|
||||||
"discord_bot_user_id",
|
|
||||||
"state",
|
"state",
|
||||||
"code_verifier",
|
"code_verifier",
|
||||||
"access_token",
|
"access_token",
|
||||||
"refresh_token",
|
"refresh_token",
|
||||||
"token_expires_at",
|
"token_expires_at",
|
||||||
|
"available_tools",
|
||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
@ -186,7 +186,6 @@ class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
|||||||
"client_id",
|
"client_id",
|
||||||
"state",
|
"state",
|
||||||
"id",
|
"id",
|
||||||
"discord_bot_user_id",
|
|
||||||
]
|
]
|
||||||
column_sortable_list = [
|
column_sortable_list = [
|
||||||
"created_at",
|
"created_at",
|
||||||
@ -360,5 +359,5 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(DiscordUserAdmin)
|
admin.add_view(DiscordUserAdmin)
|
||||||
admin.add_view(DiscordServerAdmin)
|
admin.add_view(DiscordServerAdmin)
|
||||||
admin.add_view(DiscordChannelAdmin)
|
admin.add_view(DiscordChannelAdmin)
|
||||||
admin.add_view(DiscordMCPServerAdmin)
|
admin.add_view(MCPServerAdmin)
|
||||||
admin.add_view(ScheduledLLMCallAdmin)
|
admin.add_view(ScheduledLLMCallAdmin)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from memory.common import settings
|
|||||||
from memory.common.db.connection import get_session, make_session
|
from memory.common.db.connection import get_session, make_session
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
BotUser,
|
BotUser,
|
||||||
DiscordMCPServer,
|
MCPServer,
|
||||||
HumanUser,
|
HumanUser,
|
||||||
User,
|
User,
|
||||||
UserSession,
|
UserSession,
|
||||||
@ -169,9 +169,7 @@ async def oauth_callback_discord(request: Request):
|
|||||||
# Complete the OAuth flow (exchange code for token)
|
# Complete the OAuth flow (exchange code for token)
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = (
|
mcp_server = (
|
||||||
session.query(DiscordMCPServer)
|
session.query(MCPServer).filter(MCPServer.state == state).first()
|
||||||
.filter(DiscordMCPServer.state == state)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|||||||
@ -34,7 +34,8 @@ from memory.common.db.models.discord import (
|
|||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
DiscordMCPServer,
|
MCPServer,
|
||||||
|
MCPServerAssignment,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.observations import (
|
from memory.common.db.models.observations import (
|
||||||
ObservationContradiction,
|
ObservationContradiction,
|
||||||
@ -107,7 +108,8 @@ __all__ = [
|
|||||||
"DiscordServer",
|
"DiscordServer",
|
||||||
"DiscordChannel",
|
"DiscordChannel",
|
||||||
"DiscordUser",
|
"DiscordUser",
|
||||||
"DiscordMCPServer",
|
"MCPServer",
|
||||||
|
"MCPServerAssignment",
|
||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"HumanUser",
|
"HumanUser",
|
||||||
|
|||||||
@ -127,26 +127,22 @@ class DiscordUser(Base, MessageProcessor):
|
|||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
system_user = relationship("User", back_populates="discord_users")
|
system_user = relationship("User", back_populates="discord_users")
|
||||||
mcp_servers = relationship(
|
|
||||||
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||||
|
|
||||||
|
|
||||||
class DiscordMCPServer(Base):
|
class MCPServer(Base):
|
||||||
"""MCP server configuration and OAuth state for Discord users."""
|
"""MCP server configuration and OAuth state."""
|
||||||
|
|
||||||
__tablename__ = "discord_mcp_servers"
|
__tablename__ = "mcp_servers"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
discord_bot_user_id = Column(
|
|
||||||
BigInteger, ForeignKey("discord_users.id"), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# MCP server info
|
# MCP server info
|
||||||
|
name = Column(Text, nullable=False)
|
||||||
mcp_server_url = Column(Text, nullable=False)
|
mcp_server_url = Column(Text, nullable=False)
|
||||||
client_id = Column(Text, nullable=False)
|
client_id = Column(Text, nullable=False)
|
||||||
|
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
# OAuth flow state (temporary, cleared after token exchange)
|
# OAuth flow state (temporary, cleared after token exchange)
|
||||||
state = Column(Text, nullable=True, unique=True)
|
state = Column(Text, nullable=True, unique=True)
|
||||||
@ -162,9 +158,42 @@ class DiscordMCPServer(Base):
|
|||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
|
assignments = relationship(
|
||||||
|
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerAssignment(Base):
|
||||||
|
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||||
|
|
||||||
|
__tablename__ = "mcp_server_assignments"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||||
|
|
||||||
|
# Polymorphic entity reference
|
||||||
|
entity_type = Column(
|
||||||
|
Text, nullable=False
|
||||||
|
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||||
|
entity_id = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("discord_mcp_state_idx", "state"),
|
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||||
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
|
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||||
|
Index(
|
||||||
|
"mcp_assignment_unique_idx",
|
||||||
|
"mcp_server_id",
|
||||||
|
"entity_type",
|
||||||
|
"entity_id",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -69,7 +69,6 @@ def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print("Result", result)
|
|
||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
|
|||||||
@ -55,7 +55,6 @@ CUSTOM_EXTENSIONS = {
|
|||||||
def get_mime_type(path: pathlib.Path) -> str:
|
def get_mime_type(path: pathlib.Path) -> str:
|
||||||
mime_type, _ = mimetypes.guess_type(str(path))
|
mime_type, _ = mimetypes.guess_type(str(path))
|
||||||
if mime_type:
|
if mime_type:
|
||||||
print(f"mime_type: {mime_type}")
|
|
||||||
return mime_type
|
return mime_type
|
||||||
ext = path.suffix.lower()
|
ext = path.suffix.lower()
|
||||||
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")
|
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.models.discord import DiscordMCPServer
|
from memory.common.db.models.discord import MCPServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ async def register_oauth_client(
|
|||||||
|
|
||||||
|
|
||||||
async def issue_challenge(
|
async def issue_challenge(
|
||||||
mcp_server: DiscordMCPServer,
|
mcp_server: MCPServer,
|
||||||
endpoints: OAuthEndpoints,
|
endpoints: OAuthEndpoints,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate OAuth challenge and store state in mcp_server object."""
|
"""Generate OAuth challenge and store state in mcp_server object."""
|
||||||
@ -160,7 +160,7 @@ async def issue_challenge(
|
|||||||
mcp_server.code_verifier = code_verifier # type: ignore
|
mcp_server.code_verifier = code_verifier # type: ignore
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
|
f"Generated OAuth state for MCP server {mcp_server.mcp_server_url}: "
|
||||||
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -179,7 +179,7 @@ async def issue_challenge(
|
|||||||
|
|
||||||
|
|
||||||
async def complete_oauth_flow(
|
async def complete_oauth_flow(
|
||||||
mcp_server: DiscordMCPServer, code: str, state: str
|
mcp_server: MCPServer, code: str, state: str
|
||||||
) -> tuple[int, str]:
|
) -> tuple[int, str]:
|
||||||
"""Complete OAuth flow by exchanging code for token.
|
"""Complete OAuth flow by exchanging code for token.
|
||||||
|
|
||||||
@ -196,7 +196,7 @@ async def complete_oauth_flow(
|
|||||||
return 400, "Invalid or expired OAuth state"
|
return 400, "Invalid or expired OAuth state"
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
|
f"Found MCP server config: id={mcp_server.id}, "
|
||||||
f"url={mcp_server.mcp_server_url}"
|
f"url={mcp_server.mcp_server_url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -247,8 +247,8 @@ async def complete_oauth_flow(
|
|||||||
mcp_server.code_verifier = None # type: ignore
|
mcp_server.code_verifier = None # type: ignore
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
f"Stored tokens for MCP server id={mcp_server.id}, "
|
||||||
f"server {mcp_server.mcp_server_url}"
|
f"url={mcp_server.mcp_server_url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, "✅ Authorization successful! You can now use this MCP server."
|
return 200, "✅ Authorization successful! You can now use this MCP server."
|
||||||
|
|||||||
@ -12,14 +12,12 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
|
from memory.common.db.models import DiscordBotUser
|
||||||
from memory.common.oauth import complete_oauth_flow
|
|
||||||
from memory.discord.collector import MessageCollector
|
from memory.discord.collector import MessageCollector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
"""Lightweight slash-command helpers for the Discord collector."""
|
"""Lightweight slash-command helpers for the Discord collector."""
|
||||||
|
|
||||||
from calendar import c
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Literal
|
from typing import Callable, Literal
|
||||||
@ -14,7 +13,7 @@ from memory.discord.mcp import run_mcp_server_command
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ScopeLiteral = Literal["server", "channel", "user"]
|
ScopeLiteral = Literal["bot", "server", "channel", "user"]
|
||||||
|
|
||||||
|
|
||||||
class CommandError(Exception):
|
class CommandError(Exception):
|
||||||
@ -173,18 +172,29 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
|
|
||||||
@tree.command(
|
@tree.command(
|
||||||
name=f"{name}_mcp_servers",
|
name=f"{name}_mcp_servers",
|
||||||
description="Manage MCP servers for your account",
|
description="Manage MCP servers for a scope",
|
||||||
)
|
)
|
||||||
@discord.app_commands.describe(
|
@discord.app_commands.describe(
|
||||||
|
scope="Which configuration to modify (server, channel, or user)",
|
||||||
action="Action to perform",
|
action="Action to perform",
|
||||||
url="MCP server URL (required for add, delete, connect, tools)",
|
url="MCP server URL (required for add, delete, connect, tools)",
|
||||||
|
user="Target user when the scope is 'user'",
|
||||||
)
|
)
|
||||||
async def mcp_servers_command(
|
async def mcp_servers_command(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
|
scope: ScopeLiteral,
|
||||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
|
user: discord.User | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_mcp_servers,
|
||||||
|
target_user=user,
|
||||||
|
action=action,
|
||||||
|
url=url and url.strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _run_interaction_command(
|
async def _run_interaction_command(
|
||||||
@ -199,7 +209,7 @@ async def _run_interaction_command(
|
|||||||
try:
|
try:
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
context = _build_context(session, interaction, scope, target_user)
|
context = _build_context(session, interaction, scope, target_user)
|
||||||
response = handler(context, **handler_kwargs)
|
response = await handler(context, **handler_kwargs)
|
||||||
session.commit()
|
session.commit()
|
||||||
except CommandError as exc: # pragma: no cover - passthrough
|
except CommandError as exc: # pragma: no cover - passthrough
|
||||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||||
@ -435,3 +445,29 @@ def handle_summary(context: CommandContext) -> CommandResponse:
|
|||||||
return CommandResponse(
|
return CommandResponse(
|
||||||
content=f"No summary stored for {context.display_name}.",
|
content=f"No summary stored for {context.display_name}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_servers(
|
||||||
|
context: CommandContext,
|
||||||
|
*,
|
||||||
|
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||||
|
url: str | None,
|
||||||
|
) -> CommandResponse:
|
||||||
|
"""Handle MCP server commands for a specific scope."""
|
||||||
|
entity_type_map = {
|
||||||
|
"server": "DiscordServer",
|
||||||
|
"channel": "DiscordChannel",
|
||||||
|
"user": "DiscordUser",
|
||||||
|
}
|
||||||
|
entity_type = entity_type_map[context.scope]
|
||||||
|
entity_id = context.target.id
|
||||||
|
try:
|
||||||
|
res = await run_mcp_server_command(
|
||||||
|
context.interaction.user, action, url, entity_type, entity_id
|
||||||
|
)
|
||||||
|
return CommandResponse(content=res)
|
||||||
|
except Exception as exc:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"Error running MCP server command: {traceback.format_exc()}")
|
||||||
|
raise CommandError(f"Error: {exc}") from exc
|
||||||
|
|||||||
@ -10,23 +10,27 @@ import discord
|
|||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.discord import DiscordMCPServer
|
from memory.common.db.models.discord import MCPServer, MCPServerAssignment
|
||||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def find_mcp_server(
|
def find_mcp_server(
|
||||||
session: Session | scoped_session, user_id: int, url: str
|
session: Session | scoped_session, entity_type: str, entity_id: int, url: str
|
||||||
) -> DiscordMCPServer | None:
|
) -> MCPServer | None:
|
||||||
return (
|
"""Find an MCP server assigned to an entity."""
|
||||||
session.query(DiscordMCPServer)
|
assignment = (
|
||||||
|
session.query(MCPServerAssignment)
|
||||||
|
.join(MCPServer)
|
||||||
.filter(
|
.filter(
|
||||||
DiscordMCPServer.discord_bot_user_id == user_id,
|
MCPServerAssignment.entity_type == entity_type,
|
||||||
DiscordMCPServer.mcp_server_url == url,
|
MCPServerAssignment.entity_id == entity_id,
|
||||||
|
MCPServer.mcp_server_url == url,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
return assignment and assignment.mcp_server
|
||||||
|
|
||||||
|
|
||||||
async def call_mcp_server(
|
async def call_mcp_server(
|
||||||
@ -72,35 +76,39 @@ async def call_mcp_server(
|
|||||||
continue # Skip invalid JSON lines
|
continue # Skip invalid JSON lines
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
||||||
"""List all MCP servers for the user."""
|
"""List all MCP servers for the user."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
servers = (
|
assignments = (
|
||||||
session.query(DiscordMCPServer)
|
session.query(MCPServerAssignment)
|
||||||
|
.join(MCPServer)
|
||||||
.filter(
|
.filter(
|
||||||
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
|
MCPServerAssignment.entity_type == entity_type,
|
||||||
|
MCPServerAssignment.entity_id == entity_id,
|
||||||
)
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not servers:
|
if not assignments:
|
||||||
return (
|
return (
|
||||||
"📋 **Your MCP Servers**\n\n"
|
"📋 **Your MCP Servers**\n\n"
|
||||||
"You don't have any MCP servers configured yet.\n"
|
"You don't have any MCP servers configured yet.\n"
|
||||||
"Use `/memory_mcp_servers add <url>` to add one."
|
"Use `/memory_mcp_servers add <url>` to add one."
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_server(server: DiscordMCPServer) -> str:
|
def format_server(assignment: MCPServerAssignment) -> str:
|
||||||
|
server = assignment.mcp_server
|
||||||
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
||||||
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
||||||
|
|
||||||
server_list = "\n".join(format_server(s) for s in servers)
|
server_list = "\n".join(format_server(a) for a in assignments)
|
||||||
|
|
||||||
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_add(
|
async def handle_mcp_add(
|
||||||
interaction: discord.Interaction,
|
entity_type: str,
|
||||||
|
entity_id: int,
|
||||||
bot_user: discord.User | None,
|
bot_user: discord.User | None,
|
||||||
url: str,
|
url: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -108,7 +116,7 @@ async def handle_mcp_add(
|
|||||||
if not bot_user:
|
if not bot_user:
|
||||||
raise ValueError("Bot user is required")
|
raise ValueError("Bot user is required")
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
if find_mcp_server(session, bot_user.id, url):
|
if find_mcp_server(session, entity_type, entity_id, url):
|
||||||
return (
|
return (
|
||||||
f"**MCP Server Already Exists**\n\n"
|
f"**MCP Server Already Exists**\n\n"
|
||||||
f"You already have an MCP server configured at `{url}`.\n"
|
f"You already have an MCP server configured at `{url}`.\n"
|
||||||
@ -116,25 +124,32 @@ async def handle_mcp_add(
|
|||||||
)
|
)
|
||||||
|
|
||||||
endpoints = await get_endpoints(url)
|
endpoints = await get_endpoints(url)
|
||||||
client_id = await register_oauth_client(
|
name = f"Discord Bot - {bot_user.name} ({entity_type} {entity_id})"
|
||||||
endpoints,
|
client_id = await register_oauth_client(endpoints, url, name)
|
||||||
url,
|
|
||||||
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
|
# Create MCP server
|
||||||
)
|
mcp_server = MCPServer(
|
||||||
mcp_server = DiscordMCPServer(
|
|
||||||
discord_bot_user_id=bot_user.id,
|
|
||||||
mcp_server_url=url,
|
mcp_server_url=url,
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
|
name=name,
|
||||||
)
|
)
|
||||||
session.add(mcp_server)
|
session.add(mcp_server)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
|
assignment = MCPServerAssignment(
|
||||||
|
mcp_server_id=mcp_server.id,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
session.add(assignment)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created MCP server record: id={mcp_server.id}, "
|
f"Created MCP server record: id={mcp_server.id}, "
|
||||||
f"user={interaction.user.id}, url={url}"
|
f"{entity_type}={entity_id}, url={url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -146,32 +161,54 @@ async def handle_mcp_add(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
|
async def handle_mcp_delete(entity_type: str, entity_id: int, url: str) -> str:
|
||||||
"""Delete an MCP server."""
|
"""Delete an MCP server assignment."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
# Find the assignment
|
||||||
if not mcp_server:
|
assignment = (
|
||||||
|
session.query(MCPServerAssignment)
|
||||||
|
.join(MCPServer)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.entity_type == entity_type,
|
||||||
|
MCPServerAssignment.entity_id == entity_id,
|
||||||
|
MCPServer.mcp_server_url == url,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not assignment:
|
||||||
return (
|
return (
|
||||||
f"**MCP Server Not Found**\n\n"
|
f"**MCP Server Not Found**\n\n"
|
||||||
f"You don't have an MCP server configured at `{url}`.\n"
|
f"You don't have an MCP server configured at `{url}`.\n"
|
||||||
)
|
)
|
||||||
session.delete(mcp_server)
|
|
||||||
|
# Delete the assignment (server will cascade delete if no other assignments exist)
|
||||||
|
session.delete(assignment)
|
||||||
|
|
||||||
|
# Check if server has other assignments
|
||||||
|
mcp_server = assignment.mcp_server
|
||||||
|
other_assignments = (
|
||||||
|
session.query(MCPServerAssignment)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.mcp_server_id == mcp_server.id,
|
||||||
|
MCPServerAssignment.id != assignment.id,
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no other assignments, delete the server too
|
||||||
|
if other_assignments == 0:
|
||||||
|
session.delete(mcp_server)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
|
||||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||||
if not mcp_server:
|
|
||||||
raise ValueError(
|
|
||||||
f"**MCP Server Not Found**\n\n"
|
|
||||||
f"You don't have an MCP server configured at `{url}`.\n"
|
|
||||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"**MCP Server Not Found**\n\n"
|
f"**MCP Server Not Found**\n\n"
|
||||||
@ -184,7 +221,9 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
|
logger.info(
|
||||||
|
f"Regenerated OAuth challenge for {entity_type}={entity_id}, url={url}"
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||||
@ -195,10 +234,10 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
||||||
"""List tools available on an MCP server."""
|
"""List tools available on an MCP server."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||||
|
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -265,37 +304,28 @@ async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def run_mcp_server_command(
|
async def run_mcp_server_command(
|
||||||
interaction: discord.Interaction,
|
|
||||||
bot_user: discord.User | None,
|
bot_user: discord.User | None,
|
||||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||||
url: str | None,
|
url: str | None,
|
||||||
|
entity_type: str,
|
||||||
|
entity_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle MCP server management commands."""
|
"""Handle MCP server management commands."""
|
||||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||||
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
|
raise ValueError(f"Invalid action: {action}")
|
||||||
return
|
|
||||||
if action != "list" and not url:
|
if action != "list" and not url:
|
||||||
await interaction.response.send_message(
|
raise ValueError("URL is required for this action")
|
||||||
"❌ URL is required for this action", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if not bot_user:
|
if not bot_user:
|
||||||
await interaction.response.send_message(
|
raise ValueError("Bot user is required")
|
||||||
"❌ Bot user is required", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
if action == "list" or not url:
|
||||||
if action == "list" or not url:
|
return await handle_mcp_list(entity_type, entity_id)
|
||||||
result = await handle_mcp_list(interaction)
|
elif action == "add":
|
||||||
elif action == "add":
|
return await handle_mcp_add(entity_type, entity_id, bot_user, url)
|
||||||
result = await handle_mcp_add(interaction, bot_user, url)
|
elif action == "delete":
|
||||||
elif action == "delete":
|
return await handle_mcp_delete(entity_type, entity_id, url)
|
||||||
result = await handle_mcp_delete(bot_user, url)
|
elif action == "connect":
|
||||||
elif action == "connect":
|
return await handle_mcp_connect(entity_type, entity_id, url)
|
||||||
result = await handle_mcp_connect(bot_user, url)
|
elif action == "tools":
|
||||||
elif action == "tools":
|
return await handle_mcp_tools(entity_type, entity_id, url)
|
||||||
result = await handle_mcp_tools(bot_user, url)
|
raise ValueError(f"Invalid action: {action}")
|
||||||
except Exception as exc:
|
|
||||||
result = f"❌ Error: {exc}"
|
|
||||||
await interaction.response.send_message(result, ephemeral=True)
|
|
||||||
|
|||||||
@ -113,8 +113,6 @@ def upsert_scheduled_message(
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
||||||
print(f"naive_scheduled_time: {naive_scheduled_time}")
|
|
||||||
print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}")
|
|
||||||
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
||||||
prev_call.status = "cancelled" # type: ignore
|
prev_call.status = "cancelled" # type: ignore
|
||||||
|
|
||||||
@ -141,10 +139,8 @@ def previous_messages(
|
|||||||
) -> list[DiscordMessage]:
|
) -> list[DiscordMessage]:
|
||||||
messages = session.query(DiscordMessage)
|
messages = session.query(DiscordMessage)
|
||||||
if user_id:
|
if user_id:
|
||||||
print(f"user_id: {user_id}")
|
|
||||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||||
if channel_id:
|
if channel_id:
|
||||||
print(f"channel_id: {channel_id}")
|
|
||||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||||
return list(
|
return list(
|
||||||
reversed(
|
reversed(
|
||||||
@ -294,10 +290,8 @@ def send_discord_response(
|
|||||||
True if sent successfully
|
True if sent successfully
|
||||||
"""
|
"""
|
||||||
if channel_id is not None:
|
if channel_id is not None:
|
||||||
logger.info(f"Sending message to channel {channel_id}")
|
|
||||||
return discord.send_to_channel(bot_id, channel_id, response)
|
return discord.send_to_channel(bot_id, channel_id, response)
|
||||||
elif user_identifier is not None:
|
elif user_identifier is not None:
|
||||||
logger.info(f"Sending DM to {user_identifier}")
|
|
||||||
return discord.send_dm(bot_id, user_identifier, response)
|
return discord.send_dm(bot_id, user_identifier, response)
|
||||||
else:
|
else:
|
||||||
logger.error("Neither channel_id nor user_identifier provided")
|
logger.error("Neither channel_id nor user_identifier provided")
|
||||||
|
|||||||
@ -91,6 +91,10 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if f"<@{message.recipient_user.id}>" in message.content:
|
||||||
|
logger.info("Direct mention of the bot, processing message")
|
||||||
|
return True
|
||||||
|
|
||||||
if message.from_user == message.recipient_user:
|
if message.from_user == message.recipient_user:
|
||||||
logger.info("Skipping message because from_user == recipient_user")
|
logger.info("Skipping message because from_user == recipient_user")
|
||||||
return False
|
return False
|
||||||
@ -132,6 +136,8 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"chattiness_threshold: {message.chattiness_threshold}")
|
||||||
|
logger.info(f"haiku desire: {res.group(1)}")
|
||||||
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
||||||
return False
|
return False
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user