mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-22 22:56:38 +02:00
fix admin
This commit is contained in:
parent
aaa0c2c3cd
commit
4fedd8fe04
@ -59,4 +59,4 @@ USER kb
|
||||
ENV PORT=8000
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers", "--forwarded-allow-ips", "*"]
|
@ -6,12 +6,10 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from memory.api.MCP.base import get_current_user
|
||||
from memory.api.MCP.base import get_current_user, mcp
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import ScheduledLLMCall
|
||||
from memory.common.db.models.discord import DiscordChannel, DiscordUser
|
||||
from memory.api.MCP.base import mcp
|
||||
from memory.discord.schedule import schedule_discord_message
|
||||
from memory.discord.messages import schedule_discord_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -2,36 +2,29 @@
|
||||
SQLAdmin views for the knowledge base database models.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from sqladmin import Admin, ModelView
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
import logging
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
|
||||
from sqladmin import Admin, ModelView
|
||||
from memory.common.db.models import (
|
||||
Chunk,
|
||||
SourceItem,
|
||||
MailMessage,
|
||||
EmailAttachment,
|
||||
Photo,
|
||||
Comic,
|
||||
AgentObservation,
|
||||
ArticleFeed,
|
||||
BlogPost,
|
||||
Book,
|
||||
BookSection,
|
||||
BlogPost,
|
||||
MiscDoc,
|
||||
ArticleFeed,
|
||||
ScheduledLLMCall,
|
||||
Chunk,
|
||||
Comic,
|
||||
EmailAccount,
|
||||
EmailAttachment,
|
||||
ForumPost,
|
||||
AgentObservation,
|
||||
MailMessage,
|
||||
MiscDoc,
|
||||
Note,
|
||||
Photo,
|
||||
SourceItem,
|
||||
User,
|
||||
UserSession,
|
||||
OAuthState,
|
||||
)
|
||||
from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -223,9 +216,79 @@ class NoteAdmin(ModelView, model=Note):
|
||||
class UserAdmin(ModelView, model=User):
|
||||
column_list = [
|
||||
"id",
|
||||
"user_type",
|
||||
"email",
|
||||
"api_key",
|
||||
"name",
|
||||
"created_at",
|
||||
"discord_users",
|
||||
]
|
||||
|
||||
|
||||
class DiscordUserAdmin(ModelView, model=DiscordUser):
|
||||
column_list = [
|
||||
"id",
|
||||
"username",
|
||||
"display_name",
|
||||
"track_messages",
|
||||
"ignore_messages",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"summary",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
|
||||
class DiscordServerAdmin(ModelView, model=DiscordServer):
|
||||
column_list = [
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"member_count",
|
||||
"last_sync_at",
|
||||
"track_messages",
|
||||
"ignore_messages",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"summary",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
|
||||
class DiscordChannelAdmin(ModelView, model=DiscordChannel):
|
||||
column_list = [
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"member_count",
|
||||
"last_sync_at",
|
||||
"track_messages",
|
||||
"ignore_messages",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"summary",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
|
||||
class ScheduledLLMCallAdmin(ModelView, model=ScheduledLLMCall):
|
||||
column_list = [
|
||||
"id",
|
||||
"user",
|
||||
"topic",
|
||||
"scheduled_time",
|
||||
"model",
|
||||
"status",
|
||||
"error_message",
|
||||
"response",
|
||||
"discord_channel",
|
||||
"discord_user",
|
||||
"executed_at",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
|
||||
@ -247,3 +310,7 @@ def setup_admin(admin: Admin):
|
||||
admin.add_view(ComicAdmin)
|
||||
admin.add_view(PhotoAdmin)
|
||||
admin.add_view(UserAdmin)
|
||||
admin.add_view(DiscordUserAdmin)
|
||||
admin.add_view(DiscordServerAdmin)
|
||||
admin.add_view(DiscordChannelAdmin)
|
||||
admin.add_view(ScheduledLLMCallAdmin)
|
||||
|
@ -25,6 +25,7 @@ WHITELIST = {
|
||||
"/oauth/",
|
||||
"/.well-known/",
|
||||
"/ui",
|
||||
"/admin/statics/", # SQLAdmin static resources
|
||||
}
|
||||
|
||||
|
||||
|
@ -3,7 +3,6 @@ import uuid
|
||||
from typing import Any, Dict, cast
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
|
@ -16,20 +16,39 @@ import uvicorn
|
||||
|
||||
from memory.common import settings
|
||||
from memory.discord.collector import MessageCollector
|
||||
from memory.common.db.models.users import BotUser
|
||||
from memory.common.db.connection import make_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SendDMRequest(BaseModel):
|
||||
bot_id: int
|
||||
user: str # Discord user ID or username
|
||||
message: str
|
||||
|
||||
|
||||
class SendChannelRequest(BaseModel):
|
||||
bot_id: int
|
||||
channel_name: str # Channel name (e.g., "memory-errors")
|
||||
message: str
|
||||
|
||||
|
||||
class Collector:
|
||||
collector: MessageCollector
|
||||
collector_task: asyncio.Task
|
||||
bot_id: int
|
||||
bot_token: str
|
||||
bot_name: str
|
||||
|
||||
def __init__(self, collector: MessageCollector, bot: BotUser):
|
||||
self.collector = collector
|
||||
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
|
||||
self.bot_id = bot.id
|
||||
self.bot_token = bot.api_key
|
||||
self.bot_name = bot.name
|
||||
|
||||
|
||||
# Application state
|
||||
class AppState:
|
||||
def __init__(self):
|
||||
@ -47,27 +66,28 @@ async def lifespan(app: FastAPI):
|
||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||
return
|
||||
|
||||
# Create and start the collector
|
||||
app_state.collector = MessageCollector()
|
||||
app_state.collector_task = asyncio.create_task(
|
||||
app_state.collector.start(settings.DISCORD_BOT_TOKEN)
|
||||
)
|
||||
logger.info("Discord collector started")
|
||||
def make_collector(bot: BotUser):
|
||||
collector = MessageCollector()
|
||||
return Collector(collector=collector, bot=bot)
|
||||
|
||||
with make_session() as session:
|
||||
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
|
||||
|
||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
if app_state.collector and not app_state.collector.is_closed():
|
||||
await app_state.collector.close()
|
||||
|
||||
if app_state.collector_task:
|
||||
app_state.collector_task.cancel()
|
||||
try:
|
||||
await app_state.collector_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Discord collector stopped")
|
||||
for bot in app.bots.values():
|
||||
if not bot.collector.is_closed():
|
||||
await bot.collector.close()
|
||||
if bot.collector_task:
|
||||
bot.collector_task.cancel()
|
||||
try:
|
||||
await bot.collector_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(f"Discord collectors stopped for {len(app.bots)} bots")
|
||||
|
||||
|
||||
# FastAPI app with lifespan management
|
||||
@ -77,11 +97,12 @@ app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
|
||||
@app.post("/send_dm")
|
||||
async def send_dm_endpoint(request: SendDMRequest):
|
||||
"""Send a DM via the collector's Discord client"""
|
||||
if not app_state.collector:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
collector = app.bots.get(request.bot_id)
|
||||
if not collector:
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
|
||||
try:
|
||||
success = await app_state.collector.send_dm(request.user, request.message)
|
||||
success = await collector.collector.send_dm(request.user, request.message)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"Failed to send DM: {e}")
|
||||
@ -102,11 +123,12 @@ async def send_dm_endpoint(request: SendDMRequest):
|
||||
@app.post("/send_channel")
|
||||
async def send_channel_endpoint(request: SendChannelRequest):
|
||||
"""Send a message to a channel via the collector's Discord client"""
|
||||
if not app_state.collector:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
collector = app.bots.get(request.bot_id)
|
||||
if not collector:
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
|
||||
try:
|
||||
success = await app_state.collector.send_to_channel(
|
||||
success = await collector.collector.send_to_channel(
|
||||
request.channel_name, request.message
|
||||
)
|
||||
|
||||
@ -130,27 +152,37 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Check if the Discord collector is running and healthy"""
|
||||
if not app_state.collector:
|
||||
if not app.bots:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
collector = app_state.collector
|
||||
return {
|
||||
"status": "healthy",
|
||||
"connected": not collector.is_closed(),
|
||||
"user": str(collector.user) if collector.user else None,
|
||||
"guilds": len(collector.guilds) if collector.guilds else 0,
|
||||
collector.bot_name: {
|
||||
"status": "healthy",
|
||||
"connected": not bot.collector.is_closed(),
|
||||
"user": str(bot.collector.user) if bot.collector.user else None,
|
||||
"guilds": len(bot.collector.guilds) if bot.collector.guilds else 0,
|
||||
}
|
||||
for bot in app.bots.values()
|
||||
}
|
||||
|
||||
|
||||
@app.post("/refresh_metadata")
|
||||
async def refresh_metadata():
|
||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||
if not app_state.collector:
|
||||
if not app.bots:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
try:
|
||||
result = await app_state.collector.refresh_metadata()
|
||||
return {"success": True, "message": "Metadata refreshed successfully", **result}
|
||||
result = {
|
||||
bot.bot_name: await bot.collector.refresh_metadata()
|
||||
for bot in app.bots.values()
|
||||
}
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Metadata refreshed successfully for {len(app.bots)} bots",
|
||||
"results": result,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh metadata: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
Loading…
x
Reference in New Issue
Block a user