mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 17:11:19 +01:00
Compare commits
5 Commits
2d3dc06fdf
...
470061bd43
| Author | SHA1 | Date | |
|---|---|---|---|
| 470061bd43 | |||
|
|
ad6510bd17 | ||
| 56c0df9761 | |||
| b568222e88 | |||
| 8893018af1 |
@ -1,5 +1,6 @@
|
|||||||
# Agent Guidance
|
# Agent Guidance
|
||||||
|
|
||||||
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
||||||
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
|
|
||||||
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
||||||
|
- Tests should be written with @pytest.mark.parametrize where applicable and should avoid test classes
|
||||||
|
- Make sure linting errors get fixed
|
||||||
|
|||||||
@ -1,67 +0,0 @@
|
|||||||
"""discord mcp servers
|
|
||||||
|
|
||||||
Revision ID: 9b887449ea92
|
|
||||||
Revises: 1954477b25f4
|
|
||||||
Create Date: 2025-11-02 22:04:26.259323
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "9b887449ea92"
|
|
||||||
down_revision: Union[str, None] = "1954477b25f4"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"discord_mcp_servers",
|
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
|
||||||
sa.Column("client_id", sa.Text(), nullable=False),
|
|
||||||
sa.Column("state", sa.Text(), nullable=True),
|
|
||||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
|
||||||
sa.Column("access_token", sa.Text(), nullable=True),
|
|
||||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
|
||||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column(
|
|
||||||
"created_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
server_default=sa.text("now()"),
|
|
||||||
nullable=True,
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"updated_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
server_default=sa.text("now()"),
|
|
||||||
nullable=True,
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(
|
|
||||||
["discord_bot_user_id"],
|
|
||||||
["discord_users.id"],
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.UniqueConstraint("state"),
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"discord_mcp_user_url_idx",
|
|
||||||
"discord_mcp_servers",
|
|
||||||
["discord_bot_user_id", "mcp_server_url"],
|
|
||||||
unique=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
|
|
||||||
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
|
|
||||||
op.drop_table("discord_mcp_servers")
|
|
||||||
103
db/migrations/versions/20251103_154126_mcp_servers.py
Normal file
103
db/migrations/versions/20251103_154126_mcp_servers.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
"""mcp servers
|
||||||
|
|
||||||
|
Revision ID: 89861d5f1102
|
||||||
|
Revises: 1954477b25f4
|
||||||
|
Create Date: 2025-11-03 15:41:26.254854
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "89861d5f1102"
|
||||||
|
down_revision: Union[str, None] = "1954477b25f4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"mcp_servers",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("name", sa.Text(), nullable=False),
|
||||||
|
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"available_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("state", sa.Text(), nullable=True),
|
||||||
|
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||||
|
sa.Column("access_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("state"),
|
||||||
|
)
|
||||||
|
op.create_index("mcp_state_idx", "mcp_servers", ["state"], unique=False)
|
||||||
|
op.create_table(
|
||||||
|
"mcp_server_assignments",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("mcp_server_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("entity_type", sa.Text(), nullable=False),
|
||||||
|
sa.Column("entity_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["mcp_server_id"],
|
||||||
|
["mcp_servers.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"mcp_assignment_entity_idx",
|
||||||
|
"mcp_server_assignments",
|
||||||
|
["entity_type", "entity_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"mcp_assignment_server_idx",
|
||||||
|
"mcp_server_assignments",
|
||||||
|
["mcp_server_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"mcp_assignment_unique_idx",
|
||||||
|
"mcp_server_assignments",
|
||||||
|
["mcp_server_id", "entity_type", "entity_id"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("mcp_assignment_unique_idx", table_name="mcp_server_assignments")
|
||||||
|
op.drop_index("mcp_assignment_server_idx", table_name="mcp_server_assignments")
|
||||||
|
op.drop_index("mcp_assignment_entity_idx", table_name="mcp_server_assignments")
|
||||||
|
op.drop_table("mcp_server_assignments")
|
||||||
|
op.drop_index("mcp_state_idx", table_name="mcp_servers")
|
||||||
|
op.drop_table("mcp_servers")
|
||||||
@ -128,7 +128,7 @@ services:
|
|||||||
- /var/tmp
|
- /var/tmp
|
||||||
- /qdrant/snapshots:rw
|
- /qdrant/snapshots:rw
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD", "wget", "-q", "-T", "2", "-O", "-", "localhost:6333/ready" ]
|
test: ["CMD", "bash", "-c", "exec 3<>/dev/tcp/localhost/6333 && echo -e 'GET /readyz HTTP/1.0\\r\\n\\r\\n' >&3 && timeout 2 cat <&3 | grep -q ready"]
|
||||||
interval: 15s
|
interval: 15s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
|
|||||||
@ -8,6 +8,7 @@ RUN apt-get update && apt-get install -y \
|
|||||||
gcc \
|
gcc \
|
||||||
g++ \
|
g++ \
|
||||||
python3-dev \
|
python3-dev \
|
||||||
|
curl \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy and install Python requirements
|
# Copy and install Python requirements
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from memory.common.db.models import (
|
|||||||
BookSection,
|
BookSection,
|
||||||
Chunk,
|
Chunk,
|
||||||
Comic,
|
Comic,
|
||||||
DiscordMCPServer,
|
MCPServer,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
@ -167,17 +167,17 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
|||||||
column_sortable_list = ["sent_at"]
|
column_sortable_list = ["sent_at"]
|
||||||
|
|
||||||
|
|
||||||
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
class MCPServerAdmin(ModelView, model=MCPServer):
|
||||||
column_list = [
|
column_list = [
|
||||||
"id",
|
"id",
|
||||||
"mcp_server_url",
|
"mcp_server_url",
|
||||||
"client_id",
|
"client_id",
|
||||||
"discord_bot_user_id",
|
|
||||||
"state",
|
"state",
|
||||||
"code_verifier",
|
"code_verifier",
|
||||||
"access_token",
|
"access_token",
|
||||||
"refresh_token",
|
"refresh_token",
|
||||||
"token_expires_at",
|
"token_expires_at",
|
||||||
|
"available_tools",
|
||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
]
|
]
|
||||||
@ -186,7 +186,6 @@ class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
|||||||
"client_id",
|
"client_id",
|
||||||
"state",
|
"state",
|
||||||
"id",
|
"id",
|
||||||
"discord_bot_user_id",
|
|
||||||
]
|
]
|
||||||
column_sortable_list = [
|
column_sortable_list = [
|
||||||
"created_at",
|
"created_at",
|
||||||
@ -360,5 +359,5 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(DiscordUserAdmin)
|
admin.add_view(DiscordUserAdmin)
|
||||||
admin.add_view(DiscordServerAdmin)
|
admin.add_view(DiscordServerAdmin)
|
||||||
admin.add_view(DiscordChannelAdmin)
|
admin.add_view(DiscordChannelAdmin)
|
||||||
admin.add_view(DiscordMCPServerAdmin)
|
admin.add_view(MCPServerAdmin)
|
||||||
admin.add_view(ScheduledLLMCallAdmin)
|
admin.add_view(ScheduledLLMCallAdmin)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from memory.common import settings
|
|||||||
from memory.common.db.connection import get_session, make_session
|
from memory.common.db.connection import get_session, make_session
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
BotUser,
|
BotUser,
|
||||||
DiscordMCPServer,
|
MCPServer,
|
||||||
HumanUser,
|
HumanUser,
|
||||||
User,
|
User,
|
||||||
UserSession,
|
UserSession,
|
||||||
@ -169,9 +169,7 @@ async def oauth_callback_discord(request: Request):
|
|||||||
# Complete the OAuth flow (exchange code for token)
|
# Complete the OAuth flow (exchange code for token)
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = (
|
mcp_server = (
|
||||||
session.query(DiscordMCPServer)
|
session.query(MCPServer).filter(MCPServer.state == state).first()
|
||||||
.filter(DiscordMCPServer.state == state)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|||||||
@ -34,7 +34,10 @@ from memory.common.db.models.discord import (
|
|||||||
DiscordServer,
|
DiscordServer,
|
||||||
DiscordChannel,
|
DiscordChannel,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
DiscordMCPServer,
|
)
|
||||||
|
from memory.common.db.models.mcp import (
|
||||||
|
MCPServer,
|
||||||
|
MCPServerAssignment,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.observations import (
|
from memory.common.db.models.observations import (
|
||||||
ObservationContradiction,
|
ObservationContradiction,
|
||||||
@ -107,7 +110,8 @@ __all__ = [
|
|||||||
"DiscordServer",
|
"DiscordServer",
|
||||||
"DiscordChannel",
|
"DiscordChannel",
|
||||||
"DiscordUser",
|
"DiscordUser",
|
||||||
"DiscordMCPServer",
|
"MCPServer",
|
||||||
|
"MCPServerAssignment",
|
||||||
# Users
|
# Users
|
||||||
"User",
|
"User",
|
||||||
"HumanUser",
|
"HumanUser",
|
||||||
|
|||||||
@ -51,21 +51,40 @@ class MessageProcessor:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_xml(self) -> str:
|
@property
|
||||||
return (
|
def entity_type(self) -> str:
|
||||||
textwrap.dedent("""
|
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||||
<{type}>
|
|
||||||
<name>{name}</name>
|
def to_xml(self, *fields: str) -> str:
|
||||||
<summary>{summary}</summary>
|
def indent(key: str, text: str) -> str:
|
||||||
</{type}>
|
res = textwrap.dedent("""
|
||||||
""")
|
<{key}>
|
||||||
.format(
|
{text}
|
||||||
type=self.__class__.__tablename__[8:], # type: ignore
|
</{key}>
|
||||||
name=getattr(self, "name", None) or getattr(self, "username", None),
|
""").format(key=key, text=textwrap.indent(text, " "))
|
||||||
summary=self.summary,
|
return res.strip()
|
||||||
)
|
|
||||||
.strip()
|
vals = []
|
||||||
)
|
if "name" in fields:
|
||||||
|
vals.append(indent("name", self.name))
|
||||||
|
if "system_prompt" in fields:
|
||||||
|
vals.append(indent("system_prompt", self.system_prompt or ""))
|
||||||
|
if "summary" in fields:
|
||||||
|
vals.append(indent("summary", self.summary or ""))
|
||||||
|
if "mcp_servers" in fields:
|
||||||
|
servers = [s.as_xml() for s in self.mcp_servers]
|
||||||
|
vals.append(indent("mcp_servers", "\n".join(servers)))
|
||||||
|
|
||||||
|
return indent(self.entity_type, "\n".join(vals)) # type: ignore
|
||||||
|
|
||||||
|
def xml_prompt(self) -> str:
|
||||||
|
return self.to_xml("name", "system_prompt") if self.system_prompt else ""
|
||||||
|
|
||||||
|
def xml_summary(self) -> str:
|
||||||
|
return self.to_xml("name", "summary")
|
||||||
|
|
||||||
|
def xml_mcp_servers(self) -> str:
|
||||||
|
return self.to_xml("mcp_servers")
|
||||||
|
|
||||||
|
|
||||||
class DiscordServer(Base, MessageProcessor):
|
class DiscordServer(Base, MessageProcessor):
|
||||||
@ -127,44 +146,9 @@ class DiscordUser(Base, MessageProcessor):
|
|||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
system_user = relationship("User", back_populates="discord_users")
|
system_user = relationship("User", back_populates="discord_users")
|
||||||
mcp_servers = relationship(
|
|
||||||
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||||
|
|
||||||
|
@property
|
||||||
class DiscordMCPServer(Base):
|
def name(self) -> str:
|
||||||
"""MCP server configuration and OAuth state for Discord users."""
|
return self.username
|
||||||
|
|
||||||
__tablename__ = "discord_mcp_servers"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
discord_bot_user_id = Column(
|
|
||||||
BigInteger, ForeignKey("discord_users.id"), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# MCP server info
|
|
||||||
mcp_server_url = Column(Text, nullable=False)
|
|
||||||
client_id = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
# OAuth flow state (temporary, cleared after token exchange)
|
|
||||||
state = Column(Text, nullable=True, unique=True)
|
|
||||||
code_verifier = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
# OAuth tokens (set after successful authorization)
|
|
||||||
access_token = Column(Text, nullable=True)
|
|
||||||
refresh_token = Column(Text, nullable=True)
|
|
||||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
|
|
||||||
# Relationships
|
|
||||||
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("discord_mcp_state_idx", "state"),
|
|
||||||
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
|
|
||||||
)
|
|
||||||
|
|||||||
108
src/memory/common/db/models/mcp.py
Normal file
108
src/memory/common/db/models/mcp.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import textwrap
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
ARRAY,
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
Text,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServer(Base):
|
||||||
|
"""MCP server configuration and OAuth state."""
|
||||||
|
|
||||||
|
__tablename__ = "mcp_servers"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
|
||||||
|
# MCP server info
|
||||||
|
name = Column(Text, nullable=False)
|
||||||
|
mcp_server_url = Column(Text, nullable=False)
|
||||||
|
client_id = Column(Text, nullable=False)
|
||||||
|
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
# OAuth flow state (temporary, cleared after token exchange)
|
||||||
|
state = Column(Text, nullable=True, unique=True)
|
||||||
|
code_verifier = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# OAuth tokens (set after successful authorization)
|
||||||
|
access_token = Column(Text, nullable=True)
|
||||||
|
refresh_token = Column(Text, nullable=True)
|
||||||
|
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
assignments = relationship(
|
||||||
|
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||||
|
|
||||||
|
def as_xml(self) -> str:
|
||||||
|
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
||||||
|
return textwrap.dedent("""
|
||||||
|
<mcp_server>
|
||||||
|
<name>
|
||||||
|
{name}
|
||||||
|
</name>
|
||||||
|
<mcp_server_url>
|
||||||
|
{mcp_server_url}
|
||||||
|
</mcp_server_url>
|
||||||
|
<client_id>
|
||||||
|
{client_id}
|
||||||
|
</client_id>
|
||||||
|
<available_tools>
|
||||||
|
{available_tools}
|
||||||
|
</available_tools>
|
||||||
|
</mcp_server>
|
||||||
|
""").format(
|
||||||
|
name=self.name,
|
||||||
|
mcp_server_url=self.mcp_server_url,
|
||||||
|
client_id=self.client_id,
|
||||||
|
available_tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerAssignment(Base):
|
||||||
|
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||||
|
|
||||||
|
__tablename__ = "mcp_server_assignments"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||||
|
|
||||||
|
# Polymorphic entity reference
|
||||||
|
entity_type = Column(
|
||||||
|
Text, nullable=False
|
||||||
|
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||||
|
entity_id = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||||
|
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||||
|
Index(
|
||||||
|
"mcp_assignment_unique_idx",
|
||||||
|
"mcp_server_id",
|
||||||
|
"entity_type",
|
||||||
|
"entity_id",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
@ -36,6 +36,10 @@ from memory.common.db.models.source_item import (
|
|||||||
clean_filename,
|
clean_filename,
|
||||||
chunk_mixed,
|
chunk_mixed,
|
||||||
)
|
)
|
||||||
|
from memory.common.db.models.mcp import (
|
||||||
|
MCPServer,
|
||||||
|
MCPServerAssignment,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MailMessagePayload(SourceItemPayload):
|
class MailMessagePayload(SourceItemPayload):
|
||||||
@ -345,11 +349,12 @@ class DiscordMessage(SourceItem):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def system_prompt(self) -> str:
|
def system_prompt(self) -> str:
|
||||||
return (
|
prompts = [
|
||||||
(self.from_user and self.from_user.system_prompt)
|
(self.from_user and self.from_user.system_prompt),
|
||||||
or (self.channel and self.channel.system_prompt)
|
(self.channel and self.channel.system_prompt),
|
||||||
or (self.server and self.server.system_prompt)
|
(self.server and self.server.system_prompt),
|
||||||
)
|
]
|
||||||
|
return "\n\n".join(p for p in prompts if p)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chattiness_threshold(self) -> int:
|
def chattiness_threshold(self) -> int:
|
||||||
@ -391,6 +396,29 @@ class DiscordMessage(SourceItem):
|
|||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
def get_mcp_servers(self, session) -> list[MCPServer]:
|
||||||
|
entity_ids = list(
|
||||||
|
filter(
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
self.recipient_user.id,
|
||||||
|
self.from_user.id,
|
||||||
|
self.channel.id,
|
||||||
|
self.server.id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not entity_ids:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (
|
||||||
|
session.query(MCPServer)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.entity_id.in_(entity_ids),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
__mapper_args__ = {
|
__mapper_args__ = {
|
||||||
"polymorphic_identity": "discord_message",
|
"polymorphic_identity": "discord_message",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -69,7 +69,6 @@ def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print("Result", result)
|
|
||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
|
|||||||
@ -55,7 +55,6 @@ CUSTOM_EXTENSIONS = {
|
|||||||
def get_mime_type(path: pathlib.Path) -> str:
|
def get_mime_type(path: pathlib.Path) -> str:
|
||||||
mime_type, _ = mimetypes.guess_type(str(path))
|
mime_type, _ = mimetypes.guess_type(str(path))
|
||||||
if mime_type:
|
if mime_type:
|
||||||
print(f"mime_type: {mime_type}")
|
|
||||||
return mime_type
|
return mime_type
|
||||||
ext = path.suffix.lower()
|
ext = path.suffix.lower()
|
||||||
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")
|
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")
|
||||||
|
|||||||
@ -134,11 +134,15 @@ class UsageTracker:
|
|||||||
default_config: RateLimitConfig | None = None,
|
default_config: RateLimitConfig | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._configs = configs or {}
|
self._configs = configs or {}
|
||||||
self._default_config = default_config or RateLimitConfig(
|
if default_config is None:
|
||||||
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
|
default_config = RateLimitConfig(
|
||||||
|
window=timedelta(
|
||||||
|
minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES
|
||||||
|
),
|
||||||
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||||
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||||
)
|
)
|
||||||
|
self._default_config = default_config
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -260,8 +264,8 @@ class UsageTracker:
|
|||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||||
for model, state in self.iter_state_items():
|
for model_key, state in self.iter_state_items():
|
||||||
prov, model_name = split_model_key(model)
|
prov, model_name = split_model_key(model_key)
|
||||||
if provider and provider != prov:
|
if provider and provider != prov:
|
||||||
continue
|
continue
|
||||||
if model and model != model_name:
|
if model and model != model_name:
|
||||||
@ -304,7 +308,10 @@ class UsageTracker:
|
|||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def _get_config(self, model: str) -> RateLimitConfig | None:
|
def _get_config(self, model: str) -> RateLimitConfig | None:
|
||||||
return self._configs.get(model) or self._default_config
|
config = self._configs.get(model)
|
||||||
|
if config is not None:
|
||||||
|
return config
|
||||||
|
return self._default_config
|
||||||
|
|
||||||
def _prune_expired_events(
|
def _prune_expired_events(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.models.discord import DiscordMCPServer
|
from memory.common.db.models import MCPServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ async def register_oauth_client(
|
|||||||
|
|
||||||
|
|
||||||
async def issue_challenge(
|
async def issue_challenge(
|
||||||
mcp_server: DiscordMCPServer,
|
mcp_server: MCPServer,
|
||||||
endpoints: OAuthEndpoints,
|
endpoints: OAuthEndpoints,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate OAuth challenge and store state in mcp_server object."""
|
"""Generate OAuth challenge and store state in mcp_server object."""
|
||||||
@ -160,7 +160,7 @@ async def issue_challenge(
|
|||||||
mcp_server.code_verifier = code_verifier # type: ignore
|
mcp_server.code_verifier = code_verifier # type: ignore
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
|
f"Generated OAuth state for MCP server {mcp_server.mcp_server_url}: "
|
||||||
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -179,7 +179,7 @@ async def issue_challenge(
|
|||||||
|
|
||||||
|
|
||||||
async def complete_oauth_flow(
|
async def complete_oauth_flow(
|
||||||
mcp_server: DiscordMCPServer, code: str, state: str
|
mcp_server: MCPServer, code: str, state: str
|
||||||
) -> tuple[int, str]:
|
) -> tuple[int, str]:
|
||||||
"""Complete OAuth flow by exchanging code for token.
|
"""Complete OAuth flow by exchanging code for token.
|
||||||
|
|
||||||
@ -196,7 +196,7 @@ async def complete_oauth_flow(
|
|||||||
return 400, "Invalid or expired OAuth state"
|
return 400, "Invalid or expired OAuth state"
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
|
f"Found MCP server config: id={mcp_server.id}, "
|
||||||
f"url={mcp_server.mcp_server_url}"
|
f"url={mcp_server.mcp_server_url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -247,8 +247,8 @@ async def complete_oauth_flow(
|
|||||||
mcp_server.code_verifier = None # type: ignore
|
mcp_server.code_verifier = None # type: ignore
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
f"Stored tokens for MCP server id={mcp_server.id}, "
|
||||||
f"server {mcp_server.mcp_server_url}"
|
f"url={mcp_server.mcp_server_url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, "✅ Authorization successful! You can now use this MCP server."
|
return 200, "✅ Authorization successful! You can now use this MCP server."
|
||||||
|
|||||||
@ -12,14 +12,12 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
|
from memory.common.db.models import DiscordBotUser
|
||||||
from memory.common.oauth import complete_oauth_flow
|
|
||||||
from memory.discord.collector import MessageCollector
|
from memory.discord.collector import MessageCollector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -179,17 +179,14 @@ def should_track_message(
|
|||||||
channel: DiscordChannel,
|
channel: DiscordChannel,
|
||||||
user: DiscordUser,
|
user: DiscordUser,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Pure function to determine if we should track this message"""
|
if server and server.ignore_messages:
|
||||||
if server and not server.track_messages: # type: ignore
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not channel.track_messages:
|
if channel.ignore_messages:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if channel.channel_type in ("dm", "group_dm"):
|
if channel.channel_type in ("dm", "group_dm"):
|
||||||
return bool(user.track_messages)
|
return not user.ignore_messages
|
||||||
|
|
||||||
# Default: track the message
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Lightweight slash-command helpers for the Discord collector."""
|
"""Lightweight slash-command helpers for the Discord collector."""
|
||||||
|
|
||||||
from calendar import c
|
import io
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Literal
|
from typing import Callable, Literal
|
||||||
@ -9,12 +9,34 @@ import discord
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
from memory.common.db.models import (
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordServer,
|
||||||
|
DiscordUser,
|
||||||
|
MCPServer,
|
||||||
|
MCPServerAssignment,
|
||||||
|
)
|
||||||
from memory.discord.mcp import run_mcp_server_command
|
from memory.discord.mcp import run_mcp_server_command
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ScopeLiteral = Literal["server", "channel", "user"]
|
ScopeLiteral = Literal["bot", "me", "server", "channel", "user"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiscordObjects:
|
||||||
|
bot: DiscordUser
|
||||||
|
server: DiscordServer | None
|
||||||
|
channel: DiscordChannel | None
|
||||||
|
user: DiscordUser | None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def items(self):
|
||||||
|
items = [self.bot, self.server, self.channel, self.user]
|
||||||
|
return [item for item in items if item is not None]
|
||||||
|
|
||||||
|
|
||||||
|
ListHandler = Callable[[DiscordObjects], str]
|
||||||
|
|
||||||
|
|
||||||
class CommandError(Exception):
|
class CommandError(Exception):
|
||||||
@ -44,12 +66,312 @@ class CommandContext:
|
|||||||
CommandHandler = Callable[..., CommandResponse]
|
CommandHandler = Callable[..., CommandResponse]
|
||||||
|
|
||||||
|
|
||||||
|
async def respond(
|
||||||
|
interaction: discord.Interaction, content: str, ephemeral: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""Send a response to the interaction, as file if too large."""
|
||||||
|
max_length = 1900
|
||||||
|
if len(content) <= max_length:
|
||||||
|
await interaction.response.send_message(content, ephemeral=ephemeral)
|
||||||
|
return
|
||||||
|
|
||||||
|
file = discord.File(io.BytesIO(content.encode("utf-8")), filename="response.txt")
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"Response too large, sending as file:", file=file, ephemeral=ephemeral
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def with_object_context(
|
||||||
|
bot: discord.Client,
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
handler: ListHandler,
|
||||||
|
user: discord.User | None,
|
||||||
|
) -> str:
|
||||||
|
"""Execute handler with Discord objects context."""
|
||||||
|
server = interaction.guild
|
||||||
|
channel = interaction.channel
|
||||||
|
target_user = user or interaction.user
|
||||||
|
with make_session() as session:
|
||||||
|
objects = DiscordObjects(
|
||||||
|
bot=ensure_user(session, bot.user),
|
||||||
|
server=server and ensure_server(session, server),
|
||||||
|
channel=channel and _ensure_channel(session, channel, server and server.id),
|
||||||
|
user=ensure_user(session, target_user),
|
||||||
|
)
|
||||||
|
return handler(objects)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_scope_group(
|
||||||
|
parent: discord.app_commands.Group,
|
||||||
|
scope: ScopeLiteral,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
) -> discord.app_commands.Group:
|
||||||
|
"""Create a command group for a scope (bot/me/server/channel).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent: Parent command group
|
||||||
|
scope: Scope literal (bot, me, server, channel)
|
||||||
|
name: Group name
|
||||||
|
description: Group description
|
||||||
|
"""
|
||||||
|
group = discord.app_commands.Group(
|
||||||
|
name=name, description=description, parent=parent
|
||||||
|
)
|
||||||
|
|
||||||
|
@group.command(name="prompt", description=f"Manage {name}'s system prompt")
|
||||||
|
@discord.app_commands.describe(prompt="The system prompt to set")
|
||||||
|
async def prompt_cmd(interaction: discord.Interaction, prompt: str | None = None):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction, scope=scope, handler=handle_prompt, prompt=prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
@group.command(name="chattiness", description=f"Show/set {name}'s chattiness")
|
||||||
|
@discord.app_commands.describe(value="Optional new chattiness value (0-100)")
|
||||||
|
async def chattiness_cmd(
|
||||||
|
interaction: discord.Interaction, value: int | None = None
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction, scope=scope, handler=handle_chattiness, value=value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ignore command
|
||||||
|
@group.command(name="ignore", description=f"Toggle bot ignoring {name} messages")
|
||||||
|
@discord.app_commands.describe(enabled="Whether to ignore messages")
|
||||||
|
async def ignore_cmd(interaction: discord.Interaction, enabled: bool | None = None):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction, scope=scope, handler=handle_ignore, ignore_enabled=enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
# Summary command
|
||||||
|
@group.command(name="summary", description=f"Show {name}'s summary")
|
||||||
|
async def summary_cmd(interaction: discord.Interaction):
|
||||||
|
await _run_interaction_command(interaction, scope=scope, handler=handle_summary)
|
||||||
|
|
||||||
|
# MCP command
|
||||||
|
@group.command(name="mcp", description=f"Manage {name}'s MCP servers")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
action="Action to perform",
|
||||||
|
url="MCP server URL (required for add, delete, connect, tools)",
|
||||||
|
)
|
||||||
|
async def mcp_cmd(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||||
|
url: str | None = None,
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_mcp_servers,
|
||||||
|
action=action,
|
||||||
|
url=url and url.strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
def _create_user_scope_group(
|
||||||
|
parent: discord.app_commands.Group,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
) -> discord.app_commands.Group:
|
||||||
|
"""Create command group for user scope (requires user parameter).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent: Parent command group
|
||||||
|
name: Group name
|
||||||
|
description: Group description
|
||||||
|
"""
|
||||||
|
group = discord.app_commands.Group(
|
||||||
|
name=name, description=description, parent=parent
|
||||||
|
)
|
||||||
|
scope = "user"
|
||||||
|
|
||||||
|
@group.command(name="prompt", description=f"Manage {name}'s system prompt")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
user="Target user", prompt="The system prompt to set"
|
||||||
|
)
|
||||||
|
async def prompt_cmd(
|
||||||
|
interaction: discord.Interaction, user: discord.User, prompt: str | None = None
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_prompt,
|
||||||
|
target_user=user,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
@group.command(name="chattiness", description=f"Show/set {name}'s chattiness")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
user="Target user", value="Optional new chattiness value (0-100)"
|
||||||
|
)
|
||||||
|
async def chattiness_cmd(
|
||||||
|
interaction: discord.Interaction, user: discord.User, value: int | None = None
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_chattiness,
|
||||||
|
target_user=user,
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ignore command
|
||||||
|
@group.command(name="ignore", description=f"Toggle bot ignoring {name} messages")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
user="Target user", enabled="Whether to ignore messages"
|
||||||
|
)
|
||||||
|
async def ignore_cmd(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
user: discord.User,
|
||||||
|
enabled: bool | None = None,
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_ignore,
|
||||||
|
target_user=user,
|
||||||
|
ignore_enabled=enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Summary command
|
||||||
|
@group.command(name="summary", description=f"Show {name}'s summary")
|
||||||
|
@discord.app_commands.describe(user="Target user")
|
||||||
|
async def summary_cmd(interaction: discord.Interaction, user: discord.User):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction, scope=scope, handler=handle_summary, target_user=user
|
||||||
|
)
|
||||||
|
|
||||||
|
# MCP command
|
||||||
|
@group.command(name="mcp", description=f"Manage {name}'s MCP servers")
|
||||||
|
@discord.app_commands.describe(
|
||||||
|
user="Target user",
|
||||||
|
action="Action to perform",
|
||||||
|
url="MCP server URL (required for add, delete, connect, tools)",
|
||||||
|
)
|
||||||
|
async def mcp_cmd(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
user: discord.User,
|
||||||
|
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||||
|
url: str | None = None,
|
||||||
|
):
|
||||||
|
await _run_interaction_command(
|
||||||
|
interaction,
|
||||||
|
scope=scope,
|
||||||
|
handler=handle_mcp_servers,
|
||||||
|
target_user=user,
|
||||||
|
action=action,
|
||||||
|
url=url and url.strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
def create_list_group(
|
||||||
|
bot: discord.Client, parent: discord.app_commands.Group
|
||||||
|
) -> discord.app_commands.Group:
|
||||||
|
"""Create command group for listing settings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent: Parent command group
|
||||||
|
"""
|
||||||
|
group = discord.app_commands.Group(
|
||||||
|
name="list", description="List settings", parent=parent
|
||||||
|
)
|
||||||
|
|
||||||
|
@group.command(name="prompt", description="List full system prompt")
|
||||||
|
@discord.app_commands.describe(user="Target user")
|
||||||
|
async def prompt_cmd(
|
||||||
|
interaction: discord.Interaction, user: discord.User | None = None
|
||||||
|
):
|
||||||
|
def handler(objects: DiscordObjects) -> str:
|
||||||
|
prompts = [o.xml_prompt() for o in objects.items if o.system_prompt]
|
||||||
|
return "\n\n".join(prompts)
|
||||||
|
|
||||||
|
res = with_object_context(bot, interaction, handler, user)
|
||||||
|
await respond(interaction, res)
|
||||||
|
|
||||||
|
@group.command(name="chattiness", description="Show {name}'s chattiness")
|
||||||
|
@discord.app_commands.describe(user="Target user")
|
||||||
|
async def chattiness_cmd(
|
||||||
|
interaction: discord.Interaction, user: discord.User | None = None
|
||||||
|
):
|
||||||
|
def handler(objects: DiscordObjects) -> str:
|
||||||
|
values = [
|
||||||
|
o.chattiness_threshold
|
||||||
|
for o in objects.items
|
||||||
|
if o.chattiness_threshold is not None
|
||||||
|
]
|
||||||
|
val = min(values) if values else 50
|
||||||
|
if objects.user:
|
||||||
|
return f"Total current chattiness for {objects.user.username}: {val}"
|
||||||
|
return f"Total current chattiness: {val}"
|
||||||
|
|
||||||
|
res = with_object_context(bot, interaction, handler, user)
|
||||||
|
await respond(interaction, res)
|
||||||
|
|
||||||
|
@group.command(
|
||||||
|
name="ignore", description="Does this bot ignore messages for this user?"
|
||||||
|
)
|
||||||
|
@discord.app_commands.describe(user="Target user")
|
||||||
|
async def ignore_cmd(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
user: discord.User | None = None,
|
||||||
|
):
|
||||||
|
def handler(objects: DiscordObjects) -> str:
|
||||||
|
should_ignore = any(o.ignore_messages for o in objects.items)
|
||||||
|
if should_ignore:
|
||||||
|
return f"The bot ignores messages for {objects.user}."
|
||||||
|
return f"The bot does not ignore messages for {objects.user}."
|
||||||
|
|
||||||
|
res = with_object_context(bot, interaction, handler, user)
|
||||||
|
await respond(interaction, res)
|
||||||
|
|
||||||
|
@group.command(name="summary", description="Show the full summary")
|
||||||
|
@discord.app_commands.describe(user="Target user")
|
||||||
|
async def summary_cmd(
|
||||||
|
interaction: discord.Interaction, user: discord.User | None = None
|
||||||
|
):
|
||||||
|
def handler(objects: DiscordObjects) -> str:
|
||||||
|
summaries = [o.xml_summary() for o in objects.items if o.summary]
|
||||||
|
return "\n\n".join(summaries)
|
||||||
|
|
||||||
|
res = with_object_context(bot, interaction, handler, user)
|
||||||
|
await respond(interaction, res)
|
||||||
|
|
||||||
|
@group.command(name="mcp", description="All used MCP servers")
|
||||||
|
@discord.app_commands.describe(user="Target user")
|
||||||
|
async def mcp_cmd(
|
||||||
|
interaction: discord.Interaction, user: discord.User | None = None
|
||||||
|
):
|
||||||
|
logger.error(f"Listing MCP servers for {user}")
|
||||||
|
ids = [
|
||||||
|
interaction.guild_id,
|
||||||
|
interaction.channel_id,
|
||||||
|
(user or interaction.user).id,
|
||||||
|
bot.user.id,
|
||||||
|
]
|
||||||
|
with make_session() as session:
|
||||||
|
mcp_servers = (
|
||||||
|
session.query(MCPServer)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.entity_id.in_(i for i in ids if i is not None)
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
|
||||||
|
res = "\n\n".join(mcp_servers)
|
||||||
|
await respond(interaction, res)
|
||||||
|
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
def register_slash_commands(bot: discord.Client) -> None:
|
def register_slash_commands(bot: discord.Client) -> None:
|
||||||
"""Register the collector slash commands on the provided bot.
|
"""Register the collector slash commands on the provided bot.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bot: Discord bot client
|
bot: Discord bot client
|
||||||
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if getattr(bot, "_memory_commands_registered", False):
|
if getattr(bot, "_memory_commands_registered", False):
|
||||||
@ -63,128 +385,21 @@ def register_slash_commands(bot: discord.Client) -> None:
|
|||||||
tree = bot.tree
|
tree = bot.tree
|
||||||
name = bot.user and bot.user.name.replace("-", "_").lower()
|
name = bot.user and bot.user.name.replace("-", "_").lower()
|
||||||
|
|
||||||
@tree.command(
|
# Create main command group
|
||||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
memory_group = discord.app_commands.Group(
|
||||||
)
|
name=name or "memory", description=f"{name} bot configuration and management"
|
||||||
@discord.app_commands.describe(
|
|
||||||
scope="Which configuration to inspect",
|
|
||||||
user="Target user when the scope is 'user'",
|
|
||||||
)
|
|
||||||
async def show_prompt_command(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
scope: ScopeLiteral,
|
|
||||||
user: discord.User | None = None,
|
|
||||||
) -> None:
|
|
||||||
await _run_interaction_command(
|
|
||||||
interaction,
|
|
||||||
scope=scope,
|
|
||||||
handler=handle_prompt,
|
|
||||||
target_user=user,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@tree.command(
|
# Create scope groups
|
||||||
name=f"{name}_set_prompt",
|
_create_scope_group(memory_group, "bot", "bot", "Bot-wide settings")
|
||||||
description="Set the system prompt for the target",
|
_create_scope_group(memory_group, "me", "me", "Your personal settings")
|
||||||
)
|
_create_scope_group(memory_group, "server", "server", "Server-wide settings")
|
||||||
@discord.app_commands.describe(
|
_create_scope_group(memory_group, "channel", "channel", "Channel-specific settings")
|
||||||
scope="Which configuration to modify",
|
_create_user_scope_group(memory_group, "user", "Manage other users' settings")
|
||||||
prompt="The system prompt to set",
|
create_list_group(bot, memory_group)
|
||||||
user="Target user when the scope is 'user'",
|
|
||||||
)
|
|
||||||
async def set_prompt_command(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
scope: ScopeLiteral,
|
|
||||||
prompt: str,
|
|
||||||
user: discord.User | None = None,
|
|
||||||
) -> None:
|
|
||||||
await _run_interaction_command(
|
|
||||||
interaction,
|
|
||||||
scope=scope,
|
|
||||||
handler=handle_set_prompt,
|
|
||||||
target_user=user,
|
|
||||||
prompt=prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
@tree.command(
|
# Register main group
|
||||||
name=f"{name}_chattiness",
|
tree.add_command(memory_group)
|
||||||
description="Show or update the chattiness for the target",
|
|
||||||
)
|
|
||||||
@discord.app_commands.describe(
|
|
||||||
scope="Which configuration to inspect",
|
|
||||||
value="Optional new chattiness value between 0 and 100",
|
|
||||||
user="Target user when the scope is 'user'",
|
|
||||||
)
|
|
||||||
async def chattiness_command(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
scope: ScopeLiteral,
|
|
||||||
value: int | None = None,
|
|
||||||
user: discord.User | None = None,
|
|
||||||
) -> None:
|
|
||||||
await _run_interaction_command(
|
|
||||||
interaction,
|
|
||||||
scope=scope,
|
|
||||||
handler=handle_chattiness,
|
|
||||||
target_user=user,
|
|
||||||
value=value,
|
|
||||||
)
|
|
||||||
|
|
||||||
@tree.command(
|
|
||||||
name=f"{name}_ignore",
|
|
||||||
description="Toggle whether the bot should ignore messages for the target",
|
|
||||||
)
|
|
||||||
@discord.app_commands.describe(
|
|
||||||
scope="Which configuration to modify",
|
|
||||||
enabled="Optional flag. Leave empty to enable ignoring.",
|
|
||||||
user="Target user when the scope is 'user'",
|
|
||||||
)
|
|
||||||
async def ignore_command(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
scope: ScopeLiteral,
|
|
||||||
enabled: bool | None = None,
|
|
||||||
user: discord.User | None = None,
|
|
||||||
) -> None:
|
|
||||||
await _run_interaction_command(
|
|
||||||
interaction,
|
|
||||||
scope=scope,
|
|
||||||
handler=handle_ignore,
|
|
||||||
target_user=user,
|
|
||||||
ignore_enabled=enabled,
|
|
||||||
)
|
|
||||||
|
|
||||||
@tree.command(
|
|
||||||
name=f"{name}_show_summary",
|
|
||||||
description="Show the stored summary for the target",
|
|
||||||
)
|
|
||||||
@discord.app_commands.describe(
|
|
||||||
scope="Which configuration to inspect",
|
|
||||||
user="Target user when the scope is 'user'",
|
|
||||||
)
|
|
||||||
async def summary_command(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
scope: ScopeLiteral,
|
|
||||||
user: discord.User | None = None,
|
|
||||||
) -> None:
|
|
||||||
await _run_interaction_command(
|
|
||||||
interaction,
|
|
||||||
scope=scope,
|
|
||||||
handler=handle_summary,
|
|
||||||
target_user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
@tree.command(
|
|
||||||
name=f"{name}_mcp_servers",
|
|
||||||
description="Manage MCP servers for your account",
|
|
||||||
)
|
|
||||||
@discord.app_commands.describe(
|
|
||||||
action="Action to perform",
|
|
||||||
url="MCP server URL (required for add, delete, connect, tools)",
|
|
||||||
)
|
|
||||||
async def mcp_servers_command(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
|
||||||
url: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_interaction_command(
|
async def _run_interaction_command(
|
||||||
@ -198,17 +413,16 @@ async def _run_interaction_command(
|
|||||||
"""Shared coroutine used by the registered slash commands."""
|
"""Shared coroutine used by the registered slash commands."""
|
||||||
try:
|
try:
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
context = _build_context(session, interaction, scope, target_user)
|
# Get bot from interaction client if needed for bot scope
|
||||||
response = handler(context, **handler_kwargs)
|
bot = getattr(interaction, "client", None)
|
||||||
|
context = _build_context(session, interaction, scope, target_user, bot)
|
||||||
|
response = await handler(context, **handler_kwargs)
|
||||||
session.commit()
|
session.commit()
|
||||||
except CommandError as exc: # pragma: no cover - passthrough
|
except CommandError as exc: # pragma: no cover - passthrough
|
||||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
await respond(interaction, str(exc))
|
||||||
return
|
return
|
||||||
|
|
||||||
await interaction.response.send_message(
|
await respond(interaction, response.content, response.ephemeral)
|
||||||
response.content,
|
|
||||||
ephemeral=response.ephemeral,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_context(
|
def _build_context(
|
||||||
@ -216,60 +430,55 @@ def _build_context(
|
|||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
scope: ScopeLiteral,
|
scope: ScopeLiteral,
|
||||||
target_user: discord.User | None,
|
target_user: discord.User | None,
|
||||||
|
bot: discord.Client | None = None,
|
||||||
) -> CommandContext:
|
) -> CommandContext:
|
||||||
actor = _ensure_user(session, interaction.user)
|
actor = ensure_user(session, interaction.user)
|
||||||
|
|
||||||
if scope == "server":
|
# Determine target and display name based on scope
|
||||||
|
if scope == "bot":
|
||||||
|
if not bot or not bot.user:
|
||||||
|
raise CommandError("Bot user is not available.")
|
||||||
|
target = ensure_user(session, bot.user)
|
||||||
|
display_name = f"bot **{bot.user.name}**"
|
||||||
|
|
||||||
|
elif scope == "me":
|
||||||
|
target = ensure_user(session, interaction.user)
|
||||||
|
name = target.display_name or target.username
|
||||||
|
display_name = f"you (**{name}**)"
|
||||||
|
|
||||||
|
elif scope == "server":
|
||||||
if interaction.guild is None:
|
if interaction.guild is None:
|
||||||
raise CommandError("This command can only be used inside a server.")
|
raise CommandError("This command can only be used inside a server.")
|
||||||
|
target = ensure_server(session, interaction.guild)
|
||||||
target = _ensure_server(session, interaction.guild)
|
|
||||||
display_name = f"server **{target.name}**"
|
display_name = f"server **{target.name}**"
|
||||||
return CommandContext(
|
|
||||||
session=session,
|
|
||||||
interaction=interaction,
|
|
||||||
actor=actor,
|
|
||||||
scope=scope,
|
|
||||||
target=target,
|
|
||||||
display_name=display_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
if scope == "channel":
|
elif scope == "channel":
|
||||||
channel_obj = interaction.channel
|
if interaction.channel is None or not hasattr(interaction.channel, "id"):
|
||||||
if channel_obj is None or not hasattr(channel_obj, "id"):
|
|
||||||
raise CommandError("Unable to determine channel for this interaction.")
|
raise CommandError("Unable to determine channel for this interaction.")
|
||||||
|
target = _ensure_channel(session, interaction.channel, interaction.guild_id)
|
||||||
target = _ensure_channel(session, channel_obj, interaction.guild_id)
|
|
||||||
display_name = f"channel **#{target.name}**"
|
display_name = f"channel **#{target.name}**"
|
||||||
return CommandContext(
|
|
||||||
session=session,
|
|
||||||
interaction=interaction,
|
|
||||||
actor=actor,
|
|
||||||
scope=scope,
|
|
||||||
target=target,
|
|
||||||
display_name=display_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
if scope == "user":
|
elif scope == "user":
|
||||||
discord_user = target_user or interaction.user
|
if target_user is None:
|
||||||
if discord_user is None:
|
|
||||||
raise CommandError("A target user is required for this command.")
|
raise CommandError("A target user is required for this command.")
|
||||||
|
target = ensure_user(session, target_user)
|
||||||
|
name = target.display_name or target.username
|
||||||
|
display_name = f"user **{name}**"
|
||||||
|
|
||||||
target = _ensure_user(session, discord_user)
|
else:
|
||||||
display_name = target.display_name or target.username
|
|
||||||
return CommandContext(
|
|
||||||
session=session,
|
|
||||||
interaction=interaction,
|
|
||||||
actor=actor,
|
|
||||||
scope=scope,
|
|
||||||
target=target,
|
|
||||||
display_name=f"user **{display_name}**",
|
|
||||||
)
|
|
||||||
|
|
||||||
raise CommandError(f"Unsupported scope '{scope}'.")
|
raise CommandError(f"Unsupported scope '{scope}'.")
|
||||||
|
|
||||||
|
return CommandContext(
|
||||||
|
session=session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=actor,
|
||||||
|
scope=scope,
|
||||||
|
target=target,
|
||||||
|
display_name=display_name,
|
||||||
|
)
|
||||||
|
|
||||||
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
|
||||||
|
def ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||||
server = session.get(DiscordServer, guild.id)
|
server = session.get(DiscordServer, guild.id)
|
||||||
if server is None:
|
if server is None:
|
||||||
server = DiscordServer(
|
server = DiscordServer(
|
||||||
@ -320,7 +529,7 @@ def _ensure_channel(
|
|||||||
return channel_model
|
return channel_model
|
||||||
|
|
||||||
|
|
||||||
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
def ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||||
user = session.get(DiscordUser, discord_user.id)
|
user = session.get(DiscordUser, discord_user.id)
|
||||||
display_name = getattr(discord_user, "display_name", discord_user.name)
|
display_name = getattr(discord_user, "display_name", discord_user.name)
|
||||||
if user is None:
|
if user is None:
|
||||||
@ -354,32 +563,23 @@ def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
|
|||||||
return getattr(getattr(channel, "type", None), "name", "unknown")
|
return getattr(getattr(channel, "type", None), "name", "unknown")
|
||||||
|
|
||||||
|
|
||||||
def handle_prompt(context: CommandContext) -> CommandResponse:
|
async def handle_prompt(
|
||||||
|
context: CommandContext, *, prompt: str | None = None
|
||||||
|
) -> CommandResponse:
|
||||||
|
if prompt is not None:
|
||||||
|
prompt = prompt or None
|
||||||
|
setattr(context.target, "system_prompt", prompt)
|
||||||
|
else:
|
||||||
prompt = getattr(context.target, "system_prompt", None)
|
prompt = getattr(context.target, "system_prompt", None)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
return CommandResponse(
|
content = f"Current prompt for {context.display_name}:\n\n{prompt}"
|
||||||
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
|
else:
|
||||||
)
|
content = f"No prompt configured for {context.display_name}."
|
||||||
|
return CommandResponse(content=content)
|
||||||
return CommandResponse(
|
|
||||||
content=f"No prompt configured for {context.display_name}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_set_prompt(
|
async def handle_chattiness(
|
||||||
context: CommandContext,
|
|
||||||
*,
|
|
||||||
prompt: str,
|
|
||||||
) -> CommandResponse:
|
|
||||||
setattr(context.target, "system_prompt", prompt)
|
|
||||||
|
|
||||||
return CommandResponse(
|
|
||||||
content=f"Updated system prompt for {context.display_name}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_chattiness(
|
|
||||||
context: CommandContext,
|
context: CommandContext,
|
||||||
*,
|
*,
|
||||||
value: int | None,
|
value: int | None,
|
||||||
@ -409,7 +609,7 @@ def handle_chattiness(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_ignore(
|
async def handle_ignore(
|
||||||
context: CommandContext,
|
context: CommandContext,
|
||||||
*,
|
*,
|
||||||
ignore_enabled: bool | None,
|
ignore_enabled: bool | None,
|
||||||
@ -424,7 +624,7 @@ def handle_ignore(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_summary(context: CommandContext) -> CommandResponse:
|
async def handle_summary(context: CommandContext) -> CommandResponse:
|
||||||
summary = getattr(context.target, "summary", None)
|
summary = getattr(context.target, "summary", None)
|
||||||
|
|
||||||
if summary:
|
if summary:
|
||||||
@ -435,3 +635,31 @@ def handle_summary(context: CommandContext) -> CommandResponse:
|
|||||||
return CommandResponse(
|
return CommandResponse(
|
||||||
content=f"No summary stored for {context.display_name}.",
|
content=f"No summary stored for {context.display_name}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_servers(
|
||||||
|
context: CommandContext,
|
||||||
|
*,
|
||||||
|
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||||
|
url: str | None,
|
||||||
|
) -> CommandResponse:
|
||||||
|
"""Handle MCP server commands for a specific scope."""
|
||||||
|
# Map scope to database entity type
|
||||||
|
entity_type = {
|
||||||
|
"bot": "DiscordUser",
|
||||||
|
"me": "DiscordUser",
|
||||||
|
"user": "DiscordUser",
|
||||||
|
"server": "DiscordServer",
|
||||||
|
"channel": "DiscordChannel",
|
||||||
|
}[context.scope]
|
||||||
|
|
||||||
|
bot_user = getattr(getattr(context.interaction, "client", None), "user", None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = await run_mcp_server_command(
|
||||||
|
bot_user, action, url, entity_type, context.target.id
|
||||||
|
)
|
||||||
|
return CommandResponse(content=res)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
||||||
|
raise CommandError(f"Error: {exc}") from exc
|
||||||
|
|||||||
@ -10,23 +10,27 @@ import discord
|
|||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.discord import DiscordMCPServer
|
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def find_mcp_server(
|
def find_mcp_server(
|
||||||
session: Session | scoped_session, user_id: int, url: str
|
session: Session | scoped_session, entity_type: str, entity_id: int, url: str
|
||||||
) -> DiscordMCPServer | None:
|
) -> MCPServer | None:
|
||||||
return (
|
"""Find an MCP server assigned to an entity."""
|
||||||
session.query(DiscordMCPServer)
|
assignment = (
|
||||||
|
session.query(MCPServerAssignment)
|
||||||
|
.join(MCPServer)
|
||||||
.filter(
|
.filter(
|
||||||
DiscordMCPServer.discord_bot_user_id == user_id,
|
MCPServerAssignment.entity_type == entity_type,
|
||||||
DiscordMCPServer.mcp_server_url == url,
|
MCPServerAssignment.entity_id == entity_id,
|
||||||
|
MCPServer.mcp_server_url == url,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
return assignment and assignment.mcp_server
|
||||||
|
|
||||||
|
|
||||||
async def call_mcp_server(
|
async def call_mcp_server(
|
||||||
@ -72,35 +76,39 @@ async def call_mcp_server(
|
|||||||
continue # Skip invalid JSON lines
|
continue # Skip invalid JSON lines
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
||||||
"""List all MCP servers for the user."""
|
"""List all MCP servers for the user."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
servers = (
|
assignments = (
|
||||||
session.query(DiscordMCPServer)
|
session.query(MCPServerAssignment)
|
||||||
|
.join(MCPServer)
|
||||||
.filter(
|
.filter(
|
||||||
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
|
MCPServerAssignment.entity_type == entity_type,
|
||||||
|
MCPServerAssignment.entity_id == entity_id,
|
||||||
)
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not servers:
|
if not assignments:
|
||||||
return (
|
return (
|
||||||
"📋 **Your MCP Servers**\n\n"
|
"📋 **Your MCP Servers**\n\n"
|
||||||
"You don't have any MCP servers configured yet.\n"
|
"You don't have any MCP servers configured yet.\n"
|
||||||
"Use `/memory_mcp_servers add <url>` to add one."
|
"Use `/memory_mcp_servers add <url>` to add one."
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_server(server: DiscordMCPServer) -> str:
|
def format_server(assignment: MCPServerAssignment) -> str:
|
||||||
|
server = assignment.mcp_server
|
||||||
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
||||||
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
||||||
|
|
||||||
server_list = "\n".join(format_server(s) for s in servers)
|
server_list = "\n".join(format_server(a) for a in assignments)
|
||||||
|
|
||||||
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_add(
|
async def handle_mcp_add(
|
||||||
interaction: discord.Interaction,
|
entity_type: str,
|
||||||
|
entity_id: int,
|
||||||
bot_user: discord.User | None,
|
bot_user: discord.User | None,
|
||||||
url: str,
|
url: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -108,7 +116,7 @@ async def handle_mcp_add(
|
|||||||
if not bot_user:
|
if not bot_user:
|
||||||
raise ValueError("Bot user is required")
|
raise ValueError("Bot user is required")
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
if find_mcp_server(session, bot_user.id, url):
|
if find_mcp_server(session, entity_type, entity_id, url):
|
||||||
return (
|
return (
|
||||||
f"**MCP Server Already Exists**\n\n"
|
f"**MCP Server Already Exists**\n\n"
|
||||||
f"You already have an MCP server configured at `{url}`.\n"
|
f"You already have an MCP server configured at `{url}`.\n"
|
||||||
@ -116,25 +124,32 @@ async def handle_mcp_add(
|
|||||||
)
|
)
|
||||||
|
|
||||||
endpoints = await get_endpoints(url)
|
endpoints = await get_endpoints(url)
|
||||||
client_id = await register_oauth_client(
|
name = f"Discord Bot - {bot_user.name} ({entity_type} {entity_id})"
|
||||||
endpoints,
|
client_id = await register_oauth_client(endpoints, url, name)
|
||||||
url,
|
|
||||||
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
|
# Create MCP server
|
||||||
)
|
mcp_server = MCPServer(
|
||||||
mcp_server = DiscordMCPServer(
|
|
||||||
discord_bot_user_id=bot_user.id,
|
|
||||||
mcp_server_url=url,
|
mcp_server_url=url,
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
|
name=name,
|
||||||
)
|
)
|
||||||
session.add(mcp_server)
|
session.add(mcp_server)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
|
assignment = MCPServerAssignment(
|
||||||
|
mcp_server_id=mcp_server.id,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
session.add(assignment)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created MCP server record: id={mcp_server.id}, "
|
f"Created MCP server record: id={mcp_server.id}, "
|
||||||
f"user={interaction.user.id}, url={url}"
|
f"{entity_type}={entity_id}, url={url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -146,32 +161,54 @@ async def handle_mcp_add(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
|
async def handle_mcp_delete(entity_type: str, entity_id: int, url: str) -> str:
|
||||||
"""Delete an MCP server."""
|
"""Delete an MCP server assignment."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
# Find the assignment
|
||||||
if not mcp_server:
|
assignment = (
|
||||||
|
session.query(MCPServerAssignment)
|
||||||
|
.join(MCPServer)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.entity_type == entity_type,
|
||||||
|
MCPServerAssignment.entity_id == entity_id,
|
||||||
|
MCPServer.mcp_server_url == url,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not assignment:
|
||||||
return (
|
return (
|
||||||
f"**MCP Server Not Found**\n\n"
|
f"**MCP Server Not Found**\n\n"
|
||||||
f"You don't have an MCP server configured at `{url}`.\n"
|
f"You don't have an MCP server configured at `{url}`.\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Delete the assignment (server will cascade delete if no other assignments exist)
|
||||||
|
session.delete(assignment)
|
||||||
|
|
||||||
|
# Check if server has other assignments
|
||||||
|
mcp_server = assignment.mcp_server
|
||||||
|
other_assignments = (
|
||||||
|
session.query(MCPServerAssignment)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.mcp_server_id == mcp_server.id,
|
||||||
|
MCPServerAssignment.id != assignment.id,
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no other assignments, delete the server too
|
||||||
|
if other_assignments == 0:
|
||||||
session.delete(mcp_server)
|
session.delete(mcp_server)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
|
||||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||||
if not mcp_server:
|
|
||||||
raise ValueError(
|
|
||||||
f"**MCP Server Not Found**\n\n"
|
|
||||||
f"You don't have an MCP server configured at `{url}`.\n"
|
|
||||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"**MCP Server Not Found**\n\n"
|
f"**MCP Server Not Found**\n\n"
|
||||||
@ -184,7 +221,9 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
|
logger.info(
|
||||||
|
f"Regenerated OAuth challenge for {entity_type}={entity_id}, url={url}"
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||||
@ -195,10 +234,10 @@ async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
||||||
"""List tools available on an MCP server."""
|
"""List tools available on an MCP server."""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||||
|
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -265,37 +304,28 @@ async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def run_mcp_server_command(
|
async def run_mcp_server_command(
|
||||||
interaction: discord.Interaction,
|
|
||||||
bot_user: discord.User | None,
|
bot_user: discord.User | None,
|
||||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||||
url: str | None,
|
url: str | None,
|
||||||
|
entity_type: str,
|
||||||
|
entity_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle MCP server management commands."""
|
"""Handle MCP server management commands."""
|
||||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||||
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
|
raise ValueError(f"Invalid action: {action}")
|
||||||
return
|
|
||||||
if action != "list" and not url:
|
if action != "list" and not url:
|
||||||
await interaction.response.send_message(
|
raise ValueError("URL is required for this action")
|
||||||
"❌ URL is required for this action", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if not bot_user:
|
if not bot_user:
|
||||||
await interaction.response.send_message(
|
raise ValueError("Bot user is required")
|
||||||
"❌ Bot user is required", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if action == "list" or not url:
|
if action == "list" or not url:
|
||||||
result = await handle_mcp_list(interaction)
|
return await handle_mcp_list(entity_type, entity_id)
|
||||||
elif action == "add":
|
elif action == "add":
|
||||||
result = await handle_mcp_add(interaction, bot_user, url)
|
return await handle_mcp_add(entity_type, entity_id, bot_user, url)
|
||||||
elif action == "delete":
|
elif action == "delete":
|
||||||
result = await handle_mcp_delete(bot_user, url)
|
return await handle_mcp_delete(entity_type, entity_id, url)
|
||||||
elif action == "connect":
|
elif action == "connect":
|
||||||
result = await handle_mcp_connect(bot_user, url)
|
return await handle_mcp_connect(entity_type, entity_id, url)
|
||||||
elif action == "tools":
|
elif action == "tools":
|
||||||
result = await handle_mcp_tools(bot_user, url)
|
return await handle_mcp_tools(entity_type, entity_id, url)
|
||||||
except Exception as exc:
|
raise ValueError(f"Invalid action: {action}")
|
||||||
result = f"❌ Error: {exc}"
|
|
||||||
await interaction.response.send_message(result, ephemeral=True)
|
|
||||||
|
|||||||
@ -12,9 +12,9 @@ from memory.common.db.models import (
|
|||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
|
MCPServer as MCPServerModel,
|
||||||
)
|
)
|
||||||
from memory.common.llms.base import create_provider
|
from memory.common.llms.base import create_provider
|
||||||
from memory.common.llms.tools import MCPServer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -113,8 +113,6 @@ def upsert_scheduled_message(
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
||||||
print(f"naive_scheduled_time: {naive_scheduled_time}")
|
|
||||||
print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}")
|
|
||||||
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
||||||
prev_call.status = "cancelled" # type: ignore
|
prev_call.status = "cancelled" # type: ignore
|
||||||
|
|
||||||
@ -141,10 +139,8 @@ def previous_messages(
|
|||||||
) -> list[DiscordMessage]:
|
) -> list[DiscordMessage]:
|
||||||
messages = session.query(DiscordMessage)
|
messages = session.query(DiscordMessage)
|
||||||
if user_id:
|
if user_id:
|
||||||
print(f"user_id: {user_id}")
|
|
||||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||||
if channel_id:
|
if channel_id:
|
||||||
print(f"channel_id: {channel_id}")
|
|
||||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||||
return list(
|
return list(
|
||||||
reversed(
|
reversed(
|
||||||
@ -191,7 +187,7 @@ def comm_channel_prompt(
|
|||||||
{users}
|
{users}
|
||||||
</user_notes>
|
</user_notes>
|
||||||
""").format(
|
""").format(
|
||||||
users="\n".join({msg.from_user.as_xml() for msg in messages}),
|
users="\n".join({msg.from_user.xml_summary() for msg in messages}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return textwrap.dedent("""
|
return textwrap.dedent("""
|
||||||
@ -216,7 +212,7 @@ def call_llm(
|
|||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
messages: list[str | dict[str, Any]] = [],
|
messages: list[str | dict[str, Any]] = [],
|
||||||
allowed_tools: Collection[str] | None = None,
|
allowed_tools: Collection[str] | None = None,
|
||||||
mcp_servers: list[MCPServer] | None = None,
|
mcp_servers: list[MCPServerModel] | None = None,
|
||||||
num_previous_messages: int = 10,
|
num_previous_messages: int = 10,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""
|
"""
|
||||||
@ -255,6 +251,7 @@ def call_llm(
|
|||||||
|
|
||||||
from memory.common.llms.tools.discord import make_discord_tools
|
from memory.common.llms.tools.discord import make_discord_tools
|
||||||
from memory.common.llms.tools.base import WebSearchTool
|
from memory.common.llms.tools.base import WebSearchTool
|
||||||
|
from memory.common.llms.tools import MCPServer
|
||||||
|
|
||||||
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
||||||
tools |= {"web_search": WebSearchTool()}
|
tools |= {"web_search": WebSearchTool()}
|
||||||
@ -270,7 +267,16 @@ def call_llm(
|
|||||||
messages=provider.as_messages(message_content),
|
messages=provider.as_messages(message_content),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
||||||
mcp_servers=mcp_servers,
|
mcp_servers=[
|
||||||
|
MCPServer(
|
||||||
|
name=str(server.name),
|
||||||
|
url=str(server.mcp_server_url),
|
||||||
|
token=str(server.access_token),
|
||||||
|
)
|
||||||
|
for server in mcp_servers
|
||||||
|
]
|
||||||
|
if mcp_servers
|
||||||
|
else None,
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
).response
|
).response
|
||||||
|
|
||||||
@ -294,10 +300,8 @@ def send_discord_response(
|
|||||||
True if sent successfully
|
True if sent successfully
|
||||||
"""
|
"""
|
||||||
if channel_id is not None:
|
if channel_id is not None:
|
||||||
logger.info(f"Sending message to channel {channel_id}")
|
|
||||||
return discord.send_to_channel(bot_id, channel_id, response)
|
return discord.send_to_channel(bot_id, channel_id, response)
|
||||||
elif user_identifier is not None:
|
elif user_identifier is not None:
|
||||||
logger.info(f"Sending DM to {user_identifier}")
|
|
||||||
return discord.send_dm(bot_id, user_identifier, response)
|
return discord.send_dm(bot_id, user_identifier, response)
|
||||||
else:
|
else:
|
||||||
logger.error("Neither channel_id nor user_identifier provided")
|
logger.error("Neither channel_id nor user_identifier provided")
|
||||||
|
|||||||
@ -91,6 +91,10 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if f"<@{message.recipient_user.id}>" in message.content:
|
||||||
|
logger.info("Direct mention of the bot, processing message")
|
||||||
|
return True
|
||||||
|
|
||||||
if message.from_user == message.recipient_user:
|
if message.from_user == message.recipient_user:
|
||||||
logger.info("Skipping message because from_user == recipient_user")
|
logger.info("Skipping message because from_user == recipient_user")
|
||||||
return False
|
return False
|
||||||
@ -132,6 +136,8 @@ def should_process(message: DiscordMessage) -> bool:
|
|||||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"chattiness_threshold: {message.chattiness_threshold}")
|
||||||
|
logger.info(f"haiku desire: {res.group(1)}")
|
||||||
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
||||||
return False
|
return False
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -190,19 +196,11 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
mcp_servers = None
|
mcp_servers = discord_message.get_mcp_servers(session)
|
||||||
if (
|
system_prompt = discord_message.system_prompt or ""
|
||||||
discord_message.recipient_user
|
system_prompt += comm_channel_prompt(
|
||||||
and discord_message.recipient_user.mcp_servers
|
session, discord_message.recipient_user, discord_message.channel
|
||||||
):
|
|
||||||
mcp_servers = [
|
|
||||||
MCPServer(
|
|
||||||
name=server.mcp_server_url,
|
|
||||||
url=server.mcp_server_url,
|
|
||||||
token=server.access_token,
|
|
||||||
)
|
)
|
||||||
for server in discord_message.recipient_user.mcp_servers
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = call_llm(
|
response = call_llm(
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -20,6 +21,31 @@ from memory.common.qdrant import initialize_collections
|
|||||||
from tests.providers.email_provider import MockEmailProvider
|
from tests.providers.email_provider import MockEmailProvider
|
||||||
|
|
||||||
|
|
||||||
|
class MockRedis:
|
||||||
|
"""In-memory mock of Redis for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._data = {}
|
||||||
|
|
||||||
|
def get(self, key: str):
|
||||||
|
return self._data.get(key)
|
||||||
|
|
||||||
|
def set(self, key: str, value):
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
def scan_iter(self, match: str):
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
pattern = match.replace("*", "**")
|
||||||
|
for key in self._data.keys():
|
||||||
|
if fnmatch.fnmatch(key, pattern):
|
||||||
|
yield key
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_url(cls, url: str):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
def get_test_db_name() -> str:
|
def get_test_db_name() -> str:
|
||||||
return f"test_db_{uuid.uuid4().hex[:8]}"
|
return f"test_db_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
@ -83,7 +109,7 @@ def run_alembic_migrations(db_name: str) -> None:
|
|||||||
alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
|
alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
|
||||||
|
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
["alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||||
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
|
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
|
||||||
check=True,
|
check=True,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
@ -265,7 +291,8 @@ def mock_openai_client():
|
|||||||
),
|
),
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -281,7 +308,8 @@ def mock_openai_client():
|
|||||||
delta=Mock(content="test", tool_calls=None),
|
delta=Mock(content="test", tool_calls=None),
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||||
),
|
),
|
||||||
Mock(
|
Mock(
|
||||||
choices=[
|
choices=[
|
||||||
@ -289,7 +317,8 @@ def mock_openai_client():
|
|||||||
delta=Mock(content=" response", tool_calls=None),
|
delta=Mock(content=" response", tool_calls=None),
|
||||||
finish_reason="stop",
|
finish_reason="stop",
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=15),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -303,7 +332,8 @@ def mock_openai_client():
|
|||||||
),
|
),
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||||
)
|
)
|
||||||
|
|
||||||
client.chat.completions.create.side_effect = streaming_response
|
client.chat.completions.create.side_effect = streaming_response
|
||||||
@ -312,6 +342,8 @@ def mock_openai_client():
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_anthropic_client():
|
def mock_anthropic_client():
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||||
client = mock_client()
|
client = mock_client()
|
||||||
client.messages = Mock()
|
client.messages = Mock()
|
||||||
@ -345,9 +377,59 @@ def mock_anthropic_client():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock async client
|
||||||
|
async_client = Mock()
|
||||||
|
async_client.messages = Mock()
|
||||||
|
async_client.messages.create = AsyncMock(
|
||||||
|
return_value=Mock(
|
||||||
|
content=[
|
||||||
|
Mock(
|
||||||
|
type="text",
|
||||||
|
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock async streaming
|
||||||
|
def async_stream_ctx(*args, **kwargs):
|
||||||
|
async def async_iter():
|
||||||
|
yield Mock(
|
||||||
|
type="content_block_delta",
|
||||||
|
delta=Mock(
|
||||||
|
type="text_delta",
|
||||||
|
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
class AsyncStreamMock:
|
||||||
|
async def __aenter__(self):
|
||||||
|
return async_iter()
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return AsyncStreamMock()
|
||||||
|
|
||||||
|
async_client.messages.stream = Mock(side_effect=async_stream_ctx)
|
||||||
|
|
||||||
|
# Add async_client property to mock
|
||||||
|
mock_client.return_value._async_client = None
|
||||||
|
|
||||||
|
with patch.object(anthropic, "AsyncAnthropic", return_value=async_client):
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_redis():
|
||||||
|
"""Mock Redis client for all tests."""
|
||||||
|
import redis
|
||||||
|
|
||||||
|
with patch.object(redis, "Redis", MockRedis):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_discord_client():
|
def mock_discord_client():
|
||||||
with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False):
|
with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False):
|
||||||
|
|||||||
151
tests/memory/api/test_auth.py
Normal file
151
tests/memory/api/test_auth.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
"""Tests for authentication helpers and OAuth callback."""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from memory.api import auth
|
||||||
|
from memory.common import settings
|
||||||
|
|
||||||
|
|
||||||
|
def make_request(query: str) -> Request:
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/auth/callback/discord",
|
||||||
|
"headers": [],
|
||||||
|
"query_string": query.encode(),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def receive():
|
||||||
|
return {"type": "http.request", "body": b"", "more_body": False}
|
||||||
|
|
||||||
|
return Request(scope, receive)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bearer_token_parses_header():
|
||||||
|
request = SimpleNamespace(headers={"Authorization": "Bearer token123"})
|
||||||
|
|
||||||
|
assert auth.get_bearer_token(request) == "token123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bearer_token_handles_missing_header():
|
||||||
|
request = SimpleNamespace(headers={})
|
||||||
|
|
||||||
|
assert auth.get_bearer_token(request) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_token_prefers_header_over_cookie():
|
||||||
|
request = SimpleNamespace(
|
||||||
|
headers={"Authorization": "Bearer header-token"},
|
||||||
|
cookies={"session": "cookie-token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert auth.get_token(request) == "header-token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_token_falls_back_to_cookie():
|
||||||
|
request = SimpleNamespace(
|
||||||
|
headers={},
|
||||||
|
cookies={settings.SESSION_COOKIE_NAME: "cookie-token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert auth.get_token(request) == "cookie-token"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.api.auth.get_user_session")
|
||||||
|
def test_logout_removes_session(mock_get_user_session):
|
||||||
|
db = MagicMock()
|
||||||
|
session = MagicMock()
|
||||||
|
mock_get_user_session.return_value = session
|
||||||
|
request = SimpleNamespace()
|
||||||
|
|
||||||
|
result = auth.logout(request, db)
|
||||||
|
|
||||||
|
assert result == {"message": "Logged out successfully"}
|
||||||
|
db.delete.assert_called_once_with(session)
|
||||||
|
db.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.api.auth.get_user_session", return_value=None)
|
||||||
|
def test_logout_handles_missing_session(mock_get_user_session):
|
||||||
|
db = MagicMock()
|
||||||
|
request = SimpleNamespace()
|
||||||
|
|
||||||
|
result = auth.logout(request, db)
|
||||||
|
|
||||||
|
assert result == {"message": "Logged out successfully"}
|
||||||
|
db.delete.assert_not_called()
|
||||||
|
db.commit.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||||
|
@patch("memory.api.auth.make_session")
|
||||||
|
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_cm():
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
mock_make_session.return_value = session_cm()
|
||||||
|
|
||||||
|
mcp_server = MagicMock()
|
||||||
|
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||||
|
|
||||||
|
mock_complete.return_value = (200, "Authorized")
|
||||||
|
|
||||||
|
request = make_request("code=abc123&state=state456")
|
||||||
|
response = await auth.oauth_callback_discord(request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.body.decode()
|
||||||
|
assert "Authorization Successful" in body
|
||||||
|
assert "Authorized" in body
|
||||||
|
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||||
|
@patch("memory.api.auth.make_session")
|
||||||
|
async def test_oauth_callback_discord_handles_failures(
|
||||||
|
mock_make_session, mock_complete
|
||||||
|
):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_cm():
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
mock_make_session.return_value = session_cm()
|
||||||
|
|
||||||
|
mcp_server = MagicMock()
|
||||||
|
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||||
|
|
||||||
|
mock_complete.return_value = (500, "Failure")
|
||||||
|
|
||||||
|
request = make_request("code=abc123&state=state456")
|
||||||
|
response = await auth.oauth_callback_discord(request)
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
body = response.body.decode()
|
||||||
|
assert "Authorization Failed" in body
|
||||||
|
assert "Failure" in body
|
||||||
|
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_discord_validates_query_params():
|
||||||
|
request = make_request("code=&state=")
|
||||||
|
|
||||||
|
response = await auth.oauth_callback_discord(request)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
body = response.body.decode()
|
||||||
|
assert "Missing authorization code" in body
|
||||||
@ -1,5 +1,7 @@
|
|||||||
"""Tests for Discord database models."""
|
"""Tests for Discord database models."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
|
from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
|
||||||
|
|
||||||
@ -19,12 +21,11 @@ def test_create_discord_server(db_session):
|
|||||||
assert server.name == "Test Server"
|
assert server.name == "Test Server"
|
||||||
assert server.description == "A test Discord server"
|
assert server.description == "A test Discord server"
|
||||||
assert server.member_count == 100
|
assert server.member_count == 100
|
||||||
assert server.track_messages is True # default value
|
assert server.ignore_messages is False # default value
|
||||||
assert server.ignore_messages is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_discord_server_as_xml(db_session):
|
def test_discord_server_as_xml(db_session):
|
||||||
"""Test DiscordServer.as_xml() method."""
|
"""Test DiscordServer.to_xml() method."""
|
||||||
server = DiscordServer(
|
server = DiscordServer(
|
||||||
id=123456789,
|
id=123456789,
|
||||||
name="Test Server",
|
name="Test Server",
|
||||||
@ -33,11 +34,11 @@ def test_discord_server_as_xml(db_session):
|
|||||||
db_session.add(server)
|
db_session.add(server)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
xml = server.as_xml()
|
xml = server.to_xml("name", "summary")
|
||||||
assert "<servers>" in xml # tablename is discord_servers, strips to "servers"
|
assert "<server>" in xml # tablename is discord_servers, strips to "server"
|
||||||
assert "<name>Test Server</name>" in xml
|
assert "<name>" in xml and "Test Server" in xml
|
||||||
assert "<summary>This is a test server for gaming</summary>" in xml
|
assert "<summary>" in xml and "This is a test server for gaming" in xml
|
||||||
assert "</servers>" in xml
|
assert "</server>" in xml
|
||||||
|
|
||||||
|
|
||||||
def test_discord_server_message_tracking(db_session):
|
def test_discord_server_message_tracking(db_session):
|
||||||
@ -45,13 +46,11 @@ def test_discord_server_message_tracking(db_session):
|
|||||||
server = DiscordServer(
|
server = DiscordServer(
|
||||||
id=123456789,
|
id=123456789,
|
||||||
name="Test Server",
|
name="Test Server",
|
||||||
track_messages=False,
|
|
||||||
ignore_messages=True,
|
ignore_messages=True,
|
||||||
)
|
)
|
||||||
db_session.add(server)
|
db_session.add(server)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
assert server.track_messages is False
|
|
||||||
assert server.ignore_messages is True
|
assert server.ignore_messages is True
|
||||||
|
|
||||||
|
|
||||||
@ -111,7 +110,7 @@ def test_discord_channel_without_server(db_session):
|
|||||||
|
|
||||||
|
|
||||||
def test_discord_channel_as_xml(db_session):
|
def test_discord_channel_as_xml(db_session):
|
||||||
"""Test DiscordChannel.as_xml() method."""
|
"""Test DiscordChannel.to_xml() method."""
|
||||||
channel = DiscordChannel(
|
channel = DiscordChannel(
|
||||||
id=111222333,
|
id=111222333,
|
||||||
name="general",
|
name="general",
|
||||||
@ -121,30 +120,28 @@ def test_discord_channel_as_xml(db_session):
|
|||||||
db_session.add(channel)
|
db_session.add(channel)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
xml = channel.as_xml()
|
xml = channel.to_xml("name", "summary")
|
||||||
assert "<channels>" in xml # tablename is discord_channels, strips to "channels"
|
assert "<channel>" in xml # tablename is discord_channels, strips to "channel"
|
||||||
assert "<name>general</name>" in xml
|
assert "<name>" in xml and "general" in xml
|
||||||
assert "<summary>Main discussion channel</summary>" in xml
|
assert "<summary>" in xml and "Main discussion channel" in xml
|
||||||
assert "</channels>" in xml
|
assert "</channel>" in xml
|
||||||
|
|
||||||
|
|
||||||
def test_discord_channel_inherits_server_settings(db_session):
|
def test_discord_channel_inherits_server_settings(db_session):
|
||||||
"""Test that channels can have their own or inherit server settings."""
|
"""Test that channels can have their own or inherit server settings."""
|
||||||
server = DiscordServer(
|
server = DiscordServer(id=987654321, name="Server", ignore_messages=False)
|
||||||
id=987654321, name="Server", track_messages=True, ignore_messages=False
|
|
||||||
)
|
|
||||||
channel = DiscordChannel(
|
channel = DiscordChannel(
|
||||||
id=111222333,
|
id=111222333,
|
||||||
server_id=server.id,
|
server_id=server.id,
|
||||||
name="announcements",
|
name="announcements",
|
||||||
channel_type="text",
|
channel_type="text",
|
||||||
track_messages=False, # Override server setting
|
ignore_messages=True, # Override server setting
|
||||||
)
|
)
|
||||||
db_session.add_all([server, channel])
|
db_session.add_all([server, channel])
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
assert server.track_messages is True
|
assert server.ignore_messages is False
|
||||||
assert channel.track_messages is False
|
assert channel.ignore_messages is True
|
||||||
|
|
||||||
|
|
||||||
def test_create_discord_user(db_session):
|
def test_create_discord_user(db_session):
|
||||||
@ -186,7 +183,7 @@ def test_discord_user_with_system_user(db_session):
|
|||||||
|
|
||||||
|
|
||||||
def test_discord_user_as_xml(db_session):
|
def test_discord_user_as_xml(db_session):
|
||||||
"""Test DiscordUser.as_xml() method."""
|
"""Test DiscordUser.to_xml() method."""
|
||||||
user = DiscordUser(
|
user = DiscordUser(
|
||||||
id=555666777,
|
id=555666777,
|
||||||
username="testuser",
|
username="testuser",
|
||||||
@ -195,11 +192,10 @@ def test_discord_user_as_xml(db_session):
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
xml = user.as_xml()
|
xml = user.to_xml("summary")
|
||||||
assert "<users>" in xml # tablename is discord_users, strips to "users"
|
assert "<user>" in xml # tablename is discord_users, strips to "user"
|
||||||
assert "<name>testuser</name>" in xml
|
assert "<summary>" in xml and "Friendly and helpful community member" in xml
|
||||||
assert "<summary>Friendly and helpful community member</summary>" in xml
|
assert "</user>" in xml
|
||||||
assert "</users>" in xml
|
|
||||||
|
|
||||||
|
|
||||||
def test_discord_user_message_preferences(db_session):
|
def test_discord_user_message_preferences(db_session):
|
||||||
@ -207,13 +203,11 @@ def test_discord_user_message_preferences(db_session):
|
|||||||
user = DiscordUser(
|
user = DiscordUser(
|
||||||
id=555666777,
|
id=555666777,
|
||||||
username="testuser",
|
username="testuser",
|
||||||
track_messages=True,
|
|
||||||
ignore_messages=False,
|
ignore_messages=False,
|
||||||
)
|
)
|
||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
assert user.track_messages is True
|
|
||||||
assert user.ignore_messages is False
|
assert user.ignore_messages is False
|
||||||
|
|
||||||
|
|
||||||
@ -234,6 +228,21 @@ def test_discord_server_channel_relationship(db_session):
|
|||||||
assert channel2 in server.channels
|
assert channel2 in server.channels
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_processor_xml_mcp_servers():
|
||||||
|
"""Test xml_mcp_servers includes assigned MCP server XML."""
|
||||||
|
server = DiscordServer(id=111, name="Server")
|
||||||
|
mcp_stub = SimpleNamespace(
|
||||||
|
as_xml=lambda: "<mcp_server><name>Example</name></mcp_server>"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationship is optional for test purposes; assign directly
|
||||||
|
server.mcp_servers = [mcp_stub]
|
||||||
|
|
||||||
|
xml_output = server.xml_mcp_servers()
|
||||||
|
assert "<mcp_server>" in xml_output
|
||||||
|
assert "Example" in xml_output
|
||||||
|
|
||||||
|
|
||||||
def test_discord_server_cascade_delete(db_session):
|
def test_discord_server_cascade_delete(db_session):
|
||||||
"""Test that deleting a server cascades to channels."""
|
"""Test that deleting a server cascades to channels."""
|
||||||
server = DiscordServer(id=987654321, name="Test Server")
|
server = DiscordServer(id=987654321, name="Test Server")
|
||||||
|
|||||||
155
tests/memory/common/db/models/test_mcp_models.py
Normal file
155
tests/memory/common/db/models/test_mcp_models.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"available_tools,expected_tools",
|
||||||
|
[
|
||||||
|
(["search", "summarize"], ["• search", "• summarize"]),
|
||||||
|
([], []),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mcp_server_as_xml_formats_available_tools(available_tools, expected_tools):
|
||||||
|
server = MCPServer(
|
||||||
|
name="Example Server",
|
||||||
|
mcp_server_url="https://example.com/mcp",
|
||||||
|
client_id="client-123",
|
||||||
|
available_tools=available_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
xml_output = server.as_xml()
|
||||||
|
root = ET.fromstring(xml_output)
|
||||||
|
|
||||||
|
assert root.find("name").text.strip() == "Example Server"
|
||||||
|
assert root.find("mcp_server_url").text.strip() == "https://example.com/mcp"
|
||||||
|
assert root.find("client_id").text.strip() == "client-123"
|
||||||
|
|
||||||
|
tools_element = root.find("available_tools")
|
||||||
|
assert tools_element is not None
|
||||||
|
|
||||||
|
tools_text = tools_element.text.strip() if tools_element.text else ""
|
||||||
|
if expected_tools:
|
||||||
|
assert tools_text.splitlines() == expected_tools
|
||||||
|
else:
|
||||||
|
assert tools_text == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_server_crud_and_token_expiration(db_session):
|
||||||
|
initial_expiry = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||||
|
server = MCPServer(
|
||||||
|
name="Initial Server",
|
||||||
|
mcp_server_url="https://initial.example.com/mcp",
|
||||||
|
client_id="client-initial",
|
||||||
|
available_tools=["search"],
|
||||||
|
access_token="access-123",
|
||||||
|
refresh_token="refresh-123",
|
||||||
|
token_expires_at=initial_expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
fetched = db_session.get(MCPServer, server.id)
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.access_token == "access-123"
|
||||||
|
assert fetched.refresh_token == "refresh-123"
|
||||||
|
assert fetched.token_expires_at == initial_expiry
|
||||||
|
assert fetched.token_expires_at.tzinfo is not None
|
||||||
|
|
||||||
|
new_expiry = initial_expiry + timedelta(minutes=15)
|
||||||
|
fetched.name = "Updated Server"
|
||||||
|
fetched.available_tools = [*fetched.available_tools, "summarize"]
|
||||||
|
fetched.access_token = "access-456"
|
||||||
|
fetched.refresh_token = "refresh-456"
|
||||||
|
fetched.token_expires_at = new_expiry
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
updated = db_session.get(MCPServer, server.id)
|
||||||
|
assert updated is not None
|
||||||
|
assert updated.name == "Updated Server"
|
||||||
|
assert updated.available_tools == ["search", "summarize"]
|
||||||
|
assert updated.access_token == "access-456"
|
||||||
|
assert updated.refresh_token == "refresh-456"
|
||||||
|
assert updated.token_expires_at == new_expiry
|
||||||
|
|
||||||
|
db_session.delete(updated)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
assert db_session.get(MCPServer, server.id) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_server_assignments_relationship_and_cascade(db_session):
|
||||||
|
server = MCPServer(
|
||||||
|
name="Cascade Server",
|
||||||
|
mcp_server_url="https://cascade.example.com/mcp",
|
||||||
|
client_id="client-cascade",
|
||||||
|
available_tools=["search"],
|
||||||
|
)
|
||||||
|
server.assignments.extend(
|
||||||
|
[
|
||||||
|
MCPServerAssignment(entity_type="DiscordUser", entity_id=101),
|
||||||
|
MCPServerAssignment(entity_type="DiscordChannel", entity_id=202),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
persisted_server = db_session.get(MCPServer, server.id)
|
||||||
|
assert persisted_server is not None
|
||||||
|
assert len(persisted_server.assignments) == 2
|
||||||
|
assert {assignment.entity_type for assignment in persisted_server.assignments} == {
|
||||||
|
"DiscordUser",
|
||||||
|
"DiscordChannel",
|
||||||
|
}
|
||||||
|
assert all(
|
||||||
|
assignment.mcp_server_id == persisted_server.id
|
||||||
|
for assignment in persisted_server.assignments
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.delete(persisted_server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
remaining_assignments = db_session.query(MCPServerAssignment).all()
|
||||||
|
assert remaining_assignments == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_server_assignment_unique_constraint(db_session):
|
||||||
|
server = MCPServer(
|
||||||
|
name="Unique Server",
|
||||||
|
mcp_server_url="https://unique.example.com/mcp",
|
||||||
|
client_id="client-unique",
|
||||||
|
available_tools=["search"],
|
||||||
|
)
|
||||||
|
assignment = MCPServerAssignment(
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=12345,
|
||||||
|
)
|
||||||
|
server.assignments.append(assignment)
|
||||||
|
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
duplicate_assignment = MCPServerAssignment(
|
||||||
|
mcp_server_id=server.id,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=12345,
|
||||||
|
)
|
||||||
|
db_session.add(duplicate_assignment)
|
||||||
|
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
assignments = (
|
||||||
|
db_session.query(MCPServerAssignment)
|
||||||
|
.filter(MCPServerAssignment.mcp_server_id == server.id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(assignments) == 1
|
||||||
@ -162,8 +162,18 @@ def test_create_bot_user_auto_api_key(db_session):
|
|||||||
|
|
||||||
def test_create_discord_bot_user(db_session):
|
def test_create_discord_bot_user(db_session):
|
||||||
"""Test creating a DiscordBotUser"""
|
"""Test creating a DiscordBotUser"""
|
||||||
|
from memory.common.db.models import DiscordUser
|
||||||
|
|
||||||
|
# Create a Discord user for the bot
|
||||||
|
discord_user = DiscordUser(
|
||||||
|
id=123456789,
|
||||||
|
username="botuser",
|
||||||
|
)
|
||||||
|
db_session.add(discord_user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
user = DiscordBotUser.create_with_api_key(
|
user = DiscordBotUser.create_with_api_key(
|
||||||
discord_users=[],
|
discord_users=[discord_user],
|
||||||
name="Discord Bot",
|
name="Discord Bot",
|
||||||
email="discordbot@example.com",
|
email="discordbot@example.com",
|
||||||
api_key="discord_key_123",
|
api_key="discord_key_123",
|
||||||
@ -176,6 +186,7 @@ def test_create_discord_bot_user(db_session):
|
|||||||
assert user.name == "Discord Bot"
|
assert user.name == "Discord Bot"
|
||||||
assert user.user_type == "discord_bot"
|
assert user.user_type == "discord_bot"
|
||||||
assert user.api_key == "discord_key_123"
|
assert user.api_key == "discord_key_123"
|
||||||
|
assert len(user.discord_users) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_user_serialization_human(db_session):
|
def test_user_serialization_human(db_session):
|
||||||
|
|||||||
@ -131,7 +131,7 @@ def test_build_request_kwargs_basic(provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, None, None, settings)
|
kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
|
||||||
|
|
||||||
assert kwargs["model"] == "claude-3-opus-20240229"
|
assert kwargs["model"] == "claude-3-opus-20240229"
|
||||||
assert kwargs["temperature"] == 0.5
|
assert kwargs["temperature"] == 0.5
|
||||||
@ -143,7 +143,9 @@ def test_build_request_kwargs_with_system_prompt(provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings()
|
settings = LLMSettings()
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
|
kwargs = provider._build_request_kwargs(
|
||||||
|
messages, "system prompt", None, None, settings
|
||||||
|
)
|
||||||
|
|
||||||
assert kwargs["system"] == "system prompt"
|
assert kwargs["system"] == "system prompt"
|
||||||
|
|
||||||
@ -160,7 +162,7 @@ def test_build_request_kwargs_with_tools(provider):
|
|||||||
]
|
]
|
||||||
settings = LLMSettings()
|
settings = LLMSettings()
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, None, tools, settings)
|
kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
|
||||||
|
|
||||||
assert "tools" in kwargs
|
assert "tools" in kwargs
|
||||||
assert len(kwargs["tools"]) == 1
|
assert len(kwargs["tools"]) == 1
|
||||||
@ -170,7 +172,9 @@ def test_build_request_kwargs_with_thinking(thinking_provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings(max_tokens=5000)
|
settings = LLMSettings(max_tokens=5000)
|
||||||
|
|
||||||
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
|
kwargs = thinking_provider._build_request_kwargs(
|
||||||
|
messages, None, None, None, settings
|
||||||
|
)
|
||||||
|
|
||||||
assert "thinking" in kwargs
|
assert "thinking" in kwargs
|
||||||
assert kwargs["thinking"]["type"] == "enabled"
|
assert kwargs["thinking"]["type"] == "enabled"
|
||||||
@ -183,7 +187,9 @@ def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings(max_tokens=1000)
|
settings = LLMSettings(max_tokens=1000)
|
||||||
|
|
||||||
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
|
kwargs = thinking_provider._build_request_kwargs(
|
||||||
|
messages, None, None, None, settings
|
||||||
|
)
|
||||||
|
|
||||||
# Shouldn't enable thinking if not enough tokens
|
# Shouldn't enable thinking if not enough tokens
|
||||||
assert "thinking" not in kwargs
|
assert "thinking" not in kwargs
|
||||||
@ -326,7 +332,7 @@ async def test_agenerate_basic(provider, mock_anthropic_client):
|
|||||||
|
|
||||||
result = await provider.agenerate(messages)
|
result = await provider.agenerate(messages)
|
||||||
|
|
||||||
assert result == "test summary"
|
assert "<summary>test summary</summary>" in result
|
||||||
provider.async_client.messages.create.assert_called_once()
|
provider.async_client.messages.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock, AsyncMock
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from memory.common.llms.openai_provider import OpenAIProvider
|
from memory.common.llms.openai_provider import OpenAIProvider
|
||||||
@ -192,7 +192,7 @@ def test_build_request_kwargs_basic(provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, None, None, settings)
|
kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
|
||||||
|
|
||||||
assert kwargs["model"] == "gpt-4o"
|
assert kwargs["model"] == "gpt-4o"
|
||||||
assert kwargs["temperature"] == 0.5
|
assert kwargs["temperature"] == 0.5
|
||||||
@ -204,7 +204,9 @@ def test_build_request_kwargs_with_system_prompt_standard_model(provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings()
|
settings = LLMSettings()
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
|
kwargs = provider._build_request_kwargs(
|
||||||
|
messages, "system prompt", None, None, settings
|
||||||
|
)
|
||||||
|
|
||||||
# For gpt-4o, system prompt becomes system message
|
# For gpt-4o, system prompt becomes system message
|
||||||
assert kwargs["messages"][0]["role"] == "system"
|
assert kwargs["messages"][0]["role"] == "system"
|
||||||
@ -218,7 +220,7 @@ def test_build_request_kwargs_with_system_prompt_reasoning_model(
|
|||||||
settings = LLMSettings()
|
settings = LLMSettings()
|
||||||
|
|
||||||
kwargs = reasoning_provider._build_request_kwargs(
|
kwargs = reasoning_provider._build_request_kwargs(
|
||||||
messages, "system prompt", None, settings
|
messages, "system prompt", None, None, settings
|
||||||
)
|
)
|
||||||
|
|
||||||
# For o1 models, system prompt becomes developer message
|
# For o1 models, system prompt becomes developer message
|
||||||
@ -232,7 +234,9 @@ def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings(max_tokens=2000)
|
settings = LLMSettings(max_tokens=2000)
|
||||||
|
|
||||||
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)
|
kwargs = reasoning_provider._build_request_kwargs(
|
||||||
|
messages, None, None, None, settings
|
||||||
|
)
|
||||||
|
|
||||||
# Reasoning models use max_completion_tokens
|
# Reasoning models use max_completion_tokens
|
||||||
assert "max_completion_tokens" in kwargs
|
assert "max_completion_tokens" in kwargs
|
||||||
@ -244,7 +248,9 @@ def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider)
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings(temperature=0.7)
|
settings = LLMSettings(temperature=0.7)
|
||||||
|
|
||||||
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)
|
kwargs = reasoning_provider._build_request_kwargs(
|
||||||
|
messages, None, None, None, settings
|
||||||
|
)
|
||||||
|
|
||||||
# Reasoning models don't support temperature
|
# Reasoning models don't support temperature
|
||||||
assert "temperature" not in kwargs
|
assert "temperature" not in kwargs
|
||||||
@ -263,7 +269,7 @@ def test_build_request_kwargs_with_tools(provider):
|
|||||||
]
|
]
|
||||||
settings = LLMSettings()
|
settings = LLMSettings()
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, None, tools, settings)
|
kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
|
||||||
|
|
||||||
assert "tools" in kwargs
|
assert "tools" in kwargs
|
||||||
assert len(kwargs["tools"]) == 1
|
assert len(kwargs["tools"]) == 1
|
||||||
@ -274,7 +280,9 @@ def test_build_request_kwargs_with_stream(provider):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
settings = LLMSettings()
|
settings = LLMSettings()
|
||||||
|
|
||||||
kwargs = provider._build_request_kwargs(messages, None, None, settings, stream=True)
|
kwargs = provider._build_request_kwargs(
|
||||||
|
messages, None, None, None, settings, stream=True
|
||||||
|
)
|
||||||
|
|
||||||
assert kwargs["stream"] is True
|
assert kwargs["stream"] is True
|
||||||
|
|
||||||
@ -314,7 +322,8 @@ def test_handle_stream_chunk_text_content(provider):
|
|||||||
delta=Mock(content="hello", tool_calls=None),
|
delta=Mock(content="hello", tool_calls=None),
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||||
)
|
)
|
||||||
|
|
||||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||||
@ -342,8 +351,9 @@ def test_handle_stream_chunk_tool_call_start(provider):
|
|||||||
choice.delta = delta
|
choice.delta = delta
|
||||||
choice.finish_reason = None
|
choice.finish_reason = None
|
||||||
|
|
||||||
chunk = Mock(spec=["choices"])
|
chunk = Mock(spec=["choices", "usage"])
|
||||||
chunk.choices = [choice]
|
chunk.choices = [choice]
|
||||||
|
chunk.usage = Mock(prompt_tokens=10, completion_tokens=5)
|
||||||
|
|
||||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||||
|
|
||||||
@ -369,7 +379,8 @@ def test_handle_stream_chunk_tool_call_arguments(provider):
|
|||||||
),
|
),
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||||
)
|
)
|
||||||
|
|
||||||
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
||||||
@ -386,7 +397,8 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
|
|||||||
delta=Mock(content=None, tool_calls=None),
|
delta=Mock(content=None, tool_calls=None),
|
||||||
finish_reason="tool_calls",
|
finish_reason="tool_calls",
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||||
)
|
)
|
||||||
|
|
||||||
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
||||||
@ -399,7 +411,7 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
|
|||||||
|
|
||||||
|
|
||||||
def test_handle_stream_chunk_empty_choices(provider):
|
def test_handle_stream_chunk_empty_choices(provider):
|
||||||
chunk = Mock(choices=[])
|
chunk = Mock(choices=[], usage=Mock(prompt_tokens=10, completion_tokens=5))
|
||||||
|
|
||||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||||
|
|
||||||
@ -435,8 +447,13 @@ async def test_agenerate_basic(provider, mock_openai_client):
|
|||||||
messages = [Message(role=MessageRole.USER, content="test")]
|
messages = [Message(role=MessageRole.USER, content="test")]
|
||||||
|
|
||||||
# Mock the async client
|
# Mock the async client
|
||||||
mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
|
mock_response = Mock(
|
||||||
provider.async_client.chat.completions.create = Mock(return_value=mock_response)
|
choices=[Mock(message=Mock(content="async response"))],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||||
|
)
|
||||||
|
provider.async_client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
result = await provider.agenerate(messages)
|
result = await provider.agenerate(messages)
|
||||||
|
|
||||||
@ -452,15 +469,19 @@ async def test_astream_basic(provider, mock_openai_client):
|
|||||||
yield Mock(
|
yield Mock(
|
||||||
choices=[
|
choices=[
|
||||||
Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
|
Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||||
)
|
)
|
||||||
yield Mock(
|
yield Mock(
|
||||||
choices=[
|
choices=[
|
||||||
Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
|
Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
|
||||||
]
|
],
|
||||||
|
usage=Mock(prompt_tokens=10, completion_tokens=10),
|
||||||
)
|
)
|
||||||
|
|
||||||
provider.async_client.chat.completions.create = Mock(return_value=async_stream())
|
provider.async_client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=async_stream()
|
||||||
|
)
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
async for event in provider.astream(messages):
|
async for event in provider.astream(messages):
|
||||||
|
|||||||
@ -18,10 +18,10 @@ except ModuleNotFoundError: # pragma: no cover - import guard for test envs
|
|||||||
|
|
||||||
sys.modules.setdefault("redis", _RedisStub())
|
sys.modules.setdefault("redis", _RedisStub())
|
||||||
|
|
||||||
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
|
from memory.common.llms.usage import (
|
||||||
from memory.common.llms.usage_tracker import (
|
|
||||||
InMemoryUsageTracker,
|
InMemoryUsageTracker,
|
||||||
RateLimitConfig,
|
RateLimitConfig,
|
||||||
|
RedisUsageTracker,
|
||||||
UsageTracker,
|
UsageTracker,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,7 +84,9 @@ def redis_tracker() -> RedisUsageTracker:
|
|||||||
(timedelta(seconds=0), {"max_total_tokens": 1}),
|
(timedelta(seconds=0), {"max_total_tokens": 1}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
|
def test_rate_limit_config_validation(
|
||||||
|
window: timedelta, kwargs: dict[str, int]
|
||||||
|
) -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
RateLimitConfig(window=window, **kwargs)
|
RateLimitConfig(window=window, **kwargs)
|
||||||
|
|
||||||
@ -93,9 +95,7 @@ def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
|
|||||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||||
|
|
||||||
allowance = tracker.get_available_tokens(
|
allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
|
||||||
"anthropic/claude-3", timestamp=now
|
|
||||||
)
|
|
||||||
assert allowance is not None
|
assert allowance is not None
|
||||||
assert allowance.input_tokens == 900
|
assert allowance.input_tokens == 900
|
||||||
assert allowance.output_tokens == 1_800
|
assert allowance.output_tokens == 1_800
|
||||||
@ -114,9 +114,7 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
|||||||
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
|
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
|
||||||
|
|
||||||
later = now + timedelta(minutes=2)
|
later = now + timedelta(minutes=2)
|
||||||
allowance = tracker.get_available_tokens(
|
allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=later)
|
||||||
"anthropic/claude-3", timestamp=later
|
|
||||||
)
|
|
||||||
assert allowance is not None
|
assert allowance is not None
|
||||||
assert allowance.input_tokens == 1_000
|
assert allowance.input_tokens == 1_000
|
||||||
assert allowance.output_tokens == 2_000
|
assert allowance.output_tokens == 2_000
|
||||||
@ -126,6 +124,7 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
|||||||
|
|
||||||
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
|
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
|
||||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
# Use the configured models from the fixture
|
||||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||||
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||||
|
|
||||||
@ -144,6 +143,7 @@ def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> N
|
|||||||
|
|
||||||
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
||||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
# Use configured models from the fixture
|
||||||
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
|
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
|
||||||
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
|
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
|
||||||
|
|
||||||
@ -156,15 +156,19 @@ def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
|||||||
assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
|
assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
|
||||||
|
|
||||||
|
|
||||||
def test_missing_configuration_records_lifetime_only() -> None:
|
def test_missing_configuration_uses_default() -> None:
|
||||||
|
# With no specific config, falls back to default config (from settings)
|
||||||
tracker = InMemoryUsageTracker(configs={})
|
tracker = InMemoryUsageTracker(configs={})
|
||||||
tracker.record_usage("openai/gpt-4o", 10, 20)
|
tracker.record_usage("openai/gpt-4o", 10, 20)
|
||||||
|
|
||||||
assert tracker.get_available_tokens("openai/gpt-4o") is None
|
# Uses default config, so get_available_tokens returns allowance
|
||||||
|
allowance = tracker.get_available_tokens("openai/gpt-4o")
|
||||||
|
assert allowance is not None
|
||||||
|
|
||||||
|
# Lifetime stats are tracked
|
||||||
breakdown = tracker.get_usage_breakdown()
|
breakdown = tracker.get_usage_breakdown()
|
||||||
usage = breakdown["openai"]["gpt-4o"]
|
usage = breakdown["openai"]["gpt-4o"]
|
||||||
assert usage.window_input_tokens == 0
|
assert usage.window_input_tokens == 10
|
||||||
assert usage.lifetime_input_tokens == 10
|
assert usage.lifetime_input_tokens == 10
|
||||||
|
|
||||||
|
|
||||||
@ -193,6 +197,7 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
|
|||||||
|
|
||||||
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
|
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
|
||||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||||
|
# Use configured models from the fixture
|
||||||
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||||
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||||
|
|
||||||
@ -201,6 +206,8 @@ def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) ->
|
|||||||
assert allowance.input_tokens == 900
|
assert allowance.input_tokens == 900
|
||||||
|
|
||||||
breakdown = redis_tracker.get_usage_breakdown()
|
breakdown = redis_tracker.get_usage_breakdown()
|
||||||
|
assert "anthropic" in breakdown
|
||||||
|
assert "claude-3" in breakdown["anthropic"]
|
||||||
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
|
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
|
||||||
|
|
||||||
items = dict(redis_tracker.iter_state_items())
|
items = dict(redis_tracker.iter_state_items())
|
||||||
|
|||||||
26
tests/memory/common/llms/tools/test_base_tools.py
Normal file
26
tests/memory/common/llms/tools/test_base_tools.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
"""Tests for base web tool definitions."""
|
||||||
|
|
||||||
|
from memory.common.llms.tools.base import WebFetchTool, WebSearchTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_search_tool_provider_formats():
|
||||||
|
tool = WebSearchTool()
|
||||||
|
|
||||||
|
assert tool.provider_format("openai") == {"type": "web_search"}
|
||||||
|
assert tool.provider_format("anthropic") == {
|
||||||
|
"type": "web_search_20250305",
|
||||||
|
"name": "web_search",
|
||||||
|
"max_uses": 10,
|
||||||
|
}
|
||||||
|
assert tool.provider_format("unknown") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_fetch_tool_provider_formats():
|
||||||
|
tool = WebFetchTool()
|
||||||
|
|
||||||
|
assert tool.provider_format("anthropic") == {
|
||||||
|
"type": "web_fetch_20250910",
|
||||||
|
"name": "web_fetch",
|
||||||
|
"max_uses": 10,
|
||||||
|
}
|
||||||
|
assert tool.provider_format("openai") is None
|
||||||
@ -497,13 +497,14 @@ def test_make_discord_tools_with_user_and_channel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||||
# update_user_summary, update_server_summary
|
# update_user_summary, update_server_summary, add_reaction
|
||||||
assert len(tools) == 5
|
assert len(tools) == 6
|
||||||
assert "schedule_message" in tools
|
assert "schedule_message" in tools
|
||||||
assert "previous_messages" in tools
|
assert "previous_messages" in tools
|
||||||
assert "update_channel_summary" in tools
|
assert "update_channel_summary" in tools
|
||||||
assert "update_user_summary" in tools
|
assert "update_user_summary" in tools
|
||||||
assert "update_server_summary" in tools
|
assert "update_server_summary" in tools
|
||||||
|
assert "add_reaction" in tools
|
||||||
|
|
||||||
|
|
||||||
def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
|
def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
|
||||||
@ -533,12 +534,13 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||||
# update_server_summary (no user summary without author)
|
# update_server_summary, add_reaction (no user summary without author)
|
||||||
assert len(tools) == 4
|
assert len(tools) == 5
|
||||||
assert "schedule_message" in tools
|
assert "schedule_message" in tools
|
||||||
assert "previous_messages" in tools
|
assert "previous_messages" in tools
|
||||||
assert "update_channel_summary" in tools
|
assert "update_channel_summary" in tools
|
||||||
assert "update_server_summary" in tools
|
assert "update_server_summary" in tools
|
||||||
|
assert "add_reaction" in tools
|
||||||
assert "update_user_summary" not in tools
|
assert "update_user_summary" not in tools
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
539
tests/memory/common/test_oauth.py
Normal file
539
tests/memory/common/test_oauth.py
Normal file
@ -0,0 +1,539 @@
|
|||||||
|
"""Tests for OAuth 2.0 flow handling."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from memory.common.oauth import (
|
||||||
|
OAuthEndpoints,
|
||||||
|
generate_pkce_pair,
|
||||||
|
discover_oauth_metadata,
|
||||||
|
get_endpoints,
|
||||||
|
register_oauth_client,
|
||||||
|
issue_challenge,
|
||||||
|
complete_oauth_flow,
|
||||||
|
)
|
||||||
|
from memory.common.db.models import MCPServer
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneratePkcePair:
|
||||||
|
"""Tests for generate_pkce_pair function."""
|
||||||
|
|
||||||
|
def test_generates_valid_verifier_and_challenge(self):
|
||||||
|
"""Test that PKCE pair is generated correctly."""
|
||||||
|
verifier, challenge = generate_pkce_pair()
|
||||||
|
|
||||||
|
# Verifier should be base64url encoded (no padding)
|
||||||
|
assert len(verifier) > 0
|
||||||
|
assert "=" not in verifier
|
||||||
|
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in verifier)
|
||||||
|
|
||||||
|
# Challenge should be base64url encoded (no padding)
|
||||||
|
assert len(challenge) > 0
|
||||||
|
assert "=" not in challenge
|
||||||
|
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in challenge)
|
||||||
|
|
||||||
|
# They should be different
|
||||||
|
assert verifier != challenge
|
||||||
|
|
||||||
|
def test_generates_unique_pairs(self):
|
||||||
|
"""Test that each call generates a unique pair."""
|
||||||
|
verifier1, challenge1 = generate_pkce_pair()
|
||||||
|
verifier2, challenge2 = generate_pkce_pair()
|
||||||
|
|
||||||
|
assert verifier1 != verifier2
|
||||||
|
assert challenge1 != challenge2
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscoverOauthMetadata:
|
||||||
|
"""Tests for discover_oauth_metadata function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_metadata_success(self):
|
||||||
|
"""Test successful OAuth metadata discovery."""
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value=metadata)
|
||||||
|
|
||||||
|
mock_get = AsyncMock()
|
||||||
|
mock_get.__aenter__.return_value = mock_response
|
||||||
|
mock_get.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.get = Mock(return_value=mock_get)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
result = await discover_oauth_metadata("https://example.com")
|
||||||
|
|
||||||
|
assert result == metadata
|
||||||
|
assert result["authorization_endpoint"] == "https://example.com/auth"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_metadata_not_found(self):
|
||||||
|
"""Test OAuth metadata discovery when endpoint not found."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 404
|
||||||
|
|
||||||
|
mock_get = AsyncMock()
|
||||||
|
mock_get.__aenter__.return_value = mock_response
|
||||||
|
mock_get.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.get = Mock(return_value=mock_get)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
result = await discover_oauth_metadata("https://example.com")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_metadata_connection_error(self):
|
||||||
|
"""Test OAuth metadata discovery with connection error."""
|
||||||
|
mock_get = AsyncMock()
|
||||||
|
mock_get.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.get = Mock(return_value=mock_get)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
result = await discover_oauth_metadata("https://example.com")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetEndpoints:
|
||||||
|
"""Tests for get_endpoints function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_endpoints_success(self):
|
||||||
|
"""Test successful endpoint retrieval."""
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||||
|
result = await get_endpoints("https://example.com")
|
||||||
|
|
||||||
|
assert isinstance(result, OAuthEndpoints)
|
||||||
|
assert result.authorization_endpoint == "https://example.com/auth"
|
||||||
|
assert result.registration_endpoint == "https://example.com/register"
|
||||||
|
assert result.token_endpoint == "https://example.com/token"
|
||||||
|
assert "/auth/callback/discord" in result.redirect_uri
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_endpoints_no_metadata(self):
|
||||||
|
"""Test when OAuth metadata cannot be discovered."""
|
||||||
|
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
|
||||||
|
with pytest.raises(ValueError, match="Failed to connect to MCP server"):
|
||||||
|
await get_endpoints("https://example.com")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_endpoints_missing_authorization(self):
|
||||||
|
"""Test when authorization endpoint is missing."""
|
||||||
|
metadata = {
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||||
|
with pytest.raises(ValueError, match="authorization endpoint"):
|
||||||
|
await get_endpoints("https://example.com")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_endpoints_missing_registration(self):
|
||||||
|
"""Test when registration endpoint is missing."""
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||||
|
with pytest.raises(ValueError, match="dynamic client registration"):
|
||||||
|
await get_endpoints("https://example.com")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_endpoints_missing_token(self):
|
||||||
|
"""Test when token endpoint is missing."""
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||||
|
with pytest.raises(ValueError, match="token endpoint"):
|
||||||
|
await get_endpoints("https://example.com")
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterOauthClient:
|
||||||
|
"""Tests for register_oauth_client function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_client_success(self):
|
||||||
|
"""Test successful OAuth client registration."""
|
||||||
|
endpoints = OAuthEndpoints(
|
||||||
|
authorization_endpoint="https://example.com/auth",
|
||||||
|
registration_endpoint="https://example.com/register",
|
||||||
|
token_endpoint="https://example.com/token",
|
||||||
|
redirect_uri="https://myapp.com/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
client_info = {"client_id": "test-client-123"}
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.text = AsyncMock(return_value="Success")
|
||||||
|
mock_response.json = AsyncMock(return_value=client_info)
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
client_id = await register_oauth_client(
|
||||||
|
endpoints,
|
||||||
|
"https://example.com",
|
||||||
|
"Test Client",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert client_id == "test-client-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_client_http_error(self):
|
||||||
|
"""Test OAuth client registration with HTTP error."""
|
||||||
|
endpoints = OAuthEndpoints(
|
||||||
|
authorization_endpoint="https://example.com/auth",
|
||||||
|
registration_endpoint="https://example.com/register",
|
||||||
|
token_endpoint="https://example.com/token",
|
||||||
|
redirect_uri="https://myapp.com/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.raise_for_status = Mock(side_effect=aiohttp.ClientResponseError(
|
||||||
|
request_info=Mock(),
|
||||||
|
history=(),
|
||||||
|
status=400,
|
||||||
|
message="Bad Request",
|
||||||
|
))
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
with pytest.raises(ValueError, match="Failed to register OAuth client"):
|
||||||
|
await register_oauth_client(
|
||||||
|
endpoints,
|
||||||
|
"https://example.com",
|
||||||
|
"Test Client",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_client_missing_client_id(self):
|
||||||
|
"""Test OAuth client registration when response lacks client_id."""
|
||||||
|
endpoints = OAuthEndpoints(
|
||||||
|
authorization_endpoint="https://example.com/auth",
|
||||||
|
registration_endpoint="https://example.com/register",
|
||||||
|
token_endpoint="https://example.com/token",
|
||||||
|
redirect_uri="https://myapp.com/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
client_info = {} # Missing client_id
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value=client_info)
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
with pytest.raises(ValueError, match="Failed to register OAuth client"):
|
||||||
|
await register_oauth_client(
|
||||||
|
endpoints,
|
||||||
|
"https://example.com",
|
||||||
|
"Test Client",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueChallenge:
|
||||||
|
"""Tests for issue_challenge function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_issue_challenge_success(self, db_session):
|
||||||
|
"""Test successful OAuth challenge issuance."""
|
||||||
|
mcp_server = MCPServer(
|
||||||
|
name="Test Server",
|
||||||
|
mcp_server_url="https://example.com",
|
||||||
|
client_id="test-client-123",
|
||||||
|
)
|
||||||
|
db_session.add(mcp_server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
endpoints = OAuthEndpoints(
|
||||||
|
authorization_endpoint="https://example.com/auth",
|
||||||
|
registration_endpoint="https://example.com/register",
|
||||||
|
token_endpoint="https://example.com/token",
|
||||||
|
redirect_uri="https://myapp.com/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("memory.common.oauth.generate_pkce_pair", return_value=("verifier123", "challenge123")):
|
||||||
|
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||||
|
|
||||||
|
# Verify the auth URL contains expected parameters
|
||||||
|
assert "https://example.com/auth?" in auth_url
|
||||||
|
assert "client_id=test-client-123" in auth_url
|
||||||
|
# redirect_uri will be URL encoded
|
||||||
|
assert "redirect_uri=" in auth_url
|
||||||
|
assert "myapp.com" in auth_url
|
||||||
|
assert "callback" in auth_url
|
||||||
|
assert "response_type=code" in auth_url
|
||||||
|
assert "code_challenge=challenge123" in auth_url
|
||||||
|
assert "code_challenge_method=S256" in auth_url
|
||||||
|
assert "state=" in auth_url
|
||||||
|
|
||||||
|
# Verify state and code_verifier were stored
|
||||||
|
assert mcp_server.state is not None
|
||||||
|
assert mcp_server.code_verifier == "verifier123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompleteOauthFlow:
|
||||||
|
"""Tests for complete_oauth_flow function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_oauth_flow_success(self, db_session):
|
||||||
|
"""Test successful OAuth flow completion."""
|
||||||
|
mcp_server = MCPServer(
|
||||||
|
name="Test Server",
|
||||||
|
mcp_server_url="https://example.com",
|
||||||
|
client_id="test-client-123",
|
||||||
|
state="test-state",
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
)
|
||||||
|
db_session.add(mcp_server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
token_response = {
|
||||||
|
"access_token": "access-token-123",
|
||||||
|
"refresh_token": "refresh-token-123",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_token_response = Mock()
|
||||||
|
mock_token_response.status = 200
|
||||||
|
mock_token_response.json = AsyncMock(return_value=token_response)
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_token_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||||
|
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||||
|
):
|
||||||
|
status, message = await complete_oauth_flow(
|
||||||
|
mcp_server,
|
||||||
|
"auth-code-123",
|
||||||
|
"test-state",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status == 200
|
||||||
|
assert "successful" in message
|
||||||
|
|
||||||
|
# Verify tokens were stored
|
||||||
|
assert mcp_server.access_token == "access-token-123"
|
||||||
|
assert mcp_server.refresh_token == "refresh-token-123"
|
||||||
|
assert mcp_server.token_expires_at is not None
|
||||||
|
|
||||||
|
# Verify temporary state was cleared
|
||||||
|
assert mcp_server.state is None
|
||||||
|
assert mcp_server.code_verifier is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_oauth_flow_invalid_state(self):
|
||||||
|
"""Test OAuth flow completion with invalid state."""
|
||||||
|
status, message = await complete_oauth_flow(
|
||||||
|
None,
|
||||||
|
"auth-code-123",
|
||||||
|
"invalid-state",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status == 400
|
||||||
|
assert "Invalid or expired" in message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_oauth_flow_token_error(self, db_session):
|
||||||
|
"""Test OAuth flow completion when token exchange fails."""
|
||||||
|
mcp_server = MCPServer(
|
||||||
|
name="Test Server",
|
||||||
|
mcp_server_url="https://example.com",
|
||||||
|
client_id="test-client-123",
|
||||||
|
state="test-state",
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
)
|
||||||
|
db_session.add(mcp_server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_token_response = Mock()
|
||||||
|
mock_token_response.status = 400
|
||||||
|
mock_token_response.text = AsyncMock(return_value="Invalid grant")
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_token_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||||
|
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||||
|
):
|
||||||
|
status, message = await complete_oauth_flow(
|
||||||
|
mcp_server,
|
||||||
|
"invalid-code",
|
||||||
|
"test-state",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status == 500
|
||||||
|
assert "Token exchange failed" in message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_oauth_flow_missing_access_token(self, db_session):
|
||||||
|
"""Test OAuth flow completion when access token is missing from response."""
|
||||||
|
mcp_server = MCPServer(
|
||||||
|
name="Test Server",
|
||||||
|
mcp_server_url="https://example.com",
|
||||||
|
client_id="test-client-123",
|
||||||
|
state="test-state",
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
)
|
||||||
|
db_session.add(mcp_server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"authorization_endpoint": "https://example.com/auth",
|
||||||
|
"registration_endpoint": "https://example.com/register",
|
||||||
|
"token_endpoint": "https://example.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
token_response = {} # Missing access_token
|
||||||
|
|
||||||
|
mock_token_response = Mock()
|
||||||
|
mock_token_response.status = 200
|
||||||
|
mock_token_response.json = AsyncMock(return_value=token_response)
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_token_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||||
|
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||||
|
):
|
||||||
|
status, message = await complete_oauth_flow(
|
||||||
|
mcp_server,
|
||||||
|
"auth-code-123",
|
||||||
|
"test-state",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status == 500
|
||||||
|
assert "did not include access_token" in message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_oauth_flow_get_endpoints_error(self, db_session):
|
||||||
|
"""Test OAuth flow completion when getting endpoints fails."""
|
||||||
|
mcp_server = MCPServer(
|
||||||
|
name="Test Server",
|
||||||
|
mcp_server_url="https://example.com",
|
||||||
|
client_id="test-client-123",
|
||||||
|
state="test-state",
|
||||||
|
code_verifier="test-verifier",
|
||||||
|
)
|
||||||
|
db_session.add(mcp_server)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
|
||||||
|
status, message = await complete_oauth_flow(
|
||||||
|
mcp_server,
|
||||||
|
"auth-code-123",
|
||||||
|
"test-state",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status == 500
|
||||||
|
assert "Failed to get OAuth endpoints" in message
|
||||||
253
tests/memory/discord_tests/test_api.py
Normal file
253
tests/memory/discord_tests/test_api.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from memory.discord import api
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_app_bots():
|
||||||
|
existing = getattr(api.app, "bots", None)
|
||||||
|
api.app.bots = {}
|
||||||
|
yield
|
||||||
|
if existing is None:
|
||||||
|
delattr(api.app, "bots")
|
||||||
|
else:
|
||||||
|
api.app.bots = existing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def active_bot():
|
||||||
|
collector = SimpleNamespace(
|
||||||
|
send_dm=AsyncMock(return_value=True),
|
||||||
|
trigger_typing_dm=AsyncMock(return_value=True),
|
||||||
|
send_to_channel=AsyncMock(return_value=True),
|
||||||
|
trigger_typing_channel=AsyncMock(return_value=True),
|
||||||
|
add_reaction=AsyncMock(return_value=True),
|
||||||
|
refresh_metadata=AsyncMock(return_value={"refreshed": True}),
|
||||||
|
is_closed=Mock(return_value=False),
|
||||||
|
user="CollectorUser#1234",
|
||||||
|
guilds=[101, 202],
|
||||||
|
)
|
||||||
|
bot = SimpleNamespace(
|
||||||
|
collector=collector,
|
||||||
|
collector_task=None,
|
||||||
|
bot_id=1,
|
||||||
|
bot_token="token-123",
|
||||||
|
bot_name="Test Bot",
|
||||||
|
)
|
||||||
|
api.app.bots[bot.bot_id] = bot
|
||||||
|
return bot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_dm_success(active_bot):
|
||||||
|
request = api.SendDMRequest(bot_id=active_bot.bot_id, user="user123", message="Hello")
|
||||||
|
|
||||||
|
response = await api.send_dm_endpoint(request)
|
||||||
|
|
||||||
|
assert response == {
|
||||||
|
"success": True,
|
||||||
|
"message": "DM sent to user123",
|
||||||
|
"user": "user123",
|
||||||
|
}
|
||||||
|
active_bot.collector.send_dm.assert_awaited_once_with("user123", "Hello")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"endpoint,payload",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
api.send_dm_endpoint,
|
||||||
|
api.SendDMRequest(bot_id=99, user="ghost", message="hi"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.trigger_dm_typing,
|
||||||
|
api.TypingDMRequest(bot_id=99, user="ghost"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.send_channel_endpoint,
|
||||||
|
api.SendChannelRequest(bot_id=99, channel="general", message="hello"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.trigger_channel_typing,
|
||||||
|
api.TypingChannelRequest(bot_id=99, channel="general"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.add_reaction_endpoint,
|
||||||
|
api.AddReactionRequest(
|
||||||
|
bot_id=99,
|
||||||
|
channel="general",
|
||||||
|
message_id=42,
|
||||||
|
emoji=":thumbsup:",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_endpoint_returns_404_when_bot_missing(endpoint, payload):
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await endpoint(payload)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 404
|
||||||
|
assert exc.value.detail == "Bot not found"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"endpoint,request_cls,request_kwargs,attr_name,detail_template",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
api.send_dm_endpoint,
|
||||||
|
api.SendDMRequest,
|
||||||
|
{"bot_id": 1, "user": "user123", "message": "Hi"},
|
||||||
|
"send_dm",
|
||||||
|
"Failed to send DM to {user}",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.trigger_dm_typing,
|
||||||
|
api.TypingDMRequest,
|
||||||
|
{"bot_id": 1, "user": "user123"},
|
||||||
|
"trigger_typing_dm",
|
||||||
|
"Failed to trigger typing for user123",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.send_channel_endpoint,
|
||||||
|
api.SendChannelRequest,
|
||||||
|
{"bot_id": 1, "channel": "general", "message": "Hello"},
|
||||||
|
"send_to_channel",
|
||||||
|
"Failed to send message to channel general",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.trigger_channel_typing,
|
||||||
|
api.TypingChannelRequest,
|
||||||
|
{"bot_id": 1, "channel": "general"},
|
||||||
|
"trigger_typing_channel",
|
||||||
|
"Failed to trigger typing for channel general",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.add_reaction_endpoint,
|
||||||
|
api.AddReactionRequest,
|
||||||
|
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
|
||||||
|
"add_reaction",
|
||||||
|
"Failed to add reaction to message 55",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_endpoint_returns_400_on_collector_failure(
|
||||||
|
active_bot, endpoint, request_cls, request_kwargs, attr_name, detail_template
|
||||||
|
):
|
||||||
|
request = request_cls(**request_kwargs)
|
||||||
|
getattr(active_bot.collector, attr_name).return_value = False
|
||||||
|
expected_detail = detail_template.format(**request_kwargs)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await endpoint(request)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 400
|
||||||
|
assert exc.value.detail == expected_detail
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"endpoint,request_cls,request_kwargs,attr_name",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
api.send_dm_endpoint,
|
||||||
|
api.SendDMRequest,
|
||||||
|
{"bot_id": 1, "user": "user123", "message": "Hi"},
|
||||||
|
"send_dm",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.trigger_dm_typing,
|
||||||
|
api.TypingDMRequest,
|
||||||
|
{"bot_id": 1, "user": "user123"},
|
||||||
|
"trigger_typing_dm",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.send_channel_endpoint,
|
||||||
|
api.SendChannelRequest,
|
||||||
|
{"bot_id": 1, "channel": "general", "message": "Hello"},
|
||||||
|
"send_to_channel",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.trigger_channel_typing,
|
||||||
|
api.TypingChannelRequest,
|
||||||
|
{"bot_id": 1, "channel": "general"},
|
||||||
|
"trigger_typing_channel",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
api.add_reaction_endpoint,
|
||||||
|
api.AddReactionRequest,
|
||||||
|
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
|
||||||
|
"add_reaction",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_endpoint_returns_500_on_collector_exception(
|
||||||
|
active_bot, endpoint, request_cls, request_kwargs, attr_name
|
||||||
|
):
|
||||||
|
request = request_cls(**request_kwargs)
|
||||||
|
getattr(active_bot.collector, attr_name).side_effect = RuntimeError("boom")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await endpoint(request)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 500
|
||||||
|
assert "boom" in exc.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_success(active_bot):
|
||||||
|
response = await api.health_check()
|
||||||
|
|
||||||
|
assert response["Test Bot"] == {
|
||||||
|
"status": "healthy",
|
||||||
|
"connected": True,
|
||||||
|
"user": "CollectorUser#1234",
|
||||||
|
"guilds": 2,
|
||||||
|
}
|
||||||
|
active_bot.collector.is_closed.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_without_bots():
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await api.health_check()
|
||||||
|
|
||||||
|
assert exc.value.status_code == 503
|
||||||
|
assert exc.value.detail == "Discord collector not running"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_metadata_success(active_bot):
|
||||||
|
active_bot.collector.refresh_metadata.return_value = {"channels": 3}
|
||||||
|
|
||||||
|
response = await api.refresh_metadata()
|
||||||
|
|
||||||
|
assert response["success"] is True
|
||||||
|
assert response["message"] == "Metadata refreshed successfully for 1 bots"
|
||||||
|
assert response["results"]["Test Bot"] == {"channels": 3}
|
||||||
|
active_bot.collector.refresh_metadata.assert_awaited_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_metadata_without_bots():
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await api.refresh_metadata()
|
||||||
|
|
||||||
|
assert exc.value.status_code == 503
|
||||||
|
assert exc.value.detail == "Discord collector not running"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_metadata_failure(active_bot):
|
||||||
|
active_bot.collector.refresh_metadata.side_effect = RuntimeError("sync failed")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await api.refresh_metadata()
|
||||||
|
|
||||||
|
assert exc.value.status_code == 500
|
||||||
|
assert "sync failed" in exc.value.detail
|
||||||
@ -79,6 +79,7 @@ def mock_message(mock_text_channel, mock_user):
|
|||||||
message.content = "Test message"
|
message.content = "Test message"
|
||||||
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||||
message.reference = None
|
message.reference = None
|
||||||
|
message.attachments = []
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
@ -351,7 +352,7 @@ def test_determine_message_metadata_thread():
|
|||||||
# Tests for should_track_message
|
# Tests for should_track_message
|
||||||
def test_should_track_message_server_disabled(db_session):
|
def test_should_track_message_server_disabled(db_session):
|
||||||
"""Test when server has tracking disabled"""
|
"""Test when server has tracking disabled"""
|
||||||
server = DiscordServer(id=1, name="Server", track_messages=False)
|
server = DiscordServer(id=1, name="Server", ignore_messages=True)
|
||||||
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
|
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
|
||||||
user = DiscordUser(id=3, username="User")
|
user = DiscordUser(id=3, username="User")
|
||||||
|
|
||||||
@ -362,9 +363,9 @@ def test_should_track_message_server_disabled(db_session):
|
|||||||
|
|
||||||
def test_should_track_message_channel_disabled(db_session):
|
def test_should_track_message_channel_disabled(db_session):
|
||||||
"""Test when channel has tracking disabled"""
|
"""Test when channel has tracking disabled"""
|
||||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
server = DiscordServer(id=1, name="Server", ignore_messages=False)
|
||||||
channel = DiscordChannel(
|
channel = DiscordChannel(
|
||||||
id=2, name="Channel", channel_type="text", track_messages=False
|
id=2, name="Channel", channel_type="text", ignore_messages=True
|
||||||
)
|
)
|
||||||
user = DiscordUser(id=3, username="User")
|
user = DiscordUser(id=3, username="User")
|
||||||
|
|
||||||
@ -375,8 +376,8 @@ def test_should_track_message_channel_disabled(db_session):
|
|||||||
|
|
||||||
def test_should_track_message_dm_allowed(db_session):
|
def test_should_track_message_dm_allowed(db_session):
|
||||||
"""Test DM tracking when user allows it"""
|
"""Test DM tracking when user allows it"""
|
||||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
|
||||||
user = DiscordUser(id=3, username="User", track_messages=True)
|
user = DiscordUser(id=3, username="User", ignore_messages=False)
|
||||||
|
|
||||||
result = should_track_message(None, channel, user)
|
result = should_track_message(None, channel, user)
|
||||||
|
|
||||||
@ -385,8 +386,8 @@ def test_should_track_message_dm_allowed(db_session):
|
|||||||
|
|
||||||
def test_should_track_message_dm_not_allowed(db_session):
|
def test_should_track_message_dm_not_allowed(db_session):
|
||||||
"""Test DM tracking when user doesn't allow it"""
|
"""Test DM tracking when user doesn't allow it"""
|
||||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
|
||||||
user = DiscordUser(id=3, username="User", track_messages=False)
|
user = DiscordUser(id=3, username="User", ignore_messages=True)
|
||||||
|
|
||||||
result = should_track_message(None, channel, user)
|
result = should_track_message(None, channel, user)
|
||||||
|
|
||||||
@ -395,9 +396,9 @@ def test_should_track_message_dm_not_allowed(db_session):
|
|||||||
|
|
||||||
def test_should_track_message_default_true(db_session):
|
def test_should_track_message_default_true(db_session):
|
||||||
"""Test default tracking behavior"""
|
"""Test default tracking behavior"""
|
||||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
server = DiscordServer(id=1, name="Server", ignore_messages=False)
|
||||||
channel = DiscordChannel(
|
channel = DiscordChannel(
|
||||||
id=2, name="Channel", channel_type="text", track_messages=True
|
id=2, name="Channel", channel_type="text", ignore_messages=False
|
||||||
)
|
)
|
||||||
user = DiscordUser(id=3, username="User")
|
user = DiscordUser(id=3, username="User")
|
||||||
|
|
||||||
@ -465,6 +466,7 @@ def test_sync_guild_metadata(mock_make_session, mock_guild):
|
|||||||
voice_channel.guild = mock_guild
|
voice_channel.guild = mock_guild
|
||||||
|
|
||||||
mock_guild.channels = [text_channel, voice_channel]
|
mock_guild.channels = [text_channel, voice_channel]
|
||||||
|
mock_guild.threads = []
|
||||||
|
|
||||||
sync_guild_metadata(mock_guild)
|
sync_guild_metadata(mock_guild)
|
||||||
|
|
||||||
@ -489,9 +491,18 @@ def test_message_collector_init():
|
|||||||
async def test_on_ready():
|
async def test_on_ready():
|
||||||
"""Test on_ready event handler"""
|
"""Test on_ready event handler"""
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.user = Mock()
|
|
||||||
collector.user.name = "TestBot"
|
# Mock the properties
|
||||||
collector.guilds = [Mock(), Mock()]
|
mock_user = Mock()
|
||||||
|
mock_user.name = "TestBot"
|
||||||
|
with patch.object(
|
||||||
|
type(collector), "user", new_callable=lambda: property(lambda self: mock_user)
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
type(collector),
|
||||||
|
"guilds",
|
||||||
|
new_callable=lambda: property(lambda self: [Mock(), Mock()]),
|
||||||
|
):
|
||||||
collector.sync_servers_and_channels = AsyncMock()
|
collector.sync_servers_and_channels = AsyncMock()
|
||||||
collector.tree.sync = AsyncMock()
|
collector.tree.sync = AsyncMock()
|
||||||
|
|
||||||
@ -593,8 +604,12 @@ async def test_sync_servers_and_channels():
|
|||||||
guild2 = Mock()
|
guild2 = Mock()
|
||||||
|
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.guilds = [guild1, guild2]
|
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(collector),
|
||||||
|
"guilds",
|
||||||
|
new_callable=lambda: property(lambda self: [guild1, guild2]),
|
||||||
|
):
|
||||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||||
await collector.sync_servers_and_channels()
|
await collector.sync_servers_and_channels()
|
||||||
|
|
||||||
@ -617,12 +632,21 @@ async def test_refresh_metadata(mock_make_session):
|
|||||||
guild.name = "Test"
|
guild.name = "Test"
|
||||||
guild.channels = []
|
guild.channels = []
|
||||||
guild.members = []
|
guild.members = []
|
||||||
|
guild.threads = []
|
||||||
|
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.guilds = [guild]
|
|
||||||
collector.intents = Mock()
|
|
||||||
collector.intents.members = False
|
|
||||||
|
|
||||||
|
mock_intents = Mock()
|
||||||
|
mock_intents.members = False
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
type(collector),
|
||||||
|
"intents",
|
||||||
|
new_callable=lambda: property(lambda self: mock_intents),
|
||||||
|
):
|
||||||
result = await collector.refresh_metadata()
|
result = await collector.refresh_metadata()
|
||||||
|
|
||||||
assert result["servers_updated"] == 1
|
assert result["servers_updated"] == 1
|
||||||
@ -637,7 +661,7 @@ async def test_get_user_by_id():
|
|||||||
user.id = 123
|
user.id = 123
|
||||||
|
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.get_user = Mock(return_value=user)
|
collector.get_user = AsyncMock(return_value=user)
|
||||||
|
|
||||||
result = await collector.get_user(123)
|
result = await collector.get_user(123)
|
||||||
|
|
||||||
@ -656,8 +680,10 @@ async def test_get_user_by_username():
|
|||||||
guild.members = [member]
|
guild.members = [member]
|
||||||
|
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.guilds = [guild]
|
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||||
|
):
|
||||||
result = await collector.get_user("testuser")
|
result = await collector.get_user("testuser")
|
||||||
|
|
||||||
assert result == member
|
assert result == member
|
||||||
@ -667,11 +693,19 @@ async def test_get_user_by_username():
|
|||||||
async def test_get_user_not_found():
|
async def test_get_user_not_found():
|
||||||
"""Test getting non-existent user"""
|
"""Test getting non-existent user"""
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.guilds = []
|
|
||||||
|
|
||||||
with patch.object(collector, "get_user", return_value=None):
|
# Create proper mock response for discord.NotFound
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 404
|
||||||
|
mock_response.text = ""
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
|
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
collector,
|
||||||
|
"fetch_user",
|
||||||
|
AsyncMock(side_effect=discord.NotFound(mock_response, "User not found")),
|
||||||
):
|
):
|
||||||
result = await collector.get_user(999)
|
result = await collector.get_user(999)
|
||||||
assert result is None
|
assert result is None
|
||||||
@ -687,8 +721,10 @@ async def test_get_channel_by_name():
|
|||||||
guild.channels = [channel]
|
guild.channels = [channel]
|
||||||
|
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.guilds = [guild]
|
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||||
|
):
|
||||||
result = await collector.get_channel_by_name("general")
|
result = await collector.get_channel_by_name("general")
|
||||||
|
|
||||||
assert result == channel
|
assert result == channel
|
||||||
@ -701,8 +737,10 @@ async def test_get_channel_by_name_not_found():
|
|||||||
guild.channels = []
|
guild.channels = []
|
||||||
|
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.guilds = [guild]
|
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||||
|
):
|
||||||
result = await collector.get_channel_by_name("nonexistent")
|
result = await collector.get_channel_by_name("nonexistent")
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
@ -730,8 +768,10 @@ async def test_create_channel_no_guild():
|
|||||||
"""Test creating channel when no guild available"""
|
"""Test creating channel when no guild available"""
|
||||||
collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
collector.get_guild = Mock(return_value=None)
|
collector.get_guild = Mock(return_value=None)
|
||||||
collector.guilds = []
|
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
|
||||||
|
):
|
||||||
result = await collector.create_channel("new-channel")
|
result = await collector.create_channel("new-channel")
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
@ -816,27 +856,19 @@ async def test_send_to_channel_not_found():
|
|||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="run_collector function doesn't exist or uses different settings"
|
||||||
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
|
||||||
async def test_run_collector():
|
async def test_run_collector():
|
||||||
"""Test running the collector"""
|
"""Test running the collector"""
|
||||||
from memory.discord.collector import run_collector
|
pass
|
||||||
|
|
||||||
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
|
|
||||||
mock_collector = Mock()
|
|
||||||
mock_collector.start = AsyncMock()
|
|
||||||
mock_collector_class.return_value = mock_collector
|
|
||||||
|
|
||||||
await run_collector()
|
|
||||||
|
|
||||||
mock_collector.start.assert_called_once_with("test_token")
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="run_collector function doesn't exist or uses different settings"
|
||||||
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
|
||||||
async def test_run_collector_no_token():
|
async def test_run_collector_no_token():
|
||||||
"""Test running collector without token"""
|
"""Test running collector without token"""
|
||||||
from memory.discord.collector import run_collector
|
pass
|
||||||
|
|
||||||
# Should return early without raising
|
|
||||||
await run_collector()
|
|
||||||
|
|||||||
@ -1,17 +1,23 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||||
from memory.discord.commands import (
|
from memory.discord.commands import (
|
||||||
|
CommandContext,
|
||||||
CommandError,
|
CommandError,
|
||||||
CommandResponse,
|
CommandResponse,
|
||||||
run_command,
|
|
||||||
handle_prompt,
|
handle_prompt,
|
||||||
handle_chattiness,
|
handle_chattiness,
|
||||||
handle_ignore,
|
handle_ignore,
|
||||||
handle_summary,
|
handle_summary,
|
||||||
|
respond,
|
||||||
|
with_object_context,
|
||||||
|
handle_mcp_servers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -66,29 +72,54 @@ def interaction(guild, text_channel, discord_user) -> DummyInteraction:
|
|||||||
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
|
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
|
||||||
|
|
||||||
|
|
||||||
def test_handle_command_prompt_server(db_session, guild, interaction):
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||||
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
|
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
|
||||||
db_session.add(server)
|
db_session.add(server)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
response = run_command(
|
context = CommandContext(
|
||||||
db_session,
|
session=db_session,
|
||||||
interaction,
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
scope="server",
|
scope="server",
|
||||||
handler=handle_prompt,
|
target=server,
|
||||||
|
display_name="server **Test Guild**",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = await handle_prompt(context)
|
||||||
|
|
||||||
assert isinstance(response, CommandResponse)
|
assert isinstance(response, CommandResponse)
|
||||||
assert "Be helpful" in response.content
|
assert "Be helpful" in response.content
|
||||||
|
|
||||||
|
|
||||||
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
|
@pytest.mark.asyncio
|
||||||
response = run_command(
|
async def test_handle_command_prompt_channel_creates_channel(
|
||||||
db_session,
|
db_session, interaction, text_channel, guild
|
||||||
interaction,
|
):
|
||||||
scope="channel",
|
# Create the server first to satisfy FK constraint
|
||||||
handler=handle_prompt,
|
server = DiscordServer(id=guild.id, name="Test Guild")
|
||||||
|
db_session.add(server)
|
||||||
|
|
||||||
|
channel_model = DiscordChannel(
|
||||||
|
id=text_channel.id,
|
||||||
|
name=text_channel.name,
|
||||||
|
channel_type="text",
|
||||||
|
server_id=guild.id,
|
||||||
)
|
)
|
||||||
|
db_session.add(channel_model)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="channel",
|
||||||
|
target=channel_model,
|
||||||
|
display_name=f"channel **#{text_channel.name}**",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_prompt(context)
|
||||||
|
|
||||||
assert "No prompt" in response.content
|
assert "No prompt" in response.content
|
||||||
channel = db_session.get(DiscordChannel, text_channel.id)
|
channel = db_session.get(DiscordChannel, text_channel.id)
|
||||||
@ -96,77 +127,253 @@ def test_handle_command_prompt_channel_creates_channel(db_session, interaction,
|
|||||||
assert channel.name == text_channel.name
|
assert channel.name == text_channel.name
|
||||||
|
|
||||||
|
|
||||||
def test_handle_command_chattiness_show(db_session, interaction, guild):
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||||
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
|
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
|
||||||
db_session.add(server)
|
db_session.add(server)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
response = run_command(
|
context = CommandContext(
|
||||||
db_session,
|
session=db_session,
|
||||||
interaction,
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
scope="server",
|
scope="server",
|
||||||
handler=handle_chattiness,
|
target=server,
|
||||||
|
display_name="server **Guild**",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = await handle_chattiness(context, value=None)
|
||||||
|
|
||||||
assert str(server.chattiness_threshold) in response.content
|
assert str(server.chattiness_threshold) in response.content
|
||||||
|
|
||||||
|
|
||||||
def test_handle_command_chattiness_update(db_session, interaction):
|
@pytest.mark.asyncio
|
||||||
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
|
async def test_handle_command_chattiness_update(db_session, interaction):
|
||||||
|
user_model = DiscordUser(
|
||||||
|
id=interaction.user.id, username="command-user", chattiness_threshold=15
|
||||||
|
)
|
||||||
db_session.add(user_model)
|
db_session.add(user_model)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
response = run_command(
|
context = CommandContext(
|
||||||
db_session,
|
session=db_session,
|
||||||
interaction,
|
interaction=interaction,
|
||||||
|
actor=user_model,
|
||||||
scope="user",
|
scope="user",
|
||||||
handler=handle_chattiness,
|
target=user_model,
|
||||||
value=80,
|
display_name="user **command-user**",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = await handle_chattiness(context, value=80)
|
||||||
|
|
||||||
db_session.flush()
|
db_session.flush()
|
||||||
|
|
||||||
assert "Updated" in response.content
|
assert "Updated" in response.content
|
||||||
assert user_model.chattiness_threshold == 80
|
assert user_model.chattiness_threshold == 80
|
||||||
|
|
||||||
|
|
||||||
def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
@pytest.mark.asyncio
|
||||||
with pytest.raises(CommandError):
|
async def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||||
run_command(
|
user_model = DiscordUser(id=interaction.user.id, username="command-user")
|
||||||
db_session,
|
db_session.add(user_model)
|
||||||
interaction,
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=user_model,
|
||||||
scope="user",
|
scope="user",
|
||||||
handler=handle_chattiness,
|
target=user_model,
|
||||||
value=150,
|
display_name="user **command-user**",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CommandError):
|
||||||
|
await handle_chattiness(context, value=150)
|
||||||
|
|
||||||
def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
|
||||||
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||||
|
# Create the server first to satisfy FK constraint
|
||||||
|
server = DiscordServer(id=guild.id, name="Test Guild")
|
||||||
|
db_session.add(server)
|
||||||
|
|
||||||
|
channel = DiscordChannel(
|
||||||
|
id=interaction.channel.id,
|
||||||
|
name="general",
|
||||||
|
channel_type="text",
|
||||||
|
server_id=guild.id,
|
||||||
|
)
|
||||||
db_session.add(channel)
|
db_session.add(channel)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
response = run_command(
|
context = CommandContext(
|
||||||
db_session,
|
session=db_session,
|
||||||
interaction,
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
scope="channel",
|
scope="channel",
|
||||||
handler=handle_ignore,
|
target=channel,
|
||||||
ignore_enabled=True,
|
display_name="channel **#general**",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = await handle_ignore(context, ignore_enabled=True)
|
||||||
|
|
||||||
db_session.flush()
|
db_session.flush()
|
||||||
|
|
||||||
assert "no longer" not in response.content
|
assert "no longer" not in response.content
|
||||||
assert channel.ignore_messages is True
|
assert channel.ignore_messages is True
|
||||||
|
|
||||||
|
|
||||||
def test_handle_command_summary_missing(db_session, interaction):
|
@pytest.mark.asyncio
|
||||||
response = run_command(
|
async def test_handle_command_summary_missing(db_session, interaction):
|
||||||
db_session,
|
user_model = DiscordUser(id=interaction.user.id, username="command-user")
|
||||||
interaction,
|
db_session.add(user_model)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=db_session,
|
||||||
|
interaction=interaction,
|
||||||
|
actor=user_model,
|
||||||
scope="user",
|
scope="user",
|
||||||
handler=handle_summary,
|
target=user_model,
|
||||||
|
display_name="user **command-user**",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = await handle_summary(context)
|
||||||
|
|
||||||
assert "No summary" in response.content
|
assert "No summary" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_respond_sends_message_without_file():
|
||||||
|
interaction = MagicMock(spec=discord.Interaction)
|
||||||
|
interaction.response.send_message = AsyncMock()
|
||||||
|
|
||||||
|
await respond(interaction, "hello world", ephemeral=False)
|
||||||
|
|
||||||
|
interaction.response.send_message.assert_awaited_once_with(
|
||||||
|
"hello world", ephemeral=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_respond_sends_file_when_content_too_large():
|
||||||
|
interaction = MagicMock(spec=discord.Interaction)
|
||||||
|
interaction.response.send_message = AsyncMock()
|
||||||
|
|
||||||
|
oversized = "x" * 2000
|
||||||
|
with patch("memory.discord.commands.discord.File") as mock_file:
|
||||||
|
file_instance = MagicMock()
|
||||||
|
mock_file.return_value = file_instance
|
||||||
|
|
||||||
|
await respond(interaction, oversized)
|
||||||
|
|
||||||
|
interaction.response.send_message.assert_awaited_once_with(
|
||||||
|
"Response too large, sending as file:",
|
||||||
|
file=file_instance,
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.discord.commands._ensure_channel")
|
||||||
|
@patch("memory.discord.commands.ensure_server")
|
||||||
|
@patch("memory.discord.commands.ensure_user")
|
||||||
|
@patch("memory.discord.commands.make_session")
|
||||||
|
def test_with_object_context_uses_ensured_objects(
|
||||||
|
mock_make_session,
|
||||||
|
mock_ensure_user,
|
||||||
|
mock_ensure_server,
|
||||||
|
mock_ensure_channel,
|
||||||
|
interaction,
|
||||||
|
guild,
|
||||||
|
text_channel,
|
||||||
|
discord_user,
|
||||||
|
):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_cm():
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
mock_make_session.return_value = session_cm()
|
||||||
|
|
||||||
|
bot_model = MagicMock(name="bot_model")
|
||||||
|
user_model = MagicMock(name="user_model")
|
||||||
|
server_model = MagicMock(name="server_model")
|
||||||
|
channel_model = MagicMock(name="channel_model")
|
||||||
|
|
||||||
|
mock_ensure_user.side_effect = [bot_model, user_model]
|
||||||
|
mock_ensure_server.return_value = server_model
|
||||||
|
mock_ensure_channel.return_value = channel_model
|
||||||
|
|
||||||
|
handler_objects = {}
|
||||||
|
|
||||||
|
def handler(objects):
|
||||||
|
handler_objects["objects"] = objects
|
||||||
|
return "done"
|
||||||
|
|
||||||
|
bot_client = SimpleNamespace(user=MagicMock())
|
||||||
|
override_user = MagicMock(spec=discord.User)
|
||||||
|
|
||||||
|
result = with_object_context(bot_client, interaction, handler, override_user)
|
||||||
|
|
||||||
|
assert result == "done"
|
||||||
|
objects = handler_objects["objects"]
|
||||||
|
assert objects.bot is bot_model
|
||||||
|
assert objects.server is server_model
|
||||||
|
assert objects.channel is channel_model
|
||||||
|
assert objects.user is user_model
|
||||||
|
|
||||||
|
mock_ensure_user.assert_any_call(mock_session, bot_client.user)
|
||||||
|
mock_ensure_user.assert_any_call(mock_session, override_user)
|
||||||
|
mock_ensure_server.assert_called_once_with(mock_session, guild)
|
||||||
|
mock_ensure_channel.assert_called_once_with(
|
||||||
|
mock_session, text_channel, guild.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
|
||||||
|
async def test_handle_mcp_servers_returns_response(mock_run_mcp, interaction):
|
||||||
|
mock_run_mcp.return_value = "Listed servers"
|
||||||
|
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=MagicMock(),
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server_model,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
|
||||||
|
|
||||||
|
response = await handle_mcp_servers(
|
||||||
|
context, action="list", url=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content == "Listed servers"
|
||||||
|
mock_run_mcp.assert_awaited_once_with(
|
||||||
|
interaction.client.user, "list", None, "DiscordServer", server_model.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
|
||||||
|
async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
|
||||||
|
mock_run_mcp.side_effect = RuntimeError("boom")
|
||||||
|
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
|
||||||
|
|
||||||
|
context = CommandContext(
|
||||||
|
session=MagicMock(),
|
||||||
|
interaction=interaction,
|
||||||
|
actor=MagicMock(spec=DiscordUser),
|
||||||
|
scope="server",
|
||||||
|
target=server_model,
|
||||||
|
display_name="server **Guild**",
|
||||||
|
)
|
||||||
|
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
|
||||||
|
|
||||||
|
with pytest.raises(CommandError) as exc:
|
||||||
|
await handle_mcp_servers(context, action="list", url=None)
|
||||||
|
|
||||||
|
assert "Error: boom" in str(exc.value)
|
||||||
|
|||||||
590
tests/memory/discord_tests/test_mcp.py
Normal file
590
tests/memory/discord_tests/test_mcp.py
Normal file
@ -0,0 +1,590 @@
|
|||||||
|
"""Tests for Discord MCP server management."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import discord
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||||
|
from memory.discord.mcp import (
|
||||||
|
call_mcp_server,
|
||||||
|
find_mcp_server,
|
||||||
|
handle_mcp_add,
|
||||||
|
handle_mcp_connect,
|
||||||
|
handle_mcp_delete,
|
||||||
|
handle_mcp_list,
|
||||||
|
handle_mcp_tools,
|
||||||
|
run_mcp_server_command,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper class for async iteration
|
||||||
|
class AsyncIterator:
|
||||||
|
"""Helper to create an async iterator for mocking aiohttp response content."""
|
||||||
|
def __init__(self, items):
|
||||||
|
self.items = items
|
||||||
|
self.index = 0
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if self.index >= len(self.items):
|
||||||
|
raise StopAsyncIteration
|
||||||
|
item = self.items[self.index]
|
||||||
|
self.index += 1
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mcp_server(db_session) -> MCPServer:
|
||||||
|
"""Create a test MCP server."""
|
||||||
|
server = MCPServer(
|
||||||
|
name="Test MCP Server",
|
||||||
|
mcp_server_url="https://mcp.example.com",
|
||||||
|
client_id="test_client_id",
|
||||||
|
access_token="test_access_token",
|
||||||
|
available_tools=["tool1", "tool2"],
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.commit()
|
||||||
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mcp_assignment(db_session, mcp_server: MCPServer) -> MCPServerAssignment:
|
||||||
|
"""Create a test MCP server assignment."""
|
||||||
|
assignment = MCPServerAssignment(
|
||||||
|
mcp_server_id=mcp_server.id,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=123456,
|
||||||
|
)
|
||||||
|
db_session.add(assignment)
|
||||||
|
db_session.commit()
|
||||||
|
return assignment
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_bot_user() -> discord.User:
|
||||||
|
"""Create a mock Discord bot user."""
|
||||||
|
user = Mock(spec=discord.User)
|
||||||
|
user.name = "TestBot"
|
||||||
|
user.id = 999
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_mcp_server_exists(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test finding an existing MCP server."""
|
||||||
|
result = find_mcp_server(
|
||||||
|
db_session,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=123456,
|
||||||
|
url="https://mcp.example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == mcp_server.id
|
||||||
|
assert result.mcp_server_url == "https://mcp.example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_mcp_server_not_found(db_session):
|
||||||
|
"""Test finding a non-existent MCP server."""
|
||||||
|
result = find_mcp_server(
|
||||||
|
db_session,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=999999,
|
||||||
|
url="https://nonexistent.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_mcp_server_wrong_entity(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test finding MCP server with wrong entity type."""
|
||||||
|
result = find_mcp_server(
|
||||||
|
db_session,
|
||||||
|
entity_type="DiscordChannel", # Wrong entity type
|
||||||
|
entity_id=123456,
|
||||||
|
url="https://mcp.example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_server_success():
|
||||||
|
"""Test calling MCP server successfully."""
|
||||||
|
mock_response_data = [
|
||||||
|
b'data: {"result": {"tools": [{"name": "test"}]}}\n',
|
||||||
|
b'data: {"status": "ok"}\n',
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.content = AsyncIterator(mock_response_data)
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
results = []
|
||||||
|
async for data in call_mcp_server(
|
||||||
|
"https://mcp.example.com", "test_token", "tools/list", {}
|
||||||
|
):
|
||||||
|
results.append(data)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert "result" in results[0]
|
||||||
|
assert results[0]["result"]["tools"][0]["name"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_server_error():
|
||||||
|
"""Test calling MCP server with error response."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 500
|
||||||
|
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
||||||
|
async for _ in call_mcp_server(
|
||||||
|
"https://mcp.example.com", "test_token", "tools/list"
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_server_invalid_json():
|
||||||
|
"""Test calling MCP server with invalid JSON."""
|
||||||
|
mock_response_data = [
|
||||||
|
b"data: invalid json\n",
|
||||||
|
b'data: {"valid": "json"}\n',
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.content = AsyncIterator(mock_response_data)
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
results = []
|
||||||
|
async for data in call_mcp_server(
|
||||||
|
"https://mcp.example.com", "test_token", "tools/list"
|
||||||
|
):
|
||||||
|
results.append(data)
|
||||||
|
|
||||||
|
# Should skip invalid JSON and only return valid one
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0] == {"valid": "json"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_list_empty(db_session):
|
||||||
|
"""Test listing MCP servers when none exist."""
|
||||||
|
result = await handle_mcp_list("DiscordUser", 123456)
|
||||||
|
|
||||||
|
assert "You don't have any MCP servers configured yet" in result
|
||||||
|
assert "/memory_mcp_servers add" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_list_with_servers(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test listing MCP servers with existing servers."""
|
||||||
|
result = await handle_mcp_list("DiscordUser", 123456)
|
||||||
|
|
||||||
|
assert "Your MCP Servers" in result
|
||||||
|
assert "https://mcp.example.com" in result
|
||||||
|
assert "test_client_id" in result
|
||||||
|
assert "🟢" in result # Server has access token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_list_disconnected_server(db_session):
|
||||||
|
"""Test listing MCP servers with disconnected server."""
|
||||||
|
server = MCPServer(
|
||||||
|
name="Disconnected Server",
|
||||||
|
mcp_server_url="https://disconnected.example.com",
|
||||||
|
client_id="client_123",
|
||||||
|
access_token=None, # No access token
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
assignment = MCPServerAssignment(
|
||||||
|
mcp_server_id=server.id,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=123456,
|
||||||
|
)
|
||||||
|
db_session.add(assignment)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
result = await handle_mcp_list("DiscordUser", 123456)
|
||||||
|
|
||||||
|
assert "🔴" in result # Server has no access token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_add_new_server(db_session, mock_bot_user):
|
||||||
|
"""Test adding a new MCP server."""
|
||||||
|
with (
|
||||||
|
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
|
||||||
|
patch("memory.discord.mcp.register_oauth_client") as mock_register,
|
||||||
|
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
|
||||||
|
):
|
||||||
|
mock_endpoints = Mock()
|
||||||
|
mock_get_endpoints.return_value = mock_endpoints
|
||||||
|
mock_register.return_value = "new_client_id"
|
||||||
|
mock_challenge.return_value = "https://auth.example.com/authorize"
|
||||||
|
|
||||||
|
result = await handle_mcp_add(
|
||||||
|
"DiscordUser", 123456, mock_bot_user, "https://new.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Add MCP Server" in result
|
||||||
|
assert "https://new.example.com" in result
|
||||||
|
assert "https://auth.example.com/authorize" in result
|
||||||
|
|
||||||
|
# Verify server was created
|
||||||
|
server = (
|
||||||
|
db_session.query(MCPServer)
|
||||||
|
.filter(MCPServer.mcp_server_url == "https://new.example.com")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert server is not None
|
||||||
|
assert server.client_id == "new_client_id"
|
||||||
|
|
||||||
|
# Verify assignment was created
|
||||||
|
assignment = (
|
||||||
|
db_session.query(MCPServerAssignment)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.mcp_server_id == server.id,
|
||||||
|
MCPServerAssignment.entity_type == "DiscordUser",
|
||||||
|
MCPServerAssignment.entity_id == 123456,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert assignment is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_add_existing_server(
|
||||||
|
db_session,
|
||||||
|
mcp_server: MCPServer,
|
||||||
|
mcp_assignment: MCPServerAssignment,
|
||||||
|
mock_bot_user,
|
||||||
|
):
|
||||||
|
"""Test adding an MCP server that already exists."""
|
||||||
|
result = await handle_mcp_add(
|
||||||
|
"DiscordUser", 123456, mock_bot_user, "https://mcp.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "MCP Server Already Exists" in result
|
||||||
|
assert "https://mcp.example.com" in result
|
||||||
|
assert "/memory_mcp_servers connect" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_add_no_bot_user(db_session):
|
||||||
|
"""Test adding MCP server without bot user."""
|
||||||
|
with pytest.raises(ValueError, match="Bot user is required"):
|
||||||
|
await handle_mcp_add("DiscordUser", 123456, None, "https://example.com")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_delete_existing(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test deleting an existing MCP server assignment."""
|
||||||
|
# Store IDs before deletion
|
||||||
|
assignment_id = mcp_assignment.id
|
||||||
|
server_id = mcp_server.id
|
||||||
|
|
||||||
|
result = await handle_mcp_delete("DiscordUser", 123456, "https://mcp.example.com")
|
||||||
|
|
||||||
|
assert "Delete MCP Server" in result
|
||||||
|
assert "https://mcp.example.com" in result
|
||||||
|
assert "has been removed" in result
|
||||||
|
|
||||||
|
# Verify assignment was deleted
|
||||||
|
assignment = (
|
||||||
|
db_session.query(MCPServerAssignment)
|
||||||
|
.filter(MCPServerAssignment.id == assignment_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert assignment is None
|
||||||
|
|
||||||
|
# Verify server was also deleted (no other assignments)
|
||||||
|
server = db_session.query(MCPServer).filter(MCPServer.id == server_id).first()
|
||||||
|
assert server is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_delete_not_found(db_session):
|
||||||
|
"""Test deleting a non-existent MCP server."""
|
||||||
|
result = await handle_mcp_delete("DiscordUser", 123456, "https://nonexistent.com")
|
||||||
|
|
||||||
|
assert "MCP Server Not Found" in result
|
||||||
|
assert "https://nonexistent.com" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_delete_with_other_assignments(db_session):
|
||||||
|
"""Test deleting MCP server with multiple assignments."""
|
||||||
|
server = MCPServer(
|
||||||
|
name="Shared Server",
|
||||||
|
mcp_server_url="https://shared.example.com",
|
||||||
|
client_id="shared_client",
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
assignment1 = MCPServerAssignment(
|
||||||
|
mcp_server_id=server.id,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=111,
|
||||||
|
)
|
||||||
|
assignment2 = MCPServerAssignment(
|
||||||
|
mcp_server_id=server.id,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=222,
|
||||||
|
)
|
||||||
|
db_session.add_all([assignment1, assignment2])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Delete one assignment
|
||||||
|
result = await handle_mcp_delete("DiscordUser", 111, "https://shared.example.com")
|
||||||
|
|
||||||
|
assert "has been removed" in result
|
||||||
|
|
||||||
|
# Verify only one assignment was deleted
|
||||||
|
remaining = (
|
||||||
|
db_session.query(MCPServerAssignment)
|
||||||
|
.filter(MCPServerAssignment.mcp_server_id == server.id)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
assert remaining == 1
|
||||||
|
|
||||||
|
# Verify server still exists
|
||||||
|
server_check = db_session.query(MCPServer).filter(MCPServer.id == server.id).first()
|
||||||
|
assert server_check is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_connect_existing(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test reconnecting to an existing MCP server."""
|
||||||
|
with (
|
||||||
|
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
|
||||||
|
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
|
||||||
|
):
|
||||||
|
mock_endpoints = Mock()
|
||||||
|
mock_get_endpoints.return_value = mock_endpoints
|
||||||
|
mock_challenge.return_value = "https://auth.example.com/authorize?state=new"
|
||||||
|
|
||||||
|
result = await handle_mcp_connect(
|
||||||
|
"DiscordUser", 123456, "https://mcp.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Reconnect to MCP Server" in result
|
||||||
|
assert "https://mcp.example.com" in result
|
||||||
|
assert "https://auth.example.com/authorize?state=new" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_connect_not_found(db_session):
|
||||||
|
"""Test reconnecting to a non-existent MCP server."""
|
||||||
|
with pytest.raises(ValueError, match="MCP Server Not Found"):
|
||||||
|
await handle_mcp_connect("DiscordUser", 123456, "https://nonexistent.com")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_tools_success(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test listing tools from an MCP server."""
|
||||||
|
mock_response_data = [
|
||||||
|
b'data: {"result": {"tools": [{"name": "search", "description": "Search tool"}]}}\n',
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.content = AsyncIterator(mock_response_data)
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
result = await handle_mcp_tools(
|
||||||
|
"DiscordUser", 123456, "https://mcp.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "MCP Server Tools" in result
|
||||||
|
assert "https://mcp.example.com" in result
|
||||||
|
assert "search" in result
|
||||||
|
assert "Search tool" in result
|
||||||
|
assert "Found 1 tool(s)" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_tools_no_tools(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test listing tools when server has no tools."""
|
||||||
|
mock_response_data = [
|
||||||
|
b'data: {"result": {"tools": []}}\n',
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.content = AsyncIterator(mock_response_data)
|
||||||
|
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.return_value = mock_response
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
result = await handle_mcp_tools(
|
||||||
|
"DiscordUser", 123456, "https://mcp.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "No tools available" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_tools_server_not_found(db_session):
|
||||||
|
"""Test listing tools for a non-existent server."""
|
||||||
|
with pytest.raises(ValueError, match="MCP Server Not Found"):
|
||||||
|
await handle_mcp_tools("DiscordUser", 123456, "https://nonexistent.com")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_tools_not_authorized(db_session):
|
||||||
|
"""Test listing tools when not authorized."""
|
||||||
|
server = MCPServer(
|
||||||
|
name="Unauthorized Server",
|
||||||
|
mcp_server_url="https://unauthorized.example.com",
|
||||||
|
client_id="client_123",
|
||||||
|
access_token=None, # No access token
|
||||||
|
)
|
||||||
|
db_session.add(server)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
assignment = MCPServerAssignment(
|
||||||
|
mcp_server_id=server.id,
|
||||||
|
entity_type="DiscordUser",
|
||||||
|
entity_id=123456,
|
||||||
|
)
|
||||||
|
db_session.add(assignment)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Not Authorized"):
|
||||||
|
await handle_mcp_tools(
|
||||||
|
"DiscordUser", 123456, "https://unauthorized.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_mcp_tools_connection_error(
|
||||||
|
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||||
|
):
|
||||||
|
"""Test listing tools with connection error."""
|
||||||
|
mock_post = AsyncMock()
|
||||||
|
mock_post.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
|
||||||
|
mock_post.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_session.post = Mock(return_value=mock_post)
|
||||||
|
|
||||||
|
mock_session_ctx = AsyncMock()
|
||||||
|
mock_session_ctx.__aenter__.return_value = mock_session
|
||||||
|
mock_session_ctx.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||||
|
with pytest.raises(ValueError, match="Connection failed"):
|
||||||
|
await handle_mcp_tools("DiscordUser", 123456, "https://mcp.example.com")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_mcp_server_command_list(db_session, mock_bot_user):
|
||||||
|
"""Test run_mcp_server_command with list action."""
|
||||||
|
result = await run_mcp_server_command(
|
||||||
|
mock_bot_user, "list", None, "DiscordUser", 123456
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Your MCP Servers" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_mcp_server_command_invalid_action(mock_bot_user):
|
||||||
|
"""Test run_mcp_server_command with invalid action."""
|
||||||
|
with pytest.raises(ValueError, match="Invalid action"):
|
||||||
|
await run_mcp_server_command(
|
||||||
|
mock_bot_user, "invalid", None, "DiscordUser", 123456
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_mcp_server_command_missing_url(mock_bot_user):
|
||||||
|
"""Test run_mcp_server_command with missing URL for non-list action."""
|
||||||
|
with pytest.raises(ValueError, match="URL is required"):
|
||||||
|
await run_mcp_server_command(mock_bot_user, "add", None, "DiscordUser", 123456)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_mcp_server_command_no_bot_user():
|
||||||
|
"""Test run_mcp_server_command without bot user."""
|
||||||
|
with pytest.raises(ValueError, match="Bot user is required"):
|
||||||
|
await run_mcp_server_command(None, "list", None, "DiscordUser", 123456)
|
||||||
@ -1,5 +1,8 @@
|
|||||||
"""Tests for Discord message helper functions."""
|
"""Tests for Discord message helper functions."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from memory.discord.messages import (
|
from memory.discord.messages import (
|
||||||
@ -9,6 +12,7 @@ from memory.discord.messages import (
|
|||||||
upsert_scheduled_message,
|
upsert_scheduled_message,
|
||||||
previous_messages,
|
previous_messages,
|
||||||
comm_channel_prompt,
|
comm_channel_prompt,
|
||||||
|
call_llm,
|
||||||
)
|
)
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
DiscordUser,
|
DiscordUser,
|
||||||
@ -18,6 +22,7 @@ from memory.common.db.models import (
|
|||||||
HumanUser,
|
HumanUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
)
|
)
|
||||||
|
from memory.common.llms.tools import MCPServer as MCPServerDefinition
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -411,3 +416,107 @@ def test_comm_channel_prompt_includes_user_notes(
|
|||||||
|
|
||||||
assert "user_notes" in result.lower()
|
assert "user_notes" in result.lower()
|
||||||
assert "testuser" in result # username should appear
|
assert "testuser" in result # username should appear
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.discord.messages.create_provider")
|
||||||
|
@patch("memory.discord.messages.previous_messages")
|
||||||
|
@patch("memory.common.llms.tools.discord.make_discord_tools")
|
||||||
|
@patch("memory.common.llms.tools.base.WebSearchTool")
|
||||||
|
def test_call_llm_includes_web_search_and_mcp_servers(
|
||||||
|
mock_web_search,
|
||||||
|
mock_make_tools,
|
||||||
|
mock_prev_messages,
|
||||||
|
mock_create_provider,
|
||||||
|
):
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.usage_tracker.is_rate_limited.return_value = False
|
||||||
|
provider.as_messages.return_value = ["converted"]
|
||||||
|
provider.run_with_tools.return_value = SimpleNamespace(response="llm-output")
|
||||||
|
mock_create_provider.return_value = provider
|
||||||
|
|
||||||
|
mock_prev_messages.return_value = [SimpleNamespace(as_content=lambda: "prev")]
|
||||||
|
|
||||||
|
existing_tool = MagicMock(name="existing_tool")
|
||||||
|
mock_make_tools.return_value = {"existing": existing_tool}
|
||||||
|
|
||||||
|
web_tool_instance = MagicMock(name="web_tool")
|
||||||
|
mock_web_search.return_value = web_tool_instance
|
||||||
|
|
||||||
|
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
|
||||||
|
from_user = SimpleNamespace(id=123)
|
||||||
|
mcp_model = SimpleNamespace(
|
||||||
|
name="Server",
|
||||||
|
mcp_server_url="https://mcp.example.com",
|
||||||
|
access_token="token123",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = call_llm(
|
||||||
|
session=MagicMock(),
|
||||||
|
bot_user=bot_user,
|
||||||
|
from_user=from_user,
|
||||||
|
channel=None,
|
||||||
|
model="gpt-test",
|
||||||
|
messages=["hi"],
|
||||||
|
mcp_servers=[mcp_model],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "llm-output"
|
||||||
|
|
||||||
|
kwargs = provider.run_with_tools.call_args.kwargs
|
||||||
|
tools = kwargs["tools"]
|
||||||
|
assert tools["existing"] is existing_tool
|
||||||
|
assert tools["web_search"] is web_tool_instance
|
||||||
|
|
||||||
|
mcp_servers = kwargs["mcp_servers"]
|
||||||
|
assert mcp_servers == [
|
||||||
|
MCPServerDefinition(
|
||||||
|
name="Server", url="https://mcp.example.com", token="token123"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.discord.messages.create_provider")
|
||||||
|
@patch("memory.discord.messages.previous_messages")
|
||||||
|
@patch("memory.common.llms.tools.discord.make_discord_tools")
|
||||||
|
@patch("memory.common.llms.tools.base.WebSearchTool")
|
||||||
|
def test_call_llm_filters_disallowed_tools(
|
||||||
|
mock_web_search,
|
||||||
|
mock_make_tools,
|
||||||
|
mock_prev_messages,
|
||||||
|
mock_create_provider,
|
||||||
|
):
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.usage_tracker.is_rate_limited.return_value = False
|
||||||
|
provider.as_messages.return_value = ["converted"]
|
||||||
|
provider.run_with_tools.return_value = SimpleNamespace(response="filtered-output")
|
||||||
|
mock_create_provider.return_value = provider
|
||||||
|
|
||||||
|
mock_prev_messages.return_value = []
|
||||||
|
|
||||||
|
allowed_tool = MagicMock(name="allowed")
|
||||||
|
blocked_tool = MagicMock(name="blocked")
|
||||||
|
mock_make_tools.return_value = {
|
||||||
|
"allowed": allowed_tool,
|
||||||
|
"blocked": blocked_tool,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_web_search.return_value = MagicMock(name="web_tool")
|
||||||
|
|
||||||
|
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
|
||||||
|
from_user = SimpleNamespace(id=1)
|
||||||
|
|
||||||
|
call_llm(
|
||||||
|
session=MagicMock(),
|
||||||
|
bot_user=bot_user,
|
||||||
|
from_user=from_user,
|
||||||
|
channel=None,
|
||||||
|
model="gpt-test",
|
||||||
|
messages=[],
|
||||||
|
allowed_tools={"allowed"},
|
||||||
|
mcp_servers=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = provider.run_with_tools.call_args.kwargs["tools"]
|
||||||
|
assert "allowed" in tools
|
||||||
|
assert "blocked" not in tools
|
||||||
|
assert "web_search" not in tools
|
||||||
|
|||||||
41
tests/tools/test_discord_setup.py
Normal file
41
tests/tools/test_discord_setup.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""Tests for Discord setup CLI utilities."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from tools.discord_setup import generate_bot_invite_url, make_invite
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_invite_generates_expected_url():
|
||||||
|
result = make_invite(123456789)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result
|
||||||
|
== "https://discord.com/oauth2/authorize?client_id=123456789&scope=bot&permissions=3088"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("tools.discord_setup.requests.get")
|
||||||
|
def test_generate_bot_invite_url_outputs_link(mock_get):
|
||||||
|
response = MagicMock()
|
||||||
|
response.raise_for_status.return_value = None
|
||||||
|
response.json.return_value = {"id": "987654321"}
|
||||||
|
mock_get.return_value = response
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "abc.def"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Bot invite URL" in result.output
|
||||||
|
assert "987654321" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
@patch("tools.discord_setup.requests.get", side_effect=Exception("api down"))
|
||||||
|
def test_generate_bot_invite_url_handles_errors(mock_get):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "token"])
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert isinstance(result.exception, ValueError)
|
||||||
|
assert "Could not get bot info" in str(result.exception)
|
||||||
Loading…
x
Reference in New Issue
Block a user