mirror of
https://github.com/mruwnik/memory.git
synced 2025-11-13 08:14:05 +01:00
properly handle mcp redirects
This commit is contained in:
parent
0d9f8beec3
commit
2944a0bce1
@ -14,6 +14,7 @@ from memory.common.db.models import (
|
|||||||
BookSection,
|
BookSection,
|
||||||
Chunk,
|
Chunk,
|
||||||
Comic,
|
Comic,
|
||||||
|
DiscordMCPServer,
|
||||||
DiscordMessage,
|
DiscordMessage,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
EmailAttachment,
|
EmailAttachment,
|
||||||
@ -166,6 +167,37 @@ class DiscordMessageAdmin(ModelView, model=DiscordMessage):
|
|||||||
column_sortable_list = ["sent_at"]
|
column_sortable_list = ["sent_at"]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMCPServerAdmin(ModelView, model=DiscordMCPServer):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"mcp_server_url",
|
||||||
|
"client_id",
|
||||||
|
"discord_bot_user_id",
|
||||||
|
"state",
|
||||||
|
"code_verifier",
|
||||||
|
"access_token",
|
||||||
|
"refresh_token",
|
||||||
|
"token_expires_at",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
column_searchable_list = [
|
||||||
|
"mcp_server_url",
|
||||||
|
"client_id",
|
||||||
|
"state",
|
||||||
|
"id",
|
||||||
|
"discord_bot_user_id",
|
||||||
|
]
|
||||||
|
column_sortable_list = [
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
"mcp_server_url",
|
||||||
|
"client_id",
|
||||||
|
"state",
|
||||||
|
"id",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
class ArticleFeedAdmin(ModelView, model=ArticleFeed):
|
||||||
column_list = [
|
column_list = [
|
||||||
"id",
|
"id",
|
||||||
@ -328,4 +360,5 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(DiscordUserAdmin)
|
admin.add_view(DiscordUserAdmin)
|
||||||
admin.add_view(DiscordServerAdmin)
|
admin.add_view(DiscordServerAdmin)
|
||||||
admin.add_view(DiscordChannelAdmin)
|
admin.add_view(DiscordChannelAdmin)
|
||||||
|
admin.add_view(DiscordMCPServerAdmin)
|
||||||
admin.add_view(ScheduledLLMCallAdmin)
|
admin.add_view(ScheduledLLMCallAdmin)
|
||||||
|
|||||||
@ -1,13 +1,21 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from fastapi import HTTPException, Depends, Request, Response, APIRouter
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
|
from sqlalchemy.orm import Session as DBSession
|
||||||
|
from sqlalchemy.orm import scoped_session
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from memory.common import settings
|
|
||||||
from sqlalchemy.orm import Session as DBSession, scoped_session
|
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
from memory.common.db.connection import get_session, make_session
|
from memory.common.db.connection import get_session, make_session
|
||||||
from memory.common.db.models.users import User, HumanUser, BotUser, UserSession
|
from memory.common.db.models import (
|
||||||
|
BotUser,
|
||||||
|
DiscordMCPServer,
|
||||||
|
HumanUser,
|
||||||
|
User,
|
||||||
|
UserSession,
|
||||||
|
)
|
||||||
|
from memory.common.oauth import complete_oauth_flow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -136,6 +144,58 @@ def get_me(user: User = Depends(get_current_user)):
|
|||||||
return user.serialize()
|
return user.serialize()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/callback/discord")
|
||||||
|
async def oauth_callback_discord(request: Request):
|
||||||
|
"""Get current user info"""
|
||||||
|
code = request.query_params.get("code")
|
||||||
|
state = request.query_params.get("state")
|
||||||
|
error = request.query_params.get("error")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
message, title, close, status_code = "", "", "", 200
|
||||||
|
if error:
|
||||||
|
logger.error(f"OAuth error: {error}")
|
||||||
|
message = f"Error: {error}"
|
||||||
|
title = "❌ Authorization Failed"
|
||||||
|
status_code = 400
|
||||||
|
elif not code or not state:
|
||||||
|
message = "Missing authorization code or state parameter."
|
||||||
|
title = "❌ Invalid Request"
|
||||||
|
status_code = 400
|
||||||
|
else:
|
||||||
|
# Complete the OAuth flow (exchange code for token)
|
||||||
|
with make_session() as session:
|
||||||
|
mcp_server = (
|
||||||
|
session.query(DiscordMCPServer)
|
||||||
|
.filter(DiscordMCPServer.state == state)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
status_code, message = await complete_oauth_flow(mcp_server, code, state)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
if 200 <= status_code < 300:
|
||||||
|
title = "✅ Authorization Successful!"
|
||||||
|
close = "You can close this window and return to the MCP server."
|
||||||
|
else:
|
||||||
|
title = "❌ Authorization Failed"
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=f"""
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
<h1>{title}</h1>
|
||||||
|
<p>{message}</p>
|
||||||
|
<p>{close}</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
status_code=status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||||
"""Middleware to require authentication for all endpoints except whitelisted ones."""
|
"""Middleware to require authentication for all endpoints except whitelisted ones."""
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +1,23 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
from typing import cast
|
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy.orm import Session
|
from typing import cast
|
||||||
from memory.common.db.models.base import Base
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
ARRAY,
|
||||||
|
Boolean,
|
||||||
|
CheckConstraint,
|
||||||
Column,
|
Column,
|
||||||
Integer,
|
|
||||||
String,
|
|
||||||
DateTime,
|
DateTime,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Boolean,
|
Integer,
|
||||||
ARRAY,
|
|
||||||
Numeric,
|
Numeric,
|
||||||
CheckConstraint,
|
String,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.orm import Session, relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
|
|||||||
@ -164,9 +164,6 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
kwargs["temperature"] = 1.0
|
kwargs["temperature"] = 1.0
|
||||||
kwargs.pop("top_p", None)
|
kwargs.pop("top_p", None)
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
print(f"{k}: {v}")
|
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _handle_stream_event(
|
def _handle_stream_event(
|
||||||
|
|||||||
@ -7,9 +7,7 @@ from datetime import datetime, timedelta
|
|||||||
from urllib.parse import urlencode, urljoin
|
from urllib.parse import urlencode, urljoin
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
|
||||||
from memory.common.db.models.discord import DiscordMCPServer
|
from memory.common.db.models.discord import DiscordMCPServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -99,7 +97,7 @@ async def get_endpoints(url: str) -> OAuthEndpoints:
|
|||||||
authorization_endpoint=authorization_endpoint,
|
authorization_endpoint=authorization_endpoint,
|
||||||
registration_endpoint=registration_endpoint,
|
registration_endpoint=registration_endpoint,
|
||||||
token_endpoint=token_endpoint,
|
token_endpoint=token_endpoint,
|
||||||
redirect_uri=f"{settings.SERVER_URL}/oauth/callback/discord",
|
redirect_uri=f"{settings.SERVER_URL}/auth/callback/discord",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -181,7 +179,7 @@ async def issue_challenge(
|
|||||||
|
|
||||||
|
|
||||||
async def complete_oauth_flow(
|
async def complete_oauth_flow(
|
||||||
session: Session | scoped_session, code: str, state: str
|
mcp_server: DiscordMCPServer, code: str, state: str
|
||||||
) -> tuple[int, str]:
|
) -> tuple[int, str]:
|
||||||
"""Complete OAuth flow by exchanging code for token.
|
"""Complete OAuth flow by exchanging code for token.
|
||||||
|
|
||||||
@ -193,12 +191,6 @@ async def complete_oauth_flow(
|
|||||||
Tuple of (status_code, html_message) for the callback response
|
Tuple of (status_code, html_message) for the callback response
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
mcp_server = (
|
|
||||||
session.query(DiscordMCPServer)
|
|
||||||
.filter(DiscordMCPServer.state == state)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
logger.error(f"Invalid or expired state: {state[:20]}...")
|
logger.error(f"Invalid or expired state: {state[:20]}...")
|
||||||
return 400, "Invalid or expired OAuth state"
|
return 400, "Invalid or expired OAuth state"
|
||||||
@ -254,8 +246,6 @@ async def complete_oauth_flow(
|
|||||||
mcp_server.state = None # type: ignore
|
mcp_server.state = None # type: ignore
|
||||||
mcp_server.code_verifier = None # type: ignore
|
mcp_server.code_verifier = None # type: ignore
|
||||||
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
f"Stored tokens for user {mcp_server.discord_bot_user_id}, "
|
||||||
f"server {mcp_server.mcp_server_url}"
|
f"server {mcp_server.mcp_server_url}"
|
||||||
@ -18,9 +18,9 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.users import DiscordBotUser
|
from memory.common.db.models import DiscordMCPServer, DiscordBotUser
|
||||||
|
from memory.common.oauth import complete_oauth_flow
|
||||||
from memory.discord.collector import MessageCollector
|
from memory.discord.collector import MessageCollector
|
||||||
from memory.discord.oauth import complete_oauth_flow
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -279,51 +279,6 @@ async def refresh_metadata():
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@app.get("/oauth/callback/discord", response_class=HTMLResponse)
|
|
||||||
async def oauth_callback(request: Request):
|
|
||||||
"""Handle OAuth callback from MCP server after user authorization."""
|
|
||||||
code = request.query_params.get("code")
|
|
||||||
state = request.query_params.get("state")
|
|
||||||
error = request.query_params.get("error")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
message, title, close, status_code = "", "", "", 200
|
|
||||||
if error:
|
|
||||||
logger.error(f"OAuth error: {error}")
|
|
||||||
message = f"Error: {error}"
|
|
||||||
title = "❌ Authorization Failed"
|
|
||||||
status_code = 400
|
|
||||||
elif not code or not state:
|
|
||||||
message = "Missing authorization code or state parameter."
|
|
||||||
title = "❌ Invalid Request"
|
|
||||||
status_code = 400
|
|
||||||
else:
|
|
||||||
# Complete the OAuth flow (exchange code for token)
|
|
||||||
with make_session() as session:
|
|
||||||
status_code, message = await complete_oauth_flow(session, code, state)
|
|
||||||
if 200 <= status_code < 300:
|
|
||||||
title = "✅ Authorization Successful!"
|
|
||||||
close = "You can close this window and return to the MCP server."
|
|
||||||
else:
|
|
||||||
title = "❌ Authorization Failed"
|
|
||||||
|
|
||||||
return HTMLResponse(
|
|
||||||
content=f"""
|
|
||||||
<html>
|
|
||||||
<body>
|
|
||||||
<h1>{title}</h1>
|
|
||||||
<p>{message}</p>
|
|
||||||
<p>{close}</p>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
""",
|
|
||||||
status_code=status_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
def run_discord_api_server(host: str = "127.0.0.1", port: int = 8001):
|
||||||
"""Run the Discord API server"""
|
"""Run the Discord API server"""
|
||||||
uvicorn.run(app, host=host, port=port, log_level="debug")
|
uvicorn.run(app, host=host, port=port, log_level="debug")
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, scoped_session
|
|||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models.discord import DiscordMCPServer
|
from memory.common.db.models.discord import DiscordMCPServer
|
||||||
from memory.discord.oauth import get_endpoints, issue_challenge, register_oauth_client
|
from memory.common.oauth import get_endpoints, issue_challenge, register_oauth_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -13,8 +13,8 @@ from memory.common.db.models import (
|
|||||||
DiscordUser,
|
DiscordUser,
|
||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.users import BotUser
|
|
||||||
from memory.common.llms.base import create_provider
|
from memory.common.llms.base import create_provider
|
||||||
|
from memory.common.llms.tools import MCPServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -216,6 +216,7 @@ def call_llm(
|
|||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
messages: list[str | dict[str, Any]] = [],
|
messages: list[str | dict[str, Any]] = [],
|
||||||
allowed_tools: Collection[str] | None = None,
|
allowed_tools: Collection[str] | None = None,
|
||||||
|
mcp_servers: list[MCPServer] | None = None,
|
||||||
num_previous_messages: int = 10,
|
num_previous_messages: int = 10,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""
|
"""
|
||||||
@ -269,6 +270,7 @@ def call_llm(
|
|||||||
messages=provider.as_messages(message_content),
|
messages=provider.as_messages(message_content),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
system_prompt=bot_user.system_prompt + "\n\n" + system_prompt,
|
system_prompt=bot_user.system_prompt + "\n\n" + system_prompt,
|
||||||
|
mcp_servers=mcp_servers,
|
||||||
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
max_iterations=settings.DISCORD_MAX_TOOL_CALLS,
|
||||||
).response
|
).response
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import textwrap
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from memory.common.llms.tools import MCPServer
|
||||||
import requests
|
import requests
|
||||||
from sqlalchemy import exc as sqlalchemy_exc
|
from sqlalchemy import exc as sqlalchemy_exc
|
||||||
from sqlalchemy.orm import Session, scoped_session
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
@ -189,6 +190,20 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mcp_servers = None
|
||||||
|
if (
|
||||||
|
discord_message.recipient_user
|
||||||
|
and discord_message.recipient_user.mcp_servers
|
||||||
|
):
|
||||||
|
mcp_servers = [
|
||||||
|
MCPServer(
|
||||||
|
name=server.mcp_server_url,
|
||||||
|
url=server.mcp_server_url,
|
||||||
|
token=server.access_token,
|
||||||
|
)
|
||||||
|
for server in discord_message.recipient_user.mcp_servers
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = call_llm(
|
response = call_llm(
|
||||||
session,
|
session,
|
||||||
@ -196,6 +211,7 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
|||||||
from_user=discord_message.from_user,
|
from_user=discord_message.from_user,
|
||||||
channel=discord_message.channel,
|
channel=discord_message.channel,
|
||||||
model=settings.DISCORD_MODEL,
|
model=settings.DISCORD_MODEL,
|
||||||
|
mcp_servers=mcp_servers,
|
||||||
system_prompt=discord_message.system_prompt,
|
system_prompt=discord_message.system_prompt,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user