mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 00:04:05 +01:00
properly handle mcp redirects
This commit is contained in:
parent
0d9f8beec3
commit
2944a0bce1
@ -14,6 +14,7 @@ from memory.common.db.models import (
|
||||
BookSection,
|
||||
Chunk,
|
||||
Comic,
|
||||
DiscordMCPServer,
|
||||
DiscordMessage,
|
||||
EmailAccount,
|
||||
EmailAttachment,
|
||||
@ -166,6 +167,37 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
||||
column_sortable_list = ["sent_at"]
|
||||
|
||||
|
||||
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
||||
column_list = [
|
||||
"id",
|
||||
"mcp_server_url",
|
||||
"client_id",
|
||||
"discord_bot_user_id",
|
||||
"state",
|
||||
"code_verifier",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"token_expires_at",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
column_searchable_list = [
|
||||
"mcp_server_url",
|
||||
"client_id",
|
||||
"state",
|
||||
"id",
|
||||
"discord_bot_user_id",
|
||||
]
|
||||
column_sortable_list = [
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"mcp_server_url",
|
||||
"client_id",
|
||||
"state",
|
||||
"id",
|
||||
]
|
||||
|
||||
|
||||
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
||||
column_list = [
|
||||
"id",
|
||||
@ -328,4 +360,5 @@ def setup_admin(admin: Admin):
|
||||
admin.add_view(DiscordUserAdmin)
|
||||
admin.add_view(DiscordServerAdmin)
|
||||
admin.add_view(DiscordChannelAdmin)
|
||||
admin.add_view(DiscordMCPServerAdmin)
|
||||
admin.add_view(ScheduledLLMCallAdmin)
|
||||
|
||||
@ -1,13 +1,21 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from sqlalchemy.orm import Session as DBSession
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from memory.common import settings
|
||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import get_session, make_session
|
||||
from memory.common.db.models.users import User, HumanUser, BotUser, UserSession
|
||||
from memory.common.db.models import (
|
||||
BotUser,
|
||||
DiscordMCPServer,
|
||||
HumanUser,
|
||||
User,
|
||||
UserSession,
|
||||
)
|
||||
from memory.common.oauth import complete_oauth_flow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -136,6 +144,58 @@ def get_me(user: User = Depends(get_current_user)):
|
||||
return user.serialize()
|
||||
|
||||
|
||||
@router.get("/callback/discord")
|
||||
async def oauth_callback_discord(request: Request):
|
||||
"""Get current user info"""
|
||||
code = request.query_params.get("code")
|
||||
state = request.query_params.get("state")
|
||||
error = request.query_params.get("error")
|
||||
|
||||
logger.info(
|
||||
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
|
||||
)
|
||||
|
||||
message, title, close, status_code = "", "", "", 200
|
||||
if error:
|
||||
logger.error(f"OAuth error: {error}")
|
||||
message = f"Error: {error}"
|
||||
title = "❌ Authorization Failed"
|
||||
status_code = 400
|
||||
elif not code or not state:
|
||||
message = "Missing authorization code or state parameter."
|
||||
title = "❌ Invalid Request"
|
||||
status_code = 400
|
||||
else:
|
||||
# Complete the OAuth flow (exchange code for token)
|
||||
with make_session() as session:
|
||||
mcp_server = (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(DiscordMCPServer.state == state)
|
||||
.first()
|
||||
)
|
||||
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||
session.commit()
|
||||
|
||||
if 200 <= status_code < 300:
|
||||
title = "✅ Authorization Successful!"
|
||||
close = "You can close this window and return to the MCP server."
|
||||
else:
|
||||
title = "❌ Authorization Failed"
|
||||
|
||||
return Response(
|
||||
content=f"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<p>{message}</p>
|
||||
<p>{close}</p>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to require authentication for all endpoints except whitelisted ones."""
|
||||
|
||||
|
||||
@ -1,22 +1,23 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import cast
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from memory.common.db.models.base import Base
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Boolean,
|
||||
ARRAY,
|
||||
Integer,
|
||||
Numeric,
|
||||
CheckConstraint,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
|
||||
@ -164,9 +164,6 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
kwargs["temperature"] = 1.0
|
||||
kwargs.pop("top_p", None)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
print(f"{k}: {v}")
|
||||
|
||||
return kwargs
|
||||
|
||||
def _handle_stream_event(
|
||||
|
||||
@ -7,9 +7,7 @@ from datetime import datetime, timedelta
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
import aiohttp
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -99,7 +97,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
|
||||
authorization_endpoint=authorization_endpoint,
|
||||
registration_endpoint=registration_endpoint,
|
||||
token_endpoint=token_endpoint,
|
||||
redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
|
||||
redirect_uri=f"{settings.SERVER_URL}/auth/callback/discord",
|
||||
)
|
||||
|
||||
|
||||
@ -181,7 +179,7 @@ async def issue_challenge(
|
||||
|
||||
|
||||
async def complete_oauth_flow(
|
||||
session: Session | scoped_session, code: str, state: str
|
||||
mcp_server: DiscordMCPServer, code: str, state: str
|
||||
) -> tuple[int, str]:
|
||||
"""Complete OAuth flow by exchanging code for token.
|
||||
|
||||
@ -193,12 +191,6 @@ async def complete_oauth_flow(
|
||||
Tuple of (status_code, html_message) for the callback response
|
||||
"""
|
||||
try:
|
||||
mcp_server = (
|
||||
session.query(DiscordMCPServer)
|
||||
.filter(DiscordMCPServer.state == state)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not mcp_server:
|
||||
logger.error(f"Invalid or expired state: {state[:20]}...")
|
||||
return 400, "Invalid or expired OAuth state"
|
||||
@ -254,8 +246,6 @@ async def complete_oauth_flow(
|
||||
mcp_server.state = None # type: ignore
|
||||
mcp_server.code_verifier = None # type: ignore
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
||||
f"server {mcp_server.mcp_server_url}"
|
||||
@ -18,9 +18,9 @@ from pydantic import BaseModel
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.users import DiscordBotUser
|
||||
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
|
||||
from memory.common.oauth import complete_oauth_flow
|
||||
from memory.discord.collector import MessageCollector
|
||||
from memory.discord.oauth import complete_oauth_flow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -279,51 +279,6 @@ async def refresh_metadata():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/oauth/callback/discord", response_class=HTMLResponse)
|
||||
async def oauth_callback(request: Request):
|
||||
"""Handle OAuth callback from MCP server after user authorization."""
|
||||
code = request.query_params.get("code")
|
||||
state = request.query_params.get("state")
|
||||
error = request.query_params.get("error")
|
||||
|
||||
logger.info(
|
||||
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
|
||||
)
|
||||
|
||||
message, title, close, status_code = "", "", "", 200
|
||||
if error:
|
||||
logger.error(f"OAuth error: {error}")
|
||||
message = f"Error: {error}"
|
||||
title = "❌ Authorization Failed"
|
||||
status_code = 400
|
||||
elif not code or not state:
|
||||
message = "Missing authorization code or state parameter."
|
||||
title = "❌ Invalid Request"
|
||||
status_code = 400
|
||||
else:
|
||||
# Complete the OAuth flow (exchange code for token)
|
||||
with make_session() as session:
|
||||
status_code, message = await complete_oauth_flow(session, code, state)
|
||||
if 200 <= status_code < 300:
|
||||
title = "✅ Authorization Successful!"
|
||||
close = "You can close this window and return to the MCP server."
|
||||
else:
|
||||
title = "❌ Authorization Failed"
|
||||
|
||||
return HTMLResponse(
|
||||
content=f"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<p>{message}</p>
|
||||
<p>{close}</p>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
||||
"""Run the Discord API server"""
|
||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||
|
||||
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.discord import DiscordMCPServer
|
||||
from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@ from memory.common.db.models import (
|
||||
DiscordUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
from memory.common.db.models.users import BotUser
|
||||
from memory.common.llms.base import create_provider
|
||||
from memory.common.llms.tools import MCPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -216,6 +216,7 @@ def call_llm(
|
||||
system_prompt: str = "",
|
||||
messages: list[str | dict[str, Any]] = [],
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
mcp_servers: list[MCPServer] | None = None,
|
||||
num_previous_messages: int = 10,
|
||||
) -> str | None:
|
||||
"""
|
||||
@ -269,6 +270,7 @@ def call_llm(
|
||||
messages=provider.as_messages(message_content),
|
||||
tools=tools,
|
||||
system_prompt=bot_user.system_prompt + "\n\n" + system_prompt,
|
||||
mcp_servers=mcp_servers,
|
||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||
).response
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from memory.common.llms.tools import MCPServer
|
||||
import requests
|
||||
from sqlalchemy import exc as sqlalchemy_exc
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
@ -189,6 +190,20 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
mcp_servers = None
|
||||
if (
|
||||
discord_message.recipient_user
|
||||
and discord_message.recipient_user.mcp_servers
|
||||
):
|
||||
mcp_servers = [
|
||||
MCPServer(
|
||||
name=server.mcp_server_url,
|
||||
url=server.mcp_server_url,
|
||||
token=server.access_token,
|
||||
)
|
||||
for server in discord_message.recipient_user.mcp_servers
|
||||
]
|
||||
|
||||
try:
|
||||
response = call_llm(
|
||||
session,
|
||||
@ -196,6 +211,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
from_user=discord_message.from_user,
|
||||
channel=discord_message.channel,
|
||||
model=settings.DISCORD_MODEL,
|
||||
mcp_servers=mcp_servers,
|
||||
system_prompt=discord_message.system_prompt,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user