Mirror of https://github.com/mruwnik/memory.git (synced 2025-09-06 16:52:53 +02:00)

Commit a3544222e7: add scheduled calls
Parent: a2d107fad7
db/migrations/versions/20250812_234327_discord_schedules.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""discord schedules

Revision ID: 2fb3223dc71b
Revises: 1d6bc8015ea9
Create Date: 2025-08-12 23:43:27.671182

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "2fb3223dc71b"
down_revision: Union[str, None] = "1d6bc8015ea9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "scheduled_llm_calls",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.Column("topic", sa.Text(), nullable=True),
        sa.Column("scheduled_time", sa.DateTime(), nullable=False),
        sa.Column(
            "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
        ),
        sa.Column("executed_at", sa.DateTime(), nullable=True),
        sa.Column("model", sa.String(), nullable=True),
        sa.Column("prompt", sa.Text(), nullable=False),
        sa.Column("system_prompt", sa.Text(), nullable=True),
        sa.Column("allowed_tools", sa.JSON(), nullable=True),
        sa.Column("discord_channel", sa.String(), nullable=True),
        sa.Column("discord_user", sa.String(), nullable=True),
        sa.Column("status", sa.String(), nullable=False),
        sa.Column("response", sa.Text(), nullable=True),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column("data", sa.JSON(), nullable=True),
        sa.Column("celery_task_id", sa.String(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.add_column("users", sa.Column("discord_user_id", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("users", "discord_user_id")
    op.drop_table("scheduled_llm_calls")
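As a usage note, a minimal sketch of applying or rolling back this revision from Python; it assumes the repository's standard alembic.ini location, and `alembic upgrade head` from the CLI is the usual route:

# Hedged sketch: apply or roll back the discord_schedules revision.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumption: Alembic config at the repo root
command.upgrade(cfg, "2fb3223dc71b")      # create scheduled_llm_calls + users.discord_user_id
# command.downgrade(cfg, "1d6bc8015ea9")  # revert to the previous revision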
@@ -174,7 +174,7 @@ services:
    <<: *worker-base
    environment:
      <<: *worker-env
-     QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes"
+     QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"

  ingest-hub:
    <<: *worker-base
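For context, a hedged sketch of what the extended QUEUES value amounts to: the task routing added to the Celery app later in this commit sends scheduled-call tasks to f"{CELERY_QUEUE_PREFIX}-scheduler", so the worker container has to consume that queue. Assuming the worker entrypoint prefixes each QUEUES entry with the same prefix (an assumption about the entrypoint, not shown in this diff), the equivalent programmatic start-up is roughly:

# Hedged sketch; the entrypoint's translation of QUEUES is an assumption.
from memory.common import settings
from memory.common.celery_app import app

names = ["email", "ebooks", "comic", "blogs", "forums", "maintenance", "notes", "scheduler"]
app.worker_main(argv=[
    "worker",
    "-Q", ",".join(f"{settings.CELERY_QUEUE_PREFIX}-{n}" for n in names),
    "--loglevel", "INFO",
])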
@@ -1,3 +1,6 @@
import memory.api.MCP.manifest
import memory.api.MCP.tools
import memory.api.MCP.memory
import memory.api.MCP.metadata
import memory.api.MCP.schedules
import memory.api.MCP.books
import memory.api.MCP.manifest
@@ -8,23 +8,24 @@ from mcp.server.auth.handlers.token import (
    RefreshTokenRequest,
    TokenRequest,
)
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.shared.auth import OAuthClientMetadata
-from memory.common.db.models.users import User
from pydantic import AnyHttpUrl
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates

from memory.api.MCP.oauth_provider import (
-    SimpleOAuthProvider,
    ALLOWED_SCOPES,
    BASE_SCOPES,
+    SimpleOAuthProvider,
)
from memory.common import settings
from memory.common.db.connection import make_session
-from memory.common.db.models import OAuthState
+from memory.common.db.models import OAuthState, UserSession
+from memory.common.db.models.users import User

logger = logging.getLogger(__name__)

@@ -134,3 +135,30 @@ async def handle_login(request: Request):
    if redirect_url.startswith("http://anysphere.cursor-retrieval"):
        redirect_url = redirect_url.replace("http://", "cursor://")
    return RedirectResponse(url=redirect_url, status_code=302)


def get_current_user() -> dict:
    access_token = get_access_token()

    if not access_token:
        return {"authenticated": False}

    with make_session() as session:
        user_session = (
            session.query(UserSession)
            .filter(UserSession.id == access_token.token)
            .first()
        )

        if user_session and user_session.user:
            user_info = user_session.user.serialize()
        else:
            user_info = {"error": "User not found"}

        return {
            "authenticated": True,
            "token_type": "Bearer",
            "scopes": access_token.scopes,
            "client_id": access_token.client_id,
            "user": user_info,
        }
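A short sketch of how a tool can lean on this helper (whoami is a hypothetical example, not part of the commit):

from typing import Any

from memory.api.MCP.base import get_current_user, mcp

@mcp.tool()
async def whoami() -> dict[str, Any]:
    # Resolve the calling user from the MCP access token; deny anonymous access.
    user = get_current_user()
    if not user["authenticated"]:
        return {"error": "Not authenticated"}
    return user["user"]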
@@ -13,7 +13,7 @@ from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY

-from memory.api.MCP.tools import mcp
+from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchFilters, SearchConfig
from memory.common import extract, settings
src/memory/api/MCP/schedules.py (new file, 186 lines)
@@ -0,0 +1,186 @@
"""
MCP tools for the epistemic sparring partner system.
"""

import logging
from datetime import datetime, timezone
from typing import Any

from memory.api.MCP.base import get_current_user
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.api.MCP.base import mcp

logger = logging.getLogger(__name__)


@mcp.tool()
async def schedule_llm_call(
    scheduled_time: str,
    model: str,
    prompt: str,
    topic: str | None = None,
    discord_channel: str | None = None,
    system_prompt: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Schedule an LLM call to be executed at a specific time, with the response sent to Discord.

    Args:
        scheduled_time: ISO format datetime string (e.g., "2024-12-20T15:30:00Z")
        model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the prompt is sent as-is without calling an LLM.
        prompt: The prompt to send to the LLM
        topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt.
        discord_channel: Discord channel name where the response should be sent. If not provided, the message will be sent to the user directly.
        system_prompt: Optional system prompt
        metadata: Optional metadata dict for tracking

    Returns:
        Dict with scheduled call ID and status
    """
    logger.info("schedule_llm_call tool called")

    current_user = get_current_user()
    if not current_user["authenticated"]:
        raise ValueError("Not authenticated")
    user_id = current_user.get("user", {}).get("user_id")
    if not user_id:
        raise ValueError("User not found")

    discord_user = current_user.get("user", {}).get("discord_user_id")
    if not discord_user and not discord_channel:
        raise ValueError("Either discord_user or discord_channel must be provided")

    # Parse scheduled time
    try:
        scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
        # Ensure we store as naive datetime (remove timezone info for database storage)
        if scheduled_dt.tzinfo is not None:
            scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
    except ValueError:
        raise ValueError("Invalid datetime format")

    # Validate that the scheduled time is in the future
    # Compare with naive datetime since we store naive in the database
    current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    if scheduled_dt <= current_time_naive:
        raise ValueError("Scheduled time must be in the future")

    with make_session() as session:
        # Create the scheduled call
        scheduled_call = ScheduledLLMCall(
            user_id=user_id,
            scheduled_time=scheduled_dt,
            topic=topic,
            model=model,
            prompt=prompt,
            system_prompt=system_prompt,
            discord_channel=discord_channel,
            discord_user=discord_user,
            data=metadata or {},
        )

        session.add(scheduled_call)
        session.commit()

        return {
            "success": True,
            "scheduled_call_id": scheduled_call.id,
            "scheduled_time": scheduled_dt.isoformat(),
            "message": f"LLM call scheduled for {scheduled_dt.isoformat()}",
        }


@mcp.tool()
async def list_scheduled_llm_calls(
    status: str | None = None, limit: int | None = 50
) -> dict[str, Any]:
    """
    List scheduled LLM calls for the authenticated user.

    Args:
        status: Optional status filter ("pending", "executing", "completed", "failed", "cancelled")
        limit: Maximum number of calls to return (default: 50)

    Returns:
        Dict with list of scheduled calls
    """
    logger.info("list_scheduled_llm_calls tool called")

    current_user = get_current_user()
    if not current_user["authenticated"]:
        return {"error": "Not authenticated", "user": current_user}
    user_id = current_user.get("user", {}).get("user_id")
    if not user_id:
        return {"error": "User not found", "user": current_user}

    with make_session() as session:
        query = (
            session.query(ScheduledLLMCall)
            .filter(ScheduledLLMCall.user_id == user_id)
            .order_by(ScheduledLLMCall.scheduled_time.desc())
        )

        if status:
            query = query.filter(ScheduledLLMCall.status == status)

        if limit:
            query = query.limit(limit)

        calls = query.all()

        return {
            "success": True,
            "scheduled_calls": [call.serialize() for call in calls],
            "count": len(calls),
        }


@mcp.tool()
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
    """
    Cancel a scheduled LLM call.

    Args:
        scheduled_call_id: ID of the scheduled call to cancel

    Returns:
        Dict with cancellation status
    """
    logger.info(f"cancel_scheduled_llm_call tool called for ID: {scheduled_call_id}")

    current_user = get_current_user()
    if not current_user["authenticated"]:
        return {"error": "Not authenticated", "user": current_user}
    user_id = current_user.get("user", {}).get("user_id")
    if not user_id:
        return {"error": "User not found", "user": current_user}

    with make_session() as session:
        # Find the scheduled call
        scheduled_call = (
            session.query(ScheduledLLMCall)
            .filter(
                ScheduledLLMCall.id == scheduled_call_id,
                ScheduledLLMCall.user_id == user_id,
            )
            .first()
        )

        if not scheduled_call:
            return {"error": "Scheduled call not found"}

        if not scheduled_call.can_be_cancelled():
            return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}

        # Update the status
        scheduled_call.status = "cancelled"
        session.commit()

        logger.info(f"Scheduled LLM call {scheduled_call_id} cancelled")

        return {
            "success": True,
            "message": f"Scheduled call {scheduled_call_id} has been cancelled",
        }
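A small sketch of the time handling used above, since run_scheduled_calls later in this commit relies on the same convention: "Z" suffixes are accepted and values are stored as naive UTC datetimes.

from datetime import datetime, timezone

raw = "2025-01-01T09:00:00Z"
dt = datetime.fromisoformat(raw.replace("Z", "+00:00"))
if dt.tzinfo is not None:
    dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
assert dt == datetime(2025, 1, 1, 9, 0)  # naive, interpreted as UTC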
@@ -5,18 +5,13 @@ MCP tools for the epistemic sparring partner system.
import logging
from datetime import datetime, timezone

-from mcp.server.auth.middleware.auth_context import get_access_token
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY

from memory.common.db.connection import make_session
-from memory.common.db.models import (
-    AgentObservation,
-    SourceItem,
-    UserSession,
-)
-from memory.api.MCP.base import mcp
+from memory.common.db.models import AgentObservation, SourceItem
+from memory.api.MCP.base import mcp, get_current_user

logger = logging.getLogger(__name__)

@@ -76,35 +71,4 @@ async def get_current_time() -> dict:
@mcp.tool()
async def get_authenticated_user() -> dict:
    """Get information about the authenticated user."""
-    logger.info("🔧 get_authenticated_user tool called")
-    access_token = get_access_token()
-    logger.info(f"🔧 Access token from MCP context: {access_token}")
-
-    if not access_token:
-        logger.warning("❌ No access token found in MCP context!")
-        return {"error": "Not authenticated"}
-
-    logger.info(
-        f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
-    )
-
-    # Look up the actual user from the session token
-    with make_session() as session:
-        user_session = (
-            session.query(UserSession)
-            .filter(UserSession.id == access_token.token)
-            .first()
-        )
-
-        if user_session and user_session.user:
-            user_info = user_session.user.serialize()
-        else:
-            user_info = {"error": "User not found"}
-
-        return {
-            "authenticated": True,
-            "token_type": "Bearer",
-            "scopes": access_token.scopes,
-            "client_id": access_token.client_id,
-            "user": user_info,
-        }
+    return get_current_user()
@@ -11,6 +11,7 @@ EBOOK_ROOT = "memory.workers.tasks.ebook"
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"

SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"

@@ -44,6 +45,10 @@ SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
ADD_ARTICLE_FEED = f"{BLOGS_ROOT}.add_article_feed"
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"

# Scheduled calls tasks
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"


def get_broker_url() -> str:
    protocol = settings.CELERY_BROKER_TYPE

@@ -78,6 +83,9 @@ app.conf.update(
        },
        f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
        f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
        f"{SCHEDULED_CALLS_ROOT}.*": {
            "queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
        },
    },
)
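As a sanity check, the new route can be resolved through Celery's router; a sketch, assuming the app imports without a live broker:

from memory.common import settings
from memory.common.celery_app import app, RUN_SCHEDULED_CALLS

# Both new task names should land on the dedicated scheduler queue.
route = app.amqp.router.route({}, RUN_SCHEDULED_CALLS)
assert route["queue"].name == f"{settings.CELERY_QUEUE_PREFIX}-scheduler"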
@@ -48,6 +48,9 @@ from memory.common.db.models.users import (
    OAuthState,
    OAuthRefreshToken,
)
from memory.common.db.models.scheduled_calls import (
    ScheduledLLMCall,
)

Payload = (
    SourceItemPayload

@@ -96,6 +99,8 @@ __all__ = [
    "OAuthClientInformation",
    "OAuthState",
    "OAuthRefreshToken",
    # Scheduled Calls
    "ScheduledLLMCall",
    # Payloads
    "Payload",
]
src/memory/common/db/models/scheduled_calls.py (new file, 92 lines)
@@ -0,0 +1,92 @@
from datetime import datetime
import uuid
from typing import Any, Dict, cast
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    ForeignKey,
    JSON,
    Text,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship

from memory.common.db.models.base import Base


class ScheduledLLMCall(Base):
    __tablename__ = "scheduled_llm_calls"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    topic = Column(Text, nullable=True)

    # Scheduling info
    scheduled_time = Column(DateTime, nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    executed_at = Column(DateTime, nullable=True)

    # LLM call configuration
    model = Column(
        String, nullable=True
    )  # e.g., "anthropic/claude-3-5-sonnet-20241022"
    prompt = Column(Text, nullable=False)
    system_prompt = Column(Text, nullable=True)
    allowed_tools = Column(JSON, nullable=True)  # List of allowed tool names

    # Discord configuration
    discord_channel = Column(String, nullable=True)
    discord_user = Column(String, nullable=True)

    # Execution status and results
    status = Column(
        String, nullable=False, default="pending"
    )  # pending, executing, completed, failed, cancelled
    response = Column(Text, nullable=True)  # LLM response content
    error_message = Column(Text, nullable=True)

    # Additional metadata
    data = Column(JSON, nullable=True)  # For extensibility

    # Celery task tracking
    celery_task_id = Column(String, nullable=True)  # Track the Celery Beat task

    # Relationships
    user = relationship("User")

    def serialize(self) -> Dict[str, Any]:
        def print_datetime(dt: datetime | None) -> str | None:
            if dt:
                return dt.isoformat()
            return None

        return {
            "id": self.id,
            "user_id": self.user_id,
            "topic": self.topic,
            "scheduled_time": print_datetime(cast(datetime, self.scheduled_time)),
            "created_at": print_datetime(cast(datetime, self.created_at)),
            "executed_at": print_datetime(cast(datetime, self.executed_at)),
            "model": self.model,
            "prompt": self.prompt,
            "system_prompt": self.system_prompt,
            "allowed_tools": self.allowed_tools,
            "discord_channel": self.discord_channel,
            "discord_user": self.discord_user,
            "status": self.status,
            "response": self.response,
            "error_message": self.error_message,
            "metadata": self.data,
            "celery_task_id": self.celery_task_id,
        }

    def is_pending(self) -> bool:
        return cast(str, self.status) == "pending"

    def is_completed(self) -> bool:
        return cast(str, self.status) in ("completed", "failed", "cancelled")

    def can_be_cancelled(self) -> bool:
        return cast(str, self.status) in ("pending",)
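A quick sketch of creating and serializing a record with this model (user_id=1 is assumed to exist; session handling via make_session() as elsewhere in the commit):

from datetime import datetime, timedelta, timezone

from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall

with make_session() as session:
    call = ScheduledLLMCall(
        user_id=1,  # assumption: an existing users.id
        topic="Demo",
        scheduled_time=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1),
        model="anthropic/claude-3-5-sonnet-20241022",
        prompt="Say hi",
        discord_user="123456789",
        status="pending",
    )
    session.add(call)
    session.commit()
    assert call.is_pending() and call.can_be_cancelled()
    print(call.serialize()["scheduled_time"])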
@@ -2,6 +2,8 @@ import hashlib
import secrets
from typing import cast
import uuid
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from memory.common.db.models.base import Base
from sqlalchemy import (
    Column,

@@ -39,6 +41,7 @@ class User(Base):
    name = Column(String, nullable=False)
    email = Column(String, nullable=False, unique=True)
    password_hash = Column(String, nullable=False)
    discord_user_id = Column(String, nullable=True)

    # Relationship to sessions
    sessions = relationship(

@@ -53,6 +56,7 @@ class User(Base):
            "user_id": self.id,
            "name": self.name,
            "email": self.email,
            "discord_user_id": self.discord_user_id,
        }

    def is_valid_password(self, password: str) -> bool:

@@ -193,3 +197,15 @@ class OAuthRefreshToken(Base, OAuthToken):
            "expires_at": self.expires_at.timestamp(),
            "revoked": self.revoked,
        } | super().serialize()


def purge_oauth(session: Session):
    for token in session.query(OAuthRefreshToken).all():
        session.delete(token)
    for user_session in session.query(UserSession).all():
        session.delete(user_session)

    for oauth_state in session.query(OAuthState).all():
        session.delete(oauth_state)
    for oauth_client in session.query(OAuthClientInformation).all():
        session.delete(oauth_client)
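Usage sketch for the new helper (purge_oauth only issues deletes; the caller decides when to commit):

from memory.common.db.connection import make_session
from memory.common.db.models.users import purge_oauth

with make_session() as session:
    purge_oauth(session)  # removes refresh tokens, sessions, states, and client registrations
    session.commit()      # assumption: make_session does not commit on exit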
@@ -1,6 +1,7 @@
import logging
import requests
-from typing import Any, Dict, List
+import re
+from typing import Any

from memory.common import settings

@@ -19,6 +20,7 @@ class DiscordServer(requests.Session):
        self.channels = {}
        super().__init__(*args, **kwargs)
        self.setup_channels()
        self.members = self.fetch_all_members()

    def setup_channels(self):
        resp = self.get(self.channels_url)

@@ -63,9 +65,20 @@ class DiscordServer(requests.Session):
        return channel_id

    def send_message(self, channel_id: str, content: str):
-        self.post(
+        payload: dict[str, Any] = {"content": content}
+        mentions = re.findall(r"@(\S*)", content)
+        users = {u: i for u, i in self.members.items() if u in mentions}
+        if users:
+            for u, i in users.items():
+                payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
+            payload["allowed_mentions"] = {
+                "parse": [],
+                "users": list(users.values()),
+            }
+
+        return self.post(
            f"https://discord.com/api/v10/channels/{channel_id}/messages",
-            json={"content": content},
+            json=payload,
        )

    def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:

@@ -91,8 +104,64 @@ class DiscordServer(requests.Session):
    def channels_url(self) -> str:
        return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"

    @property
    def members_url(self) -> str:
        return f"https://discord.com/api/v10/guilds/{self.server_id}/members"

-def get_bot_servers() -> List[Dict]:
    @property
    def dm_create_url(self) -> str:
        return "https://discord.com/api/v10/users/@me/channels"

    def list_members(
        self, limit: int = 1000, after: str | None = None
    ) -> list[dict[str, Any]]:
        """List up to `limit` members in this guild, starting after a user ID.

        Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
        """
        params: dict[str, Any] = {"limit": limit}
        if after:
            params["after"] = after
        resp = self.get(self.members_url, params=params)
        resp.raise_for_status()
        return resp.json()

    def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
        """Retrieve all members in the guild by paginating the members list.

        Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
        """
        members: dict[str, str] = {}
        after: str | None = None
        while batch := self.list_members(limit=page_size, after=after):
            for member in batch:
                user = member.get("user", {})
                members[user.get("global_name") or user.get("username", "")] = user.get(
                    "id", ""
                )
                after = user.get("id", "")
        return members

    def create_dm_channel(self, user_id: str) -> str:
        """Create (or retrieve) a DM channel with the given user and return the channel ID.

        The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
        """
        resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
        resp.raise_for_status()
        data = resp.json()
        return data["id"]

    def send_dm(self, user_id: str, content: str):
        """Send a direct message to a specific user by ID."""
        channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
        return self.post(
            f"https://discord.com/api/v10/channels/{channel_id}/messages",
            json={"content": content},
        )


+def get_bot_servers() -> list[dict[str, Any]]:
    """Get list of servers the bot is in."""
    if not settings.DISCORD_BOT_TOKEN:
        return []

@@ -141,6 +210,14 @@ def send_chat_message(message: str):
    broadcast_message(CHAT_CHANNEL, message)


def send_dm(user_id: str, message: str):
    for server in servers.values():
        if not server.members.get(user_id) and user_id not in server.members.values():
            continue

        server.send_dm(user_id, message)


def notify_task_failure(
    task_name: str,
    error_message: str,
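A self-contained sketch of the mention rewriting that send_message now performs, using a stand-in for self.members (which fetch_all_members builds as a name-to-ID map):

import re
from typing import Any

members = {"alice": "111", "bob": "222"}  # stand-in for self.members
content = "ping @alice and @bob"

payload: dict[str, Any] = {"content": content}
mentions = re.findall(r"@(\S*)", content)
users = {u: i for u, i in members.items() if u in mentions}
if users:
    for u, i in users.items():
        payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
    # Only the matched users may be pinged; "parse": [] disables @everyone/@here.
    payload["allowed_mentions"] = {"parse": [], "users": list(users.values())}

assert payload["content"] == "ping <@111> and <@222>"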
@@ -107,6 +107,7 @@ CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 *
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))

CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))

@@ -168,6 +169,6 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")


# Enable Discord notifications if bot token is set
-DISCORD_NOTIFICATIONS_ENABLED = (
+DISCORD_NOTIFICATIONS_ENABLED = bool(
    boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)
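The bool(...) wrapper matters because `and` returns one of its operands, so the old form could leave the setting holding the raw token string rather than a boolean. A tiny sketch:

token = "abc123"                    # stand-in for DISCORD_BOT_TOKEN
old_value = (True and token)        # -> "abc123" (truthy, but a str)
new_value = bool(True and token)    # -> True
assert old_value == "abc123" and new_value is True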
@@ -1,4 +1,3 @@
from dataclasses import dataclass, field
import logging
import time
from datetime import datetime, timedelta
@@ -9,6 +9,7 @@ from memory.common.celery_app import (
    SYNC_ALL_ARTICLE_FEEDS,
    TRACK_GIT_CHANGES,
    SYNC_LESSWRONG,
    RUN_SCHEDULED_CALLS,
)

logger = logging.getLogger(__name__)

@@ -43,4 +44,8 @@ app.conf.beat_schedule = {
        "task": SYNC_LESSWRONG,
        "schedule": settings.LESSWRONG_SYNC_INTERVAL,
    },
    "run-scheduled-calls": {
        "task": RUN_SCHEDULED_CALLS,
        "schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
    },
}
@@ -11,6 +11,7 @@ from memory.workers.tasks import (
    maintenance,
    notes,
    observations,
    scheduled_calls,
) # noqa


@@ -23,4 +24,5 @@ __all__ = [
    "maintenance",
    "notes",
    "observations",
    "scheduled_calls",
]
src/memory/workers/tasks/scheduled_calls.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import logging
from datetime import datetime, timezone
import textwrap
from typing import cast

from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall
from memory.common.celery_app import (
    app,
    EXECUTE_SCHEDULED_CALL,
    RUN_SCHEDULED_CALLS,
)
from memory.common import llms, discord
from memory.workers.tasks.content_processing import safe_task_execution

logger = logging.getLogger(__name__)


def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
    """
    Send the LLM response to the specified Discord user.

    Args:
        scheduled_call: The scheduled call object
        response: The LLM response to send
    """
    message = response
    if cast(str, scheduled_call.topic):
        message = f"**{scheduled_call.topic}**\n\n{message}"

    # Discord has a 2000 character limit, so we may need to split the message
    if len(message) > 1900:  # Leave some buffer
        message = message[:1900] + "\n\n... (response truncated)"

    if discord_user := cast(str, scheduled_call.discord_user):
        logger.info(f"Sending DM to {discord_user}: {message}")
        discord.send_dm(discord_user, message)
    elif discord_channel := cast(str, scheduled_call.discord_channel):
        logger.info(f"Broadcasting message to {discord_channel}: {message}")
        discord.broadcast_message(discord_channel, message)
    else:
        logger.warning(
            f"No Discord user or channel found for scheduled call {scheduled_call.id}"
        )


@app.task(bind=True, name=EXECUTE_SCHEDULED_CALL)
@safe_task_execution
def execute_scheduled_call(self, scheduled_call_id: str):
    """
    Execute a scheduled LLM call and send the response to Discord.

    Args:
        scheduled_call_id: The ID of the scheduled call to execute
    """
    logger.info(f"Executing scheduled LLM call: {scheduled_call_id}")

    with make_session() as session:
        # Fetch the scheduled call
        scheduled_call = (
            session.query(ScheduledLLMCall)
            .filter(ScheduledLLMCall.id == scheduled_call_id)
            .first()
        )

        if not scheduled_call:
            logger.error(f"Scheduled call {scheduled_call_id} not found")
            return {"error": "Scheduled call not found"}

        # Check if the call is still pending
        if not scheduled_call.is_pending():
            logger.warning(
                f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})"
            )
            return {"error": f"Call is not pending (status: {scheduled_call.status})"}

        # Update status to executing
        scheduled_call.status = "executing"
        scheduled_call.executed_at = datetime.now(timezone.utc)
        session.commit()

        logger.info(f"Calling LLM with model {scheduled_call.model}")

        # Make the LLM call
        if scheduled_call.model:
            response = llms.call(
                prompt=cast(str, scheduled_call.prompt),
                model=cast(str, scheduled_call.model),
                system_prompt=cast(str, scheduled_call.system_prompt)
                or llms.SYSTEM_PROMPT,
            )
        else:
            response = cast(str, scheduled_call.prompt)

        # Store the response
        scheduled_call.response = response
        scheduled_call.status = "completed"
        session.commit()

        logger.info(f"LLM call completed for {scheduled_call_id}")

        # Send to Discord
        try:
            _send_to_discord(scheduled_call, response)
            logger.info(f"Response sent to Discord for {scheduled_call_id}")
        except Exception as discord_error:
            logger.error(f"Failed to send to Discord: {discord_error}")
            # Don't mark as failed since the LLM call succeeded
            scheduled_call.data = scheduled_call.data or {}
            scheduled_call.data["discord_error"] = str(discord_error)
            session.commit()

        return {
            "success": True,
            "scheduled_call_id": scheduled_call_id,
            "response": response[:100] + "..." if len(response) > 100 else response,
            "discord_sent": True,
        }


@app.task(name=RUN_SCHEDULED_CALLS)
@safe_task_execution
def run_scheduled_calls():
    """Run scheduled calls that are due."""
    with make_session() as session:
        calls = (
            session.query(ScheduledLLMCall)
            .filter(
                ScheduledLLMCall.status.in_(["pending"]),
                ScheduledLLMCall.scheduled_time
                < datetime.now(timezone.utc).replace(tzinfo=None),
            )
            .all()
        )
        for call in calls:
            execute_scheduled_call.delay(call.id)

        return {
            "calls": [call.id for call in calls],
            "count": len(calls),
        }
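A sketch of exercising the pipeline by hand in a dev shell (run_scheduled_calls fans execute_scheduled_call out over the scheduler queue; the ID below is a placeholder):

from memory.workers.tasks import scheduled_calls

summary = scheduled_calls.run_scheduled_calls()  # picks up every pending, past-due call
print(summary["count"], summary["calls"])

# Or target one call directly, bypassing the beat schedule:
# scheduled_calls.execute_scheduled_call.delay("<scheduled-call-id>")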
@@ -1,8 +1,6 @@
import logging
import pytest
from unittest.mock import Mock, patch
import requests
import json

from memory.common import discord, settings
tests/memory/workers/tasks/test_scheduled_calls.py (new file, 585 lines)
@@ -0,0 +1,585 @@
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch
import uuid

from memory.common.db.models import ScheduledLLMCall, User
from memory.workers.tasks import scheduled_calls


@pytest.fixture
def sample_user(db_session):
    """Create a sample user for testing."""
    user = User(
        name="testuser",
        email="test@example.com",
        discord_user_id="123456789",
        password_hash="password",
    )
    db_session.add(user)
    db_session.commit()
    return user


@pytest.fixture
def pending_scheduled_call(db_session, sample_user):
    """Create a pending scheduled call for testing."""
    call = ScheduledLLMCall(
        id=str(uuid.uuid4()),
        user_id=sample_user.id,
        topic="Test Topic",
        scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
        model="anthropic/claude-3-5-sonnet-20241022",
        prompt="What is the weather like today?",
        system_prompt="You are a helpful assistant.",
        discord_user="123456789",
        status="pending",
    )
    db_session.add(call)
    db_session.commit()
    return call


@pytest.fixture
def completed_scheduled_call(db_session, sample_user):
    """Create a completed scheduled call for testing."""
    call = ScheduledLLMCall(
        id=str(uuid.uuid4()),
        user_id=sample_user.id,
        topic="Completed Topic",
        scheduled_time=datetime.now(timezone.utc) - timedelta(hours=1),
        executed_at=datetime.now(timezone.utc) - timedelta(minutes=30),
        model="anthropic/claude-3-5-sonnet-20241022",
        prompt="Tell me a joke.",
        system_prompt="You are a funny assistant.",
        discord_channel="987654321",
        status="completed",
        response="Why did the chicken cross the road? To get to the other side!",
    )
    db_session.add(call)
    db_session.commit()
    return call


@pytest.fixture
def future_scheduled_call(db_session, sample_user):
    """Create a future scheduled call for testing."""
    call = ScheduledLLMCall(
        id=str(uuid.uuid4()),
        user_id=sample_user.id,
        topic="Future Topic",
        scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1),
        model="anthropic/claude-3-5-sonnet-20241022",
        prompt="What will happen tomorrow?",
        discord_user="123456789",
        status="pending",
    )
    db_session.add(call)
    db_session.commit()
    return call

@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
"""Test sending to Discord user."""
|
||||
response = "This is a test response."
|
||||
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
"123456789",
|
||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
"""Test sending to Discord channel."""
|
||||
response = "This is a channel response."
|
||||
|
||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
"987654321",
|
||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
|
||||
"""Test message truncation for long responses."""
|
||||
long_response = "A" * 2500 # Very long response
|
||||
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, long_response)
|
||||
|
||||
# Verify the message was truncated
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
message = args[1]
|
||||
assert len(message) <= 1950 # Should be truncated
|
||||
assert message.endswith("... (response truncated)")
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
|
||||
"""Test that normal length messages are not truncated."""
|
||||
normal_response = "This is a normal length response."
|
||||
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
message = args[1]
|
||||
assert not message.endswith("... (response truncated)")
|
||||
assert "This is a normal length response." in message
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_success(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test successful execution of a scheduled LLM call."""
|
||||
mock_llm_call.return_value = "The weather is sunny today!"
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
# Verify LLM was called with correct parameters
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="What is the weather like today?",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result["success"] is True
|
||||
assert result["scheduled_call_id"] == pending_scheduled_call.id
|
||||
assert result["response"] == "The weather is sunny today!"
|
||||
assert result["discord_sent"] is True
|
||||
|
||||
# Verify database was updated
|
||||
db_session.refresh(pending_scheduled_call)
|
||||
assert pending_scheduled_call.status == "completed"
|
||||
assert pending_scheduled_call.response == "The weather is sunny today!"
|
||||
assert pending_scheduled_call.executed_at is not None
|
||||
|
||||
|
||||
def test_execute_scheduled_call_not_found(db_session):
|
||||
"""Test execution with non-existent call ID."""
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(fake_id)
|
||||
|
||||
assert result == {"error": "Scheduled call not found"}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_not_pending(
|
||||
mock_llm_call, completed_scheduled_call, db_session
|
||||
):
|
||||
"""Test execution of a call that is not pending."""
|
||||
result = scheduled_calls.execute_scheduled_call(completed_scheduled_call.id)
|
||||
|
||||
assert result == {"error": "Call is not pending (status: completed)"}
|
||||
mock_llm_call.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_with_default_system_prompt(
|
||||
mock_llm_call, mock_send_discord, db_session, sample_user
|
||||
):
|
||||
"""Test execution when system_prompt is None, should use default."""
|
||||
# Create call without system prompt
|
||||
call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
topic="No System Prompt",
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
prompt="Test prompt",
|
||||
system_prompt=None,
|
||||
discord_user="123456789",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(call)
|
||||
db_session.commit()
|
||||
|
||||
mock_llm_call.return_value = "Response"
|
||||
|
||||
scheduled_calls.execute_scheduled_call(call.id)
|
||||
|
||||
# Verify default system prompt was used
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="Test prompt",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt=scheduled_calls.llms.SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_discord_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test execution when Discord sending fails."""
|
||||
mock_llm_call.return_value = "LLM response"
|
||||
mock_send_discord.side_effect = Exception("Discord API error")
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
# Should still return success since LLM call succeeded
|
||||
assert result["success"] is True
|
||||
assert (
|
||||
result["discord_sent"] is True
|
||||
) # This is always True in current implementation
|
||||
|
||||
# Verify the call was marked as completed despite Discord error
|
||||
db_session.refresh(pending_scheduled_call)
|
||||
assert pending_scheduled_call.status == "completed"
|
||||
assert pending_scheduled_call.response == "LLM response"
|
||||
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_llm_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test execution when LLM call fails."""
|
||||
mock_llm_call.side_effect = Exception("LLM API error")
|
||||
|
||||
# The safe_task_execution decorator should catch this
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "LLM API error" in result["error"]
|
||||
|
||||
# Discord should not be called
|
||||
mock_send_discord.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_long_response_truncation(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test that long responses are truncated in the result."""
|
||||
long_response = "A" * 500 # Long response
|
||||
mock_llm_call.return_value = long_response
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
# Response in result should be truncated
|
||||
assert len(result["response"]) <= 103 # 100 chars + "..."
|
||||
assert result["response"].endswith("...")
|
||||
|
||||
# But full response should be stored in database
|
||||
db_session.refresh(pending_scheduled_call)
|
||||
assert pending_scheduled_call.response == long_response
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_with_due_calls(
|
||||
mock_execute_delay, db_session, sample_user
|
||||
):
|
||||
"""Test running scheduled calls with due calls."""
|
||||
# Create multiple due calls
|
||||
due_call1 = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10),
|
||||
model="test-model",
|
||||
prompt="Test 1",
|
||||
discord_user="123",
|
||||
status="pending",
|
||||
)
|
||||
due_call2 = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
prompt="Test 2",
|
||||
discord_user="123",
|
||||
status="pending",
|
||||
)
|
||||
|
||||
db_session.add_all([due_call1, due_call2])
|
||||
db_session.commit()
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.id = "task-123"
|
||||
mock_execute_delay.delay.return_value = mock_task
|
||||
|
||||
result = scheduled_calls.run_scheduled_calls()
|
||||
|
||||
assert result["count"] == 2
|
||||
assert due_call1.id in result["calls"]
|
||||
assert due_call2.id in result["calls"]
|
||||
|
||||
# Verify execute_scheduled_call.delay was called for both
|
||||
assert mock_execute_delay.delay.call_count == 2
|
||||
mock_execute_delay.delay.assert_any_call(due_call1.id)
|
||||
mock_execute_delay.delay.assert_any_call(due_call2.id)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_no_due_calls(
|
||||
mock_execute_delay, future_scheduled_call, db_session
|
||||
):
|
||||
"""Test running scheduled calls when no calls are due."""
|
||||
result = scheduled_calls.run_scheduled_calls()
|
||||
|
||||
assert result["count"] == 0
|
||||
assert result["calls"] == []
|
||||
|
||||
# No tasks should be scheduled
|
||||
mock_execute_delay.delay.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_mixed_statuses(
|
||||
mock_execute_delay, db_session, sample_user
|
||||
):
|
||||
"""Test that only pending calls are processed."""
|
||||
# Create calls with different statuses
|
||||
pending_call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
prompt="Pending",
|
||||
discord_user="123",
|
||||
status="pending",
|
||||
)
|
||||
executing_call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
prompt="Executing",
|
||||
discord_user="123",
|
||||
status="executing",
|
||||
)
|
||||
completed_call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
prompt="Completed",
|
||||
discord_user="123",
|
||||
status="completed",
|
||||
)
|
||||
|
||||
db_session.add_all([pending_call, executing_call, completed_call])
|
||||
db_session.commit()
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.id = "task-123"
|
||||
mock_execute_delay.delay.return_value = mock_task
|
||||
|
||||
result = scheduled_calls.run_scheduled_calls()
|
||||
|
||||
# Only the pending call should be processed
|
||||
assert result["count"] == 1
|
||||
assert result["calls"] == [pending_call.id]
|
||||
|
||||
mock_execute_delay.delay.assert_called_once_with(pending_call.id)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||
def test_run_scheduled_calls_timezone_handling(
|
||||
mock_execute_delay, db_session, sample_user
|
||||
):
|
||||
"""Test that timezone handling works correctly."""
|
||||
# Create a call that's due (scheduled time in the past)
|
||||
past_time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
due_call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime
|
||||
model="test-model",
|
||||
prompt="Due call",
|
||||
discord_user="123",
|
||||
status="pending",
|
||||
)
|
||||
|
||||
# Create a call that's not due yet
|
||||
future_time = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
future_call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime
|
||||
model="test-model",
|
||||
prompt="Future call",
|
||||
discord_user="123",
|
||||
status="pending",
|
||||
)
|
||||
|
||||
db_session.add_all([due_call, future_call])
|
||||
db_session.commit()
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.id = "task-123"
|
||||
mock_execute_delay.delay.return_value = mock_task
|
||||
|
||||
result = scheduled_calls.run_scheduled_calls()
|
||||
|
||||
# Only the due call should be processed
|
||||
assert result["count"] == 1
|
||||
assert result["calls"] == [due_call.id]
|
||||
|
||||
mock_execute_delay.delay.assert_called_once_with(due_call.id)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_status_transition_pending_to_executing_to_completed(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test that status transitions correctly during execution."""
|
||||
mock_llm_call.return_value = "Response"
|
||||
|
||||
# Initial status should be pending
|
||||
assert pending_scheduled_call.status == "pending"
|
||||
assert pending_scheduled_call.executed_at is None
|
||||
|
||||
scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
# Final status should be completed
|
||||
db_session.refresh(pending_scheduled_call)
|
||||
assert pending_scheduled_call.status == "completed"
|
||||
assert pending_scheduled_call.executed_at is not None
|
||||
assert pending_scheduled_call.response == "Response"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"discord_user,discord_channel,expected_method",
|
||||
[
|
||||
("123456789", None, "send_dm"),
|
||||
(None, "987654321", "broadcast_message"),
|
||||
("123456789", "987654321", "send_dm"), # User takes precedence
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||
def test_discord_destination_priority(
|
||||
mock_broadcast,
|
||||
mock_send_dm,
|
||||
discord_user,
|
||||
discord_channel,
|
||||
expected_method,
|
||||
db_session,
|
||||
sample_user,
|
||||
):
|
||||
"""Test that Discord user takes precedence over channel."""
|
||||
call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
topic="Priority Test",
|
||||
scheduled_time=datetime.now(timezone.utc),
|
||||
model="test-model",
|
||||
prompt="Test",
|
||||
discord_user=discord_user,
|
||||
discord_channel=discord_channel,
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(call)
|
||||
db_session.commit()
|
||||
|
||||
response = "Test response"
|
||||
scheduled_calls._send_to_discord(call, response)
|
||||
|
||||
if expected_method == "send_dm":
|
||||
mock_send_dm.assert_called_once()
|
||||
mock_broadcast.assert_not_called()
|
||||
else:
|
||||
mock_broadcast.assert_called_once()
|
||||
mock_send_dm.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic,model,response,expected_in_message",
|
||||
[
|
||||
(
|
||||
"Weather Check",
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
"It's sunny!",
|
||||
[
|
||||
"**Topic:** Weather Check",
|
||||
"**Model:** anthropic/claude-3-5-sonnet-20241022",
|
||||
"**Response:** It's sunny!",
|
||||
],
|
||||
),
|
||||
(
|
||||
"Test Topic",
|
||||
"gpt-4",
|
||||
"Hello world",
|
||||
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
|
||||
),
|
||||
(
|
||||
"Long Topic Name Here",
|
||||
"claude-2",
|
||||
"Short",
|
||||
[
|
||||
"**Topic:** Long Topic Name Here",
|
||||
"**Model:** claude-2",
|
||||
"**Response:** Short",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
|
||||
"""Test the Discord message formatting with different inputs."""
|
||||
# Create a mock scheduled call
|
||||
mock_call = Mock()
|
||||
mock_call.topic = topic
|
||||
mock_call.model = model
|
||||
mock_call.discord_user = "123456789"
|
||||
|
||||
scheduled_calls._send_to_discord(mock_call, response)
|
||||
|
||||
# Get the actual message that was sent
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
actual_message = args[1]
|
||||
|
||||
# Verify all expected parts are in the message
|
||||
for expected_part in expected_in_message:
|
||||
assert expected_part in actual_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status,should_execute",
|
||||
[
|
||||
("pending", True),
|
||||
("executing", False),
|
||||
("completed", False),
|
||||
("failed", False),
|
||||
("cancelled", False),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||
def test_execute_scheduled_call_status_check(
|
||||
mock_llm_call, status, should_execute, db_session, sample_user
|
||||
):
|
||||
"""Test that only pending calls are executed."""
|
||||
call = ScheduledLLMCall(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=sample_user.id,
|
||||
topic="Status Test",
|
||||
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
model="test-model",
|
||||
prompt="Test",
|
||||
discord_user="123",
|
||||
status=status,
|
||||
)
|
||||
db_session.add(call)
|
||||
db_session.commit()
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(call.id)
|
||||
|
||||
if should_execute:
|
||||
mock_llm_call.assert_called_once()
|
||||
# We don't check the full result here since it depends on mocking more functions
|
||||
else:
|
||||
assert result == {"error": f"Call is not pending (status: {status})"}
|
||||
mock_llm_call.assert_not_called()
|