mirror of
https://github.com/mruwnik/memory.git
synced 2025-10-23 15:16:35 +02:00
add scheduled calls
This commit is contained in:
parent
a2d107fad7
commit
a3544222e7
55
db/migrations/versions/20250812_234327_discord_schedules.py
Normal file
55
db/migrations/versions/20250812_234327_discord_schedules.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""discord schedules
|
||||||
|
|
||||||
|
Revision ID: 2fb3223dc71b
|
||||||
|
Revises: 1d6bc8015ea9
|
||||||
|
Create Date: 2025-08-12 23:43:27.671182
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "2fb3223dc71b"
|
||||||
|
down_revision: Union[str, None] = "1d6bc8015ea9"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"scheduled_llm_calls",
|
||||||
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("topic", sa.Text(), nullable=True),
|
||||||
|
sa.Column("scheduled_time", sa.DateTime(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("executed_at", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("model", sa.String(), nullable=True),
|
||||||
|
sa.Column("prompt", sa.Text(), nullable=False),
|
||||||
|
sa.Column("system_prompt", sa.Text(), nullable=True),
|
||||||
|
sa.Column("allowed_tools", sa.JSON(), nullable=True),
|
||||||
|
sa.Column("discord_channel", sa.String(), nullable=True),
|
||||||
|
sa.Column("discord_user", sa.String(), nullable=True),
|
||||||
|
sa.Column("status", sa.String(), nullable=False),
|
||||||
|
sa.Column("response", sa.Text(), nullable=True),
|
||||||
|
sa.Column("error_message", sa.Text(), nullable=True),
|
||||||
|
sa.Column("data", sa.JSON(), nullable=True),
|
||||||
|
sa.Column("celery_task_id", sa.String(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.add_column("users", sa.Column("discord_user_id", sa.String(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("users", "discord_user_id")
|
||||||
|
op.drop_table("scheduled_llm_calls")
|
@ -174,7 +174,7 @@ services:
|
|||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
environment:
|
environment:
|
||||||
<<: *worker-env
|
<<: *worker-env
|
||||||
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes"
|
QUEUES: "email,ebooks,comic,blogs,forums,maintenance,notes,scheduler"
|
||||||
|
|
||||||
ingest-hub:
|
ingest-hub:
|
||||||
<<: *worker-base
|
<<: *worker-base
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
import memory.api.MCP.manifest
|
import memory.api.MCP.tools
|
||||||
import memory.api.MCP.memory
|
import memory.api.MCP.memory
|
||||||
import memory.api.MCP.metadata
|
import memory.api.MCP.metadata
|
||||||
|
import memory.api.MCP.schedules
|
||||||
|
import memory.api.MCP.books
|
||||||
|
import memory.api.MCP.manifest
|
||||||
|
@ -8,23 +8,24 @@ from mcp.server.auth.handlers.token import (
|
|||||||
RefreshTokenRequest,
|
RefreshTokenRequest,
|
||||||
TokenRequest,
|
TokenRequest,
|
||||||
)
|
)
|
||||||
|
from mcp.server.auth.middleware.auth_context import get_access_token
|
||||||
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
from mcp.shared.auth import OAuthClientMetadata
|
from mcp.shared.auth import OAuthClientMetadata
|
||||||
from memory.common.db.models.users import User
|
|
||||||
from pydantic import AnyHttpUrl
|
from pydantic import AnyHttpUrl
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse, RedirectResponse
|
from starlette.responses import JSONResponse, RedirectResponse
|
||||||
from starlette.templating import Jinja2Templates
|
from starlette.templating import Jinja2Templates
|
||||||
|
|
||||||
from memory.api.MCP.oauth_provider import (
|
from memory.api.MCP.oauth_provider import (
|
||||||
SimpleOAuthProvider,
|
|
||||||
ALLOWED_SCOPES,
|
ALLOWED_SCOPES,
|
||||||
BASE_SCOPES,
|
BASE_SCOPES,
|
||||||
|
SimpleOAuthProvider,
|
||||||
)
|
)
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import OAuthState
|
from memory.common.db.models import OAuthState, UserSession
|
||||||
|
from memory.common.db.models.users import User
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -134,3 +135,30 @@ async def handle_login(request: Request):
|
|||||||
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
||||||
redirect_url = redirect_url.replace("http://", "cursor://")
|
redirect_url = redirect_url.replace("http://", "cursor://")
|
||||||
return RedirectResponse(url=redirect_url, status_code=302)
|
return RedirectResponse(url=redirect_url, status_code=302)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user() -> dict:
|
||||||
|
access_token = get_access_token()
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
return {"authenticated": False}
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
user_session = (
|
||||||
|
session.query(UserSession)
|
||||||
|
.filter(UserSession.id == access_token.token)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_session and user_session.user:
|
||||||
|
user_info = user_session.user.serialize()
|
||||||
|
else:
|
||||||
|
user_info = {"error": "User not found"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"authenticated": True,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scopes": access_token.scopes,
|
||||||
|
"client_id": access_token.client_id,
|
||||||
|
"user": user_info,
|
||||||
|
}
|
||||||
|
@ -13,7 +13,7 @@ from sqlalchemy import Text
|
|||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
from memory.api.MCP.tools import mcp
|
from memory.api.MCP.base import mcp
|
||||||
from memory.api.search.search import search
|
from memory.api.search.search import search
|
||||||
from memory.api.search.types import SearchFilters, SearchConfig
|
from memory.api.search.types import SearchFilters, SearchConfig
|
||||||
from memory.common import extract, settings
|
from memory.common import extract, settings
|
||||||
|
186
src/memory/api/MCP/schedules.py
Normal file
186
src/memory/api/MCP/schedules.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
MCP tools for the epistemic sparring partner system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from memory.api.MCP.base import get_current_user
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import ScheduledLLMCall
|
||||||
|
from memory.api.MCP.base import mcp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def schedule_llm_call(
|
||||||
|
scheduled_time: str,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
topic: str | None = None,
|
||||||
|
discord_channel: str | None = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Schedule an LLM call to be executed at a specific time with response sent to Discord.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduled_time: ISO format datetime string (e.g., "2024-12-20T15:30:00Z")
|
||||||
|
model: Model to use (e.g., "anthropic/claude-3-5-sonnet-20241022"). If not provided, the message will be sent to the user directly.
|
||||||
|
prompt: The prompt to send to the LLM
|
||||||
|
topic: The topic of the scheduled call. If not provided, the topic will be inferred from the prompt.
|
||||||
|
discord_channel: Discord channel name where the response should be sent. If not provided, the message will be sent to the user directly.
|
||||||
|
system_prompt: Optional system prompt
|
||||||
|
metadata: Optional metadata dict for tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with scheduled call ID and status
|
||||||
|
"""
|
||||||
|
logger.info("schedule_llm_call tool called")
|
||||||
|
|
||||||
|
current_user = get_current_user()
|
||||||
|
if not current_user["authenticated"]:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
user_id = current_user.get("user", {}).get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("User not found")
|
||||||
|
|
||||||
|
discord_user = current_user.get("user", {}).get("discord_user_id")
|
||||||
|
if not discord_user and not discord_channel:
|
||||||
|
raise ValueError("Either discord_user or discord_channel must be provided")
|
||||||
|
|
||||||
|
# Parse scheduled time
|
||||||
|
try:
|
||||||
|
scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
|
||||||
|
# Ensure we store as naive datetime (remove timezone info for database storage)
|
||||||
|
if scheduled_dt.tzinfo is not None:
|
||||||
|
scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Invalid datetime format")
|
||||||
|
|
||||||
|
# Validate that the scheduled time is in the future
|
||||||
|
# Compare with naive datetime since we store naive in the database
|
||||||
|
current_time_naive = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
if scheduled_dt <= current_time_naive:
|
||||||
|
raise ValueError("Scheduled time must be in the future")
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Create the scheduled call
|
||||||
|
scheduled_call = ScheduledLLMCall(
|
||||||
|
user_id=user_id,
|
||||||
|
scheduled_time=scheduled_dt,
|
||||||
|
topic=topic,
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
discord_channel=discord_channel,
|
||||||
|
discord_user=discord_user,
|
||||||
|
data=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(scheduled_call)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"scheduled_call_id": scheduled_call.id,
|
||||||
|
"scheduled_time": scheduled_dt.isoformat(),
|
||||||
|
"message": f"LLM call scheduled for {scheduled_dt.isoformat()}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def list_scheduled_llm_calls(
|
||||||
|
status: str | None = None, limit: int | None = 50
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
List scheduled LLM calls for the authenticated user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Optional status filter ("pending", "executing", "completed", "failed", "cancelled")
|
||||||
|
limit: Maximum number of calls to return (default: 50)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with list of scheduled calls
|
||||||
|
"""
|
||||||
|
logger.info("list_scheduled_llm_calls tool called")
|
||||||
|
|
||||||
|
current_user = get_current_user()
|
||||||
|
if not current_user["authenticated"]:
|
||||||
|
return {"error": "Not authenticated", "user": current_user}
|
||||||
|
user_id = current_user.get("user", {}).get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "User not found", "user": current_user}
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
query = (
|
||||||
|
session.query(ScheduledLLMCall)
|
||||||
|
.filter(ScheduledLLMCall.user_id == user_id)
|
||||||
|
.order_by(ScheduledLLMCall.scheduled_time.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
query = query.filter(ScheduledLLMCall.status == status)
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
query = query.limit(limit)
|
||||||
|
|
||||||
|
calls = query.all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"scheduled_calls": [call.serialize() for call in calls],
|
||||||
|
"count": len(calls),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Cancel a scheduled LLM call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduled_call_id: ID of the scheduled call to cancel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with cancellation status
|
||||||
|
"""
|
||||||
|
logger.info(f"cancel_scheduled_llm_call tool called for ID: {scheduled_call_id}")
|
||||||
|
|
||||||
|
current_user = get_current_user()
|
||||||
|
if not current_user["authenticated"]:
|
||||||
|
return {"error": "Not authenticated", "user": current_user}
|
||||||
|
user_id = current_user.get("user", {}).get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "User not found", "user": current_user}
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Find the scheduled call
|
||||||
|
scheduled_call = (
|
||||||
|
session.query(ScheduledLLMCall)
|
||||||
|
.filter(
|
||||||
|
ScheduledLLMCall.id == scheduled_call_id,
|
||||||
|
ScheduledLLMCall.user_id == user_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not scheduled_call:
|
||||||
|
return {"error": "Scheduled call not found"}
|
||||||
|
|
||||||
|
if not scheduled_call.can_be_cancelled():
|
||||||
|
return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
|
||||||
|
|
||||||
|
# Update the status
|
||||||
|
scheduled_call.status = "cancelled"
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
logger.info(f"Scheduled LLM call {scheduled_call_id} cancelled")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"Scheduled call {scheduled_call_id} has been cancelled",
|
||||||
|
}
|
@ -5,18 +5,13 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from mcp.server.auth.middleware.auth_context import get_access_token
|
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import (
|
from memory.common.db.models import AgentObservation, SourceItem
|
||||||
AgentObservation,
|
from memory.api.MCP.base import mcp, get_current_user
|
||||||
SourceItem,
|
|
||||||
UserSession,
|
|
||||||
)
|
|
||||||
from memory.api.MCP.base import mcp
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -76,35 +71,4 @@ async def get_current_time() -> dict:
|
|||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def get_authenticated_user() -> dict:
|
async def get_authenticated_user() -> dict:
|
||||||
"""Get information about the authenticated user."""
|
"""Get information about the authenticated user."""
|
||||||
logger.info("🔧 get_authenticated_user tool called")
|
return get_current_user()
|
||||||
access_token = get_access_token()
|
|
||||||
logger.info(f"🔧 Access token from MCP context: {access_token}")
|
|
||||||
|
|
||||||
if not access_token:
|
|
||||||
logger.warning("❌ No access token found in MCP context!")
|
|
||||||
return {"error": "Not authenticated"}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"🔧 MCP context token details - scopes: {access_token.scopes}, client_id: {access_token.client_id}, token: {access_token.token[:20]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Look up the actual user from the session token
|
|
||||||
with make_session() as session:
|
|
||||||
user_session = (
|
|
||||||
session.query(UserSession)
|
|
||||||
.filter(UserSession.id == access_token.token)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_session and user_session.user:
|
|
||||||
user_info = user_session.user.serialize()
|
|
||||||
else:
|
|
||||||
user_info = {"error": "User not found"}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"authenticated": True,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
"scopes": access_token.scopes,
|
|
||||||
"client_id": access_token.client_id,
|
|
||||||
"user": user_info,
|
|
||||||
}
|
|
||||||
|
@ -11,6 +11,7 @@ EBOOK_ROOT = "memory.workers.tasks.ebook"
|
|||||||
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
|
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
|
||||||
NOTES_ROOT = "memory.workers.tasks.notes"
|
NOTES_ROOT = "memory.workers.tasks.notes"
|
||||||
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
||||||
|
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||||
|
|
||||||
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
|
||||||
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
|
||||||
@ -44,6 +45,10 @@ SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
|
|||||||
ADD_ARTICLE_FEED = f"{BLOGS_ROOT}.add_article_feed"
|
ADD_ARTICLE_FEED = f"{BLOGS_ROOT}.add_article_feed"
|
||||||
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
|
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
|
||||||
|
|
||||||
|
# Scheduled calls tasks
|
||||||
|
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
|
||||||
|
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
|
||||||
|
|
||||||
|
|
||||||
def get_broker_url() -> str:
|
def get_broker_url() -> str:
|
||||||
protocol = settings.CELERY_BROKER_TYPE
|
protocol = settings.CELERY_BROKER_TYPE
|
||||||
@ -78,6 +83,9 @@ app.conf.update(
|
|||||||
},
|
},
|
||||||
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
f"{NOTES_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||||
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
f"{OBSERVATIONS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-notes"},
|
||||||
|
f"{SCHEDULED_CALLS_ROOT}.*": {
|
||||||
|
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,6 +48,9 @@ from memory.common.db.models.users import (
|
|||||||
OAuthState,
|
OAuthState,
|
||||||
OAuthRefreshToken,
|
OAuthRefreshToken,
|
||||||
)
|
)
|
||||||
|
from memory.common.db.models.scheduled_calls import (
|
||||||
|
ScheduledLLMCall,
|
||||||
|
)
|
||||||
|
|
||||||
Payload = (
|
Payload = (
|
||||||
SourceItemPayload
|
SourceItemPayload
|
||||||
@ -96,6 +99,8 @@ __all__ = [
|
|||||||
"OAuthClientInformation",
|
"OAuthClientInformation",
|
||||||
"OAuthState",
|
"OAuthState",
|
||||||
"OAuthRefreshToken",
|
"OAuthRefreshToken",
|
||||||
|
# Scheduled Calls
|
||||||
|
"ScheduledLLMCall",
|
||||||
# Payloads
|
# Payloads
|
||||||
"Payload",
|
"Payload",
|
||||||
]
|
]
|
||||||
|
92
src/memory/common/db/models/scheduled_calls.py
Normal file
92
src/memory/common/db/models/scheduled_calls.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, cast
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
JSON,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledLLMCall(Base):
|
||||||
|
__tablename__ = "scheduled_llm_calls"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
|
topic = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Scheduling info
|
||||||
|
scheduled_time = Column(DateTime, nullable=False)
|
||||||
|
created_at = Column(DateTime, server_default=func.now())
|
||||||
|
executed_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
# LLM call configuration
|
||||||
|
model = Column(
|
||||||
|
String, nullable=True
|
||||||
|
) # e.g., "anthropic/claude-3-5-sonnet-20241022"
|
||||||
|
prompt = Column(Text, nullable=False)
|
||||||
|
system_prompt = Column(Text, nullable=True)
|
||||||
|
allowed_tools = Column(JSON, nullable=True) # List of allowed tool names
|
||||||
|
|
||||||
|
# Discord configuration
|
||||||
|
discord_channel = Column(String, nullable=True)
|
||||||
|
discord_user = Column(String, nullable=True)
|
||||||
|
|
||||||
|
# Execution status and results
|
||||||
|
status = Column(
|
||||||
|
String, nullable=False, default="pending"
|
||||||
|
) # pending, executing, completed, failed, cancelled
|
||||||
|
response = Column(Text, nullable=True) # LLM response content
|
||||||
|
error_message = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Additional metadata
|
||||||
|
data = Column(JSON, nullable=True) # For extensibility
|
||||||
|
|
||||||
|
# Celery task tracking
|
||||||
|
celery_task_id = Column(String, nullable=True) # Track the Celery Beat task
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user = relationship("User")
|
||||||
|
|
||||||
|
def serialize(self) -> Dict[str, Any]:
|
||||||
|
def print_datetime(dt: datetime | None) -> str | None:
|
||||||
|
if dt:
|
||||||
|
return dt.isoformat()
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"topic": self.topic,
|
||||||
|
"scheduled_time": print_datetime(cast(datetime, self.scheduled_time)),
|
||||||
|
"created_at": print_datetime(cast(datetime, self.created_at)),
|
||||||
|
"executed_at": print_datetime(cast(datetime, self.executed_at)),
|
||||||
|
"model": self.model,
|
||||||
|
"prompt": self.prompt,
|
||||||
|
"system_prompt": self.system_prompt,
|
||||||
|
"allowed_tools": self.allowed_tools,
|
||||||
|
"discord_channel": self.discord_channel,
|
||||||
|
"discord_user": self.discord_user,
|
||||||
|
"status": self.status,
|
||||||
|
"response": self.response,
|
||||||
|
"error_message": self.error_message,
|
||||||
|
"metadata": self.data,
|
||||||
|
"celery_task_id": self.celery_task_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_pending(self) -> bool:
|
||||||
|
return cast(str, self.status) == "pending"
|
||||||
|
|
||||||
|
def is_completed(self) -> bool:
|
||||||
|
return cast(str, self.status) in ("completed", "failed", "cancelled")
|
||||||
|
|
||||||
|
def can_be_cancelled(self) -> bool:
|
||||||
|
return cast(str, self.status) in ("pending",)
|
@ -2,6 +2,8 @@ import hashlib
|
|||||||
import secrets
|
import secrets
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column,
|
Column,
|
||||||
@ -39,6 +41,7 @@ class User(Base):
|
|||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
email = Column(String, nullable=False, unique=True)
|
email = Column(String, nullable=False, unique=True)
|
||||||
password_hash = Column(String, nullable=False)
|
password_hash = Column(String, nullable=False)
|
||||||
|
discord_user_id = Column(String, nullable=True)
|
||||||
|
|
||||||
# Relationship to sessions
|
# Relationship to sessions
|
||||||
sessions = relationship(
|
sessions = relationship(
|
||||||
@ -53,6 +56,7 @@ class User(Base):
|
|||||||
"user_id": self.id,
|
"user_id": self.id,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"email": self.email,
|
"email": self.email,
|
||||||
|
"discord_user_id": self.discord_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
def is_valid_password(self, password: str) -> bool:
|
def is_valid_password(self, password: str) -> bool:
|
||||||
@ -193,3 +197,15 @@ class OAuthRefreshToken(Base, OAuthToken):
|
|||||||
"expires_at": self.expires_at.timestamp(),
|
"expires_at": self.expires_at.timestamp(),
|
||||||
"revoked": self.revoked,
|
"revoked": self.revoked,
|
||||||
} | super().serialize()
|
} | super().serialize()
|
||||||
|
|
||||||
|
|
||||||
|
def purge_oauth(session: Session):
|
||||||
|
for token in session.query(OAuthRefreshToken).all():
|
||||||
|
session.delete(token)
|
||||||
|
for user_session in session.query(UserSession).all():
|
||||||
|
session.delete(user_session)
|
||||||
|
|
||||||
|
for oauth_state in session.query(OAuthState).all():
|
||||||
|
session.delete(oauth_state)
|
||||||
|
for oauth_client in session.query(OAuthClientInformation).all():
|
||||||
|
session.delete(oauth_client)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
from typing import Any, Dict, List
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ class DiscordServer(requests.Session):
|
|||||||
self.channels = {}
|
self.channels = {}
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.setup_channels()
|
self.setup_channels()
|
||||||
|
self.members = self.fetch_all_members()
|
||||||
|
|
||||||
def setup_channels(self):
|
def setup_channels(self):
|
||||||
resp = self.get(self.channels_url)
|
resp = self.get(self.channels_url)
|
||||||
@ -63,9 +65,20 @@ class DiscordServer(requests.Session):
|
|||||||
return channel_id
|
return channel_id
|
||||||
|
|
||||||
def send_message(self, channel_id: str, content: str):
|
def send_message(self, channel_id: str, content: str):
|
||||||
self.post(
|
payload: dict[str, Any] = {"content": content}
|
||||||
|
mentions = re.findall(r"@(\S*)", content)
|
||||||
|
users = {u: i for u, i in self.members.items() if u in mentions}
|
||||||
|
if users:
|
||||||
|
for u, i in users.items():
|
||||||
|
payload["content"] = payload["content"].replace(f"@{u}", f"<@{i}>")
|
||||||
|
payload["allowed_mentions"] = {
|
||||||
|
"parse": [],
|
||||||
|
"users": list(users.values()),
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.post(
|
||||||
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
||||||
json={"content": content},
|
json=payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
|
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
|
||||||
@ -91,8 +104,64 @@ class DiscordServer(requests.Session):
|
|||||||
def channels_url(self) -> str:
|
def channels_url(self) -> str:
|
||||||
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def members_url(self) -> str:
|
||||||
|
return f"https://discord.com/api/v10/guilds/{self.server_id}/members"
|
||||||
|
|
||||||
def get_bot_servers() -> List[Dict]:
|
@property
|
||||||
|
def dm_create_url(self) -> str:
|
||||||
|
return "https://discord.com/api/v10/users/@me/channels"
|
||||||
|
|
||||||
|
def list_members(
|
||||||
|
self, limit: int = 1000, after: str | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""List up to `limit` members in this guild, starting after a user ID.
|
||||||
|
|
||||||
|
Requires the bot to have the Server Members Intent enabled in the Discord developer portal.
|
||||||
|
"""
|
||||||
|
params: dict[str, Any] = {"limit": limit}
|
||||||
|
if after:
|
||||||
|
params["after"] = after
|
||||||
|
resp = self.get(self.members_url, params=params)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def fetch_all_members(self, page_size: int = 1000) -> dict[str, str]:
|
||||||
|
"""Retrieve all members in the guild by paginating the members list.
|
||||||
|
|
||||||
|
Note: Large guilds may take multiple requests. Rate limits are respected by requests.Session automatically.
|
||||||
|
"""
|
||||||
|
members: dict[str, str] = {}
|
||||||
|
after: str | None = None
|
||||||
|
while batch := self.list_members(limit=page_size, after=after):
|
||||||
|
for member in batch:
|
||||||
|
user = member.get("user", {})
|
||||||
|
members[user.get("global_name") or user.get("username", "")] = user.get(
|
||||||
|
"id", ""
|
||||||
|
)
|
||||||
|
after = user.get("id", "")
|
||||||
|
return members
|
||||||
|
|
||||||
|
def create_dm_channel(self, user_id: str) -> str:
|
||||||
|
"""Create (or retrieve) a DM channel with the given user and return the channel ID.
|
||||||
|
|
||||||
|
The bot must share a guild with the user, and the user's privacy settings must allow DMs from server members.
|
||||||
|
"""
|
||||||
|
resp = self.post(self.dm_create_url, json={"recipient_id": user_id})
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
return data["id"]
|
||||||
|
|
||||||
|
def send_dm(self, user_id: str, content: str):
|
||||||
|
"""Send a direct message to a specific user by ID."""
|
||||||
|
channel_id = self.create_dm_channel(self.members.get(user_id) or user_id)
|
||||||
|
return self.post(
|
||||||
|
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
||||||
|
json={"content": content},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot_servers() -> list[dict[str, Any]]:
|
||||||
"""Get list of servers the bot is in."""
|
"""Get list of servers the bot is in."""
|
||||||
if not settings.DISCORD_BOT_TOKEN:
|
if not settings.DISCORD_BOT_TOKEN:
|
||||||
return []
|
return []
|
||||||
@ -141,6 +210,14 @@ def send_chat_message(message: str):
|
|||||||
broadcast_message(CHAT_CHANNEL, message)
|
broadcast_message(CHAT_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
|
def send_dm(user_id: str, message: str):
|
||||||
|
for server in servers.values():
|
||||||
|
if not server.members.get(user_id) and user_id not in server.members.values():
|
||||||
|
continue
|
||||||
|
|
||||||
|
server.send_dm(user_id, message)
|
||||||
|
|
||||||
|
|
||||||
def notify_task_failure(
|
def notify_task_failure(
|
||||||
task_name: str,
|
task_name: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
|
@ -107,6 +107,7 @@ CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 *
|
|||||||
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
|
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
|
||||||
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
|
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
|
||||||
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
|
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
|
||||||
|
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
|
||||||
|
|
||||||
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
|
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
|
||||||
|
|
||||||
@ -168,6 +169,6 @@ DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
|||||||
|
|
||||||
|
|
||||||
# Enable Discord notifications if bot token is set
|
# Enable Discord notifications if bot token is set
|
||||||
DISCORD_NOTIFICATIONS_ENABLED = (
|
DISCORD_NOTIFICATIONS_ENABLED = bool(
|
||||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
@ -9,6 +9,7 @@ from memory.common.celery_app import (
|
|||||||
SYNC_ALL_ARTICLE_FEEDS,
|
SYNC_ALL_ARTICLE_FEEDS,
|
||||||
TRACK_GIT_CHANGES,
|
TRACK_GIT_CHANGES,
|
||||||
SYNC_LESSWRONG,
|
SYNC_LESSWRONG,
|
||||||
|
RUN_SCHEDULED_CALLS,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -43,4 +44,8 @@ app.conf.beat_schedule = {
|
|||||||
"task": SYNC_LESSWRONG,
|
"task": SYNC_LESSWRONG,
|
||||||
"schedule": settings.LESSWRONG_SYNC_INTERVAL,
|
"schedule": settings.LESSWRONG_SYNC_INTERVAL,
|
||||||
},
|
},
|
||||||
|
"run-scheduled-calls": {
|
||||||
|
"task": RUN_SCHEDULED_CALLS,
|
||||||
|
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ from memory.workers.tasks import (
|
|||||||
maintenance,
|
maintenance,
|
||||||
notes,
|
notes,
|
||||||
observations,
|
observations,
|
||||||
|
scheduled_calls,
|
||||||
) # noqa
|
) # noqa
|
||||||
|
|
||||||
|
|
||||||
@ -23,4 +24,5 @@ __all__ = [
|
|||||||
"maintenance",
|
"maintenance",
|
||||||
"notes",
|
"notes",
|
||||||
"observations",
|
"observations",
|
||||||
|
"scheduled_calls",
|
||||||
]
|
]
|
||||||
|
141
src/memory/workers/tasks/scheduled_calls.py
Normal file
141
src/memory/workers/tasks/scheduled_calls.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import textwrap
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import ScheduledLLMCall
|
||||||
|
from memory.common.celery_app import (
|
||||||
|
app,
|
||||||
|
EXECUTE_SCHEDULED_CALL,
|
||||||
|
RUN_SCHEDULED_CALLS,
|
||||||
|
)
|
||||||
|
from memory.common import llms, discord
|
||||||
|
from memory.workers.tasks.content_processing import safe_task_execution
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
||||||
|
"""
|
||||||
|
Send the LLM response to the specified Discord user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduled_call: The scheduled call object
|
||||||
|
response: The LLM response to send
|
||||||
|
"""
|
||||||
|
message = response
|
||||||
|
if cast(str, scheduled_call.topic):
|
||||||
|
message = f"**{scheduled_call.topic}**\n\n{message}"
|
||||||
|
|
||||||
|
# Discord has a 2000 character limit, so we may need to split the message
|
||||||
|
if len(message) > 1900: # Leave some buffer
|
||||||
|
message = message[:1900] + "\n\n... (response truncated)"
|
||||||
|
|
||||||
|
if discord_user := cast(str, scheduled_call.discord_user):
|
||||||
|
logger.info(f"Sending DM to {discord_user}: {message}")
|
||||||
|
discord.send_dm(discord_user, message)
|
||||||
|
elif discord_channel := cast(str, scheduled_call.discord_channel):
|
||||||
|
logger.info(f"Broadcasting message to {discord_channel}: {message}")
|
||||||
|
discord.broadcast_message(discord_channel, message)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name=EXECUTE_SCHEDULED_CALL)
|
||||||
|
@safe_task_execution
|
||||||
|
def execute_scheduled_call(self, scheduled_call_id: str):
|
||||||
|
"""
|
||||||
|
Execute a scheduled LLM call and send the response to Discord.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduled_call_id: The ID of the scheduled call to execute
|
||||||
|
"""
|
||||||
|
logger.info(f"Executing scheduled LLM call: {scheduled_call_id}")
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
# Fetch the scheduled call
|
||||||
|
scheduled_call = (
|
||||||
|
session.query(ScheduledLLMCall)
|
||||||
|
.filter(ScheduledLLMCall.id == scheduled_call_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not scheduled_call:
|
||||||
|
logger.error(f"Scheduled call {scheduled_call_id} not found")
|
||||||
|
return {"error": "Scheduled call not found"}
|
||||||
|
|
||||||
|
# Check if the call is still pending
|
||||||
|
if not scheduled_call.is_pending():
|
||||||
|
logger.warning(
|
||||||
|
f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})"
|
||||||
|
)
|
||||||
|
return {"error": f"Call is not pending (status: {scheduled_call.status})"}
|
||||||
|
|
||||||
|
# Update status to executing
|
||||||
|
scheduled_call.status = "executing"
|
||||||
|
scheduled_call.executed_at = datetime.now(timezone.utc)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
logger.info(f"Calling LLM with model {scheduled_call.model}")
|
||||||
|
|
||||||
|
# Make the LLM call
|
||||||
|
if scheduled_call.model:
|
||||||
|
response = llms.call(
|
||||||
|
prompt=cast(str, scheduled_call.prompt),
|
||||||
|
model=cast(str, scheduled_call.model),
|
||||||
|
system_prompt=cast(str, scheduled_call.system_prompt)
|
||||||
|
or llms.SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = cast(str, scheduled_call.prompt)
|
||||||
|
|
||||||
|
# Store the response
|
||||||
|
scheduled_call.response = response
|
||||||
|
scheduled_call.status = "completed"
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
logger.info(f"LLM call completed for {scheduled_call_id}")
|
||||||
|
|
||||||
|
# Send to Discord
|
||||||
|
try:
|
||||||
|
_send_to_discord(scheduled_call, response)
|
||||||
|
logger.info(f"Response sent to Discord for {scheduled_call_id}")
|
||||||
|
except Exception as discord_error:
|
||||||
|
logger.error(f"Failed to send to Discord: {discord_error}")
|
||||||
|
# Don't mark as failed since the LLM call succeeded
|
||||||
|
scheduled_call.data = scheduled_call.data or {}
|
||||||
|
scheduled_call.data["discord_error"] = str(discord_error)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"scheduled_call_id": scheduled_call_id,
|
||||||
|
"response": response[:100] + "..." if len(response) > 100 else response,
|
||||||
|
"discord_sent": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(name=RUN_SCHEDULED_CALLS)
|
||||||
|
@safe_task_execution
|
||||||
|
def run_scheduled_calls():
|
||||||
|
"""Run scheduled calls that are due."""
|
||||||
|
with make_session() as session:
|
||||||
|
calls = (
|
||||||
|
session.query(ScheduledLLMCall)
|
||||||
|
.filter(
|
||||||
|
ScheduledLLMCall.status.in_(["pending"]),
|
||||||
|
ScheduledLLMCall.scheduled_time
|
||||||
|
< datetime.now(timezone.utc).replace(tzinfo=None),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for call in calls:
|
||||||
|
execute_scheduled_call.delay(call.id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"calls": [call.id for call in calls],
|
||||||
|
"count": len(calls),
|
||||||
|
}
|
@ -1,8 +1,6 @@
|
|||||||
import logging
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
import requests
|
import requests
|
||||||
import json
|
|
||||||
|
|
||||||
from memory.common import discord, settings
|
from memory.common import discord, settings
|
||||||
|
|
||||||
|
585
tests/memory/workers/tasks/test_scheduled_calls.py
Normal file
585
tests/memory/workers/tasks/test_scheduled_calls.py
Normal file
@ -0,0 +1,585 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from memory.common.db.models import ScheduledLLMCall, User
|
||||||
|
from memory.workers.tasks import scheduled_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_user(db_session):
|
||||||
|
"""Create a sample user for testing."""
|
||||||
|
user = User(
|
||||||
|
name="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
discord_user_id="123456789",
|
||||||
|
password_hash="password",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pending_scheduled_call(db_session, sample_user):
|
||||||
|
"""Create a pending scheduled call for testing."""
|
||||||
|
call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
topic="Test Topic",
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
prompt="What is the weather like today?",
|
||||||
|
system_prompt="You are a helpful assistant.",
|
||||||
|
discord_user="123456789",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
db_session.add(call)
|
||||||
|
db_session.commit()
|
||||||
|
return call
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def completed_scheduled_call(db_session, sample_user):
|
||||||
|
"""Create a completed scheduled call for testing."""
|
||||||
|
call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
topic="Completed Topic",
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||||
|
executed_at=datetime.now(timezone.utc) - timedelta(minutes=30),
|
||||||
|
model="anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
prompt="Tell me a joke.",
|
||||||
|
system_prompt="You are a funny assistant.",
|
||||||
|
discord_channel="987654321",
|
||||||
|
status="completed",
|
||||||
|
response="Why did the chicken cross the road? To get to the other side!",
|
||||||
|
)
|
||||||
|
db_session.add(call)
|
||||||
|
db_session.commit()
|
||||||
|
return call
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def future_scheduled_call(db_session, sample_user):
|
||||||
|
"""Create a future scheduled call for testing."""
|
||||||
|
call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
topic="Future Topic",
|
||||||
|
scheduled_time=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||||
|
model="anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
prompt="What will happen tomorrow?",
|
||||||
|
discord_user="123456789",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
db_session.add(call)
|
||||||
|
db_session.commit()
|
||||||
|
return call
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||||
|
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||||
|
"""Test sending to Discord user."""
|
||||||
|
response = "This is a test response."
|
||||||
|
|
||||||
|
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||||
|
|
||||||
|
mock_send_dm.assert_called_once_with(
|
||||||
|
"123456789",
|
||||||
|
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||||
|
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||||
|
"""Test sending to Discord channel."""
|
||||||
|
response = "This is a channel response."
|
||||||
|
|
||||||
|
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||||
|
|
||||||
|
mock_broadcast.assert_called_once_with(
|
||||||
|
"987654321",
|
||||||
|
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||||
|
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
|
||||||
|
"""Test message truncation for long responses."""
|
||||||
|
long_response = "A" * 2500 # Very long response
|
||||||
|
|
||||||
|
scheduled_calls._send_to_discord(pending_scheduled_call, long_response)
|
||||||
|
|
||||||
|
# Verify the message was truncated
|
||||||
|
args, kwargs = mock_send_dm.call_args
|
||||||
|
message = args[1]
|
||||||
|
assert len(message) <= 1950 # Should be truncated
|
||||||
|
assert message.endswith("... (response truncated)")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||||
|
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
|
||||||
|
"""Test that normal length messages are not truncated."""
|
||||||
|
normal_response = "This is a normal length response."
|
||||||
|
|
||||||
|
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||||
|
|
||||||
|
args, kwargs = mock_send_dm.call_args
|
||||||
|
message = args[1]
|
||||||
|
assert not message.endswith("... (response truncated)")
|
||||||
|
assert "This is a normal length response." in message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_success(
|
||||||
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test successful execution of a scheduled LLM call."""
|
||||||
|
mock_llm_call.return_value = "The weather is sunny today!"
|
||||||
|
|
||||||
|
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
|
# Verify LLM was called with correct parameters
|
||||||
|
mock_llm_call.assert_called_once_with(
|
||||||
|
prompt="What is the weather like today?",
|
||||||
|
model="anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
system_prompt="You are a helpful assistant.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["scheduled_call_id"] == pending_scheduled_call.id
|
||||||
|
assert result["response"] == "The weather is sunny today!"
|
||||||
|
assert result["discord_sent"] is True
|
||||||
|
|
||||||
|
# Verify database was updated
|
||||||
|
db_session.refresh(pending_scheduled_call)
|
||||||
|
assert pending_scheduled_call.status == "completed"
|
||||||
|
assert pending_scheduled_call.response == "The weather is sunny today!"
|
||||||
|
assert pending_scheduled_call.executed_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_scheduled_call_not_found(db_session):
|
||||||
|
"""Test execution with non-existent call ID."""
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
result = scheduled_calls.execute_scheduled_call(fake_id)
|
||||||
|
|
||||||
|
assert result == {"error": "Scheduled call not found"}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_not_pending(
|
||||||
|
mock_llm_call, completed_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test execution of a call that is not pending."""
|
||||||
|
result = scheduled_calls.execute_scheduled_call(completed_scheduled_call.id)
|
||||||
|
|
||||||
|
assert result == {"error": "Call is not pending (status: completed)"}
|
||||||
|
mock_llm_call.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_with_default_system_prompt(
|
||||||
|
mock_llm_call, mock_send_discord, db_session, sample_user
|
||||||
|
):
|
||||||
|
"""Test execution when system_prompt is None, should use default."""
|
||||||
|
# Create call without system prompt
|
||||||
|
call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
topic="No System Prompt",
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
prompt="Test prompt",
|
||||||
|
system_prompt=None,
|
||||||
|
discord_user="123456789",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
db_session.add(call)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
mock_llm_call.return_value = "Response"
|
||||||
|
|
||||||
|
scheduled_calls.execute_scheduled_call(call.id)
|
||||||
|
|
||||||
|
# Verify default system prompt was used
|
||||||
|
mock_llm_call.assert_called_once_with(
|
||||||
|
prompt="Test prompt",
|
||||||
|
model="anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
system_prompt=scheduled_calls.llms.SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_discord_error(
|
||||||
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test execution when Discord sending fails."""
|
||||||
|
mock_llm_call.return_value = "LLM response"
|
||||||
|
mock_send_discord.side_effect = Exception("Discord API error")
|
||||||
|
|
||||||
|
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
|
# Should still return success since LLM call succeeded
|
||||||
|
assert result["success"] is True
|
||||||
|
assert (
|
||||||
|
result["discord_sent"] is True
|
||||||
|
) # This is always True in current implementation
|
||||||
|
|
||||||
|
# Verify the call was marked as completed despite Discord error
|
||||||
|
db_session.refresh(pending_scheduled_call)
|
||||||
|
assert pending_scheduled_call.status == "completed"
|
||||||
|
assert pending_scheduled_call.response == "LLM response"
|
||||||
|
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_llm_error(
|
||||||
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test execution when LLM call fails."""
|
||||||
|
mock_llm_call.side_effect = Exception("LLM API error")
|
||||||
|
|
||||||
|
# The safe_task_execution decorator should catch this
|
||||||
|
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert "LLM API error" in result["error"]
|
||||||
|
|
||||||
|
# Discord should not be called
|
||||||
|
mock_send_discord.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_long_response_truncation(
|
||||||
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test that long responses are truncated in the result."""
|
||||||
|
long_response = "A" * 500 # Long response
|
||||||
|
mock_llm_call.return_value = long_response
|
||||||
|
|
||||||
|
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
|
# Response in result should be truncated
|
||||||
|
assert len(result["response"]) <= 103 # 100 chars + "..."
|
||||||
|
assert result["response"].endswith("...")
|
||||||
|
|
||||||
|
# But full response should be stored in database
|
||||||
|
db_session.refresh(pending_scheduled_call)
|
||||||
|
assert pending_scheduled_call.response == long_response
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||||
|
def test_run_scheduled_calls_with_due_calls(
|
||||||
|
mock_execute_delay, db_session, sample_user
|
||||||
|
):
|
||||||
|
"""Test running scheduled calls with due calls."""
|
||||||
|
# Create multiple due calls
|
||||||
|
due_call1 = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=10),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Test 1",
|
||||||
|
discord_user="123",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
due_call2 = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Test 2",
|
||||||
|
discord_user="123",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all([due_call1, due_call2])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.id = "task-123"
|
||||||
|
mock_execute_delay.delay.return_value = mock_task
|
||||||
|
|
||||||
|
result = scheduled_calls.run_scheduled_calls()
|
||||||
|
|
||||||
|
assert result["count"] == 2
|
||||||
|
assert due_call1.id in result["calls"]
|
||||||
|
assert due_call2.id in result["calls"]
|
||||||
|
|
||||||
|
# Verify execute_scheduled_call.delay was called for both
|
||||||
|
assert mock_execute_delay.delay.call_count == 2
|
||||||
|
mock_execute_delay.delay.assert_any_call(due_call1.id)
|
||||||
|
mock_execute_delay.delay.assert_any_call(due_call2.id)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||||
|
def test_run_scheduled_calls_no_due_calls(
|
||||||
|
mock_execute_delay, future_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test running scheduled calls when no calls are due."""
|
||||||
|
result = scheduled_calls.run_scheduled_calls()
|
||||||
|
|
||||||
|
assert result["count"] == 0
|
||||||
|
assert result["calls"] == []
|
||||||
|
|
||||||
|
# No tasks should be scheduled
|
||||||
|
mock_execute_delay.delay.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||||
|
def test_run_scheduled_calls_mixed_statuses(
|
||||||
|
mock_execute_delay, db_session, sample_user
|
||||||
|
):
|
||||||
|
"""Test that only pending calls are processed."""
|
||||||
|
# Create calls with different statuses
|
||||||
|
pending_call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Pending",
|
||||||
|
discord_user="123",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
executing_call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Executing",
|
||||||
|
discord_user="123",
|
||||||
|
status="executing",
|
||||||
|
)
|
||||||
|
completed_call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Completed",
|
||||||
|
discord_user="123",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all([pending_call, executing_call, completed_call])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.id = "task-123"
|
||||||
|
mock_execute_delay.delay.return_value = mock_task
|
||||||
|
|
||||||
|
result = scheduled_calls.run_scheduled_calls()
|
||||||
|
|
||||||
|
# Only the pending call should be processed
|
||||||
|
assert result["count"] == 1
|
||||||
|
assert result["calls"] == [pending_call.id]
|
||||||
|
|
||||||
|
mock_execute_delay.delay.assert_called_once_with(pending_call.id)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.execute_scheduled_call")
|
||||||
|
def test_run_scheduled_calls_timezone_handling(
|
||||||
|
mock_execute_delay, db_session, sample_user
|
||||||
|
):
|
||||||
|
"""Test that timezone handling works correctly."""
|
||||||
|
# Create a call that's due (scheduled time in the past)
|
||||||
|
past_time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||||
|
due_call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=past_time.replace(tzinfo=None), # Store as naive datetime
|
||||||
|
model="test-model",
|
||||||
|
prompt="Due call",
|
||||||
|
discord_user="123",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a call that's not due yet
|
||||||
|
future_time = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||||
|
future_call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
scheduled_time=future_time.replace(tzinfo=None), # Store as naive datetime
|
||||||
|
model="test-model",
|
||||||
|
prompt="Future call",
|
||||||
|
discord_user="123",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all([due_call, future_call])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.id = "task-123"
|
||||||
|
mock_execute_delay.delay.return_value = mock_task
|
||||||
|
|
||||||
|
result = scheduled_calls.run_scheduled_calls()
|
||||||
|
|
||||||
|
# Only the due call should be processed
|
||||||
|
assert result["count"] == 1
|
||||||
|
assert result["calls"] == [due_call.id]
|
||||||
|
|
||||||
|
mock_execute_delay.delay.assert_called_once_with(due_call.id)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_status_transition_pending_to_executing_to_completed(
|
||||||
|
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||||
|
):
|
||||||
|
"""Test that status transitions correctly during execution."""
|
||||||
|
mock_llm_call.return_value = "Response"
|
||||||
|
|
||||||
|
# Initial status should be pending
|
||||||
|
assert pending_scheduled_call.status == "pending"
|
||||||
|
assert pending_scheduled_call.executed_at is None
|
||||||
|
|
||||||
|
scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||||
|
|
||||||
|
# Final status should be completed
|
||||||
|
db_session.refresh(pending_scheduled_call)
|
||||||
|
assert pending_scheduled_call.status == "completed"
|
||||||
|
assert pending_scheduled_call.executed_at is not None
|
||||||
|
assert pending_scheduled_call.response == "Response"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"discord_user,discord_channel,expected_method",
|
||||||
|
[
|
||||||
|
("123456789", None, "send_dm"),
|
||||||
|
(None, "987654321", "broadcast_message"),
|
||||||
|
("123456789", "987654321", "send_dm"), # User takes precedence
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||||
|
def test_discord_destination_priority(
|
||||||
|
mock_broadcast,
|
||||||
|
mock_send_dm,
|
||||||
|
discord_user,
|
||||||
|
discord_channel,
|
||||||
|
expected_method,
|
||||||
|
db_session,
|
||||||
|
sample_user,
|
||||||
|
):
|
||||||
|
"""Test that Discord user takes precedence over channel."""
|
||||||
|
call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
topic="Priority Test",
|
||||||
|
scheduled_time=datetime.now(timezone.utc),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Test",
|
||||||
|
discord_user=discord_user,
|
||||||
|
discord_channel=discord_channel,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
db_session.add(call)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = "Test response"
|
||||||
|
scheduled_calls._send_to_discord(call, response)
|
||||||
|
|
||||||
|
if expected_method == "send_dm":
|
||||||
|
mock_send_dm.assert_called_once()
|
||||||
|
mock_broadcast.assert_not_called()
|
||||||
|
else:
|
||||||
|
mock_broadcast.assert_called_once()
|
||||||
|
mock_send_dm.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"topic,model,response,expected_in_message",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"Weather Check",
|
||||||
|
"anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
"It's sunny!",
|
||||||
|
[
|
||||||
|
"**Topic:** Weather Check",
|
||||||
|
"**Model:** anthropic/claude-3-5-sonnet-20241022",
|
||||||
|
"**Response:** It's sunny!",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Test Topic",
|
||||||
|
"gpt-4",
|
||||||
|
"Hello world",
|
||||||
|
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Long Topic Name Here",
|
||||||
|
"claude-2",
|
||||||
|
"Short",
|
||||||
|
[
|
||||||
|
"**Topic:** Long Topic Name Here",
|
||||||
|
"**Model:** claude-2",
|
||||||
|
"**Response:** Short",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||||
|
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
|
||||||
|
"""Test the Discord message formatting with different inputs."""
|
||||||
|
# Create a mock scheduled call
|
||||||
|
mock_call = Mock()
|
||||||
|
mock_call.topic = topic
|
||||||
|
mock_call.model = model
|
||||||
|
mock_call.discord_user = "123456789"
|
||||||
|
|
||||||
|
scheduled_calls._send_to_discord(mock_call, response)
|
||||||
|
|
||||||
|
# Get the actual message that was sent
|
||||||
|
args, kwargs = mock_send_dm.call_args
|
||||||
|
actual_message = args[1]
|
||||||
|
|
||||||
|
# Verify all expected parts are in the message
|
||||||
|
for expected_part in expected_in_message:
|
||||||
|
assert expected_part in actual_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status,should_execute",
|
||||||
|
[
|
||||||
|
("pending", True),
|
||||||
|
("executing", False),
|
||||||
|
("completed", False),
|
||||||
|
("failed", False),
|
||||||
|
("cancelled", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("memory.workers.tasks.scheduled_calls.llms.call")
|
||||||
|
def test_execute_scheduled_call_status_check(
|
||||||
|
mock_llm_call, status, should_execute, db_session, sample_user
|
||||||
|
):
|
||||||
|
"""Test that only pending calls are executed."""
|
||||||
|
call = ScheduledLLMCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=sample_user.id,
|
||||||
|
topic="Status Test",
|
||||||
|
scheduled_time=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||||
|
model="test-model",
|
||||||
|
prompt="Test",
|
||||||
|
discord_user="123",
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
db_session.add(call)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
result = scheduled_calls.execute_scheduled_call(call.id)
|
||||||
|
|
||||||
|
if should_execute:
|
||||||
|
mock_llm_call.assert_called_once()
|
||||||
|
# We don't check the full result here since it depends on mocking more functions
|
||||||
|
else:
|
||||||
|
assert result == {"error": f"Call is not pending (status: {status})"}
|
||||||
|
mock_llm_call.assert_not_called()
|
Loading…
x
Reference in New Issue
Block a user