mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-23 07:06:36 +02:00
fix admin
This commit is contained in:
parent
aaa0c2c3cd
commit
4fedd8fe04
@ -59,4 +59,4 @@ USER kb
|
|||||||
ENV PORT=8000
|
ENV PORT=8000
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
CMD ["uvicorn", "memory.api.app:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers", "--forwarded-allow-ips", "*"]
|
@ -6,12 +6,10 @@ import logging
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from memory.api.MCP.base import get_current_user
|
from memory.api.MCP.base import get_current_user, mcp
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import ScheduledLLMCall
|
from memory.common.db.models import ScheduledLLMCall
|
||||||
from memory.common.db.models.discord import DiscordChannel, DiscordUser
|
from memory.discord.messages import schedule_discord_message
|
||||||
from memory.api.MCP.base import mcp
|
|
||||||
from memory.discord.schedule import schedule_discord_message
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -2,36 +2,29 @@
|
|||||||
SQLAdmin views for the knowledge base database models.
|
SQLAdmin views for the knowledge base database models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
|
||||||
from sqladmin import Admin, ModelView
|
|
||||||
from fastapi import Request
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import RedirectResponse
|
|
||||||
import logging
|
import logging
|
||||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
|
||||||
from memory.api.MCP.oauth_provider import create_expiration, ACCESS_TOKEN_LIFETIME
|
from sqladmin import Admin, ModelView
|
||||||
from memory.common import settings
|
|
||||||
from memory.common.db.connection import make_session
|
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import (
|
||||||
Chunk,
|
AgentObservation,
|
||||||
SourceItem,
|
ArticleFeed,
|
||||||
MailMessage,
|
BlogPost,
|
||||||
EmailAttachment,
|
|
||||||
Photo,
|
|
||||||
Comic,
|
|
||||||
Book,
|
Book,
|
||||||
BookSection,
|
BookSection,
|
||||||
BlogPost,
|
ScheduledLLMCall,
|
||||||
MiscDoc,
|
Chunk,
|
||||||
ArticleFeed,
|
Comic,
|
||||||
EmailAccount,
|
EmailAccount,
|
||||||
|
EmailAttachment,
|
||||||
ForumPost,
|
ForumPost,
|
||||||
AgentObservation,
|
MailMessage,
|
||||||
|
MiscDoc,
|
||||||
Note,
|
Note,
|
||||||
|
Photo,
|
||||||
|
SourceItem,
|
||||||
User,
|
User,
|
||||||
UserSession,
|
|
||||||
OAuthState,
|
|
||||||
)
|
)
|
||||||
|
from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -223,9 +216,79 @@ class NoteAdmin(ModelView, model=Note):
|
|||||||
class UserAdmin(ModelView, model=User):
|
class UserAdmin(ModelView, model=User):
|
||||||
column_list = [
|
column_list = [
|
||||||
"id",
|
"id",
|
||||||
|
"user_type",
|
||||||
"email",
|
"email",
|
||||||
|
"api_key",
|
||||||
"name",
|
"name",
|
||||||
"created_at",
|
"created_at",
|
||||||
|
"discord_users",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordUserAdmin(ModelView, model=DiscordUser):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"username",
|
||||||
|
"display_name",
|
||||||
|
"track_messages",
|
||||||
|
"ignore_messages",
|
||||||
|
"allowed_tools",
|
||||||
|
"disallowed_tools",
|
||||||
|
"summary",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordServerAdmin(ModelView, model=DiscordServer):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"member_count",
|
||||||
|
"last_sync_at",
|
||||||
|
"track_messages",
|
||||||
|
"ignore_messages",
|
||||||
|
"allowed_tools",
|
||||||
|
"disallowed_tools",
|
||||||
|
"summary",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordChannelAdmin(ModelView, model=DiscordChannel):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"member_count",
|
||||||
|
"last_sync_at",
|
||||||
|
"track_messages",
|
||||||
|
"ignore_messages",
|
||||||
|
"allowed_tools",
|
||||||
|
"disallowed_tools",
|
||||||
|
"summary",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledLLMCallAdmin(ModelView, model=ScheduledLLMCall):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"user",
|
||||||
|
"topic",
|
||||||
|
"scheduled_time",
|
||||||
|
"model",
|
||||||
|
"status",
|
||||||
|
"error_message",
|
||||||
|
"response",
|
||||||
|
"discord_channel",
|
||||||
|
"discord_user",
|
||||||
|
"executed_at",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -247,3 +310,7 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(ComicAdmin)
|
admin.add_view(ComicAdmin)
|
||||||
admin.add_view(PhotoAdmin)
|
admin.add_view(PhotoAdmin)
|
||||||
admin.add_view(UserAdmin)
|
admin.add_view(UserAdmin)
|
||||||
|
admin.add_view(DiscordUserAdmin)
|
||||||
|
admin.add_view(DiscordServerAdmin)
|
||||||
|
admin.add_view(DiscordChannelAdmin)
|
||||||
|
admin.add_view(ScheduledLLMCallAdmin)
|
||||||
|
@ -25,6 +25,7 @@ WHITELIST = {
|
|||||||
"/oauth/",
|
"/oauth/",
|
||||||
"/.well-known/",
|
"/.well-known/",
|
||||||
"/ui",
|
"/ui",
|
||||||
|
"/admin/statics/", # SQLAdmin static resources
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ import uuid
|
|||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, cast
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column,
|
Column,
|
||||||
Integer,
|
|
||||||
String,
|
String,
|
||||||
DateTime,
|
DateTime,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
|
@ -16,20 +16,39 @@ import uvicorn
|
|||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.discord.collector import MessageCollector
|
from memory.discord.collector import MessageCollector
|
||||||
|
from memory.common.db.models.users import BotUser
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SendDMRequest(BaseModel):
|
class SendDMRequest(BaseModel):
|
||||||
|
bot_id: int
|
||||||
user: str # Discord user ID or username
|
user: str # Discord user ID or username
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
class SendChannelRequest(BaseModel):
|
class SendChannelRequest(BaseModel):
|
||||||
|
bot_id: int
|
||||||
channel_name: str # Channel name (e.g., "memory-errors")
|
channel_name: str # Channel name (e.g., "memory-errors")
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class Collector:
|
||||||
|
collector: MessageCollector
|
||||||
|
collector_task: asyncio.Task
|
||||||
|
bot_id: int
|
||||||
|
bot_token: str
|
||||||
|
bot_name: str
|
||||||
|
|
||||||
|
def __init__(self, collector: MessageCollector, bot: BotUser):
|
||||||
|
self.collector = collector
|
||||||
|
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
|
||||||
|
self.bot_id = bot.id
|
||||||
|
self.bot_token = bot.api_key
|
||||||
|
self.bot_name = bot.name
|
||||||
|
|
||||||
|
|
||||||
# Application state
|
# Application state
|
||||||
class AppState:
|
class AppState:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -47,27 +66,28 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create and start the collector
|
def make_collector(bot: BotUser):
|
||||||
app_state.collector = MessageCollector()
|
collector = MessageCollector()
|
||||||
app_state.collector_task = asyncio.create_task(
|
return Collector(collector=collector, bot=bot)
|
||||||
app_state.collector.start(settings.DISCORD_BOT_TOKEN)
|
|
||||||
)
|
with make_session() as session:
|
||||||
logger.info("Discord collector started")
|
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
|
||||||
|
|
||||||
|
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
if app_state.collector and not app_state.collector.is_closed():
|
for bot in app.bots.values():
|
||||||
await app_state.collector.close()
|
if not bot.collector.is_closed():
|
||||||
|
await bot.collector.close()
|
||||||
if app_state.collector_task:
|
if bot.collector_task:
|
||||||
app_state.collector_task.cancel()
|
bot.collector_task.cancel()
|
||||||
try:
|
try:
|
||||||
await app_state.collector_task
|
await bot.collector_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
logger.info(f"Discord collectors stopped for {len(app.bots)} bots")
|
||||||
logger.info("Discord collector stopped")
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI app with lifespan management
|
# FastAPI app with lifespan management
|
||||||
@ -77,11 +97,12 @@ app = FastAPI(title="Discord Collector API", version="1.0.0", lifespan=lifespan)
|
|||||||
@app.post("/send_dm")
|
@app.post("/send_dm")
|
||||||
async def send_dm_endpoint(request: SendDMRequest):
|
async def send_dm_endpoint(request: SendDMRequest):
|
||||||
"""Send a DM via the collector's Discord client"""
|
"""Send a DM via the collector's Discord client"""
|
||||||
if not app_state.collector:
|
collector = app.bots.get(request.bot_id)
|
||||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
if not collector:
|
||||||
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = await app_state.collector.send_dm(request.user, request.message)
|
success = await collector.collector.send_dm(request.user, request.message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.error(f"Failed to send DM: {e}")
|
logger.error(f"Failed to send DM: {e}")
|
||||||
@ -102,11 +123,12 @@ async def send_dm_endpoint(request: SendDMRequest):
|
|||||||
@app.post("/send_channel")
|
@app.post("/send_channel")
|
||||||
async def send_channel_endpoint(request: SendChannelRequest):
|
async def send_channel_endpoint(request: SendChannelRequest):
|
||||||
"""Send a message to a channel via the collector's Discord client"""
|
"""Send a message to a channel via the collector's Discord client"""
|
||||||
if not app_state.collector:
|
collector = app.bots.get(request.bot_id)
|
||||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
if not collector:
|
||||||
|
raise HTTPException(status_code=404, detail="Bot not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = await app_state.collector.send_to_channel(
|
success = await collector.collector.send_to_channel(
|
||||||
request.channel_name, request.message
|
request.channel_name, request.message
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -130,27 +152,37 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
|||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Check if the Discord collector is running and healthy"""
|
"""Check if the Discord collector is running and healthy"""
|
||||||
if not app_state.collector:
|
if not app.bots:
|
||||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||||
|
|
||||||
collector = app_state.collector
|
collector = app_state.collector
|
||||||
return {
|
return {
|
||||||
|
collector.bot_name: {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"connected": not collector.is_closed(),
|
"connected": not bot.collector.is_closed(),
|
||||||
"user": str(collector.user) if collector.user else None,
|
"user": str(bot.collector.user) if bot.collector.user else None,
|
||||||
"guilds": len(collector.guilds) if collector.guilds else 0,
|
"guilds": len(bot.collector.guilds) if bot.collector.guilds else 0,
|
||||||
|
}
|
||||||
|
for bot in app.bots.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/refresh_metadata")
|
@app.post("/refresh_metadata")
|
||||||
async def refresh_metadata():
|
async def refresh_metadata():
|
||||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||||
if not app_state.collector:
|
if not app.bots:
|
||||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await app_state.collector.refresh_metadata()
|
result = {
|
||||||
return {"success": True, "message": "Metadata refreshed successfully", **result}
|
bot.bot_name: await bot.collector.refresh_metadata()
|
||||||
|
for bot in app.bots.values()
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"Metadata refreshed successfully for {len(app.bots)} bots",
|
||||||
|
"results": result,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to refresh metadata: {e}")
|
logger.error(f"Failed to refresh metadata: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user