mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 17:11:19 +01:00
Compare commits
No commits in common. "470061bd436a7270ff54fc56a0604d44191f98ac" and "2d3dc06fdf992a9dff34d719561fe1d427c265ff" have entirely different histories.
470061bd43
...
2d3dc06fdf
@ -1,6 +1,5 @@
|
||||
# Agent Guidance
|
||||
|
||||
- Assume Python 3.10+ features are available; avoid `from __future__ import annotations` unless necessary.
|
||||
- Treat LLM model identifiers as `<provider>/<model_name>` strings throughout the codebase.
|
||||
- Prefer straightforward control flow (`if`/`else`) instead of nested ternaries when clarity is improved.
|
||||
- Tests should be written with @pytest.mark.parametrize where applicable and should avoid test classes
|
||||
- Make sure linting errors get fixed
|
||||
|
||||
@ -0,0 +1,67 @@
|
||||
"""discord mcp servers
|
||||
|
||||
Revision ID: 9b887449ea92
|
||||
Revises: 1954477b25f4
|
||||
Create Date: 2025-11-02 22:04:26.259323
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9b887449ea92"
|
||||
down_revision: Union[str, None] = "1954477b25f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"discord_mcp_servers",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("discord_bot_user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||
sa.Column("client_id", sa.Text(), nullable=False),
|
||||
sa.Column("state", sa.Text(), nullable=True),
|
||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["discord_bot_user_id"],
|
||||
["discord_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("state"),
|
||||
)
|
||||
op.create_index(
|
||||
"discord_mcp_state_idx", "discord_mcp_servers", ["state"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"discord_mcp_user_url_idx",
|
||||
"discord_mcp_servers",
|
||||
["discord_bot_user_id", "mcp_server_url"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("discord_mcp_user_url_idx", table_name="discord_mcp_servers")
|
||||
op.drop_index("discord_mcp_state_idx", table_name="discord_mcp_servers")
|
||||
op.drop_table("discord_mcp_servers")
|
||||
@ -1,103 +0,0 @@
|
||||
"""mcp servers
|
||||
|
||||
Revision ID: 89861d5f1102
|
||||
Revises: 1954477b25f4
|
||||
Create Date: 2025-11-03 15:41:26.254854
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "89861d5f1102"
|
||||
down_revision: Union[str, None] = "1954477b25f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"mcp_servers",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("mcp_server_url", sa.Text(), nullable=False),
|
||||
sa.Column("client_id", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"available_tools", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
|
||||
),
|
||||
sa.Column("state", sa.Text(), nullable=True),
|
||||
sa.Column("code_verifier", sa.Text(), nullable=True),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("state"),
|
||||
)
|
||||
op.create_index("mcp_state_idx", "mcp_servers", ["state"], unique=False)
|
||||
op.create_table(
|
||||
"mcp_server_assignments",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("mcp_server_id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_type", sa.Text(), nullable=False),
|
||||
sa.Column("entity_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mcp_server_id"],
|
||||
["mcp_servers.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"mcp_assignment_entity_idx",
|
||||
"mcp_server_assignments",
|
||||
["entity_type", "entity_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"mcp_assignment_server_idx",
|
||||
"mcp_server_assignments",
|
||||
["mcp_server_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"mcp_assignment_unique_idx",
|
||||
"mcp_server_assignments",
|
||||
["mcp_server_id", "entity_type", "entity_id"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("mcp_assignment_unique_idx", table_name="mcp_server_assignments")
|
||||
op.drop_index("mcp_assignment_server_idx", table_name="mcp_server_assignments")
|
||||
op.drop_index("mcp_assignment_entity_idx", table_name="mcp_server_assignments")
|
||||
op.drop_table("mcp_server_assignments")
|
||||
op.drop_index("mcp_state_idx", table_name="mcp_servers")
|
||||
op.drop_table("mcp_servers")
|
||||
@ -128,7 +128,7 @@ services:
|
||||
- /var/tmp
|
||||
- /qdrant/snapshots:rw
|
||||
healthcheck:
|
||||
test: ["CMD", "bash", "-c", "exec 3<>/dev/tcp/localhost/6333 && echo -e 'GET /readyz HTTP/1.0\\r\\n\\r\\n' >&3 && timeout 2 cat <&3 | grep -q ready"]
|
||||
test: [ "CMD", "wget", "-q", "-T", "2", "-O", "-", "localhost:6333/ready" ]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
@ -8,7 +8,6 @@ RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
g++ \
|
||||
python3-dev \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy and install Python requirements
|
||||
|
||||
@ -14,7 +14,7 @@ from memory.common.db.models import (
|
||||
BookSection,
|
||||
Chunk,
|
||||
Comic,
|
||||
MCPServer,
|
||||
DiscordMCPServer,
|
||||
DiscordMessage,
|
||||
EmailAccount,
|
||||
EmailAttachment,
|
||||
@ -167,17 +167,17 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
||||
column_sortable_list = ["sent_at"]
|
||||
|
||||
|
||||
class MCPServerAdmin(ModelView, model=MCPServer):
|
||||
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
||||
column_list = [
|
||||
"id",
|
||||
"mcp_server_url",
|
||||
"client_id",
|
||||
"discord_bot_user_id",
|
||||
"state",
|
||||
"code_verifier",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"token_expires_at",
|
||||
"available_tools",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
@ -186,6 +186,7 @@ class MCPServerAdmin(ModelView, model=MCPServer):
|
||||
"client_id",
|
||||
"state",
|
||||
"id",
|
||||
"discord_bot_user_id",
|
||||
]
|
||||
column_sortable_list = [
|
||||
"created_at",
|
||||
@ -359,5 +360,5 @@ def setup_admin(admin: Admin):
|
||||
admin.add_view(DiscordUserAdmin)
|
||||
admin.add_view(DiscordServerAdmin)
|
||||
admin.add_view(DiscordChannelAdmin)
|
||||
admin.add_view(MCPServerAdmin)
|
||||
admin.add_view(DiscordMCPServerAdmin)
|
||||
admin.add_view(ScheduledLLMCallAdmin)
|
||||
|
||||
@ -10,7 +10,7 @@ from memory.common import settings
|
||||
from memory.common.db.connection import get_session, make_session
|
||||
from memory.common.db.models import (
|
||||
BotUser,
|
||||
MCPServer,
|
||||
DiscordMCPServer,
|
||||
HumanUser,
|
||||
User,
|
||||
UserSession,
|
||||
@ -169,7 +169,9 @@ async def oauth_callback_discord(request: Request):
|
||||
# Complete the OAuth flow (exchange code for token)
|
||||
with make_session() as session:
|
||||
mcp_server = (
|
||||
session.query(MCPServer).filter(MCPServer.state == state).first()
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(DiscordMCPServer.state == state)
|
||||
.first()
|
||||
)
|
||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||
session.commit()
|
||||
|
||||
@ -34,10 +34,7 @@ from memory.common.db.models.discord import (
|
||||
DiscordServer,
|
||||
DiscordChannel,
|
||||
DiscordUser,
|
||||
)
|
||||
from memory.common.db.models.mcp import (
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
DiscordMCPServer,
|
||||
)
|
||||
from memory.common.db.models.observations import (
|
||||
ObservationContradiction,
|
||||
@ -110,8 +107,7 @@ __all__ = [
|
||||
"DiscordServer",
|
||||
"DiscordChannel",
|
||||
"DiscordUser",
|
||||
"MCPServer",
|
||||
"MCPServerAssignment",
|
||||
"DiscordMCPServer",
|
||||
# Users
|
||||
"User",
|
||||
"HumanUser",
|
||||
|
||||
@ -51,40 +51,21 @@ class MessageProcessor:
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def entity_type(self) -> str:
|
||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||
|
||||
def to_xml(self, *fields: str) -> str:
|
||||
def indent(key: str, text: str) -> str:
|
||||
res = textwrap.dedent("""
|
||||
<{key}>
|
||||
{text}
|
||||
</{key}>
|
||||
""").format(key=key, text=textwrap.indent(text, " "))
|
||||
return res.strip()
|
||||
|
||||
vals = []
|
||||
if "name" in fields:
|
||||
vals.append(indent("name", self.name))
|
||||
if "system_prompt" in fields:
|
||||
vals.append(indent("system_prompt", self.system_prompt or ""))
|
||||
if "summary" in fields:
|
||||
vals.append(indent("summary", self.summary or ""))
|
||||
if "mcp_servers" in fields:
|
||||
servers = [s.as_xml() for s in self.mcp_servers]
|
||||
vals.append(indent("mcp_servers", "\n".join(servers)))
|
||||
|
||||
return indent(self.entity_type, "\n".join(vals)) # type: ignore
|
||||
|
||||
def xml_prompt(self) -> str:
|
||||
return self.to_xml("name", "system_prompt") if self.system_prompt else ""
|
||||
|
||||
def xml_summary(self) -> str:
|
||||
return self.to_xml("name", "summary")
|
||||
|
||||
def xml_mcp_servers(self) -> str:
|
||||
return self.to_xml("mcp_servers")
|
||||
def as_xml(self) -> str:
|
||||
return (
|
||||
textwrap.dedent("""
|
||||
<{type}>
|
||||
<name>{name}</name>
|
||||
<summary>{summary}</summary>
|
||||
</{type}>
|
||||
""")
|
||||
.format(
|
||||
type=self.__class__.__tablename__[8:], # type: ignore
|
||||
name=getattr(self, "name", None) or getattr(self, "username", None),
|
||||
summary=self.summary,
|
||||
)
|
||||
.strip()
|
||||
)
|
||||
|
||||
|
||||
class DiscordServer(Base, MessageProcessor):
|
||||
@ -146,9 +127,44 @@ class DiscordUser(Base, MessageProcessor):
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
system_user = relationship("User", back_populates="discord_users")
|
||||
mcp_servers = relationship(
|
||||
"DiscordMCPServer", back_populates="discord_user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("discord_users_system_user_idx", "system_user_id"),)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.username
|
||||
|
||||
class DiscordMCPServer(Base):
|
||||
"""MCP server configuration and OAuth state for Discord users."""
|
||||
|
||||
__tablename__ = "discord_mcp_servers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
discord_bot_user_id = Column(
|
||||
BigInteger, ForeignKey("discord_users.id"), nullable=False
|
||||
)
|
||||
|
||||
# MCP server info
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
code_verifier = Column(Text, nullable=True)
|
||||
|
||||
# OAuth tokens (set after successful authorization)
|
||||
access_token = Column(Text, nullable=True)
|
||||
refresh_token = Column(Text, nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
discord_user = relationship("DiscordUser", back_populates="mcp_servers")
|
||||
|
||||
__table_args__ = (
|
||||
Index("discord_mcp_state_idx", "state"),
|
||||
Index("discord_mcp_user_url_idx", "discord_bot_user_id", "mcp_server_url"),
|
||||
)
|
||||
|
||||
@ -1,108 +0,0 @@
|
||||
import textwrap
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
class MCPServer(Base):
|
||||
"""MCP server configuration and OAuth state."""
|
||||
|
||||
__tablename__ = "mcp_servers"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
# MCP server info
|
||||
name = Column(Text, nullable=False)
|
||||
mcp_server_url = Column(Text, nullable=False)
|
||||
client_id = Column(Text, nullable=False)
|
||||
available_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# OAuth flow state (temporary, cleared after token exchange)
|
||||
state = Column(Text, nullable=True, unique=True)
|
||||
code_verifier = Column(Text, nullable=True)
|
||||
|
||||
# OAuth tokens (set after successful authorization)
|
||||
access_token = Column(Text, nullable=True)
|
||||
refresh_token = Column(Text, nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
assignments = relationship(
|
||||
"MCPServerAssignment", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (Index("mcp_state_idx", "state"),)
|
||||
|
||||
def as_xml(self) -> str:
|
||||
tools = "\n".join(f"• {tool}" for tool in self.available_tools).strip()
|
||||
return textwrap.dedent("""
|
||||
<mcp_server>
|
||||
<name>
|
||||
{name}
|
||||
</name>
|
||||
<mcp_server_url>
|
||||
{mcp_server_url}
|
||||
</mcp_server_url>
|
||||
<client_id>
|
||||
{client_id}
|
||||
</client_id>
|
||||
<available_tools>
|
||||
{available_tools}
|
||||
</available_tools>
|
||||
</mcp_server>
|
||||
""").format(
|
||||
name=self.name,
|
||||
mcp_server_url=self.mcp_server_url,
|
||||
client_id=self.client_id,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
|
||||
class MCPServerAssignment(Base):
|
||||
"""Assignment of MCP servers to entities (users, channels, servers, etc.)."""
|
||||
|
||||
__tablename__ = "mcp_server_assignments"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
mcp_server_id = Column(Integer, ForeignKey("mcp_servers.id"), nullable=False)
|
||||
|
||||
# Polymorphic entity reference
|
||||
entity_type = Column(
|
||||
Text, nullable=False
|
||||
) # "User", "DiscordUser", "DiscordServer", "DiscordChannel"
|
||||
entity_id = Column(BigInteger, nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
mcp_server = relationship("MCPServer", back_populates="assignments")
|
||||
|
||||
__table_args__ = (
|
||||
Index("mcp_assignment_entity_idx", "entity_type", "entity_id"),
|
||||
Index("mcp_assignment_server_idx", "mcp_server_id"),
|
||||
Index(
|
||||
"mcp_assignment_unique_idx",
|
||||
"mcp_server_id",
|
||||
"entity_type",
|
||||
"entity_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
@ -36,10 +36,6 @@ from memory.common.db.models.source_item import (
|
||||
clean_filename,
|
||||
chunk_mixed,
|
||||
)
|
||||
from memory.common.db.models.mcp import (
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
)
|
||||
|
||||
|
||||
class MailMessagePayload(SourceItemPayload):
|
||||
@ -349,12 +345,11 @@ class DiscordMessage(SourceItem):
|
||||
|
||||
@property
|
||||
def system_prompt(self) -> str:
|
||||
prompts = [
|
||||
(self.from_user and self.from_user.system_prompt),
|
||||
(self.channel and self.channel.system_prompt),
|
||||
(self.server and self.server.system_prompt),
|
||||
]
|
||||
return "\n\n".join(p for p in prompts if p)
|
||||
return (
|
||||
(self.from_user and self.from_user.system_prompt)
|
||||
or (self.channel and self.channel.system_prompt)
|
||||
or (self.server and self.server.system_prompt)
|
||||
)
|
||||
|
||||
@property
|
||||
def chattiness_threshold(self) -> int:
|
||||
@ -396,29 +391,6 @@ class DiscordMessage(SourceItem):
|
||||
|
||||
return content
|
||||
|
||||
def get_mcp_servers(self, session) -> list[MCPServer]:
|
||||
entity_ids = list(
|
||||
filter(
|
||||
None,
|
||||
[
|
||||
self.recipient_user.id,
|
||||
self.from_user.id,
|
||||
self.channel.id,
|
||||
self.server.id,
|
||||
],
|
||||
)
|
||||
)
|
||||
if not entity_ids:
|
||||
return None
|
||||
|
||||
return (
|
||||
session.query(MCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_id.in_(entity_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "discord_message",
|
||||
}
|
||||
|
||||
@ -69,6 +69,7 @@ def send_to_channel(bot_id: int, channel: int | str, message: str) -> bool:
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
print("Result", result)
|
||||
return result.get("success", False)
|
||||
|
||||
except requests.RequestException as e:
|
||||
|
||||
@ -55,6 +55,7 @@ CUSTOM_EXTENSIONS = {
|
||||
def get_mime_type(path: pathlib.Path) -> str:
|
||||
mime_type, _ = mimetypes.guess_type(str(path))
|
||||
if mime_type:
|
||||
print(f"mime_type: {mime_type}")
|
||||
return mime_type
|
||||
ext = path.suffix.lower()
|
||||
return CUSTOM_EXTENSIONS.get(ext, "application/octet-stream")
|
||||
|
||||
@ -134,15 +134,11 @@ class UsageTracker:
|
||||
default_config: RateLimitConfig | None = None,
|
||||
) -> None:
|
||||
self._configs = configs or {}
|
||||
if default_config is None:
|
||||
default_config = RateLimitConfig(
|
||||
window=timedelta(
|
||||
minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES
|
||||
),
|
||||
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||
)
|
||||
self._default_config = default_config
|
||||
self._default_config = default_config or RateLimitConfig(
|
||||
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
|
||||
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||
)
|
||||
self._lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@ -264,8 +260,8 @@ class UsageTracker:
|
||||
|
||||
with self._lock:
|
||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||
for model_key, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(model_key)
|
||||
for model, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(model)
|
||||
if provider and provider != prov:
|
||||
continue
|
||||
if model and model != model_name:
|
||||
@ -308,10 +304,7 @@ class UsageTracker:
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_config(self, model: str) -> RateLimitConfig | None:
|
||||
config = self._configs.get(model)
|
||||
if config is not None:
|
||||
return config
|
||||
return self._default_config
|
||||
return self._configs.get(model) or self._default_config
|
||||
|
||||
def _prune_expired_events(
|
||||
self,
|
||||
|
||||
@ -8,7 +8,7 @@ from urllib.parse import urlencode, urljoin
|
||||
|
||||
import aiohttp
|
||||
from memory.common import settings
|
||||
from memory.common.db.models import MCPServer
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -148,7 +148,7 @@ async def register_oauth_client(
|
||||
|
||||
|
||||
async def issue_challenge(
|
||||
mcp_server: MCPServer,
|
||||
mcp_server: DiscordMCPServer,
|
||||
endpoints: OAuthEndpoints,
|
||||
) -> str:
|
||||
"""Generate OAuth challenge and store state in mcp_server object."""
|
||||
@ -160,7 +160,7 @@ async def issue_challenge(
|
||||
mcp_server.code_verifier = code_verifier # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for MCP server {mcp_server.mcp_server_url}: "
|
||||
f"Generated OAuth state for user {mcp_server.discord_bot_user_id}: "
|
||||
f"state={state[:20]}..., verifier={code_verifier[:20]}..."
|
||||
)
|
||||
|
||||
@ -179,7 +179,7 @@ async def issue_challenge(
|
||||
|
||||
|
||||
async def complete_oauth_flow(
|
||||
mcp_server: MCPServer, code: str, state: str
|
||||
mcp_server: DiscordMCPServer, code: str, state: str
|
||||
) -> tuple[int, str]:
|
||||
"""Complete OAuth flow by exchanging code for token.
|
||||
|
||||
@ -196,7 +196,7 @@ async def complete_oauth_flow(
|
||||
return 400, "Invalid or expired OAuth state"
|
||||
|
||||
logger.info(
|
||||
f"Found MCP server config: id={mcp_server.id}, "
|
||||
f"Found MCP server config: user={mcp_server.discord_bot_user_id}, "
|
||||
f"url={mcp_server.mcp_server_url}"
|
||||
)
|
||||
|
||||
@ -247,8 +247,8 @@ async def complete_oauth_flow(
|
||||
mcp_server.code_verifier = None # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"Stored tokens for MCP server id={mcp_server.id}, "
|
||||
f"url={mcp_server.mcp_server_url}"
|
||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
||||
f"server {mcp_server.mcp_server_url}"
|
||||
)
|
||||
|
||||
return 200, "✅ Authorization successful! You can now use this MCP server."
|
||||
|
||||
@ -12,12 +12,14 @@ from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordBotUser
|
||||
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
|
||||
from memory.common.oauth import complete_oauth_flow
|
||||
from memory.discord.collector import MessageCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -179,14 +179,17 @@ def should_track_message(
|
||||
channel: DiscordChannel,
|
||||
user: DiscordUser,
|
||||
) -> bool:
|
||||
if server and server.ignore_messages:
|
||||
"""Pure function to determine if we should track this message"""
|
||||
if server and not server.track_messages: # type: ignore
|
||||
return False
|
||||
|
||||
if channel.ignore_messages:
|
||||
if not channel.track_messages:
|
||||
return False
|
||||
|
||||
if channel.channel_type in ("dm", "group_dm"):
|
||||
return not user.ignore_messages
|
||||
return bool(user.track_messages)
|
||||
|
||||
# Default: track the message
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Lightweight slash-command helpers for the Discord collector."""
|
||||
|
||||
import io
|
||||
from calendar import c
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal
|
||||
@ -9,34 +9,12 @@ import discord
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import (
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
DiscordUser,
|
||||
MCPServer,
|
||||
MCPServerAssignment,
|
||||
)
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.mcp import run_mcp_server_command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ScopeLiteral = Literal["bot", "me", "server", "channel", "user"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscordObjects:
|
||||
bot: DiscordUser
|
||||
server: DiscordServer | None
|
||||
channel: DiscordChannel | None
|
||||
user: DiscordUser | None
|
||||
|
||||
@property
|
||||
def items(self):
|
||||
items = [self.bot, self.server, self.channel, self.user]
|
||||
return [item for item in items if item is not None]
|
||||
|
||||
|
||||
ListHandler = Callable[[DiscordObjects], str]
|
||||
ScopeLiteral = Literal["server", "channel", "user"]
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
@ -66,312 +44,12 @@ class CommandContext:
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
async def respond(
|
||||
interaction: discord.Interaction, content: str, ephemeral: bool = True
|
||||
) -> None:
|
||||
"""Send a response to the interaction, as file if too large."""
|
||||
max_length = 1900
|
||||
if len(content) <= max_length:
|
||||
await interaction.response.send_message(content, ephemeral=ephemeral)
|
||||
return
|
||||
|
||||
file = discord.File(io.BytesIO(content.encode("utf-8")), filename="response.txt")
|
||||
await interaction.response.send_message(
|
||||
"Response too large, sending as file:", file=file, ephemeral=ephemeral
|
||||
)
|
||||
|
||||
|
||||
def with_object_context(
|
||||
bot: discord.Client,
|
||||
interaction: discord.Interaction,
|
||||
handler: ListHandler,
|
||||
user: discord.User | None,
|
||||
) -> str:
|
||||
"""Execute handler with Discord objects context."""
|
||||
server = interaction.guild
|
||||
channel = interaction.channel
|
||||
target_user = user or interaction.user
|
||||
with make_session() as session:
|
||||
objects = DiscordObjects(
|
||||
bot=ensure_user(session, bot.user),
|
||||
server=server and ensure_server(session, server),
|
||||
channel=channel and _ensure_channel(session, channel, server and server.id),
|
||||
user=ensure_user(session, target_user),
|
||||
)
|
||||
return handler(objects)
|
||||
|
||||
|
||||
def _create_scope_group(
|
||||
parent: discord.app_commands.Group,
|
||||
scope: ScopeLiteral,
|
||||
name: str,
|
||||
description: str,
|
||||
) -> discord.app_commands.Group:
|
||||
"""Create a command group for a scope (bot/me/server/channel).
|
||||
|
||||
Args:
|
||||
parent: Parent command group
|
||||
scope: Scope literal (bot, me, server, channel)
|
||||
name: Group name
|
||||
description: Group description
|
||||
"""
|
||||
group = discord.app_commands.Group(
|
||||
name=name, description=description, parent=parent
|
||||
)
|
||||
|
||||
@group.command(name="prompt", description=f"Manage {name}'s system prompt")
|
||||
@discord.app_commands.describe(prompt="The system prompt to set")
|
||||
async def prompt_cmd(interaction: discord.Interaction, prompt: str | None = None):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_prompt, prompt=prompt
|
||||
)
|
||||
|
||||
@group.command(name="chattiness", description=f"Show/set {name}'s chattiness")
|
||||
@discord.app_commands.describe(value="Optional new chattiness value (0-100)")
|
||||
async def chattiness_cmd(
|
||||
interaction: discord.Interaction, value: int | None = None
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_chattiness, value=value
|
||||
)
|
||||
|
||||
# Ignore command
|
||||
@group.command(name="ignore", description=f"Toggle bot ignoring {name} messages")
|
||||
@discord.app_commands.describe(enabled="Whether to ignore messages")
|
||||
async def ignore_cmd(interaction: discord.Interaction, enabled: bool | None = None):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_ignore, ignore_enabled=enabled
|
||||
)
|
||||
|
||||
# Summary command
|
||||
@group.command(name="summary", description=f"Show {name}'s summary")
|
||||
async def summary_cmd(interaction: discord.Interaction):
|
||||
await _run_interaction_command(interaction, scope=scope, handler=handle_summary)
|
||||
|
||||
# MCP command
|
||||
@group.command(name="mcp", description=f"Manage {name}'s MCP servers")
|
||||
@discord.app_commands.describe(
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
)
|
||||
async def mcp_cmd(
|
||||
interaction: discord.Interaction,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_mcp_servers,
|
||||
action=action,
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
def _create_user_scope_group(
|
||||
parent: discord.app_commands.Group,
|
||||
name: str,
|
||||
description: str,
|
||||
) -> discord.app_commands.Group:
|
||||
"""Create command group for user scope (requires user parameter).
|
||||
|
||||
Args:
|
||||
parent: Parent command group
|
||||
name: Group name
|
||||
description: Group description
|
||||
"""
|
||||
group = discord.app_commands.Group(
|
||||
name=name, description=description, parent=parent
|
||||
)
|
||||
scope = "user"
|
||||
|
||||
@group.command(name="prompt", description=f"Manage {name}'s system prompt")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user", prompt="The system prompt to set"
|
||||
)
|
||||
async def prompt_cmd(
|
||||
interaction: discord.Interaction, user: discord.User, prompt: str | None = None
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_prompt,
|
||||
target_user=user,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
@group.command(name="chattiness", description=f"Show/set {name}'s chattiness")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user", value="Optional new chattiness value (0-100)"
|
||||
)
|
||||
async def chattiness_cmd(
|
||||
interaction: discord.Interaction, user: discord.User, value: int | None = None
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_chattiness,
|
||||
target_user=user,
|
||||
value=value,
|
||||
)
|
||||
|
||||
# Ignore command
|
||||
@group.command(name="ignore", description=f"Toggle bot ignoring {name} messages")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user", enabled="Whether to ignore messages"
|
||||
)
|
||||
async def ignore_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User,
|
||||
enabled: bool | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_ignore,
|
||||
target_user=user,
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
# Summary command
|
||||
@group.command(name="summary", description=f"Show {name}'s summary")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def summary_cmd(interaction: discord.Interaction, user: discord.User):
|
||||
await _run_interaction_command(
|
||||
interaction, scope=scope, handler=handle_summary, target_user=user
|
||||
)
|
||||
|
||||
# MCP command
|
||||
@group.command(name="mcp", description=f"Manage {name}'s MCP servers")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user",
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
)
|
||||
async def mcp_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_mcp_servers,
|
||||
target_user=user,
|
||||
action=action,
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
def create_list_group(
|
||||
bot: discord.Client, parent: discord.app_commands.Group
|
||||
) -> discord.app_commands.Group:
|
||||
"""Create command group for listing settings.
|
||||
|
||||
Args:
|
||||
parent: Parent command group
|
||||
"""
|
||||
group = discord.app_commands.Group(
|
||||
name="list", description="List settings", parent=parent
|
||||
)
|
||||
|
||||
@group.command(name="prompt", description="List full system prompt")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def prompt_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
prompts = [o.xml_prompt() for o in objects.items if o.system_prompt]
|
||||
return "\n\n".join(prompts)
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(name="chattiness", description="Show {name}'s chattiness")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def chattiness_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
values = [
|
||||
o.chattiness_threshold
|
||||
for o in objects.items
|
||||
if o.chattiness_threshold is not None
|
||||
]
|
||||
val = min(values) if values else 50
|
||||
if objects.user:
|
||||
return f"Total current chattiness for {objects.user.username}: {val}"
|
||||
return f"Total current chattiness: {val}"
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(
|
||||
name="ignore", description="Does this bot ignore messages for this user?"
|
||||
)
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def ignore_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User | None = None,
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
should_ignore = any(o.ignore_messages for o in objects.items)
|
||||
if should_ignore:
|
||||
return f"The bot ignores messages for {objects.user}."
|
||||
return f"The bot does not ignore messages for {objects.user}."
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(name="summary", description="Show the full summary")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def summary_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
def handler(objects: DiscordObjects) -> str:
|
||||
summaries = [o.xml_summary() for o in objects.items if o.summary]
|
||||
return "\n\n".join(summaries)
|
||||
|
||||
res = with_object_context(bot, interaction, handler, user)
|
||||
await respond(interaction, res)
|
||||
|
||||
@group.command(name="mcp", description="All used MCP servers")
|
||||
@discord.app_commands.describe(user="Target user")
|
||||
async def mcp_cmd(
|
||||
interaction: discord.Interaction, user: discord.User | None = None
|
||||
):
|
||||
logger.error(f"Listing MCP servers for {user}")
|
||||
ids = [
|
||||
interaction.guild_id,
|
||||
interaction.channel_id,
|
||||
(user or interaction.user).id,
|
||||
bot.user.id,
|
||||
]
|
||||
with make_session() as session:
|
||||
mcp_servers = (
|
||||
session.query(MCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_id.in_(i for i in ids if i is not None)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
mcp_servers = [mcp_server.as_xml() for mcp_server in mcp_servers]
|
||||
res = "\n\n".join(mcp_servers)
|
||||
await respond(interaction, res)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot.
|
||||
|
||||
Args:
|
||||
bot: Discord bot client
|
||||
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
||||
"""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
@ -385,21 +63,128 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
tree = bot.tree
|
||||
name = bot.user and bot.user.name.replace("-", "_").lower()
|
||||
|
||||
# Create main command group
|
||||
memory_group = discord.app_commands.Group(
|
||||
name=name or "memory", description=f"{name} bot configuration and management"
|
||||
@tree.command(
|
||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def show_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_prompt,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
# Create scope groups
|
||||
_create_scope_group(memory_group, "bot", "bot", "Bot-wide settings")
|
||||
_create_scope_group(memory_group, "me", "me", "Your personal settings")
|
||||
_create_scope_group(memory_group, "server", "server", "Server-wide settings")
|
||||
_create_scope_group(memory_group, "channel", "channel", "Channel-specific settings")
|
||||
_create_user_scope_group(memory_group, "user", "Manage other users' settings")
|
||||
create_list_group(bot, memory_group)
|
||||
@tree.command(
|
||||
name=f"{name}_set_prompt",
|
||||
description="Set the system prompt for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
prompt="The system prompt to set",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def set_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
prompt: str,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_set_prompt,
|
||||
target_user=user,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
# Register main group
|
||||
tree.add_command(memory_group)
|
||||
@tree.command(
|
||||
name=f"{name}_chattiness",
|
||||
description="Show or update the chattiness for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
value="Optional new chattiness value between 0 and 100",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def chattiness_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
value: int | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_chattiness,
|
||||
target_user=user,
|
||||
value=value,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_ignore",
|
||||
description="Toggle whether the bot should ignore messages for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
enabled="Optional flag. Leave empty to enable ignoring.",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def ignore_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
enabled: bool | None = None,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_ignore,
|
||||
target_user=user,
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_show_summary",
|
||||
description="Show the stored summary for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def summary_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_summary,
|
||||
target_user=user,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_mcp_servers",
|
||||
description="Manage MCP servers for your account",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
action="Action to perform",
|
||||
url="MCP server URL (required for add, delete, connect, tools)",
|
||||
)
|
||||
async def mcp_servers_command(
|
||||
interaction: discord.Interaction,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"] = "list",
|
||||
url: str | None = None,
|
||||
) -> None:
|
||||
await run_mcp_server_command(interaction, bot.user, action, url and url.strip())
|
||||
|
||||
|
||||
async def _run_interaction_command(
|
||||
@ -413,16 +198,17 @@ async def _run_interaction_command(
|
||||
"""Shared coroutine used by the registered slash commands."""
|
||||
try:
|
||||
with make_session() as session:
|
||||
# Get bot from interaction client if needed for bot scope
|
||||
bot = getattr(interaction, "client", None)
|
||||
context = _build_context(session, interaction, scope, target_user, bot)
|
||||
response = await handler(context, **handler_kwargs)
|
||||
context = _build_context(session, interaction, scope, target_user)
|
||||
response = handler(context, **handler_kwargs)
|
||||
session.commit()
|
||||
except CommandError as exc: # pragma: no cover - passthrough
|
||||
await respond(interaction, str(exc))
|
||||
await interaction.response.send_message(str(exc), ephemeral=True)
|
||||
return
|
||||
|
||||
await respond(interaction, response.content, response.ephemeral)
|
||||
await interaction.response.send_message(
|
||||
response.content,
|
||||
ephemeral=response.ephemeral,
|
||||
)
|
||||
|
||||
|
||||
def _build_context(
|
||||
@ -430,55 +216,60 @@ def _build_context(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
target_user: discord.User | None,
|
||||
bot: discord.Client | None = None,
|
||||
) -> CommandContext:
|
||||
actor = ensure_user(session, interaction.user)
|
||||
actor = _ensure_user(session, interaction.user)
|
||||
|
||||
# Determine target and display name based on scope
|
||||
if scope == "bot":
|
||||
if not bot or not bot.user:
|
||||
raise CommandError("Bot user is not available.")
|
||||
target = ensure_user(session, bot.user)
|
||||
display_name = f"bot **{bot.user.name}**"
|
||||
|
||||
elif scope == "me":
|
||||
target = ensure_user(session, interaction.user)
|
||||
name = target.display_name or target.username
|
||||
display_name = f"you (**{name}**)"
|
||||
|
||||
elif scope == "server":
|
||||
if scope == "server":
|
||||
if interaction.guild is None:
|
||||
raise CommandError("This command can only be used inside a server.")
|
||||
target = ensure_server(session, interaction.guild)
|
||||
|
||||
target = _ensure_server(session, interaction.guild)
|
||||
display_name = f"server **{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
elif scope == "channel":
|
||||
if interaction.channel is None or not hasattr(interaction.channel, "id"):
|
||||
if scope == "channel":
|
||||
channel_obj = interaction.channel
|
||||
if channel_obj is None or not hasattr(channel_obj, "id"):
|
||||
raise CommandError("Unable to determine channel for this interaction.")
|
||||
target = _ensure_channel(session, interaction.channel, interaction.guild_id)
|
||||
|
||||
target = _ensure_channel(session, channel_obj, interaction.guild_id)
|
||||
display_name = f"channel **#{target.name}**"
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
elif scope == "user":
|
||||
if target_user is None:
|
||||
if scope == "user":
|
||||
discord_user = target_user or interaction.user
|
||||
if discord_user is None:
|
||||
raise CommandError("A target user is required for this command.")
|
||||
target = ensure_user(session, target_user)
|
||||
name = target.display_name or target.username
|
||||
display_name = f"user **{name}**"
|
||||
|
||||
else:
|
||||
raise CommandError(f"Unsupported scope '{scope}'.")
|
||||
target = _ensure_user(session, discord_user)
|
||||
display_name = target.display_name or target.username
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=f"user **{display_name}**",
|
||||
)
|
||||
|
||||
return CommandContext(
|
||||
session=session,
|
||||
interaction=interaction,
|
||||
actor=actor,
|
||||
scope=scope,
|
||||
target=target,
|
||||
display_name=display_name,
|
||||
)
|
||||
raise CommandError(f"Unsupported scope '{scope}'.")
|
||||
|
||||
|
||||
def ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||
def _ensure_server(session: Session, guild: discord.Guild) -> DiscordServer:
|
||||
server = session.get(DiscordServer, guild.id)
|
||||
if server is None:
|
||||
server = DiscordServer(
|
||||
@ -529,7 +320,7 @@ def _ensure_channel(
|
||||
return channel_model
|
||||
|
||||
|
||||
def ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||
def _ensure_user(session: Session, discord_user: discord.abc.User) -> DiscordUser:
|
||||
user = session.get(DiscordUser, discord_user.id)
|
||||
display_name = getattr(discord_user, "display_name", discord_user.name)
|
||||
if user is None:
|
||||
@ -563,23 +354,32 @@ def _resolve_channel_type(channel: discord.abc.Messageable) -> str:
|
||||
return getattr(getattr(channel, "type", None), "name", "unknown")
|
||||
|
||||
|
||||
async def handle_prompt(
|
||||
context: CommandContext, *, prompt: str | None = None
|
||||
) -> CommandResponse:
|
||||
if prompt is not None:
|
||||
prompt = prompt or None
|
||||
setattr(context.target, "system_prompt", prompt)
|
||||
else:
|
||||
prompt = getattr(context.target, "system_prompt", None)
|
||||
def handle_prompt(context: CommandContext) -> CommandResponse:
|
||||
prompt = getattr(context.target, "system_prompt", None)
|
||||
|
||||
if prompt:
|
||||
content = f"Current prompt for {context.display_name}:\n\n{prompt}"
|
||||
else:
|
||||
content = f"No prompt configured for {context.display_name}."
|
||||
return CommandResponse(content=content)
|
||||
return CommandResponse(
|
||||
content=f"Current prompt for {context.display_name}:\n\n{prompt}",
|
||||
)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"No prompt configured for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
async def handle_chattiness(
|
||||
def handle_set_prompt(
|
||||
context: CommandContext,
|
||||
*,
|
||||
prompt: str,
|
||||
) -> CommandResponse:
|
||||
setattr(context.target, "system_prompt", prompt)
|
||||
|
||||
return CommandResponse(
|
||||
content=f"Updated system prompt for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
def handle_chattiness(
|
||||
context: CommandContext,
|
||||
*,
|
||||
value: int | None,
|
||||
@ -609,7 +409,7 @@ async def handle_chattiness(
|
||||
)
|
||||
|
||||
|
||||
async def handle_ignore(
|
||||
def handle_ignore(
|
||||
context: CommandContext,
|
||||
*,
|
||||
ignore_enabled: bool | None,
|
||||
@ -624,7 +424,7 @@ async def handle_ignore(
|
||||
)
|
||||
|
||||
|
||||
async def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
summary = getattr(context.target, "summary", None)
|
||||
|
||||
if summary:
|
||||
@ -635,31 +435,3 @@ async def handle_summary(context: CommandContext) -> CommandResponse:
|
||||
return CommandResponse(
|
||||
content=f"No summary stored for {context.display_name}.",
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_servers(
|
||||
context: CommandContext,
|
||||
*,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||
url: str | None,
|
||||
) -> CommandResponse:
|
||||
"""Handle MCP server commands for a specific scope."""
|
||||
# Map scope to database entity type
|
||||
entity_type = {
|
||||
"bot": "DiscordUser",
|
||||
"me": "DiscordUser",
|
||||
"user": "DiscordUser",
|
||||
"server": "DiscordServer",
|
||||
"channel": "DiscordChannel",
|
||||
}[context.scope]
|
||||
|
||||
bot_user = getattr(getattr(context.interaction, "client", None), "user", None)
|
||||
|
||||
try:
|
||||
res = await run_mcp_server_command(
|
||||
bot_user, action, url, entity_type, context.target.id
|
||||
)
|
||||
return CommandResponse(content=res)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
||||
raise CommandError(f"Error: {exc}") from exc
|
||||
|
||||
@ -10,27 +10,23 @@ import discord
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_mcp_server(
|
||||
session: Session | scoped_session, entity_type: str, entity_id: int, url: str
|
||||
) -> MCPServer | None:
|
||||
"""Find an MCP server assigned to an entity."""
|
||||
assignment = (
|
||||
session.query(MCPServerAssignment)
|
||||
.join(MCPServer)
|
||||
session: Session | scoped_session, user_id: int, url: str
|
||||
) -> DiscordMCPServer | None:
|
||||
return (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_type == entity_type,
|
||||
MCPServerAssignment.entity_id == entity_id,
|
||||
MCPServer.mcp_server_url == url,
|
||||
DiscordMCPServer.discord_bot_user_id == user_id,
|
||||
DiscordMCPServer.mcp_server_url == url,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return assignment and assignment.mcp_server
|
||||
|
||||
|
||||
async def call_mcp_server(
|
||||
@ -76,39 +72,35 @@ async def call_mcp_server(
|
||||
continue # Skip invalid JSON lines
|
||||
|
||||
|
||||
async def handle_mcp_list(entity_type: str, entity_id: int) -> str:
|
||||
async def handle_mcp_list(interaction: discord.Interaction) -> str:
|
||||
"""List all MCP servers for the user."""
|
||||
with make_session() as session:
|
||||
assignments = (
|
||||
session.query(MCPServerAssignment)
|
||||
.join(MCPServer)
|
||||
servers = (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_type == entity_type,
|
||||
MCPServerAssignment.entity_id == entity_id,
|
||||
DiscordMCPServer.discord_bot_user_id == interaction.user.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not assignments:
|
||||
if not servers:
|
||||
return (
|
||||
"📋 **Your MCP Servers**\n\n"
|
||||
"You don't have any MCP servers configured yet.\n"
|
||||
"Use `/memory_mcp_servers add <url>` to add one."
|
||||
)
|
||||
|
||||
def format_server(assignment: MCPServerAssignment) -> str:
|
||||
server = assignment.mcp_server
|
||||
def format_server(server: DiscordMCPServer) -> str:
|
||||
con = "🟢" if cast(str | None, server.access_token) else "🔴"
|
||||
return f"{con} **{server.mcp_server_url}**\n`{server.client_id}`"
|
||||
|
||||
server_list = "\n".join(format_server(a) for a in assignments)
|
||||
server_list = "\n".join(format_server(s) for s in servers)
|
||||
|
||||
return f"📋 **Your MCP Servers**\n\n{server_list}"
|
||||
|
||||
|
||||
async def handle_mcp_add(
|
||||
entity_type: str,
|
||||
entity_id: int,
|
||||
interaction: discord.Interaction,
|
||||
bot_user: discord.User | None,
|
||||
url: str,
|
||||
) -> str:
|
||||
@ -116,7 +108,7 @@ async def handle_mcp_add(
|
||||
if not bot_user:
|
||||
raise ValueError("Bot user is required")
|
||||
with make_session() as session:
|
||||
if find_mcp_server(session, entity_type, entity_id, url):
|
||||
if find_mcp_server(session, bot_user.id, url):
|
||||
return (
|
||||
f"**MCP Server Already Exists**\n\n"
|
||||
f"You already have an MCP server configured at `{url}`.\n"
|
||||
@ -124,32 +116,25 @@ async def handle_mcp_add(
|
||||
)
|
||||
|
||||
endpoints = await get_endpoints(url)
|
||||
name = f"Discord Bot - {bot_user.name} ({entity_type} {entity_id})"
|
||||
client_id = await register_oauth_client(endpoints, url, name)
|
||||
|
||||
# Create MCP server
|
||||
mcp_server = MCPServer(
|
||||
client_id = await register_oauth_client(
|
||||
endpoints,
|
||||
url,
|
||||
f"Discord Bot - {bot_user.name} ({interaction.user.name})",
|
||||
)
|
||||
mcp_server = DiscordMCPServer(
|
||||
discord_bot_user_id=bot_user.id,
|
||||
mcp_server_url=url,
|
||||
client_id=client_id,
|
||||
name=name,
|
||||
)
|
||||
session.add(mcp_server)
|
||||
session.flush()
|
||||
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=mcp_server.id,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
session.add(assignment)
|
||||
session.flush()
|
||||
|
||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Created MCP server record: id={mcp_server.id}, "
|
||||
f"{entity_type}={entity_id}, url={url}"
|
||||
f"user={interaction.user.id}, url={url}"
|
||||
)
|
||||
|
||||
return (
|
||||
@ -161,54 +146,32 @@ async def handle_mcp_add(
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_delete(entity_type: str, entity_id: int, url: str) -> str:
|
||||
"""Delete an MCP server assignment."""
|
||||
async def handle_mcp_delete(bot_user: discord.User, url: str) -> str:
|
||||
"""Delete an MCP server."""
|
||||
with make_session() as session:
|
||||
# Find the assignment
|
||||
assignment = (
|
||||
session.query(MCPServerAssignment)
|
||||
.join(MCPServer)
|
||||
.filter(
|
||||
MCPServerAssignment.entity_type == entity_type,
|
||||
MCPServerAssignment.entity_id == entity_id,
|
||||
MCPServer.mcp_server_url == url,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not assignment:
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
if not mcp_server:
|
||||
return (
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
)
|
||||
|
||||
# Delete the assignment (server will cascade delete if no other assignments exist)
|
||||
session.delete(assignment)
|
||||
|
||||
# Check if server has other assignments
|
||||
mcp_server = assignment.mcp_server
|
||||
other_assignments = (
|
||||
session.query(MCPServerAssignment)
|
||||
.filter(
|
||||
MCPServerAssignment.mcp_server_id == mcp_server.id,
|
||||
MCPServerAssignment.id != assignment.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# If no other assignments, delete the server too
|
||||
if other_assignments == 0:
|
||||
session.delete(mcp_server)
|
||||
|
||||
session.delete(mcp_server)
|
||||
session.commit()
|
||||
|
||||
return f"🗑️ **Delete MCP Server**\n\nServer `{url}` has been removed."
|
||||
|
||||
|
||||
async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
|
||||
async def handle_mcp_connect(bot_user: discord.User, url: str) -> str:
|
||||
"""Reconnect to an existing MCP server (redo OAuth)."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
f"You don't have an MCP server configured at `{url}`.\n"
|
||||
f"Use `/memory_mcp_servers add {url}` to add it first."
|
||||
)
|
||||
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
f"**MCP Server Not Found**\n\n"
|
||||
@ -221,9 +184,7 @@ async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Regenerated OAuth challenge for {entity_type}={entity_id}, url={url}"
|
||||
)
|
||||
logger.info(f"Regenerated OAuth challenge for user={bot_user.id}, url={url}")
|
||||
|
||||
return (
|
||||
f"🔄 **Reconnect to MCP Server**\n\n"
|
||||
@ -234,10 +195,10 @@ async def handle_mcp_connect(entity_type: str, entity_id: int, url: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
||||
async def handle_mcp_tools(bot_user: discord.User, url: str) -> str:
|
||||
"""List tools available on an MCP server."""
|
||||
with make_session() as session:
|
||||
mcp_server = find_mcp_server(session, entity_type, entity_id, url)
|
||||
mcp_server = find_mcp_server(session, bot_user.id, url)
|
||||
|
||||
if not mcp_server:
|
||||
raise ValueError(
|
||||
@ -304,28 +265,37 @@ async def handle_mcp_tools(entity_type: str, entity_id: int, url: str) -> str:
|
||||
|
||||
|
||||
async def run_mcp_server_command(
|
||||
interaction: discord.Interaction,
|
||||
bot_user: discord.User | None,
|
||||
action: Literal["list", "add", "delete", "connect", "tools"],
|
||||
url: str | None,
|
||||
entity_type: str,
|
||||
entity_id: int,
|
||||
) -> None:
|
||||
"""Handle MCP server management commands."""
|
||||
if action not in ["list", "add", "delete", "connect", "tools"]:
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
await interaction.response.send_message("❌ Invalid action", ephemeral=True)
|
||||
return
|
||||
if action != "list" and not url:
|
||||
raise ValueError("URL is required for this action")
|
||||
await interaction.response.send_message(
|
||||
"❌ URL is required for this action", ephemeral=True
|
||||
)
|
||||
return
|
||||
if not bot_user:
|
||||
raise ValueError("Bot user is required")
|
||||
await interaction.response.send_message(
|
||||
"❌ Bot user is required", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
if action == "list" or not url:
|
||||
return await handle_mcp_list(entity_type, entity_id)
|
||||
elif action == "add":
|
||||
return await handle_mcp_add(entity_type, entity_id, bot_user, url)
|
||||
elif action == "delete":
|
||||
return await handle_mcp_delete(entity_type, entity_id, url)
|
||||
elif action == "connect":
|
||||
return await handle_mcp_connect(entity_type, entity_id, url)
|
||||
elif action == "tools":
|
||||
return await handle_mcp_tools(entity_type, entity_id, url)
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
try:
|
||||
if action == "list" or not url:
|
||||
result = await handle_mcp_list(interaction)
|
||||
elif action == "add":
|
||||
result = await handle_mcp_add(interaction, bot_user, url)
|
||||
elif action == "delete":
|
||||
result = await handle_mcp_delete(bot_user, url)
|
||||
elif action == "connect":
|
||||
result = await handle_mcp_connect(bot_user, url)
|
||||
elif action == "tools":
|
||||
result = await handle_mcp_tools(bot_user, url)
|
||||
except Exception as exc:
|
||||
result = f"❌ Error: {exc}"
|
||||
await interaction.response.send_message(result, ephemeral=True)
|
||||
|
||||
@ -12,9 +12,9 @@ from memory.common.db.models import (
|
||||
DiscordMessage,
|
||||
DiscordUser,
|
||||
ScheduledLLMCall,
|
||||
MCPServer as MCPServerModel,
|
||||
)
|
||||
from memory.common.llms.base import create_provider
|
||||
from memory.common.llms.tools import MCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -113,6 +113,8 @@ def upsert_scheduled_message(
|
||||
.first()
|
||||
)
|
||||
naive_scheduled_time = scheduled_time.replace(tzinfo=None)
|
||||
print(f"naive_scheduled_time: {naive_scheduled_time}")
|
||||
print(f"prev_call.scheduled_time: {prev_call and prev_call.scheduled_time}")
|
||||
if prev_call and cast(datetime, prev_call.scheduled_time) > naive_scheduled_time:
|
||||
prev_call.status = "cancelled" # type: ignore
|
||||
|
||||
@ -139,8 +141,10 @@ def previous_messages(
|
||||
) -> list[DiscordMessage]:
|
||||
messages = session.query(DiscordMessage)
|
||||
if user_id:
|
||||
print(f"user_id: {user_id}")
|
||||
messages = messages.filter(DiscordMessage.recipient_id == user_id)
|
||||
if channel_id:
|
||||
print(f"channel_id: {channel_id}")
|
||||
messages = messages.filter(DiscordMessage.channel_id == channel_id)
|
||||
return list(
|
||||
reversed(
|
||||
@ -187,7 +191,7 @@ def comm_channel_prompt(
|
||||
{users}
|
||||
</user_notes>
|
||||
""").format(
|
||||
users="\n".join({msg.from_user.xml_summary() for msg in messages}),
|
||||
users="\n".join({msg.from_user.as_xml() for msg in messages}),
|
||||
)
|
||||
|
||||
return textwrap.dedent("""
|
||||
@ -212,7 +216,7 @@ def call_llm(
|
||||
system_prompt: str = "",
|
||||
messages: list[str | dict[str, Any]] = [],
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
mcp_servers: list[MCPServerModel] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
num_previous_messages: int = 10,
|
||||
) -> str | None:
|
||||
"""
|
||||
@ -251,7 +255,6 @@ def call_llm(
|
||||
|
||||
from memory.common.llms.tools.discord import make_discord_tools
|
||||
from memory.common.llms.tools.base import WebSearchTool
|
||||
from memory.common.llms.tools import MCPServer
|
||||
|
||||
tools = make_discord_tools(bot_user.system_user, from_user, channel, model=model)
|
||||
tools |= {"web_search": WebSearchTool()}
|
||||
@ -267,16 +270,7 @@ def call_llm(
|
||||
messages=provider.as_messages(message_content),
|
||||
tools=tools,
|
||||
system_prompt=(bot_user.system_prompt or "") + "\n\n" + (system_prompt or ""),
|
||||
mcp_servers=[
|
||||
MCPServer(
|
||||
name=str(server.name),
|
||||
url=str(server.mcp_server_url),
|
||||
token=str(server.access_token),
|
||||
)
|
||||
for server in mcp_servers
|
||||
]
|
||||
if mcp_servers
|
||||
else None,
|
||||
mcp_servers=mcp_servers,
|
||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||
).response
|
||||
|
||||
@ -300,8 +294,10 @@ def send_discord_response(
|
||||
True if sent successfully
|
||||
"""
|
||||
if channel_id is not None:
|
||||
logger.info(f"Sending message to channel {channel_id}")
|
||||
return discord.send_to_channel(bot_id, channel_id, response)
|
||||
elif user_identifier is not None:
|
||||
logger.info(f"Sending DM to {user_identifier}")
|
||||
return discord.send_dm(bot_id, user_identifier, response)
|
||||
else:
|
||||
logger.error("Neither channel_id nor user_identifier provided")
|
||||
|
||||
@ -91,10 +91,6 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
):
|
||||
return False
|
||||
|
||||
if f"<@{message.recipient_user.id}>" in message.content:
|
||||
logger.info("Direct mention of the bot, processing message")
|
||||
return True
|
||||
|
||||
if message.from_user == message.recipient_user:
|
||||
logger.info("Skipping message because from_user == recipient_user")
|
||||
return False
|
||||
@ -136,8 +132,6 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||
return False
|
||||
try:
|
||||
logger.info(f"chattiness_threshold: {message.chattiness_threshold}")
|
||||
logger.info(f"haiku desire: {res.group(1)}")
|
||||
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
||||
return False
|
||||
except ValueError:
|
||||
@ -196,11 +190,19 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
mcp_servers = discord_message.get_mcp_servers(session)
|
||||
system_prompt = discord_message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, discord_message.recipient_user, discord_message.channel
|
||||
)
|
||||
mcp_servers = None
|
||||
if (
|
||||
discord_message.recipient_user
|
||||
and discord_message.recipient_user.mcp_servers
|
||||
):
|
||||
mcp_servers = [
|
||||
MCPServer(
|
||||
name=server.mcp_server_url,
|
||||
url=server.mcp_server_url,
|
||||
token=server.access_token,
|
||||
)
|
||||
for server in discord_message.recipient_user.mcp_servers
|
||||
]
|
||||
|
||||
try:
|
||||
response = call_llm(
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -21,31 +20,6 @@ from memory.common.qdrant import initialize_collections
|
||||
from tests.providers.email_provider import MockEmailProvider
|
||||
|
||||
|
||||
class MockRedis:
|
||||
"""In-memory mock of Redis for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
|
||||
def get(self, key: str):
|
||||
return self._data.get(key)
|
||||
|
||||
def set(self, key: str, value):
|
||||
self._data[key] = value
|
||||
|
||||
def scan_iter(self, match: str):
|
||||
import fnmatch
|
||||
|
||||
pattern = match.replace("*", "**")
|
||||
for key in self._data.keys():
|
||||
if fnmatch.fnmatch(key, pattern):
|
||||
yield key
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str):
|
||||
return cls()
|
||||
|
||||
|
||||
def get_test_db_name() -> str:
|
||||
return f"test_db_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@ -109,7 +83,7 @@ def run_alembic_migrations(db_name: str) -> None:
|
||||
alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
|
||||
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||
["alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
|
||||
check=True,
|
||||
capture_output=True,
|
||||
@ -291,8 +265,7 @@ def mock_openai_client():
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@ -308,8 +281,7 @@ def mock_openai_client():
|
||||
delta=Mock(content="test", tool_calls=None),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
]
|
||||
),
|
||||
Mock(
|
||||
choices=[
|
||||
@ -317,8 +289,7 @@ def mock_openai_client():
|
||||
delta=Mock(content=" response", tool_calls=None),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=15),
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
@ -332,8 +303,7 @@ def mock_openai_client():
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||
]
|
||||
)
|
||||
|
||||
client.chat.completions.create.side_effect = streaming_response
|
||||
@ -342,8 +312,6 @@ def mock_openai_client():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_anthropic_client():
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||
client = mock_client()
|
||||
client.messages = Mock()
|
||||
@ -377,57 +345,7 @@ def mock_anthropic_client():
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Mock async client
|
||||
async_client = Mock()
|
||||
async_client.messages = Mock()
|
||||
async_client.messages.create = AsyncMock(
|
||||
return_value=Mock(
|
||||
content=[
|
||||
Mock(
|
||||
type="text",
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Mock async streaming
|
||||
def async_stream_ctx(*args, **kwargs):
|
||||
async def async_iter():
|
||||
yield Mock(
|
||||
type="content_block_delta",
|
||||
delta=Mock(
|
||||
type="text_delta",
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
),
|
||||
)
|
||||
|
||||
class AsyncStreamMock:
|
||||
async def __aenter__(self):
|
||||
return async_iter()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
return AsyncStreamMock()
|
||||
|
||||
async_client.messages.stream = Mock(side_effect=async_stream_ctx)
|
||||
|
||||
# Add async_client property to mock
|
||||
mock_client.return_value._async_client = None
|
||||
|
||||
with patch.object(anthropic, "AsyncAnthropic", return_value=async_client):
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_redis():
|
||||
"""Mock Redis client for all tests."""
|
||||
import redis
|
||||
|
||||
with patch.object(redis, "Redis", MockRedis):
|
||||
yield
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@ -1,151 +0,0 @@
|
||||
"""Tests for authentication helpers and OAuth callback."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
|
||||
from memory.api import auth
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
def make_request(query: str) -> Request:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/auth/callback/discord",
|
||||
"headers": [],
|
||||
"query_string": query.encode(),
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
|
||||
return Request(scope, receive)
|
||||
|
||||
|
||||
def test_get_bearer_token_parses_header():
|
||||
request = SimpleNamespace(headers={"Authorization": "Bearer token123"})
|
||||
|
||||
assert auth.get_bearer_token(request) == "token123"
|
||||
|
||||
|
||||
def test_get_bearer_token_handles_missing_header():
|
||||
request = SimpleNamespace(headers={})
|
||||
|
||||
assert auth.get_bearer_token(request) is None
|
||||
|
||||
|
||||
def test_get_token_prefers_header_over_cookie():
|
||||
request = SimpleNamespace(
|
||||
headers={"Authorization": "Bearer header-token"},
|
||||
cookies={"session": "cookie-token"},
|
||||
)
|
||||
|
||||
assert auth.get_token(request) == "header-token"
|
||||
|
||||
|
||||
def test_get_token_falls_back_to_cookie():
|
||||
request = SimpleNamespace(
|
||||
headers={},
|
||||
cookies={settings.SESSION_COOKIE_NAME: "cookie-token"},
|
||||
)
|
||||
|
||||
assert auth.get_token(request) == "cookie-token"
|
||||
|
||||
|
||||
@patch("memory.api.auth.get_user_session")
|
||||
def test_logout_removes_session(mock_get_user_session):
|
||||
db = MagicMock()
|
||||
session = MagicMock()
|
||||
mock_get_user_session.return_value = session
|
||||
request = SimpleNamespace()
|
||||
|
||||
result = auth.logout(request, db)
|
||||
|
||||
assert result == {"message": "Logged out successfully"}
|
||||
db.delete.assert_called_once_with(session)
|
||||
db.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.api.auth.get_user_session", return_value=None)
|
||||
def test_logout_handles_missing_session(mock_get_user_session):
|
||||
db = MagicMock()
|
||||
request = SimpleNamespace()
|
||||
|
||||
result = auth.logout(request, db)
|
||||
|
||||
assert result == {"message": "Logged out successfully"}
|
||||
db.delete.assert_not_called()
|
||||
db.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
def session_cm():
|
||||
yield mock_session
|
||||
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (200, "Authorized")
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.body.decode()
|
||||
assert "Authorization Successful" in body
|
||||
assert "Authorized" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_handles_failures(
|
||||
mock_make_session, mock_complete
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
def session_cm():
|
||||
yield mock_session
|
||||
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (500, "Failure")
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
|
||||
assert response.status_code == 500
|
||||
body = response.body.decode()
|
||||
assert "Authorization Failed" in body
|
||||
assert "Failure" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_discord_validates_query_params():
|
||||
request = make_request("code=&state=")
|
||||
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
|
||||
assert response.status_code == 400
|
||||
body = response.body.decode()
|
||||
assert "Missing authorization code" in body
|
||||
@ -1,7 +1,5 @@
|
||||
"""Tests for Discord database models."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from memory.common.db.models import DiscordServer, DiscordChannel, DiscordUser
|
||||
|
||||
@ -21,11 +19,12 @@ def test_create_discord_server(db_session):
|
||||
assert server.name == "Test Server"
|
||||
assert server.description == "A test Discord server"
|
||||
assert server.member_count == 100
|
||||
assert server.ignore_messages is False # default value
|
||||
assert server.track_messages is True # default value
|
||||
assert server.ignore_messages is False
|
||||
|
||||
|
||||
def test_discord_server_as_xml(db_session):
|
||||
"""Test DiscordServer.to_xml() method."""
|
||||
"""Test DiscordServer.as_xml() method."""
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
@ -34,11 +33,11 @@ def test_discord_server_as_xml(db_session):
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
xml = server.to_xml("name", "summary")
|
||||
assert "<server>" in xml # tablename is discord_servers, strips to "server"
|
||||
assert "<name>" in xml and "Test Server" in xml
|
||||
assert "<summary>" in xml and "This is a test server for gaming" in xml
|
||||
assert "</server>" in xml
|
||||
xml = server.as_xml()
|
||||
assert "<servers>" in xml # tablename is discord_servers, strips to "servers"
|
||||
assert "<name>Test Server</name>" in xml
|
||||
assert "<summary>This is a test server for gaming</summary>" in xml
|
||||
assert "</servers>" in xml
|
||||
|
||||
|
||||
def test_discord_server_message_tracking(db_session):
|
||||
@ -46,11 +45,13 @@ def test_discord_server_message_tracking(db_session):
|
||||
server = DiscordServer(
|
||||
id=123456789,
|
||||
name="Test Server",
|
||||
track_messages=False,
|
||||
ignore_messages=True,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
assert server.track_messages is False
|
||||
assert server.ignore_messages is True
|
||||
|
||||
|
||||
@ -110,7 +111,7 @@ def test_discord_channel_without_server(db_session):
|
||||
|
||||
|
||||
def test_discord_channel_as_xml(db_session):
|
||||
"""Test DiscordChannel.to_xml() method."""
|
||||
"""Test DiscordChannel.as_xml() method."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="general",
|
||||
@ -120,28 +121,30 @@ def test_discord_channel_as_xml(db_session):
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
xml = channel.to_xml("name", "summary")
|
||||
assert "<channel>" in xml # tablename is discord_channels, strips to "channel"
|
||||
assert "<name>" in xml and "general" in xml
|
||||
assert "<summary>" in xml and "Main discussion channel" in xml
|
||||
assert "</channel>" in xml
|
||||
xml = channel.as_xml()
|
||||
assert "<channels>" in xml # tablename is discord_channels, strips to "channels"
|
||||
assert "<name>general</name>" in xml
|
||||
assert "<summary>Main discussion channel</summary>" in xml
|
||||
assert "</channels>" in xml
|
||||
|
||||
|
||||
def test_discord_channel_inherits_server_settings(db_session):
|
||||
"""Test that channels can have their own or inherit server settings."""
|
||||
server = DiscordServer(id=987654321, name="Server", ignore_messages=False)
|
||||
server = DiscordServer(
|
||||
id=987654321, name="Server", track_messages=True, ignore_messages=False
|
||||
)
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
server_id=server.id,
|
||||
name="announcements",
|
||||
channel_type="text",
|
||||
ignore_messages=True, # Override server setting
|
||||
track_messages=False, # Override server setting
|
||||
)
|
||||
db_session.add_all([server, channel])
|
||||
db_session.commit()
|
||||
|
||||
assert server.ignore_messages is False
|
||||
assert channel.ignore_messages is True
|
||||
assert server.track_messages is True
|
||||
assert channel.track_messages is False
|
||||
|
||||
|
||||
def test_create_discord_user(db_session):
|
||||
@ -183,7 +186,7 @@ def test_discord_user_with_system_user(db_session):
|
||||
|
||||
|
||||
def test_discord_user_as_xml(db_session):
|
||||
"""Test DiscordUser.to_xml() method."""
|
||||
"""Test DiscordUser.as_xml() method."""
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
@ -192,10 +195,11 @@ def test_discord_user_as_xml(db_session):
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
xml = user.to_xml("summary")
|
||||
assert "<user>" in xml # tablename is discord_users, strips to "user"
|
||||
assert "<summary>" in xml and "Friendly and helpful community member" in xml
|
||||
assert "</user>" in xml
|
||||
xml = user.as_xml()
|
||||
assert "<users>" in xml # tablename is discord_users, strips to "users"
|
||||
assert "<name>testuser</name>" in xml
|
||||
assert "<summary>Friendly and helpful community member</summary>" in xml
|
||||
assert "</users>" in xml
|
||||
|
||||
|
||||
def test_discord_user_message_preferences(db_session):
|
||||
@ -203,11 +207,13 @@ def test_discord_user_message_preferences(db_session):
|
||||
user = DiscordUser(
|
||||
id=555666777,
|
||||
username="testuser",
|
||||
track_messages=True,
|
||||
ignore_messages=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
assert user.track_messages is True
|
||||
assert user.ignore_messages is False
|
||||
|
||||
|
||||
@ -228,21 +234,6 @@ def test_discord_server_channel_relationship(db_session):
|
||||
assert channel2 in server.channels
|
||||
|
||||
|
||||
def test_discord_processor_xml_mcp_servers():
|
||||
"""Test xml_mcp_servers includes assigned MCP server XML."""
|
||||
server = DiscordServer(id=111, name="Server")
|
||||
mcp_stub = SimpleNamespace(
|
||||
as_xml=lambda: "<mcp_server><name>Example</name></mcp_server>"
|
||||
)
|
||||
|
||||
# Relationship is optional for test purposes; assign directly
|
||||
server.mcp_servers = [mcp_stub]
|
||||
|
||||
xml_output = server.xml_mcp_servers()
|
||||
assert "<mcp_server>" in xml_output
|
||||
assert "Example" in xml_output
|
||||
|
||||
|
||||
def test_discord_server_cascade_delete(db_session):
|
||||
"""Test that deleting a server cascades to channels."""
|
||||
server = DiscordServer(id=987654321, name="Test Server")
|
||||
|
||||
@ -1,155 +0,0 @@
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"available_tools,expected_tools",
|
||||
[
|
||||
(["search", "summarize"], ["• search", "• summarize"]),
|
||||
([], []),
|
||||
],
|
||||
)
|
||||
def test_mcp_server_as_xml_formats_available_tools(available_tools, expected_tools):
|
||||
server = MCPServer(
|
||||
name="Example Server",
|
||||
mcp_server_url="https://example.com/mcp",
|
||||
client_id="client-123",
|
||||
available_tools=available_tools,
|
||||
)
|
||||
|
||||
xml_output = server.as_xml()
|
||||
root = ET.fromstring(xml_output)
|
||||
|
||||
assert root.find("name").text.strip() == "Example Server"
|
||||
assert root.find("mcp_server_url").text.strip() == "https://example.com/mcp"
|
||||
assert root.find("client_id").text.strip() == "client-123"
|
||||
|
||||
tools_element = root.find("available_tools")
|
||||
assert tools_element is not None
|
||||
|
||||
tools_text = tools_element.text.strip() if tools_element.text else ""
|
||||
if expected_tools:
|
||||
assert tools_text.splitlines() == expected_tools
|
||||
else:
|
||||
assert tools_text == ""
|
||||
|
||||
|
||||
def test_mcp_server_crud_and_token_expiration(db_session):
|
||||
initial_expiry = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
server = MCPServer(
|
||||
name="Initial Server",
|
||||
mcp_server_url="https://initial.example.com/mcp",
|
||||
client_id="client-initial",
|
||||
available_tools=["search"],
|
||||
access_token="access-123",
|
||||
refresh_token="refresh-123",
|
||||
token_expires_at=initial_expiry,
|
||||
)
|
||||
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
fetched = db_session.get(MCPServer, server.id)
|
||||
assert fetched is not None
|
||||
assert fetched.access_token == "access-123"
|
||||
assert fetched.refresh_token == "refresh-123"
|
||||
assert fetched.token_expires_at == initial_expiry
|
||||
assert fetched.token_expires_at.tzinfo is not None
|
||||
|
||||
new_expiry = initial_expiry + timedelta(minutes=15)
|
||||
fetched.name = "Updated Server"
|
||||
fetched.available_tools = [*fetched.available_tools, "summarize"]
|
||||
fetched.access_token = "access-456"
|
||||
fetched.refresh_token = "refresh-456"
|
||||
fetched.token_expires_at = new_expiry
|
||||
db_session.commit()
|
||||
|
||||
updated = db_session.get(MCPServer, server.id)
|
||||
assert updated is not None
|
||||
assert updated.name == "Updated Server"
|
||||
assert updated.available_tools == ["search", "summarize"]
|
||||
assert updated.access_token == "access-456"
|
||||
assert updated.refresh_token == "refresh-456"
|
||||
assert updated.token_expires_at == new_expiry
|
||||
|
||||
db_session.delete(updated)
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.get(MCPServer, server.id) is None
|
||||
|
||||
|
||||
def test_mcp_server_assignments_relationship_and_cascade(db_session):
|
||||
server = MCPServer(
|
||||
name="Cascade Server",
|
||||
mcp_server_url="https://cascade.example.com/mcp",
|
||||
client_id="client-cascade",
|
||||
available_tools=["search"],
|
||||
)
|
||||
server.assignments.extend(
|
||||
[
|
||||
MCPServerAssignment(entity_type="DiscordUser", entity_id=101),
|
||||
MCPServerAssignment(entity_type="DiscordChannel", entity_id=202),
|
||||
]
|
||||
)
|
||||
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
persisted_server = db_session.get(MCPServer, server.id)
|
||||
assert persisted_server is not None
|
||||
assert len(persisted_server.assignments) == 2
|
||||
assert {assignment.entity_type for assignment in persisted_server.assignments} == {
|
||||
"DiscordUser",
|
||||
"DiscordChannel",
|
||||
}
|
||||
assert all(
|
||||
assignment.mcp_server_id == persisted_server.id
|
||||
for assignment in persisted_server.assignments
|
||||
)
|
||||
|
||||
db_session.delete(persisted_server)
|
||||
db_session.commit()
|
||||
|
||||
remaining_assignments = db_session.query(MCPServerAssignment).all()
|
||||
assert remaining_assignments == []
|
||||
|
||||
|
||||
def test_mcp_server_assignment_unique_constraint(db_session):
|
||||
server = MCPServer(
|
||||
name="Unique Server",
|
||||
mcp_server_url="https://unique.example.com/mcp",
|
||||
client_id="client-unique",
|
||||
available_tools=["search"],
|
||||
)
|
||||
assignment = MCPServerAssignment(
|
||||
entity_type="DiscordUser",
|
||||
entity_id=12345,
|
||||
)
|
||||
server.assignments.append(assignment)
|
||||
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
duplicate_assignment = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=12345,
|
||||
)
|
||||
db_session.add(duplicate_assignment)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
assignments = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(MCPServerAssignment.mcp_server_id == server.id)
|
||||
.all()
|
||||
)
|
||||
assert len(assignments) == 1
|
||||
@ -162,18 +162,8 @@ def test_create_bot_user_auto_api_key(db_session):
|
||||
|
||||
def test_create_discord_bot_user(db_session):
|
||||
"""Test creating a DiscordBotUser"""
|
||||
from memory.common.db.models import DiscordUser
|
||||
|
||||
# Create a Discord user for the bot
|
||||
discord_user = DiscordUser(
|
||||
id=123456789,
|
||||
username="botuser",
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[discord_user],
|
||||
discord_users=[],
|
||||
name="Discord Bot",
|
||||
email="discordbot@example.com",
|
||||
api_key="discord_key_123",
|
||||
@ -186,7 +176,6 @@ def test_create_discord_bot_user(db_session):
|
||||
assert user.name == "Discord Bot"
|
||||
assert user.user_type == "discord_bot"
|
||||
assert user.api_key == "discord_key_123"
|
||||
assert len(user.discord_users) == 1
|
||||
|
||||
|
||||
def test_user_serialization_human(db_session):
|
||||
|
||||
@ -131,7 +131,7 @@ def test_build_request_kwargs_basic(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, settings)
|
||||
|
||||
assert kwargs["model"] == "claude-3-opus-20240229"
|
||||
assert kwargs["temperature"] == 0.5
|
||||
@ -143,9 +143,7 @@ def test_build_request_kwargs_with_system_prompt(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(
|
||||
messages, "system prompt", None, None, settings
|
||||
)
|
||||
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
|
||||
|
||||
assert kwargs["system"] == "system prompt"
|
||||
|
||||
@ -162,7 +160,7 @@ def test_build_request_kwargs_with_tools(provider):
|
||||
]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, settings)
|
||||
|
||||
assert "tools" in kwargs
|
||||
assert len(kwargs["tools"]) == 1
|
||||
@ -172,9 +170,7 @@ def test_build_request_kwargs_with_thinking(thinking_provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(max_tokens=5000)
|
||||
|
||||
kwargs = thinking_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
|
||||
|
||||
assert "thinking" in kwargs
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
@ -187,9 +183,7 @@ def test_build_request_kwargs_thinking_insufficient_tokens(thinking_provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(max_tokens=1000)
|
||||
|
||||
kwargs = thinking_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
kwargs = thinking_provider._build_request_kwargs(messages, None, None, settings)
|
||||
|
||||
# Shouldn't enable thinking if not enough tokens
|
||||
assert "thinking" not in kwargs
|
||||
@ -332,7 +326,7 @@ async def test_agenerate_basic(provider, mock_anthropic_client):
|
||||
|
||||
result = await provider.agenerate(messages)
|
||||
|
||||
assert "<summary>test summary</summary>" in result
|
||||
assert result == "test summary"
|
||||
provider.async_client.messages.create.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from unittest.mock import Mock
|
||||
from PIL import Image
|
||||
|
||||
from memory.common.llms.openai_provider import OpenAIProvider
|
||||
@ -192,7 +192,7 @@ def test_build_request_kwargs_basic(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(temperature=0.5, max_tokens=1000)
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, None, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, settings)
|
||||
|
||||
assert kwargs["model"] == "gpt-4o"
|
||||
assert kwargs["temperature"] == 0.5
|
||||
@ -204,9 +204,7 @@ def test_build_request_kwargs_with_system_prompt_standard_model(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(
|
||||
messages, "system prompt", None, None, settings
|
||||
)
|
||||
kwargs = provider._build_request_kwargs(messages, "system prompt", None, settings)
|
||||
|
||||
# For gpt-4o, system prompt becomes system message
|
||||
assert kwargs["messages"][0]["role"] == "system"
|
||||
@ -220,7 +218,7 @@ def test_build_request_kwargs_with_system_prompt_reasoning_model(
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = reasoning_provider._build_request_kwargs(
|
||||
messages, "system prompt", None, None, settings
|
||||
messages, "system prompt", None, settings
|
||||
)
|
||||
|
||||
# For o1 models, system prompt becomes developer message
|
||||
@ -234,9 +232,7 @@ def test_build_request_kwargs_reasoning_model_uses_max_completion_tokens(
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(max_tokens=2000)
|
||||
|
||||
kwargs = reasoning_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)
|
||||
|
||||
# Reasoning models use max_completion_tokens
|
||||
assert "max_completion_tokens" in kwargs
|
||||
@ -248,9 +244,7 @@ def test_build_request_kwargs_reasoning_model_no_temperature(reasoning_provider)
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings(temperature=0.7)
|
||||
|
||||
kwargs = reasoning_provider._build_request_kwargs(
|
||||
messages, None, None, None, settings
|
||||
)
|
||||
kwargs = reasoning_provider._build_request_kwargs(messages, None, None, settings)
|
||||
|
||||
# Reasoning models don't support temperature
|
||||
assert "temperature" not in kwargs
|
||||
@ -269,7 +263,7 @@ def test_build_request_kwargs_with_tools(provider):
|
||||
]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, None, settings)
|
||||
kwargs = provider._build_request_kwargs(messages, None, tools, settings)
|
||||
|
||||
assert "tools" in kwargs
|
||||
assert len(kwargs["tools"]) == 1
|
||||
@ -280,9 +274,7 @@ def test_build_request_kwargs_with_stream(provider):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
settings = LLMSettings()
|
||||
|
||||
kwargs = provider._build_request_kwargs(
|
||||
messages, None, None, None, settings, stream=True
|
||||
)
|
||||
kwargs = provider._build_request_kwargs(messages, None, None, settings, stream=True)
|
||||
|
||||
assert kwargs["stream"] is True
|
||||
|
||||
@ -322,8 +314,7 @@ def test_handle_stream_chunk_text_content(provider):
|
||||
delta=Mock(content="hello", tool_calls=None),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
]
|
||||
)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||
@ -351,9 +342,8 @@ def test_handle_stream_chunk_tool_call_start(provider):
|
||||
choice.delta = delta
|
||||
choice.finish_reason = None
|
||||
|
||||
chunk = Mock(spec=["choices", "usage"])
|
||||
chunk = Mock(spec=["choices"])
|
||||
chunk.choices = [choice]
|
||||
chunk.usage = Mock(prompt_tokens=10, completion_tokens=5)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||
|
||||
@ -379,8 +369,7 @@ def test_handle_stream_chunk_tool_call_arguments(provider):
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
]
|
||||
)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
||||
@ -397,8 +386,7 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
|
||||
delta=Mock(content=None, tool_calls=None),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
]
|
||||
)
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, current_tool)
|
||||
@ -411,7 +399,7 @@ def test_handle_stream_chunk_finish_with_tool_call(provider):
|
||||
|
||||
|
||||
def test_handle_stream_chunk_empty_choices(provider):
|
||||
chunk = Mock(choices=[], usage=Mock(prompt_tokens=10, completion_tokens=5))
|
||||
chunk = Mock(choices=[])
|
||||
|
||||
events, tool_call = provider._handle_stream_chunk(chunk, None)
|
||||
|
||||
@ -447,13 +435,8 @@ async def test_agenerate_basic(provider, mock_openai_client):
|
||||
messages = [Message(role=MessageRole.USER, content="test")]
|
||||
|
||||
# Mock the async client
|
||||
mock_response = Mock(
|
||||
choices=[Mock(message=Mock(content="async response"))],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
provider.async_client.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_response = Mock(choices=[Mock(message=Mock(content="async response"))])
|
||||
provider.async_client.chat.completions.create = Mock(return_value=mock_response)
|
||||
|
||||
result = await provider.agenerate(messages)
|
||||
|
||||
@ -469,19 +452,15 @@ async def test_astream_basic(provider, mock_openai_client):
|
||||
yield Mock(
|
||||
choices=[
|
||||
Mock(delta=Mock(content="async", tool_calls=None), finish_reason=None)
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=5),
|
||||
]
|
||||
)
|
||||
yield Mock(
|
||||
choices=[
|
||||
Mock(delta=Mock(content=" test", tool_calls=None), finish_reason="stop")
|
||||
],
|
||||
usage=Mock(prompt_tokens=10, completion_tokens=10),
|
||||
]
|
||||
)
|
||||
|
||||
provider.async_client.chat.completions.create = AsyncMock(
|
||||
return_value=async_stream()
|
||||
)
|
||||
provider.async_client.chat.completions.create = Mock(return_value=async_stream())
|
||||
|
||||
events = []
|
||||
async for event in provider.astream(messages):
|
||||
|
||||
@ -18,10 +18,10 @@ except ModuleNotFoundError: # pragma: no cover - import guard for test envs
|
||||
|
||||
sys.modules.setdefault("redis", _RedisStub())
|
||||
|
||||
from memory.common.llms.usage import (
|
||||
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
|
||||
from memory.common.llms.usage_tracker import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
RedisUsageTracker,
|
||||
UsageTracker,
|
||||
)
|
||||
|
||||
@ -84,9 +84,7 @@ def redis_tracker() -> RedisUsageTracker:
|
||||
(timedelta(seconds=0), {"max_total_tokens": 1}),
|
||||
],
|
||||
)
|
||||
def test_rate_limit_config_validation(
|
||||
window: timedelta, kwargs: dict[str, int]
|
||||
) -> None:
|
||||
def test_rate_limit_config_validation(window: timedelta, kwargs: dict[str, int]) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
RateLimitConfig(window=window, **kwargs)
|
||||
|
||||
@ -95,7 +93,9 @@ def test_allows_usage_within_limits(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
|
||||
allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
|
||||
allowance = tracker.get_available_tokens(
|
||||
"anthropic/claude-3", timestamp=now
|
||||
)
|
||||
assert allowance is not None
|
||||
assert allowance.input_tokens == 900
|
||||
assert allowance.output_tokens == 1_800
|
||||
@ -114,7 +114,9 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
||||
tracker.record_usage("anthropic/claude-3", 800, 1_700, timestamp=now)
|
||||
|
||||
later = now + timedelta(minutes=2)
|
||||
allowance = tracker.get_available_tokens("anthropic/claude-3", timestamp=later)
|
||||
allowance = tracker.get_available_tokens(
|
||||
"anthropic/claude-3", timestamp=later
|
||||
)
|
||||
assert allowance is not None
|
||||
assert allowance.input_tokens == 1_000
|
||||
assert allowance.output_tokens == 2_000
|
||||
@ -124,7 +126,6 @@ def test_recovers_after_window(tracker: InMemoryUsageTracker) -> None:
|
||||
|
||||
def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
# Use the configured models from the fixture
|
||||
tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||
|
||||
@ -143,7 +144,6 @@ def test_usage_breakdown_and_provider_totals(tracker: InMemoryUsageTracker) -> N
|
||||
|
||||
def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
# Use configured models from the fixture
|
||||
tracker.record_usage("anthropic/claude-3", 10, 20, timestamp=now)
|
||||
tracker.record_usage("openai/gpt-4o", 5, 5, timestamp=now)
|
||||
|
||||
@ -156,19 +156,15 @@ def test_get_usage_breakdown_filters(tracker: InMemoryUsageTracker) -> None:
|
||||
assert set(filtered_model["openai"].keys()) == {"gpt-4o"}
|
||||
|
||||
|
||||
def test_missing_configuration_uses_default() -> None:
|
||||
# With no specific config, falls back to default config (from settings)
|
||||
def test_missing_configuration_records_lifetime_only() -> None:
|
||||
tracker = InMemoryUsageTracker(configs={})
|
||||
tracker.record_usage("openai/gpt-4o", 10, 20)
|
||||
|
||||
# Uses default config, so get_available_tokens returns allowance
|
||||
allowance = tracker.get_available_tokens("openai/gpt-4o")
|
||||
assert allowance is not None
|
||||
assert tracker.get_available_tokens("openai/gpt-4o") is None
|
||||
|
||||
# Lifetime stats are tracked
|
||||
breakdown = tracker.get_usage_breakdown()
|
||||
usage = breakdown["openai"]["gpt-4o"]
|
||||
assert usage.window_input_tokens == 10
|
||||
assert usage.window_input_tokens == 0
|
||||
assert usage.lifetime_input_tokens == 10
|
||||
|
||||
|
||||
@ -197,7 +193,6 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
|
||||
|
||||
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
|
||||
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
# Use configured models from the fixture
|
||||
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
|
||||
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
|
||||
|
||||
@ -206,8 +201,6 @@ def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) ->
|
||||
assert allowance.input_tokens == 900
|
||||
|
||||
breakdown = redis_tracker.get_usage_breakdown()
|
||||
assert "anthropic" in breakdown
|
||||
assert "claude-3" in breakdown["anthropic"]
|
||||
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
|
||||
|
||||
items = dict(redis_tracker.iter_state_items())
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
"""Tests for base web tool definitions."""
|
||||
|
||||
from memory.common.llms.tools.base import WebFetchTool, WebSearchTool
|
||||
|
||||
|
||||
def test_web_search_tool_provider_formats():
|
||||
tool = WebSearchTool()
|
||||
|
||||
assert tool.provider_format("openai") == {"type": "web_search"}
|
||||
assert tool.provider_format("anthropic") == {
|
||||
"type": "web_search_20250305",
|
||||
"name": "web_search",
|
||||
"max_uses": 10,
|
||||
}
|
||||
assert tool.provider_format("unknown") is None
|
||||
|
||||
|
||||
def test_web_fetch_tool_provider_formats():
|
||||
tool = WebFetchTool()
|
||||
|
||||
assert tool.provider_format("anthropic") == {
|
||||
"type": "web_fetch_20250910",
|
||||
"name": "web_fetch",
|
||||
"max_uses": 10,
|
||||
}
|
||||
assert tool.provider_format("openai") is None
|
||||
@ -497,14 +497,13 @@ def test_make_discord_tools_with_user_and_channel(
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# update_user_summary, update_server_summary, add_reaction
|
||||
assert len(tools) == 6
|
||||
# update_user_summary, update_server_summary
|
||||
assert len(tools) == 5
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_user_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
assert "add_reaction" in tools
|
||||
|
||||
|
||||
def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user):
|
||||
@ -534,13 +533,12 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# update_server_summary, add_reaction (no user summary without author)
|
||||
assert len(tools) == 5
|
||||
# update_server_summary (no user summary without author)
|
||||
assert len(tools) == 4
|
||||
assert "schedule_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
assert "add_reaction" in tools
|
||||
assert "update_user_summary" not in tools
|
||||
|
||||
|
||||
|
||||
@ -1,539 +0,0 @@
|
||||
"""Tests for OAuth 2.0 flow handling."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
|
||||
from memory.common.oauth import (
|
||||
OAuthEndpoints,
|
||||
generate_pkce_pair,
|
||||
discover_oauth_metadata,
|
||||
get_endpoints,
|
||||
register_oauth_client,
|
||||
issue_challenge,
|
||||
complete_oauth_flow,
|
||||
)
|
||||
from memory.common.db.models import MCPServer
|
||||
|
||||
|
||||
class TestGeneratePkcePair:
|
||||
"""Tests for generate_pkce_pair function."""
|
||||
|
||||
def test_generates_valid_verifier_and_challenge(self):
|
||||
"""Test that PKCE pair is generated correctly."""
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
|
||||
# Verifier should be base64url encoded (no padding)
|
||||
assert len(verifier) > 0
|
||||
assert "=" not in verifier
|
||||
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in verifier)
|
||||
|
||||
# Challenge should be base64url encoded (no padding)
|
||||
assert len(challenge) > 0
|
||||
assert "=" not in challenge
|
||||
assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" for c in challenge)
|
||||
|
||||
# They should be different
|
||||
assert verifier != challenge
|
||||
|
||||
def test_generates_unique_pairs(self):
|
||||
"""Test that each call generates a unique pair."""
|
||||
verifier1, challenge1 = generate_pkce_pair()
|
||||
verifier2, challenge2 = generate_pkce_pair()
|
||||
|
||||
assert verifier1 != verifier2
|
||||
assert challenge1 != challenge2
|
||||
|
||||
|
||||
class TestDiscoverOauthMetadata:
|
||||
"""Tests for discover_oauth_metadata function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_success(self):
|
||||
"""Test successful OAuth metadata discovery."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=metadata)
|
||||
|
||||
mock_get = AsyncMock()
|
||||
mock_get.__aenter__.return_value = mock_response
|
||||
mock_get.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.get = Mock(return_value=mock_get)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await discover_oauth_metadata("https://example.com")
|
||||
|
||||
assert result == metadata
|
||||
assert result["authorization_endpoint"] == "https://example.com/auth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_not_found(self):
|
||||
"""Test OAuth metadata discovery when endpoint not found."""
|
||||
mock_response = Mock()
|
||||
mock_response.status = 404
|
||||
|
||||
mock_get = AsyncMock()
|
||||
mock_get.__aenter__.return_value = mock_response
|
||||
mock_get.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.get = Mock(return_value=mock_get)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await discover_oauth_metadata("https://example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_connection_error(self):
|
||||
"""Test OAuth metadata discovery with connection error."""
|
||||
mock_get = AsyncMock()
|
||||
mock_get.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.get = Mock(return_value=mock_get)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await discover_oauth_metadata("https://example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetEndpoints:
|
||||
"""Tests for get_endpoints function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_success(self):
|
||||
"""Test successful endpoint retrieval."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
result = await get_endpoints("https://example.com")
|
||||
|
||||
assert isinstance(result, OAuthEndpoints)
|
||||
assert result.authorization_endpoint == "https://example.com/auth"
|
||||
assert result.registration_endpoint == "https://example.com/register"
|
||||
assert result.token_endpoint == "https://example.com/token"
|
||||
assert "/auth/callback/discord" in result.redirect_uri
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_no_metadata(self):
|
||||
"""Test when OAuth metadata cannot be discovered."""
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
|
||||
with pytest.raises(ValueError, match="Failed to connect to MCP server"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_missing_authorization(self):
|
||||
"""Test when authorization endpoint is missing."""
|
||||
metadata = {
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
with pytest.raises(ValueError, match="authorization endpoint"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_missing_registration(self):
|
||||
"""Test when registration endpoint is missing."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
with pytest.raises(ValueError, match="dynamic client registration"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_endpoints_missing_token(self):
|
||||
"""Test when token endpoint is missing."""
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
}
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata):
|
||||
with pytest.raises(ValueError, match="token endpoint"):
|
||||
await get_endpoints("https://example.com")
|
||||
|
||||
|
||||
class TestRegisterOauthClient:
|
||||
"""Tests for register_oauth_client function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_success(self):
|
||||
"""Test successful OAuth client registration."""
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
client_info = {"client_id": "test-client-123"}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.text = AsyncMock(return_value="Success")
|
||||
mock_response.json = AsyncMock(return_value=client_info)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
client_id = await register_oauth_client(
|
||||
endpoints,
|
||||
"https://example.com",
|
||||
"Test Client",
|
||||
)
|
||||
|
||||
assert client_id == "test-client-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_http_error(self):
|
||||
"""Test OAuth client registration with HTTP error."""
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status = Mock(side_effect=aiohttp.ClientResponseError(
|
||||
request_info=Mock(),
|
||||
history=(),
|
||||
status=400,
|
||||
message="Bad Request",
|
||||
))
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to register OAuth client"):
|
||||
await register_oauth_client(
|
||||
endpoints,
|
||||
"https://example.com",
|
||||
"Test Client",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_missing_client_id(self):
|
||||
"""Test OAuth client registration when response lacks client_id."""
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
client_info = {} # Missing client_id
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=client_info)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to register OAuth client"):
|
||||
await register_oauth_client(
|
||||
endpoints,
|
||||
"https://example.com",
|
||||
"Test Client",
|
||||
)
|
||||
|
||||
|
||||
class TestIssueChallenge:
|
||||
"""Tests for issue_challenge function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_issue_challenge_success(self, db_session):
|
||||
"""Test successful OAuth challenge issuance."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
endpoints = OAuthEndpoints(
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
registration_endpoint="https://example.com/register",
|
||||
token_endpoint="https://example.com/token",
|
||||
redirect_uri="https://myapp.com/callback",
|
||||
)
|
||||
|
||||
with patch("memory.common.oauth.generate_pkce_pair", return_value=("verifier123", "challenge123")):
|
||||
auth_url = await issue_challenge(mcp_server, endpoints)
|
||||
|
||||
# Verify the auth URL contains expected parameters
|
||||
assert "https://example.com/auth?" in auth_url
|
||||
assert "client_id=test-client-123" in auth_url
|
||||
# redirect_uri will be URL encoded
|
||||
assert "redirect_uri=" in auth_url
|
||||
assert "myapp.com" in auth_url
|
||||
assert "callback" in auth_url
|
||||
assert "response_type=code" in auth_url
|
||||
assert "code_challenge=challenge123" in auth_url
|
||||
assert "code_challenge_method=S256" in auth_url
|
||||
assert "state=" in auth_url
|
||||
|
||||
# Verify state and code_verifier were stored
|
||||
assert mcp_server.state is not None
|
||||
assert mcp_server.code_verifier == "verifier123"
|
||||
|
||||
|
||||
class TestCompleteOauthFlow:
|
||||
"""Tests for complete_oauth_flow function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_success(self, db_session):
|
||||
"""Test successful OAuth flow completion."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
token_response = {
|
||||
"access_token": "access-token-123",
|
||||
"refresh_token": "refresh-token-123",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
mock_token_response = Mock()
|
||||
mock_token_response.status = 200
|
||||
mock_token_response.json = AsyncMock(return_value=token_response)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_token_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||
):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"auth-code-123",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert "successful" in message
|
||||
|
||||
# Verify tokens were stored
|
||||
assert mcp_server.access_token == "access-token-123"
|
||||
assert mcp_server.refresh_token == "refresh-token-123"
|
||||
assert mcp_server.token_expires_at is not None
|
||||
|
||||
# Verify temporary state was cleared
|
||||
assert mcp_server.state is None
|
||||
assert mcp_server.code_verifier is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_invalid_state(self):
|
||||
"""Test OAuth flow completion with invalid state."""
|
||||
status, message = await complete_oauth_flow(
|
||||
None,
|
||||
"auth-code-123",
|
||||
"invalid-state",
|
||||
)
|
||||
|
||||
assert status == 400
|
||||
assert "Invalid or expired" in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_token_error(self, db_session):
|
||||
"""Test OAuth flow completion when token exchange fails."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
mock_token_response = Mock()
|
||||
mock_token_response.status = 400
|
||||
mock_token_response.text = AsyncMock(return_value="Invalid grant")
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_token_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||
):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"invalid-code",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 500
|
||||
assert "Token exchange failed" in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_missing_access_token(self, db_session):
|
||||
"""Test OAuth flow completion when access token is missing from response."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
metadata = {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"registration_endpoint": "https://example.com/register",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
}
|
||||
|
||||
token_response = {} # Missing access_token
|
||||
|
||||
mock_token_response = Mock()
|
||||
mock_token_response.status = 200
|
||||
mock_token_response.json = AsyncMock(return_value=token_response)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_token_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("memory.common.oauth.discover_oauth_metadata", return_value=metadata),
|
||||
patch("aiohttp.ClientSession", return_value=mock_session_ctx),
|
||||
):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"auth-code-123",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 500
|
||||
assert "did not include access_token" in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_oauth_flow_get_endpoints_error(self, db_session):
|
||||
"""Test OAuth flow completion when getting endpoints fails."""
|
||||
mcp_server = MCPServer(
|
||||
name="Test Server",
|
||||
mcp_server_url="https://example.com",
|
||||
client_id="test-client-123",
|
||||
state="test-state",
|
||||
code_verifier="test-verifier",
|
||||
)
|
||||
db_session.add(mcp_server)
|
||||
db_session.commit()
|
||||
|
||||
with patch("memory.common.oauth.discover_oauth_metadata", return_value=None):
|
||||
status, message = await complete_oauth_flow(
|
||||
mcp_server,
|
||||
"auth-code-123",
|
||||
"test-state",
|
||||
)
|
||||
|
||||
assert status == 500
|
||||
assert "Failed to get OAuth endpoints" in message
|
||||
@ -1,253 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from memory.discord import api
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_app_bots():
|
||||
existing = getattr(api.app, "bots", None)
|
||||
api.app.bots = {}
|
||||
yield
|
||||
if existing is None:
|
||||
delattr(api.app, "bots")
|
||||
else:
|
||||
api.app.bots = existing
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def active_bot():
|
||||
collector = SimpleNamespace(
|
||||
send_dm=AsyncMock(return_value=True),
|
||||
trigger_typing_dm=AsyncMock(return_value=True),
|
||||
send_to_channel=AsyncMock(return_value=True),
|
||||
trigger_typing_channel=AsyncMock(return_value=True),
|
||||
add_reaction=AsyncMock(return_value=True),
|
||||
refresh_metadata=AsyncMock(return_value={"refreshed": True}),
|
||||
is_closed=Mock(return_value=False),
|
||||
user="CollectorUser#1234",
|
||||
guilds=[101, 202],
|
||||
)
|
||||
bot = SimpleNamespace(
|
||||
collector=collector,
|
||||
collector_task=None,
|
||||
bot_id=1,
|
||||
bot_token="token-123",
|
||||
bot_name="Test Bot",
|
||||
)
|
||||
api.app.bots[bot.bot_id] = bot
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_success(active_bot):
|
||||
request = api.SendDMRequest(bot_id=active_bot.bot_id, user="user123", message="Hello")
|
||||
|
||||
response = await api.send_dm_endpoint(request)
|
||||
|
||||
assert response == {
|
||||
"success": True,
|
||||
"message": "DM sent to user123",
|
||||
"user": "user123",
|
||||
}
|
||||
active_bot.collector.send_dm.assert_awaited_once_with("user123", "Hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,payload",
|
||||
[
|
||||
(
|
||||
api.send_dm_endpoint,
|
||||
api.SendDMRequest(bot_id=99, user="ghost", message="hi"),
|
||||
),
|
||||
(
|
||||
api.trigger_dm_typing,
|
||||
api.TypingDMRequest(bot_id=99, user="ghost"),
|
||||
),
|
||||
(
|
||||
api.send_channel_endpoint,
|
||||
api.SendChannelRequest(bot_id=99, channel="general", message="hello"),
|
||||
),
|
||||
(
|
||||
api.trigger_channel_typing,
|
||||
api.TypingChannelRequest(bot_id=99, channel="general"),
|
||||
),
|
||||
(
|
||||
api.add_reaction_endpoint,
|
||||
api.AddReactionRequest(
|
||||
bot_id=99,
|
||||
channel="general",
|
||||
message_id=42,
|
||||
emoji=":thumbsup:",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_endpoint_returns_404_when_bot_missing(endpoint, payload):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await endpoint(payload)
|
||||
|
||||
assert exc.value.status_code == 404
|
||||
assert exc.value.detail == "Bot not found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,request_cls,request_kwargs,attr_name,detail_template",
|
||||
[
|
||||
(
|
||||
api.send_dm_endpoint,
|
||||
api.SendDMRequest,
|
||||
{"bot_id": 1, "user": "user123", "message": "Hi"},
|
||||
"send_dm",
|
||||
"Failed to send DM to {user}",
|
||||
),
|
||||
(
|
||||
api.trigger_dm_typing,
|
||||
api.TypingDMRequest,
|
||||
{"bot_id": 1, "user": "user123"},
|
||||
"trigger_typing_dm",
|
||||
"Failed to trigger typing for user123",
|
||||
),
|
||||
(
|
||||
api.send_channel_endpoint,
|
||||
api.SendChannelRequest,
|
||||
{"bot_id": 1, "channel": "general", "message": "Hello"},
|
||||
"send_to_channel",
|
||||
"Failed to send message to channel general",
|
||||
),
|
||||
(
|
||||
api.trigger_channel_typing,
|
||||
api.TypingChannelRequest,
|
||||
{"bot_id": 1, "channel": "general"},
|
||||
"trigger_typing_channel",
|
||||
"Failed to trigger typing for channel general",
|
||||
),
|
||||
(
|
||||
api.add_reaction_endpoint,
|
||||
api.AddReactionRequest,
|
||||
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
|
||||
"add_reaction",
|
||||
"Failed to add reaction to message 55",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_endpoint_returns_400_on_collector_failure(
|
||||
active_bot, endpoint, request_cls, request_kwargs, attr_name, detail_template
|
||||
):
|
||||
request = request_cls(**request_kwargs)
|
||||
getattr(active_bot.collector, attr_name).return_value = False
|
||||
expected_detail = detail_template.format(**request_kwargs)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await endpoint(request)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
assert exc.value.detail == expected_detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,request_cls,request_kwargs,attr_name",
|
||||
[
|
||||
(
|
||||
api.send_dm_endpoint,
|
||||
api.SendDMRequest,
|
||||
{"bot_id": 1, "user": "user123", "message": "Hi"},
|
||||
"send_dm",
|
||||
),
|
||||
(
|
||||
api.trigger_dm_typing,
|
||||
api.TypingDMRequest,
|
||||
{"bot_id": 1, "user": "user123"},
|
||||
"trigger_typing_dm",
|
||||
),
|
||||
(
|
||||
api.send_channel_endpoint,
|
||||
api.SendChannelRequest,
|
||||
{"bot_id": 1, "channel": "general", "message": "Hello"},
|
||||
"send_to_channel",
|
||||
),
|
||||
(
|
||||
api.trigger_channel_typing,
|
||||
api.TypingChannelRequest,
|
||||
{"bot_id": 1, "channel": "general"},
|
||||
"trigger_typing_channel",
|
||||
),
|
||||
(
|
||||
api.add_reaction_endpoint,
|
||||
api.AddReactionRequest,
|
||||
{"bot_id": 1, "channel": "general", "message_id": 55, "emoji": ":fire:"},
|
||||
"add_reaction",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_endpoint_returns_500_on_collector_exception(
|
||||
active_bot, endpoint, request_cls, request_kwargs, attr_name
|
||||
):
|
||||
request = request_cls(**request_kwargs)
|
||||
getattr(active_bot.collector, attr_name).side_effect = RuntimeError("boom")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await endpoint(request)
|
||||
|
||||
assert exc.value.status_code == 500
|
||||
assert "boom" in exc.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(active_bot):
|
||||
response = await api.health_check()
|
||||
|
||||
assert response["Test Bot"] == {
|
||||
"status": "healthy",
|
||||
"connected": True,
|
||||
"user": "CollectorUser#1234",
|
||||
"guilds": 2,
|
||||
}
|
||||
active_bot.collector.is_closed.assert_called_once_with()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_without_bots():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await api.health_check()
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
assert exc.value.detail == "Discord collector not running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_metadata_success(active_bot):
|
||||
active_bot.collector.refresh_metadata.return_value = {"channels": 3}
|
||||
|
||||
response = await api.refresh_metadata()
|
||||
|
||||
assert response["success"] is True
|
||||
assert response["message"] == "Metadata refreshed successfully for 1 bots"
|
||||
assert response["results"]["Test Bot"] == {"channels": 3}
|
||||
active_bot.collector.refresh_metadata.assert_awaited_once_with()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_metadata_without_bots():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await api.refresh_metadata()
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
assert exc.value.detail == "Discord collector not running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_metadata_failure(active_bot):
|
||||
active_bot.collector.refresh_metadata.side_effect = RuntimeError("sync failed")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await api.refresh_metadata()
|
||||
|
||||
assert exc.value.status_code == 500
|
||||
assert "sync failed" in exc.value.detail
|
||||
@ -79,7 +79,6 @@ def mock_message(mock_text_channel, mock_user):
|
||||
message.content = "Test message"
|
||||
message.created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
message.reference = None
|
||||
message.attachments = []
|
||||
return message
|
||||
|
||||
|
||||
@ -352,7 +351,7 @@ def test_determine_message_metadata_thread():
|
||||
# Tests for should_track_message
|
||||
def test_should_track_message_server_disabled(db_session):
|
||||
"""Test when server has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", ignore_messages=True)
|
||||
server = DiscordServer(id=1, name="Server", track_messages=False)
|
||||
channel = DiscordChannel(id=2, name="Channel", channel_type="text")
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
@ -363,9 +362,9 @@ def test_should_track_message_server_disabled(db_session):
|
||||
|
||||
def test_should_track_message_channel_disabled(db_session):
|
||||
"""Test when channel has tracking disabled"""
|
||||
server = DiscordServer(id=1, name="Server", ignore_messages=False)
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", ignore_messages=True
|
||||
id=2, name="Channel", channel_type="text", track_messages=False
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
@ -376,8 +375,8 @@ def test_should_track_message_channel_disabled(db_session):
|
||||
|
||||
def test_should_track_message_dm_allowed(db_session):
|
||||
"""Test DM tracking when user allows it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
|
||||
user = DiscordUser(id=3, username="User", ignore_messages=False)
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=True)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
@ -386,8 +385,8 @@ def test_should_track_message_dm_allowed(db_session):
|
||||
|
||||
def test_should_track_message_dm_not_allowed(db_session):
|
||||
"""Test DM tracking when user doesn't allow it"""
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", ignore_messages=False)
|
||||
user = DiscordUser(id=3, username="User", ignore_messages=True)
|
||||
channel = DiscordChannel(id=2, name="DM", channel_type="dm", track_messages=True)
|
||||
user = DiscordUser(id=3, username="User", track_messages=False)
|
||||
|
||||
result = should_track_message(None, channel, user)
|
||||
|
||||
@ -396,9 +395,9 @@ def test_should_track_message_dm_not_allowed(db_session):
|
||||
|
||||
def test_should_track_message_default_true(db_session):
|
||||
"""Test default tracking behavior"""
|
||||
server = DiscordServer(id=1, name="Server", ignore_messages=False)
|
||||
server = DiscordServer(id=1, name="Server", track_messages=True)
|
||||
channel = DiscordChannel(
|
||||
id=2, name="Channel", channel_type="text", ignore_messages=False
|
||||
id=2, name="Channel", channel_type="text", track_messages=True
|
||||
)
|
||||
user = DiscordUser(id=3, username="User")
|
||||
|
||||
@ -466,7 +465,6 @@ def test_sync_guild_metadata(mock_make_session, mock_guild):
|
||||
voice_channel.guild = mock_guild
|
||||
|
||||
mock_guild.channels = [text_channel, voice_channel]
|
||||
mock_guild.threads = []
|
||||
|
||||
sync_guild_metadata(mock_guild)
|
||||
|
||||
@ -491,25 +489,16 @@ def test_message_collector_init():
|
||||
async def test_on_ready():
|
||||
"""Test on_ready event handler"""
|
||||
collector = MessageCollector()
|
||||
collector.user = Mock()
|
||||
collector.user.name = "TestBot"
|
||||
collector.guilds = [Mock(), Mock()]
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
collector.tree.sync = AsyncMock()
|
||||
|
||||
# Mock the properties
|
||||
mock_user = Mock()
|
||||
mock_user.name = "TestBot"
|
||||
with patch.object(
|
||||
type(collector), "user", new_callable=lambda: property(lambda self: mock_user)
|
||||
):
|
||||
with patch.object(
|
||||
type(collector),
|
||||
"guilds",
|
||||
new_callable=lambda: property(lambda self: [Mock(), Mock()]),
|
||||
):
|
||||
collector.sync_servers_and_channels = AsyncMock()
|
||||
collector.tree.sync = AsyncMock()
|
||||
await collector.on_ready()
|
||||
|
||||
await collector.on_ready()
|
||||
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
collector.tree.sync.assert_awaited()
|
||||
collector.sync_servers_and_channels.assert_called_once()
|
||||
collector.tree.sync.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -604,18 +593,14 @@ async def test_sync_servers_and_channels():
|
||||
guild2 = Mock()
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild1, guild2]
|
||||
|
||||
with patch.object(
|
||||
type(collector),
|
||||
"guilds",
|
||||
new_callable=lambda: property(lambda self: [guild1, guild2]),
|
||||
):
|
||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||
await collector.sync_servers_and_channels()
|
||||
with patch("memory.discord.collector.sync_guild_metadata") as mock_sync:
|
||||
await collector.sync_servers_and_channels()
|
||||
|
||||
assert mock_sync.call_count == 2
|
||||
mock_sync.assert_any_call(guild1)
|
||||
mock_sync.assert_any_call(guild2)
|
||||
assert mock_sync.call_count == 2
|
||||
mock_sync.assert_any_call(guild1)
|
||||
mock_sync.assert_any_call(guild2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -632,26 +617,17 @@ async def test_refresh_metadata(mock_make_session):
|
||||
guild.name = "Test"
|
||||
guild.channels = []
|
||||
guild.members = []
|
||||
guild.threads = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
collector.intents = Mock()
|
||||
collector.intents.members = False
|
||||
|
||||
mock_intents = Mock()
|
||||
mock_intents.members = False
|
||||
result = await collector.refresh_metadata()
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
with patch.object(
|
||||
type(collector),
|
||||
"intents",
|
||||
new_callable=lambda: property(lambda self: mock_intents),
|
||||
):
|
||||
result = await collector.refresh_metadata()
|
||||
|
||||
assert result["servers_updated"] == 1
|
||||
assert result["channels_updated"] == 0
|
||||
assert result["users_updated"] == 0
|
||||
assert result["servers_updated"] == 1
|
||||
assert result["channels_updated"] == 0
|
||||
assert result["users_updated"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -661,7 +637,7 @@ async def test_get_user_by_id():
|
||||
user.id = 123
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.get_user = AsyncMock(return_value=user)
|
||||
collector.get_user = Mock(return_value=user)
|
||||
|
||||
result = await collector.get_user(123)
|
||||
|
||||
@ -680,32 +656,22 @@ async def test_get_user_by_username():
|
||||
guild.members = [member]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
result = await collector.get_user("testuser")
|
||||
result = await collector.get_user("testuser")
|
||||
|
||||
assert result == member
|
||||
assert result == member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found():
|
||||
"""Test getting non-existent user"""
|
||||
collector = MessageCollector()
|
||||
collector.guilds = []
|
||||
|
||||
# Create proper mock response for discord.NotFound
|
||||
mock_response = Mock()
|
||||
mock_response.status = 404
|
||||
mock_response.text = ""
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
|
||||
):
|
||||
with patch.object(collector, "get_user", return_value=None):
|
||||
with patch.object(
|
||||
collector,
|
||||
"fetch_user",
|
||||
AsyncMock(side_effect=discord.NotFound(mock_response, "User not found")),
|
||||
collector, "fetch_user", side_effect=discord.NotFound(Mock(), Mock())
|
||||
):
|
||||
result = await collector.get_user(999)
|
||||
assert result is None
|
||||
@ -721,13 +687,11 @@ async def test_get_channel_by_name():
|
||||
guild.channels = [channel]
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
result = await collector.get_channel_by_name("general")
|
||||
result = await collector.get_channel_by_name("general")
|
||||
|
||||
assert result == channel
|
||||
assert result == channel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -737,13 +701,11 @@ async def test_get_channel_by_name_not_found():
|
||||
guild.channels = []
|
||||
|
||||
collector = MessageCollector()
|
||||
collector.guilds = [guild]
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [guild])
|
||||
):
|
||||
result = await collector.get_channel_by_name("nonexistent")
|
||||
result = await collector.get_channel_by_name("nonexistent")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -768,13 +730,11 @@ async def test_create_channel_no_guild():
|
||||
"""Test creating channel when no guild available"""
|
||||
collector = MessageCollector()
|
||||
collector.get_guild = Mock(return_value=None)
|
||||
collector.guilds = []
|
||||
|
||||
with patch.object(
|
||||
type(collector), "guilds", new_callable=lambda: property(lambda self: [])
|
||||
):
|
||||
result = await collector.create_channel("new-channel")
|
||||
result = await collector.create_channel("new-channel")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -856,19 +816,27 @@ async def test_send_to_channel_not_found():
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="run_collector function doesn't exist or uses different settings"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
async def test_run_collector():
|
||||
"""Test running the collector"""
|
||||
pass
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
with patch("memory.discord.collector.MessageCollector") as mock_collector_class:
|
||||
mock_collector = Mock()
|
||||
mock_collector.start = AsyncMock()
|
||||
mock_collector_class.return_value = mock_collector
|
||||
|
||||
await run_collector()
|
||||
|
||||
mock_collector.start.assert_called_once_with("test_token")
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="run_collector function doesn't exist or uses different settings"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
async def test_run_collector_no_token():
|
||||
"""Test running collector without token"""
|
||||
pass
|
||||
from memory.discord.collector import run_collector
|
||||
|
||||
# Should return early without raising
|
||||
await run_collector()
|
||||
|
||||
@ -1,23 +1,17 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.commands import (
|
||||
CommandContext,
|
||||
CommandError,
|
||||
CommandResponse,
|
||||
run_command,
|
||||
handle_prompt,
|
||||
handle_chattiness,
|
||||
handle_ignore,
|
||||
handle_summary,
|
||||
respond,
|
||||
with_object_context,
|
||||
handle_mcp_servers,
|
||||
)
|
||||
|
||||
|
||||
@ -72,308 +66,107 @@ def interaction(guild, text_channel, discord_user) -> DummyInteraction:
|
||||
return DummyInteraction(guild=guild, channel=text_channel, user=discord_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||
def test_handle_command_prompt_server(db_session, guild, interaction):
|
||||
server = DiscordServer(id=guild.id, name="Test Guild", system_prompt="Be helpful")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Test Guild**",
|
||||
handler=handle_prompt,
|
||||
)
|
||||
|
||||
response = await handle_prompt(context)
|
||||
|
||||
assert isinstance(response, CommandResponse)
|
||||
assert "Be helpful" in response.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_prompt_channel_creates_channel(
|
||||
db_session, interaction, text_channel, guild
|
||||
):
|
||||
# Create the server first to satisfy FK constraint
|
||||
server = DiscordServer(id=guild.id, name="Test Guild")
|
||||
db_session.add(server)
|
||||
|
||||
channel_model = DiscordChannel(
|
||||
id=text_channel.id,
|
||||
name=text_channel.name,
|
||||
channel_type="text",
|
||||
server_id=guild.id,
|
||||
)
|
||||
db_session.add(channel_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
def test_handle_command_prompt_channel_creates_channel(db_session, interaction, text_channel):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
target=channel_model,
|
||||
display_name=f"channel **#{text_channel.name}**",
|
||||
handler=handle_prompt,
|
||||
)
|
||||
|
||||
response = await handle_prompt(context)
|
||||
|
||||
assert "No prompt" in response.content
|
||||
channel = db_session.get(DiscordChannel, text_channel.id)
|
||||
assert channel is not None
|
||||
assert channel.name == text_channel.name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||
def test_handle_command_chattiness_show(db_session, interaction, guild):
|
||||
server = DiscordServer(id=guild.id, name="Guild", chattiness_threshold=73)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
handler=handle_chattiness,
|
||||
)
|
||||
|
||||
response = await handle_chattiness(context, value=None)
|
||||
|
||||
assert str(server.chattiness_threshold) in response.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_chattiness_update(db_session, interaction):
|
||||
user_model = DiscordUser(
|
||||
id=interaction.user.id, username="command-user", chattiness_threshold=15
|
||||
)
|
||||
def test_handle_command_chattiness_update(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user", chattiness_threshold=15)
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=user_model,
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
target=user_model,
|
||||
display_name="user **command-user**",
|
||||
handler=handle_chattiness,
|
||||
value=80,
|
||||
)
|
||||
|
||||
response = await handle_chattiness(context, value=80)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert user_model.chattiness_threshold == 80
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user")
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=user_model,
|
||||
scope="user",
|
||||
target=user_model,
|
||||
display_name="user **command-user**",
|
||||
)
|
||||
|
||||
def test_handle_command_chattiness_invalid_value(db_session, interaction):
|
||||
with pytest.raises(CommandError):
|
||||
await handle_chattiness(context, value=150)
|
||||
run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
handler=handle_chattiness,
|
||||
value=150,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||
# Create the server first to satisfy FK constraint
|
||||
server = DiscordServer(id=guild.id, name="Test Guild")
|
||||
db_session.add(server)
|
||||
|
||||
channel = DiscordChannel(
|
||||
id=interaction.channel.id,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
server_id=guild.id,
|
||||
)
|
||||
def test_handle_command_ignore_toggle(db_session, interaction, guild):
|
||||
channel = DiscordChannel(id=interaction.channel.id, name="general", channel_type="text", server_id=guild.id)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="channel",
|
||||
target=channel,
|
||||
display_name="channel **#general**",
|
||||
handler=handle_ignore,
|
||||
ignore_enabled=True,
|
||||
)
|
||||
|
||||
response = await handle_ignore(context, ignore_enabled=True)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
assert "no longer" not in response.content
|
||||
assert channel.ignore_messages is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_command_summary_missing(db_session, interaction):
|
||||
user_model = DiscordUser(id=interaction.user.id, username="command-user")
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=user_model,
|
||||
def test_handle_command_summary_missing(db_session, interaction):
|
||||
response = run_command(
|
||||
db_session,
|
||||
interaction,
|
||||
scope="user",
|
||||
target=user_model,
|
||||
display_name="user **command-user**",
|
||||
handler=handle_summary,
|
||||
)
|
||||
|
||||
response = await handle_summary(context)
|
||||
|
||||
assert "No summary" in response.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_sends_message_without_file():
|
||||
interaction = MagicMock(spec=discord.Interaction)
|
||||
interaction.response.send_message = AsyncMock()
|
||||
|
||||
await respond(interaction, "hello world", ephemeral=False)
|
||||
|
||||
interaction.response.send_message.assert_awaited_once_with(
|
||||
"hello world", ephemeral=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_sends_file_when_content_too_large():
|
||||
interaction = MagicMock(spec=discord.Interaction)
|
||||
interaction.response.send_message = AsyncMock()
|
||||
|
||||
oversized = "x" * 2000
|
||||
with patch("memory.discord.commands.discord.File") as mock_file:
|
||||
file_instance = MagicMock()
|
||||
mock_file.return_value = file_instance
|
||||
|
||||
await respond(interaction, oversized)
|
||||
|
||||
interaction.response.send_message.assert_awaited_once_with(
|
||||
"Response too large, sending as file:",
|
||||
file=file_instance,
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.discord.commands._ensure_channel")
|
||||
@patch("memory.discord.commands.ensure_server")
|
||||
@patch("memory.discord.commands.ensure_user")
|
||||
@patch("memory.discord.commands.make_session")
|
||||
def test_with_object_context_uses_ensured_objects(
|
||||
mock_make_session,
|
||||
mock_ensure_user,
|
||||
mock_ensure_server,
|
||||
mock_ensure_channel,
|
||||
interaction,
|
||||
guild,
|
||||
text_channel,
|
||||
discord_user,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
def session_cm():
|
||||
yield mock_session
|
||||
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
bot_model = MagicMock(name="bot_model")
|
||||
user_model = MagicMock(name="user_model")
|
||||
server_model = MagicMock(name="server_model")
|
||||
channel_model = MagicMock(name="channel_model")
|
||||
|
||||
mock_ensure_user.side_effect = [bot_model, user_model]
|
||||
mock_ensure_server.return_value = server_model
|
||||
mock_ensure_channel.return_value = channel_model
|
||||
|
||||
handler_objects = {}
|
||||
|
||||
def handler(objects):
|
||||
handler_objects["objects"] = objects
|
||||
return "done"
|
||||
|
||||
bot_client = SimpleNamespace(user=MagicMock())
|
||||
override_user = MagicMock(spec=discord.User)
|
||||
|
||||
result = with_object_context(bot_client, interaction, handler, override_user)
|
||||
|
||||
assert result == "done"
|
||||
objects = handler_objects["objects"]
|
||||
assert objects.bot is bot_model
|
||||
assert objects.server is server_model
|
||||
assert objects.channel is channel_model
|
||||
assert objects.user is user_model
|
||||
|
||||
mock_ensure_user.assert_any_call(mock_session, bot_client.user)
|
||||
mock_ensure_user.assert_any_call(mock_session, override_user)
|
||||
mock_ensure_server.assert_called_once_with(mock_session, guild)
|
||||
mock_ensure_channel.assert_called_once_with(
|
||||
mock_session, text_channel, guild.id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
|
||||
async def test_handle_mcp_servers_returns_response(mock_run_mcp, interaction):
|
||||
mock_run_mcp.return_value = "Listed servers"
|
||||
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
|
||||
|
||||
context = CommandContext(
|
||||
session=MagicMock(),
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server_model,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
|
||||
|
||||
response = await handle_mcp_servers(
|
||||
context, action="list", url=None
|
||||
)
|
||||
|
||||
assert response.content == "Listed servers"
|
||||
mock_run_mcp.assert_awaited_once_with(
|
||||
interaction.client.user, "list", None, "DiscordServer", server_model.id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.discord.commands.run_mcp_server_command", new_callable=AsyncMock)
|
||||
async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
|
||||
mock_run_mcp.side_effect = RuntimeError("boom")
|
||||
server_model = DiscordServer(id=interaction.guild.id, name="Guild")
|
||||
|
||||
context = CommandContext(
|
||||
session=MagicMock(),
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server_model,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
interaction.client = SimpleNamespace(user=MagicMock(spec=discord.User))
|
||||
|
||||
with pytest.raises(CommandError) as exc:
|
||||
await handle_mcp_servers(context, action="list", url=None)
|
||||
|
||||
assert "Error: boom" in str(exc.value)
|
||||
|
||||
@ -1,590 +0,0 @@
|
||||
"""Tests for Discord MCP server management."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.discord.mcp import (
|
||||
call_mcp_server,
|
||||
find_mcp_server,
|
||||
handle_mcp_add,
|
||||
handle_mcp_connect,
|
||||
handle_mcp_delete,
|
||||
handle_mcp_list,
|
||||
handle_mcp_tools,
|
||||
run_mcp_server_command,
|
||||
)
|
||||
|
||||
|
||||
# Helper class for async iteration
|
||||
class AsyncIterator:
|
||||
"""Helper to create an async iterator for mocking aiohttp response content."""
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
self.index = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.index >= len(self.items):
|
||||
raise StopAsyncIteration
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server(db_session) -> MCPServer:
|
||||
"""Create a test MCP server."""
|
||||
server = MCPServer(
|
||||
name="Test MCP Server",
|
||||
mcp_server_url="https://mcp.example.com",
|
||||
client_id="test_client_id",
|
||||
access_token="test_access_token",
|
||||
available_tools=["tool1", "tool2"],
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_assignment(db_session, mcp_server: MCPServer) -> MCPServerAssignment:
|
||||
"""Create a test MCP server assignment."""
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=mcp_server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
)
|
||||
db_session.add(assignment)
|
||||
db_session.commit()
|
||||
return assignment
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_user() -> discord.User:
|
||||
"""Create a mock Discord bot user."""
|
||||
user = Mock(spec=discord.User)
|
||||
user.name = "TestBot"
|
||||
user.id = 999
|
||||
return user
|
||||
|
||||
|
||||
def test_find_mcp_server_exists(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test finding an existing MCP server."""
|
||||
result = find_mcp_server(
|
||||
db_session,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
url="https://mcp.example.com",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == mcp_server.id
|
||||
assert result.mcp_server_url == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_find_mcp_server_not_found(db_session):
|
||||
"""Test finding a non-existent MCP server."""
|
||||
result = find_mcp_server(
|
||||
db_session,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=999999,
|
||||
url="https://nonexistent.com",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_mcp_server_wrong_entity(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test finding MCP server with wrong entity type."""
|
||||
result = find_mcp_server(
|
||||
db_session,
|
||||
entity_type="DiscordChannel", # Wrong entity type
|
||||
entity_id=123456,
|
||||
url="https://mcp.example.com",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_server_success():
|
||||
"""Test calling MCP server successfully."""
|
||||
mock_response_data = [
|
||||
b'data: {"result": {"tools": [{"name": "test"}]}}\n',
|
||||
b'data: {"status": "ok"}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
"https://mcp.example.com", "test_token", "tools/list", {}
|
||||
):
|
||||
results.append(data)
|
||||
|
||||
assert len(results) == 2
|
||||
assert "result" in results[0]
|
||||
assert results[0]["result"]["tools"][0]["name"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_server_error():
|
||||
"""Test calling MCP server with error response."""
|
||||
mock_response = Mock()
|
||||
mock_response.status = 500
|
||||
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
||||
async for _ in call_mcp_server(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_server_invalid_json():
|
||||
"""Test calling MCP server with invalid JSON."""
|
||||
mock_response_data = [
|
||||
b"data: invalid json\n",
|
||||
b'data: {"valid": "json"}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
results.append(data)
|
||||
|
||||
# Should skip invalid JSON and only return valid one
|
||||
assert len(results) == 1
|
||||
assert results[0] == {"valid": "json"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_list_empty(db_session):
|
||||
"""Test listing MCP servers when none exist."""
|
||||
result = await handle_mcp_list("DiscordUser", 123456)
|
||||
|
||||
assert "You don't have any MCP servers configured yet" in result
|
||||
assert "/memory_mcp_servers add" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_list_with_servers(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing MCP servers with existing servers."""
|
||||
result = await handle_mcp_list("DiscordUser", 123456)
|
||||
|
||||
assert "Your MCP Servers" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "test_client_id" in result
|
||||
assert "🟢" in result # Server has access token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_list_disconnected_server(db_session):
|
||||
"""Test listing MCP servers with disconnected server."""
|
||||
server = MCPServer(
|
||||
name="Disconnected Server",
|
||||
mcp_server_url="https://disconnected.example.com",
|
||||
client_id="client_123",
|
||||
access_token=None, # No access token
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
)
|
||||
db_session.add(assignment)
|
||||
db_session.commit()
|
||||
|
||||
result = await handle_mcp_list("DiscordUser", 123456)
|
||||
|
||||
assert "🔴" in result # Server has no access token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_add_new_server(db_session, mock_bot_user):
|
||||
"""Test adding a new MCP server."""
|
||||
with (
|
||||
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
|
||||
patch("memory.discord.mcp.register_oauth_client") as mock_register,
|
||||
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
|
||||
):
|
||||
mock_endpoints = Mock()
|
||||
mock_get_endpoints.return_value = mock_endpoints
|
||||
mock_register.return_value = "new_client_id"
|
||||
mock_challenge.return_value = "https://auth.example.com/authorize"
|
||||
|
||||
result = await handle_mcp_add(
|
||||
"DiscordUser", 123456, mock_bot_user, "https://new.example.com"
|
||||
)
|
||||
|
||||
assert "Add MCP Server" in result
|
||||
assert "https://new.example.com" in result
|
||||
assert "https://auth.example.com/authorize" in result
|
||||
|
||||
# Verify server was created
|
||||
server = (
|
||||
db_session.query(MCPServer)
|
||||
.filter(MCPServer.mcp_server_url == "https://new.example.com")
|
||||
.first()
|
||||
)
|
||||
assert server is not None
|
||||
assert server.client_id == "new_client_id"
|
||||
|
||||
# Verify assignment was created
|
||||
assignment = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(
|
||||
MCPServerAssignment.mcp_server_id == server.id,
|
||||
MCPServerAssignment.entity_type == "DiscordUser",
|
||||
MCPServerAssignment.entity_id == 123456,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert assignment is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_add_existing_server(
|
||||
db_session,
|
||||
mcp_server: MCPServer,
|
||||
mcp_assignment: MCPServerAssignment,
|
||||
mock_bot_user,
|
||||
):
|
||||
"""Test adding an MCP server that already exists."""
|
||||
result = await handle_mcp_add(
|
||||
"DiscordUser", 123456, mock_bot_user, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "MCP Server Already Exists" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "/memory_mcp_servers connect" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_add_no_bot_user(db_session):
|
||||
"""Test adding MCP server without bot user."""
|
||||
with pytest.raises(ValueError, match="Bot user is required"):
|
||||
await handle_mcp_add("DiscordUser", 123456, None, "https://example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_delete_existing(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test deleting an existing MCP server assignment."""
|
||||
# Store IDs before deletion
|
||||
assignment_id = mcp_assignment.id
|
||||
server_id = mcp_server.id
|
||||
|
||||
result = await handle_mcp_delete("DiscordUser", 123456, "https://mcp.example.com")
|
||||
|
||||
assert "Delete MCP Server" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "has been removed" in result
|
||||
|
||||
# Verify assignment was deleted
|
||||
assignment = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(MCPServerAssignment.id == assignment_id)
|
||||
.first()
|
||||
)
|
||||
assert assignment is None
|
||||
|
||||
# Verify server was also deleted (no other assignments)
|
||||
server = db_session.query(MCPServer).filter(MCPServer.id == server_id).first()
|
||||
assert server is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_delete_not_found(db_session):
|
||||
"""Test deleting a non-existent MCP server."""
|
||||
result = await handle_mcp_delete("DiscordUser", 123456, "https://nonexistent.com")
|
||||
|
||||
assert "MCP Server Not Found" in result
|
||||
assert "https://nonexistent.com" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_delete_with_other_assignments(db_session):
|
||||
"""Test deleting MCP server with multiple assignments."""
|
||||
server = MCPServer(
|
||||
name="Shared Server",
|
||||
mcp_server_url="https://shared.example.com",
|
||||
client_id="shared_client",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
assignment1 = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=111,
|
||||
)
|
||||
assignment2 = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=222,
|
||||
)
|
||||
db_session.add_all([assignment1, assignment2])
|
||||
db_session.commit()
|
||||
|
||||
# Delete one assignment
|
||||
result = await handle_mcp_delete("DiscordUser", 111, "https://shared.example.com")
|
||||
|
||||
assert "has been removed" in result
|
||||
|
||||
# Verify only one assignment was deleted
|
||||
remaining = (
|
||||
db_session.query(MCPServerAssignment)
|
||||
.filter(MCPServerAssignment.mcp_server_id == server.id)
|
||||
.count()
|
||||
)
|
||||
assert remaining == 1
|
||||
|
||||
# Verify server still exists
|
||||
server_check = db_session.query(MCPServer).filter(MCPServer.id == server.id).first()
|
||||
assert server_check is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_connect_existing(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test reconnecting to an existing MCP server."""
|
||||
with (
|
||||
patch("memory.discord.mcp.get_endpoints") as mock_get_endpoints,
|
||||
patch("memory.discord.mcp.issue_challenge") as mock_challenge,
|
||||
):
|
||||
mock_endpoints = Mock()
|
||||
mock_get_endpoints.return_value = mock_endpoints
|
||||
mock_challenge.return_value = "https://auth.example.com/authorize?state=new"
|
||||
|
||||
result = await handle_mcp_connect(
|
||||
"DiscordUser", 123456, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "Reconnect to MCP Server" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "https://auth.example.com/authorize?state=new" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_connect_not_found(db_session):
|
||||
"""Test reconnecting to a non-existent MCP server."""
|
||||
with pytest.raises(ValueError, match="MCP Server Not Found"):
|
||||
await handle_mcp_connect("DiscordUser", 123456, "https://nonexistent.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_success(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing tools from an MCP server."""
|
||||
mock_response_data = [
|
||||
b'data: {"result": {"tools": [{"name": "search", "description": "Search tool"}]}}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await handle_mcp_tools(
|
||||
"DiscordUser", 123456, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "MCP Server Tools" in result
|
||||
assert "https://mcp.example.com" in result
|
||||
assert "search" in result
|
||||
assert "Search tool" in result
|
||||
assert "Found 1 tool(s)" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_no_tools(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing tools when server has no tools."""
|
||||
mock_response_data = [
|
||||
b'data: {"result": {"tools": []}}\n',
|
||||
]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.content = AsyncIterator(mock_response_data)
|
||||
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
result = await handle_mcp_tools(
|
||||
"DiscordUser", 123456, "https://mcp.example.com"
|
||||
)
|
||||
|
||||
assert "No tools available" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_server_not_found(db_session):
|
||||
"""Test listing tools for a non-existent server."""
|
||||
with pytest.raises(ValueError, match="MCP Server Not Found"):
|
||||
await handle_mcp_tools("DiscordUser", 123456, "https://nonexistent.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_not_authorized(db_session):
|
||||
"""Test listing tools when not authorized."""
|
||||
server = MCPServer(
|
||||
name="Unauthorized Server",
|
||||
mcp_server_url="https://unauthorized.example.com",
|
||||
client_id="client_123",
|
||||
access_token=None, # No access token
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
assignment = MCPServerAssignment(
|
||||
mcp_server_id=server.id,
|
||||
entity_type="DiscordUser",
|
||||
entity_id=123456,
|
||||
)
|
||||
db_session.add(assignment)
|
||||
db_session.commit()
|
||||
|
||||
with pytest.raises(ValueError, match="Not Authorized"):
|
||||
await handle_mcp_tools(
|
||||
"DiscordUser", 123456, "https://unauthorized.example.com"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_mcp_tools_connection_error(
|
||||
db_session, mcp_server: MCPServer, mcp_assignment: MCPServerAssignment
|
||||
):
|
||||
"""Test listing tools with connection error."""
|
||||
mock_post = AsyncMock()
|
||||
mock_post.__aenter__.side_effect = aiohttp.ClientError("Connection failed")
|
||||
mock_post.__aexit__.return_value = None
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_post)
|
||||
|
||||
mock_session_ctx = AsyncMock()
|
||||
mock_session_ctx.__aenter__.return_value = mock_session
|
||||
mock_session_ctx.__aexit__.return_value = None
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Connection failed"):
|
||||
await handle_mcp_tools("DiscordUser", 123456, "https://mcp.example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_list(db_session, mock_bot_user):
|
||||
"""Test run_mcp_server_command with list action."""
|
||||
result = await run_mcp_server_command(
|
||||
mock_bot_user, "list", None, "DiscordUser", 123456
|
||||
)
|
||||
|
||||
assert "Your MCP Servers" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_invalid_action(mock_bot_user):
|
||||
"""Test run_mcp_server_command with invalid action."""
|
||||
with pytest.raises(ValueError, match="Invalid action"):
|
||||
await run_mcp_server_command(
|
||||
mock_bot_user, "invalid", None, "DiscordUser", 123456
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_missing_url(mock_bot_user):
|
||||
"""Test run_mcp_server_command with missing URL for non-list action."""
|
||||
with pytest.raises(ValueError, match="URL is required"):
|
||||
await run_mcp_server_command(mock_bot_user, "add", None, "DiscordUser", 123456)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_server_command_no_bot_user():
|
||||
"""Test run_mcp_server_command without bot user."""
|
||||
with pytest.raises(ValueError, match="Bot user is required"):
|
||||
await run_mcp_server_command(None, "list", None, "DiscordUser", 123456)
|
||||
@ -1,8 +1,5 @@
|
||||
"""Tests for Discord message helper functions."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from memory.discord.messages import (
|
||||
@ -12,7 +9,6 @@ from memory.discord.messages import (
|
||||
upsert_scheduled_message,
|
||||
previous_messages,
|
||||
comm_channel_prompt,
|
||||
call_llm,
|
||||
)
|
||||
from memory.common.db.models import (
|
||||
DiscordUser,
|
||||
@ -22,7 +18,6 @@ from memory.common.db.models import (
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
from memory.common.llms.tools import MCPServer as MCPServerDefinition
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -416,107 +411,3 @@ def test_comm_channel_prompt_includes_user_notes(
|
||||
|
||||
assert "user_notes" in result.lower()
|
||||
assert "testuser" in result # username should appear
|
||||
|
||||
|
||||
@patch("memory.discord.messages.create_provider")
|
||||
@patch("memory.discord.messages.previous_messages")
|
||||
@patch("memory.common.llms.tools.discord.make_discord_tools")
|
||||
@patch("memory.common.llms.tools.base.WebSearchTool")
|
||||
def test_call_llm_includes_web_search_and_mcp_servers(
|
||||
mock_web_search,
|
||||
mock_make_tools,
|
||||
mock_prev_messages,
|
||||
mock_create_provider,
|
||||
):
|
||||
provider = MagicMock()
|
||||
provider.usage_tracker.is_rate_limited.return_value = False
|
||||
provider.as_messages.return_value = ["converted"]
|
||||
provider.run_with_tools.return_value = SimpleNamespace(response="llm-output")
|
||||
mock_create_provider.return_value = provider
|
||||
|
||||
mock_prev_messages.return_value = [SimpleNamespace(as_content=lambda: "prev")]
|
||||
|
||||
existing_tool = MagicMock(name="existing_tool")
|
||||
mock_make_tools.return_value = {"existing": existing_tool}
|
||||
|
||||
web_tool_instance = MagicMock(name="web_tool")
|
||||
mock_web_search.return_value = web_tool_instance
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
|
||||
from_user = SimpleNamespace(id=123)
|
||||
mcp_model = SimpleNamespace(
|
||||
name="Server",
|
||||
mcp_server_url="https://mcp.example.com",
|
||||
access_token="token123",
|
||||
)
|
||||
|
||||
result = call_llm(
|
||||
session=MagicMock(),
|
||||
bot_user=bot_user,
|
||||
from_user=from_user,
|
||||
channel=None,
|
||||
model="gpt-test",
|
||||
messages=["hi"],
|
||||
mcp_servers=[mcp_model],
|
||||
)
|
||||
|
||||
assert result == "llm-output"
|
||||
|
||||
kwargs = provider.run_with_tools.call_args.kwargs
|
||||
tools = kwargs["tools"]
|
||||
assert tools["existing"] is existing_tool
|
||||
assert tools["web_search"] is web_tool_instance
|
||||
|
||||
mcp_servers = kwargs["mcp_servers"]
|
||||
assert mcp_servers == [
|
||||
MCPServerDefinition(
|
||||
name="Server", url="https://mcp.example.com", token="token123"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@patch("memory.discord.messages.create_provider")
|
||||
@patch("memory.discord.messages.previous_messages")
|
||||
@patch("memory.common.llms.tools.discord.make_discord_tools")
|
||||
@patch("memory.common.llms.tools.base.WebSearchTool")
|
||||
def test_call_llm_filters_disallowed_tools(
|
||||
mock_web_search,
|
||||
mock_make_tools,
|
||||
mock_prev_messages,
|
||||
mock_create_provider,
|
||||
):
|
||||
provider = MagicMock()
|
||||
provider.usage_tracker.is_rate_limited.return_value = False
|
||||
provider.as_messages.return_value = ["converted"]
|
||||
provider.run_with_tools.return_value = SimpleNamespace(response="filtered-output")
|
||||
mock_create_provider.return_value = provider
|
||||
|
||||
mock_prev_messages.return_value = []
|
||||
|
||||
allowed_tool = MagicMock(name="allowed")
|
||||
blocked_tool = MagicMock(name="blocked")
|
||||
mock_make_tools.return_value = {
|
||||
"allowed": allowed_tool,
|
||||
"blocked": blocked_tool,
|
||||
}
|
||||
|
||||
mock_web_search.return_value = MagicMock(name="web_tool")
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
|
||||
from_user = SimpleNamespace(id=1)
|
||||
|
||||
call_llm(
|
||||
session=MagicMock(),
|
||||
bot_user=bot_user,
|
||||
from_user=from_user,
|
||||
channel=None,
|
||||
model="gpt-test",
|
||||
messages=[],
|
||||
allowed_tools={"allowed"},
|
||||
mcp_servers=None,
|
||||
)
|
||||
|
||||
tools = provider.run_with_tools.call_args.kwargs["tools"]
|
||||
assert "allowed" in tools
|
||||
assert "blocked" not in tools
|
||||
assert "web_search" not in tools
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
"""Tests for Discord setup CLI utilities."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from tools.discord_setup import generate_bot_invite_url, make_invite
|
||||
|
||||
|
||||
def test_make_invite_generates_expected_url():
|
||||
result = make_invite(123456789)
|
||||
|
||||
assert (
|
||||
result
|
||||
== "https://discord.com/oauth2/authorize?client_id=123456789&scope=bot&permissions=3088"
|
||||
)
|
||||
|
||||
|
||||
@patch("tools.discord_setup.requests.get")
|
||||
def test_generate_bot_invite_url_outputs_link(mock_get):
|
||||
response = MagicMock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.json.return_value = {"id": "987654321"}
|
||||
mock_get.return_value = response
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "abc.def"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Bot invite URL" in result.output
|
||||
assert "987654321" in result.output
|
||||
|
||||
|
||||
@patch("tools.discord_setup.requests.get", side_effect=Exception("api down"))
|
||||
def test_generate_bot_invite_url_handles_errors(mock_get):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(generate_bot_invite_url, ["--bot-token", "token"])
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert isinstance(result.exception, ValueError)
|
||||
assert "Could not get bot info" in str(result.exception)
|
||||
Loading…
x
Reference in New Issue
Block a user