proactive stuff

This commit is contained in:
mruwnik 2025-12-29 14:07:12 +00:00
parent 47180e1e71
commit f042f9aed8
37 changed files with 2004 additions and 725 deletions

View File

@ -0,0 +1,45 @@
"""Add proactive check-in fields to Discord entities
Revision ID: e1f2a3b4c5d6
Revises: d0e1f2a3b4c5
Create Date: 2025-12-24 16:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "e1f2a3b4c5d6"
down_revision: Union[str, None] = "d0e1f2a3b4c5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add proactive fields to all MessageProcessor tables
for table in ["discord_servers", "discord_channels", "discord_users"]:
op.add_column(
table,
sa.Column("proactive_cron", sa.Text(), nullable=True),
)
op.add_column(
table,
sa.Column("proactive_prompt", sa.Text(), nullable=True),
)
op.add_column(
table,
sa.Column(
"last_proactive_at", sa.DateTime(timezone=True), nullable=True
),
)
def downgrade() -> None:
for table in ["discord_servers", "discord_channels", "discord_users"]:
op.drop_column(table, "last_proactive_at")
op.drop_column(table, "proactive_prompt")
op.drop_column(table, "proactive_cron")

View File

@ -1,7 +1,8 @@
fastapi==0.112.2
fastapi>=0.115.12
uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
sqladmin==0.20.1
mcp==1.10.0
fastmcp>=2.10.0
slowapi==0.1.9

View File

@ -1,14 +1,15 @@
sqlalchemy==2.0.30
psycopg2-binary==2.9.9
pydantic==2.7.2
pydantic>=2.11.7
alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.69.0
openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
httpx>=0.28.1
celery[redis,sqs]==5.3.6
croniter==2.0.1
cryptography==43.0.0
bcrypt==4.1.2

View File

@ -1,7 +1,9 @@
pytest==7.4.4
pytest-cov==4.1.0
pytest-asyncio==0.23.0
black==23.12.1
mypy==1.8.0
isort==5.13.2
isort==5.13.2
testcontainers[qdrant]==4.10.0
click==8.1.7
click==8.1.7
croniter==2.0.1

View File

@ -1,8 +1,16 @@
import memory.api.MCP.tools
import memory.api.MCP.memory
import memory.api.MCP.metadata
import memory.api.MCP.schedules
import memory.api.MCP.books
import memory.api.MCP.manifest
import memory.api.MCP.github
import memory.api.MCP.people
"""
MCP server with composed subservers.
Subservers are mounted with prefixes:
- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file
- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview
- people: add_person, update_person_info, get_person, list_people, delete_person
- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call
- books: all_books, read_book
- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts
"""
# Import base to trigger subserver mounting
from memory.api.MCP.base import mcp, get_current_user
__all__ = ["mcp", "get_current_user"]

View File

@ -1,60 +1,28 @@
import logging
import pathlib
from typing import cast
from mcp.server.auth.handlers.authorize import AuthorizationRequest
from mcp.server.auth.handlers.token import (
AuthorizationCodeRequest,
RefreshTokenRequest,
TokenRequest,
)
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.shared.auth import OAuthClientMetadata
from pydantic import AnyHttpUrl
from fastmcp import FastMCP
from fastmcp.server.dependencies import get_access_token
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
from memory.api.MCP.oauth_provider import (
ALLOWED_SCOPES,
BASE_SCOPES,
SimpleOAuthProvider,
)
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
from memory.api.MCP.servers.books import books_mcp
from memory.api.MCP.servers.core import core_mcp
from memory.api.MCP.servers.github import github_mcp
from memory.api.MCP.servers.meta import meta_mcp
from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth
from memory.api.MCP.servers.people import people_mcp
from memory.api.MCP.servers.schedule import schedule_mcp
from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.connection import make_session, get_engine
from memory.common.db.models import OAuthState, UserSession
from memory.common.db.models.users import HumanUser
logger = logging.getLogger(__name__)
def validate_metadata(klass: type):
orig_validate = klass.model_validate
def validate(data: dict):
data = dict(data)
if "redirect_uris" in data:
data["redirect_uris"] = [
str(uri).replace("cursor://", "http://")
for uri in data["redirect_uris"]
]
if "redirect_uri" in data:
data["redirect_uri"] = str(data["redirect_uri"]).replace(
"cursor://", "http://"
)
return orig_validate(data)
klass.model_validate = validate
validate_metadata(OAuthClientMetadata)
validate_metadata(AuthorizationRequest)
validate_metadata(AuthorizationCodeRequest)
validate_metadata(RefreshTokenRequest)
validate_metadata(TokenRequest)
engine = get_engine()
# Setup templates
@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir)
oauth_provider = SimpleOAuthProvider()
auth_settings = AuthSettings(
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=ALLOWED_SCOPES,
default_scopes=BASE_SCOPES,
),
required_scopes=BASE_SCOPES,
)
mcp = FastMCP(
"memory",
stateless_http=True,
auth_server_provider=oauth_provider,
auth=auth_settings,
auth=oauth_provider,
)
@ -162,3 +118,52 @@ def get_current_user() -> dict:
"client_id": access_token.client_id,
"user": user_info,
}
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Health check endpoint that verifies all dependencies are accessible."""
from sqlalchemy import text
checks = {"mcp_oauth": "enabled"}
all_healthy = True
# Check database connection
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
checks["database"] = "healthy"
except Exception as e:
logger.error(f"Database health check failed: {e}")
checks["database"] = "unhealthy"
all_healthy = False
# Check Qdrant connection
try:
from memory.common.qdrant import get_qdrant_client
client = get_qdrant_client()
client.get_collections()
checks["qdrant"] = "healthy"
except Exception as e:
logger.error(f"Qdrant health check failed: {e}")
checks["qdrant"] = "unhealthy"
all_healthy = False
checks["status"] = "healthy" if all_healthy else "degraded"
status_code = 200 if all_healthy else 503
return JSONResponse(checks, status_code=status_code)
# Inject auth provider into subservers that need it
set_schedule_auth(get_current_user)
set_meta_auth(get_current_user)
# Mount all subservers onto the main MCP server
# Tools will be prefixed with their server name (e.g., core_search_knowledge_base)
mcp.mount(core_mcp, prefix="core")
mcp.mount(github_mcp, prefix="github")
mcp.mount(people_mcp, prefix="people")
mcp.mount(schedule_mcp, prefix="schedule")
mcp.mount(books_mcp, prefix="books")
mcp.mount(meta_mcp, prefix="meta")

View File

@ -1,119 +0,0 @@
import asyncio
import aiohttp
from datetime import datetime
from typing import TypedDict, NotRequired, Literal
from memory.api.MCP.tools import mcp
class BinaryProbs(TypedDict):
prob: float
class MultiProbs(TypedDict):
answerProbs: dict[str, float]
Probs = dict[str, BinaryProbs | MultiProbs]
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
class MarketAnswer(TypedDict):
id: str
text: str
resolutionProbability: float
class MarketDetails(TypedDict):
id: str
createdTime: int
question: str
outcomeType: OutcomeType
textDescription: str
groupSlugs: list[str]
volume: float
isResolved: bool
answers: list[MarketAnswer]
class Market(TypedDict):
id: str
url: str
question: str
volume: int
createdTime: int
outcomeType: OutcomeType
createdAt: NotRequired[str]
description: NotRequired[str]
answers: NotRequired[dict[str, float]]
probability: NotRequired[float]
details: NotRequired[MarketDetails]
async def get_details(session: aiohttp.ClientSession, market_id: str):
async with session.get(
f"https://api.manifold.markets/v0/market/{market_id}"
) as resp:
resp.raise_for_status()
return await resp.json()
async def format_market(session: aiohttp.ClientSession, market: Market):
if market.get("outcomeType") != "BINARY":
details = await get_details(session, market["id"])
market["answers"] = {
answer["text"]: round(
answer.get("resolutionProbability") or answer.get("probability") or 0, 3
)
for answer in details["answers"]
}
if creationTime := market.get("createdTime"):
market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
fields = [
"id",
"name",
"url",
"question",
"volume",
"createdAt",
"details",
"probability",
"answers",
]
return {k: v for k, v in market.items() if k in fields}
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.manifold.markets/v0/search-markets",
params={
"term": term,
"contractType": "BINARY" if binary else "ALL",
},
) as resp:
resp.raise_for_status()
markets = await resp.json()
return await asyncio.gather(
*[
format_market(session, market)
for market in markets
if market.get("volume", 0) >= min_volume
]
)
@mcp.tool()
async def get_forecasts(
term: str, min_volume: int = 1000, binary: bool = False
) -> list[dict]:
"""Get prediction market forecasts for a given term.
Args:
term: The term to search for.
min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
binary: Whether to only return binary markets.
"""
return await search_markets(term, min_volume, binary)

View File

@ -1,119 +0,0 @@
import logging
from collections import defaultdict
from typing import Annotated, TypedDict, get_args, get_type_hints
from memory.common import qdrant
from sqlalchemy import func
from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation
logger = logging.getLogger(__name__)
class SchemaArg(TypedDict):
type: str | None
description: str | None
class CollectionMetadata(TypedDict):
schema: dict[str, SchemaArg]
size: int
def from_annotation(annotation: Annotated) -> SchemaArg | None:
try:
type_, description = get_args(annotation)
type_str = str(type_)
if type_str.startswith("typing."):
type_str = type_str[7:]
elif len((parts := type_str.split("'"))) > 1:
type_str = parts[1]
return SchemaArg(type=type_str, description=description)
except IndexError:
logger.error(f"Error from annotation: {annotation}")
return None
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
if not hasattr(klass, "as_payload"):
return {}
if not (payload_type := get_type_hints(klass.as_payload).get("return")):
return {}
return {
name: schema
for name, arg in payload_type.__annotations__.items()
if (schema := from_annotation(arg))
}
@mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
"""Get the metadata schema for each collection used in the knowledge base.
These schemas can be used to filter the knowledge base.
Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
Example:
```
{
"mail": {"subject": {"type": "str", "description": "The subject of the email."}},
"chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
}
"""
client = qdrant.get_qdrant_client()
sizes = qdrant.get_collection_sizes(client)
schemas = defaultdict(dict)
for klass in SourceItem.__subclasses__():
for collection in klass.get_collections():
schemas[collection].update(get_schema(klass))
return {
collection: CollectionMetadata(schema=schema, size=size)
for collection, schema in schemas.items()
if (size := sizes.get(collection))
}
@mcp.tool()
async def get_all_tags() -> list[str]:
"""Get all unique tags used across the entire knowledge base.
Returns sorted list of tags from both observations and content.
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@mcp.tool()
async def get_all_subjects() -> list[str]:
"""Get all unique subjects from observations about the user.
Returns sorted list of subject identifiers used in observations.
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@mcp.tool()
async def get_all_observation_types() -> list[str]:
"""Get all observation types that have been used.
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)

View File

@ -5,11 +5,12 @@ from datetime import datetime, timezone
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from fastmcp.server.auth import OAuthProvider
from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken
from mcp.server.auth.provider import (
AccessToken,
AuthorizationCode,
AuthorizationParams,
OAuthAuthorizationServerProvider,
RefreshToken,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
@ -133,8 +134,45 @@ def make_token(
)
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
class SimpleOAuthProvider(OAuthProvider):
"""OAuth provider that extends fastmcp's OAuthProvider with custom login flow."""
def __init__(self):
super().__init__(
base_url=settings.SERVER_URL,
issuer_url=settings.SERVER_URL,
client_registration_options=None, # We handle registration ourselves
required_scopes=BASE_SCOPES,
)
async def verify_token(self, token: str) -> FastMCPAccessToken | None:
"""Verify an access token and return token info if valid."""
with make_session() as session:
# Try as OAuth access token first
user_session = session.query(UserSession).get(token)
if user_session:
now = datetime.now(timezone.utc).replace(tzinfo=None)
if user_session.expires_at < now:
return None
return FastMCPAccessToken(
token=token,
client_id=user_session.oauth_state.client_id,
scopes=user_session.oauth_state.scopes,
)
# Try as bot API key
bot = session.query(User).filter(User.api_key == token).first()
if bot:
logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
return FastMCPAccessToken(
token=token,
client_id=cast(str, bot.name or bot.email),
scopes=["read", "write"],
)
return None
def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""Get OAuth client information."""
with make_session() as session:
client = session.get(OAuthClientInformation, client_id)

View File

@ -0,0 +1,17 @@
"""MCP subservers for composable tool organization."""
from memory.api.MCP.servers.core import core_mcp
from memory.api.MCP.servers.github import github_mcp
from memory.api.MCP.servers.people import people_mcp
from memory.api.MCP.servers.schedule import schedule_mcp
from memory.api.MCP.servers.books import books_mcp
from memory.api.MCP.servers.meta import meta_mcp
__all__ = [
"core_mcp",
"github_mcp",
"people_mcp",
"schedule_mcp",
"books_mcp",
"meta_mcp",
]

View File

@ -1,15 +1,19 @@
"""MCP subserver for ebook access."""
import logging
from fastmcp import FastMCP
from sqlalchemy.orm import joinedload
from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Book, BookSection, BookSectionPayload
logger = logging.getLogger(__name__)
books_mcp = FastMCP("memory-books")
@mcp.tool()
@books_mcp.tool()
async def all_books(sections: bool = False) -> list[dict]:
"""
Get all books in the database.
@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
return [book.as_payload(sections=sections) for book in books]
@mcp.tool()
@books_mcp.tool()
def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
"""
Read a book from the database.

View File

@ -1,53 +1,45 @@
"""
MCP tools for the epistemic sparring partner system.
Core MCP subserver for knowledge base search, observations, and notes.
"""
import base64
import logging
import pathlib
from datetime import datetime, timezone
from PIL import Image
from fastmcp import FastMCP
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchFilters, SearchConfig
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract, settings
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
from memory.common.celery_app import app as celery_app
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem, AgentObservation
from memory.common.db.models import AgentObservation, SourceItem
from memory.common.formatters import observation
logger = logging.getLogger(__name__)
core_mcp = FastMCP("memory-core")
def validate_path_within_directory(
base_dir: pathlib.Path, requested_path: str
) -> pathlib.Path:
"""Validate that a requested path resolves within the base directory.
Prevents path traversal attacks using ../ or similar techniques.
Args:
base_dir: The allowed base directory
requested_path: The user-provided path
Returns:
The resolved absolute path if valid
Raises:
ValueError: If the path would escape the base directory
"""
"""Validate that a requested path resolves within the base directory."""
resolved = (base_dir / requested_path.lstrip("/")).resolve()
base_resolved = base_dir.resolve()
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
if (
not str(resolved).startswith(str(base_resolved) + "/")
and resolved != base_resolved
):
raise ValueError(f"Path escapes allowed directory: {requested_path}")
return resolved
@ -63,7 +55,6 @@ def filter_observation_source_ids(
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
items_query = session.query(SourceItem.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
return source_ids
@mcp.tool()
@core_mcp.tool()
async def search_knowledge_base(
query: str,
filters: SearchFilters = {},
@ -141,7 +131,6 @@ async def search_knowledge_base(
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
# Filter to valid collections, excluding observation collections
modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
search_filters = SearchFilters(**filters)
@ -167,7 +156,7 @@ class RawObservation(BaseModel):
tags: list[str] = []
@mcp.tool()
@core_mcp.tool()
async def observe(
observations: list[RawObservation],
session_id: str | None = None,
@ -212,23 +201,23 @@ async def observe(
logger.info("MCP: Observing")
tasks = [
(
observation,
obs,
celery_app.send_task(
SYNC_OBSERVATION,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
"subject": observation.subject,
"content": observation.content,
"observation_type": observation.observation_type,
"confidences": observation.confidences,
"evidence": observation.evidence,
"tags": observation.tags,
"subject": obs.subject,
"content": obs.content,
"observation_type": obs.observation_type,
"confidences": obs.confidences,
"evidence": obs.evidence,
"tags": obs.tags,
"session_id": session_id,
"agent_model": agent_model,
},
),
)
for observation in observations
for obs in observations
]
def short_content(obs: RawObservation) -> str:
@ -242,7 +231,7 @@ async def observe(
}
@mcp.tool()
@core_mcp.tool()
async def search_observations(
query: str,
subject: str = "",
@ -308,7 +297,7 @@ async def search_observations(
]
@mcp.tool()
@core_mcp.tool()
async def create_note(
subject: str,
content: str,
@ -362,7 +351,7 @@ async def create_note(
}
@mcp.tool()
@core_mcp.tool()
async def note_files(path: str = "/"):
"""
List note files in the user's note storage.
@ -386,7 +375,7 @@ async def note_files(path: str = "/"):
]
@mcp.tool()
@core_mcp.tool()
def fetch_file(filename: str) -> dict:
"""
Read file content with automatic type detection.

View File

@ -1,14 +1,14 @@
"""MCP tools for GitHub issue tracking and management."""
"""MCP subserver for GitHub issue tracking and management."""
import logging
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import Text, case, desc, func, asc
from fastmcp import FastMCP
from sqlalchemy import Text, case, desc, func
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract
@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem
logger = logging.getLogger(__name__)
github_mcp = FastMCP("memory-github")
def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
"""Build GitHub URL from repo path and issue/PR number."""
@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
}
if include_content:
result["content"] = item.content
# Include PR-specific data if available
if item.kind == "pr" and item.pr_data:
result["pr_data"] = {
"additions": item.pr_data.additions,
@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
"files": item.pr_data.files,
"reviews": item.pr_data.reviews,
"review_comments": item.pr_data.review_comments,
"diff": item.pr_data.diff, # Decompressed via property
"diff": item.pr_data.diff,
}
return result
@mcp.tool()
@github_mcp.tool()
async def list_github_issues(
repo: str | None = None,
assignee: str | None = None,
@ -109,64 +110,48 @@ async def list_github_issues(
with make_session() as session:
query = session.query(GithubItem)
# Apply filters
if repo:
query = query.filter(GithubItem.repo_path == repo)
if assignee:
query = query.filter(GithubItem.assignees.any(assignee))
if author:
query = query.filter(GithubItem.author == author)
if state:
query = query.filter(GithubItem.state == state)
if kind:
query = query.filter(GithubItem.kind == kind)
else:
# Exclude comments by default, only show issues and PRs
query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
if labels:
# Match any label in the list using PostgreSQL array overlap
query = query.filter(
GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
)
if project_status:
query = query.filter(GithubItem.project_status == project_status)
if project_field:
for key, value in project_field.items():
query = query.filter(
GithubItem.project_fields[key].astext == value
)
query = query.filter(GithubItem.project_fields[key].astext == value)
if updated_since:
since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at >= since_dt)
if updated_before:
before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at <= before_dt)
# Apply ordering
if order_by == "created":
query = query.order_by(desc(GithubItem.created_at))
elif order_by == "number":
query = query.order_by(desc(GithubItem.number))
else: # default: updated
else:
query = query.order_by(desc(GithubItem.github_updated_at))
query = query.limit(limit)
items = query.all()
return [_serialize_issue(item) for item in items]
@mcp.tool()
@github_mcp.tool()
async def search_github_issues(
query: str,
repo: str | None = None,
@ -191,7 +176,6 @@ async def search_github_issues(
limit = min(limit, 100)
# Pre-filter source_ids if repo/state/kind filters are specified
source_ids = None
if repo or state or kind:
with make_session() as session:
@ -206,7 +190,6 @@ async def search_github_issues(
q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
source_ids = [item.id for item in q.all()]
# Use the existing search infrastructure
data = extract.extract_text(query, skip_summary=True)
config = SearchConfig(limit=limit, previews=True)
filters = SearchFilters()
@ -220,7 +203,6 @@ async def search_github_issues(
config=config,
)
# Fetch full issue details for the results
output = []
with make_session() as session:
for result in results:
@ -233,7 +215,7 @@ async def search_github_issues(
return output
@mcp.tool()
@github_mcp.tool()
async def github_issue_details(
repo: str,
number: int,
@ -267,7 +249,7 @@ async def github_issue_details(
return _serialize_issue(item, include_content=True)
@mcp.tool()
@github_mcp.tool()
async def github_work_summary(
since: str,
until: str | None = None,
@ -294,7 +276,6 @@ async def github_work_summary(
else:
until_dt = datetime.now(timezone.utc)
# Map group_by to SQL expression
group_mappings = {
"client": GithubItem.project_fields["EquiStamp.Client"].astext,
"status": GithubItem.project_status,
@ -311,7 +292,6 @@ async def github_work_summary(
group_col = group_mappings[group_by]
with make_session() as session:
# Build base query for the period
base_query = session.query(GithubItem).filter(
GithubItem.github_updated_at >= since_dt,
GithubItem.github_updated_at <= until_dt,
@ -321,7 +301,6 @@ async def github_work_summary(
if repo:
base_query = base_query.filter(GithubItem.repo_path == repo)
# Get aggregated counts by group
agg_query = (
session.query(
group_col.label("group_name"),
@ -346,7 +325,6 @@ async def github_work_summary(
groups = agg_query.all()
# Build summary with sample issues for each group
summary = []
total_issues = 0
total_prs = 0
@ -358,7 +336,6 @@ async def github_work_summary(
total_issues += issue_count
total_prs += pr_count
# Get sample issues for this group
sample_query = base_query.filter(group_col == group_name).limit(5)
samples = [
{
@ -394,7 +371,7 @@ async def github_work_summary(
}
@mcp.tool()
@github_mcp.tool()
async def github_repo_overview(
repo: str,
) -> dict:
@ -410,13 +387,6 @@ async def github_repo_overview(
logger.info(f"github_repo_overview called: repo={repo}")
with make_session() as session:
# Base query for this repo
base_query = session.query(GithubItem).filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
)
# Get total counts
counts_query = session.query(
func.count(GithubItem.id).label("total"),
func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
@ -441,7 +411,6 @@ async def github_repo_overview(
counts = counts_query.first()
# Status breakdown (for project_status)
status_query = (
session.query(
GithubItem.project_status.label("status"),
@ -458,7 +427,6 @@ async def github_repo_overview(
status_breakdown = {row.status: row.count for row in status_query.all()}
# Top assignees (open issues only)
assignee_query = (
session.query(
func.unnest(GithubItem.assignees).label("assignee"),
@ -479,7 +447,6 @@ async def github_repo_overview(
for row in assignee_query.all()
]
# Label counts
label_query = (
session.query(
func.unnest(GithubItem.labels).label("label"),

View File

@ -0,0 +1,277 @@
"""MCP subserver for metadata, utilities, and forecasting."""
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints
import aiohttp
from fastmcp import FastMCP
from sqlalchemy import func
from memory.common import qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation
logger = logging.getLogger(__name__)
meta_mcp = FastMCP("memory-meta")
# Auth provider will be injected at mount time
_get_current_user = None
def set_auth_provider(get_current_user_func):
"""Set the authentication provider function."""
global _get_current_user
_get_current_user = get_current_user_func
def get_current_user() -> dict:
"""Get the current authenticated user."""
if _get_current_user is None:
return {"authenticated": False, "error": "Auth provider not configured"}
return _get_current_user()
# --- Metadata tools ---
class SchemaArg(TypedDict):
type: str | None
description: str | None
class CollectionMetadata(TypedDict):
schema: dict[str, SchemaArg]
size: int
def from_annotation(annotation: Annotated) -> SchemaArg | None:
try:
type_, description = get_args(annotation)
type_str = str(type_)
if type_str.startswith("typing."):
type_str = type_str[7:]
elif len((parts := type_str.split("'"))) > 1:
type_str = parts[1]
return SchemaArg(type=type_str, description=description)
except IndexError:
logger.error(f"Error from annotation: {annotation}")
return None
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
if not hasattr(klass, "as_payload"):
return {}
if not (payload_type := get_type_hints(klass.as_payload).get("return")):
return {}
return {
name: schema
for name, arg in payload_type.__annotations__.items()
if (schema := from_annotation(arg))
}
@meta_mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
"""Get the metadata schema for each collection used in the knowledge base.
These schemas can be used to filter the knowledge base.
Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
Example:
```
{
"mail": {"subject": {"type": "str", "description": "The subject of the email."}},
"chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
}
"""
client = qdrant.get_qdrant_client()
sizes = qdrant.get_collection_sizes(client)
schemas = defaultdict(dict)
for klass in SourceItem.__subclasses__():
for collection in klass.get_collections():
schemas[collection].update(get_schema(klass))
return {
collection: CollectionMetadata(schema=schema, size=size)
for collection, schema in schemas.items()
if (size := sizes.get(collection))
}
@meta_mcp.tool()
async def get_all_tags() -> list[str]:
"""Get all unique tags used across the entire knowledge base.
Returns sorted list of tags from both observations and content.
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@meta_mcp.tool()
async def get_all_subjects() -> list[str]:
"""Get all unique subjects from observations about the user.
Returns sorted list of subject identifiers used in observations.
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@meta_mcp.tool()
async def get_all_observation_types() -> list[str]:
"""Get all observation types that have been used.
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)
# --- Utility tools ---
@meta_mcp.tool()
async def get_current_time() -> dict:
"""Get the current time in UTC."""
logger.info("get_current_time tool called")
return {"current_time": datetime.now(timezone.utc).isoformat()}
@meta_mcp.tool()
async def get_authenticated_user() -> dict:
"""Get information about the authenticated user."""
return get_current_user()
# --- Forecasting tools ---
class BinaryProbs(TypedDict):
prob: float
class MultiProbs(TypedDict):
answerProbs: dict[str, float]
Probs = dict[str, BinaryProbs | MultiProbs]
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
class MarketAnswer(TypedDict):
id: str
text: str
resolutionProbability: float
class MarketDetails(TypedDict):
id: str
createdTime: int
question: str
outcomeType: OutcomeType
textDescription: str
groupSlugs: list[str]
volume: float
isResolved: bool
answers: list[MarketAnswer]
class Market(TypedDict):
id: str
url: str
question: str
volume: int
createdTime: int
outcomeType: OutcomeType
createdAt: NotRequired[str]
description: NotRequired[str]
answers: NotRequired[dict[str, float]]
probability: NotRequired[float]
details: NotRequired[MarketDetails]
async def get_details(session: aiohttp.ClientSession, market_id: str):
async with session.get(
f"https://api.manifold.markets/v0/market/{market_id}"
) as resp:
resp.raise_for_status()
return await resp.json()
async def format_market(session: aiohttp.ClientSession, market: Market):
if market.get("outcomeType") != "BINARY":
details = await get_details(session, market["id"])
market["answers"] = {
answer["text"]: round(
answer.get("resolutionProbability") or answer.get("probability") or 0, 3
)
for answer in details["answers"]
}
if creationTime := market.get("createdTime"):
market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
fields = [
"id",
"name",
"url",
"question",
"volume",
"createdAt",
"details",
"probability",
"answers",
]
return {k: v for k, v in market.items() if k in fields}
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.manifold.markets/v0/search-markets",
params={
"term": term,
"contractType": "BINARY" if binary else "ALL",
},
) as resp:
resp.raise_for_status()
markets = await resp.json()
return await asyncio.gather(
*[
format_market(session, market)
for market in markets
if market.get("volume", 0) >= min_volume
]
)
@meta_mcp.tool()
async def get_forecasts(
term: str, min_volume: int = 1000, binary: bool = False
) -> list[dict]:
"""Get prediction market forecasts for a given term.
Args:
term: The term to search for.
min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
binary: Whether to only return binary markets.
"""
return await search_markets(term, min_volume, binary)

View File

@ -1,23 +1,23 @@
"""
MCP tools for tracking people.
"""
"""MCP subserver for tracking people."""
import logging
from typing import Any
from fastmcp import FastMCP
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Person
from memory.common import settings
from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
from memory.common.celery_app import app as celery_app
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import Person
logger = logging.getLogger(__name__)
people_mcp = FastMCP("memory-people")
def _person_to_dict(person: Person) -> dict[str, Any]:
"""Convert a Person model to a dictionary for API responses."""
@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]:
}
@mcp.tool()
@people_mcp.tool()
async def add_person(
identifier: str,
display_name: str,
@ -67,7 +67,6 @@ async def add_person(
"""
logger.info(f"MCP: Adding person: {identifier}")
# Check if person already exists
with make_session() as session:
existing = session.query(Person).filter(Person.identifier == identifier).first()
if existing:
@ -93,7 +92,7 @@ async def add_person(
}
@mcp.tool()
@people_mcp.tool()
async def update_person_info(
identifier: str,
display_name: str | None = None,
@ -135,7 +134,6 @@ async def update_person_info(
"""
logger.info(f"MCP: Updating person: {identifier}")
# Verify person exists
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
@ -162,7 +160,7 @@ async def update_person_info(
}
@mcp.tool()
@people_mcp.tool()
async def get_person(identifier: str) -> dict | None:
"""
Get a person by their identifier.
@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None:
return _person_to_dict(person)
@mcp.tool()
@people_mcp.tool()
async def list_people(
tags: list[str] | None = None,
search: str | None = None,
@ -207,9 +205,7 @@ async def list_people(
query = session.query(Person)
if tags:
query = query.filter(
Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
)
query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))))
if search:
search_term = f"%{search.lower()}%"
@ -225,7 +221,7 @@ async def list_people(
return [_person_to_dict(p) for p in people]
@mcp.tool()
@people_mcp.tool()
async def delete_person(identifier: str) -> dict:
"""
Delete a person by their identifier.

View File

@ -1,20 +1,37 @@
"""
MCP tools for the epistemic sparring partner system.
"""
"""MCP subserver for scheduling messages and LLM calls."""
import logging
from datetime import datetime, timezone
from typing import Any, cast
from memory.api.MCP.base import get_current_user, mcp
from fastmcp import FastMCP
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
from memory.common.db.models import DiscordBotUser, ScheduledLLMCall
from memory.discord.messages import schedule_discord_message
logger = logging.getLogger(__name__)
schedule_mcp = FastMCP("memory-schedule")
@mcp.tool()
# We need access to get_current_user from base - this will be injected at mount time
_get_current_user = None
def set_auth_provider(get_current_user_func):
"""Set the authentication provider function."""
global _get_current_user
_get_current_user = get_current_user_func
def get_current_user() -> dict:
"""Get the current authenticated user."""
if _get_current_user is None:
return {"authenticated": False, "error": "Auth provider not configured"}
return _get_current_user()
@schedule_mcp.tool()
async def schedule_message(
scheduled_time: str,
message: str,
@ -61,10 +78,8 @@ async def schedule_message(
if not discord_user and not discord_channel:
raise ValueError("Either discord_user or discord_channel must be provided")
# Parse scheduled time
try:
scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
# Ensure we store as naive datetime (remove timezone info for database storage)
if scheduled_dt.tzinfo is not None:
scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
except ValueError:
@ -98,7 +113,7 @@ async def schedule_message(
}
@mcp.tool()
@schedule_mcp.tool()
async def list_scheduled_llm_calls(
status: str | None = None, limit: int | None = 50
) -> dict[str, Any]:
@ -143,7 +158,7 @@ async def list_scheduled_llm_calls(
}
@mcp.tool()
@schedule_mcp.tool()
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
"""
Cancel a scheduled LLM call.
@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
return {"error": "User not found", "user": current_user}
with make_session() as session:
# Find the scheduled call
scheduled_call = (
session.query(ScheduledLLMCall)
.filter(
@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
if not scheduled_call.can_be_cancelled():
return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
# Update the status
scheduled_call.status = "cancelled"
session.commit()

View File

@ -1,74 +0,0 @@
"""
MCP tools for the epistemic sparring partner system.
"""
import logging
from datetime import datetime, timezone
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem
from memory.api.MCP.base import mcp, get_current_user
logger = logging.getLogger(__name__)
def filter_observation_source_ids(
tags: list[str] | None = None, observation_types: list[str] | None = None
):
if not tags and not observation_types:
return None
with make_session() as session:
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if observation_types:
items_query = items_query.filter(
AgentObservation.observation_type.in_(observation_types)
)
source_ids = [item.id for item in items_query.all()]
return source_ids
def filter_source_ids(
modalities: set[str],
tags: list[str] | None = None,
):
if not tags:
return None
with make_session() as session:
items_query = session.query(SourceItem.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if modalities:
items_query = items_query.filter(SourceItem.modality.in_(modalities))
source_ids = [item.id for item in items_query.all()]
return source_ids
@mcp.tool()
async def get_current_time() -> dict:
"""Get the current time in UTC."""
logger.info("get_current_time tool called")
return {"current_time": datetime.now(timezone.utc).isoformat()}
@mcp.tool()
async def get_authenticated_user() -> dict:
"""Get information about the authenticated user."""
return get_current_user()

View File

@ -2,7 +2,6 @@
FastAPI application for the knowledge base.
"""
import contextlib
import os
import logging
import mimetypes
@ -35,15 +34,10 @@ limiter = Limiter(
enabled=settings.API_RATE_LIMIT_ENABLED,
)
# Create the MCP http app to get its lifespan
mcp_http_app = mcp.http_app(stateless_http=True)
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
async with contextlib.AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run())
yield
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan)
app.state.limiter = limiter
# Rate limit exception handler
@ -158,46 +152,9 @@ app.include_router(auth_router)
# Add health check to MCP server instead of main app
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Health check endpoint that verifies all dependencies are accessible."""
from fastapi.responses import JSONResponse
from sqlalchemy import text
checks = {"mcp_oauth": "enabled"}
all_healthy = True
# Check database connection
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
checks["database"] = "healthy"
except Exception as e:
# Log error details but don't expose in response
logger.error(f"Database health check failed: {e}")
checks["database"] = "unhealthy"
all_healthy = False
# Check Qdrant connection
try:
from memory.common.qdrant import get_qdrant_client
client = get_qdrant_client()
client.get_collections()
checks["qdrant"] = "healthy"
except Exception as e:
# Log error details but don't expose in response
logger.error(f"Qdrant health check failed: {e}")
checks["qdrant"] = "unhealthy"
all_healthy = False
checks["status"] = "healthy" if all_healthy else "degraded"
status_code = 200 if all_healthy else 503
return JSONResponse(checks, status_code=status_code)
# Mount MCP server at root - OAuth endpoints need to be at root level
app.mount("/", mcp.streamable_http_app())
# Health check is defined in MCP/base.py
app.mount("/", mcp_http_app)
def main(reload: bool = False):

View File

@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
GITHUB_ROOT = "memory.workers.tasks.github"
PEOPLE_ROOT = "memory.workers.tasks.people"
PROACTIVE_ROOT = "memory.workers.tasks.proactive"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
# Proactive check-in tasks
EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
@ -130,12 +135,17 @@ app.conf.update(
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
},
beat_schedule={
"sync-github-repos-hourly": {
"task": SYNC_ALL_GITHUB_REPOS,
"schedule": crontab(minute=0), # Every hour at :00
},
"evaluate-proactive-checkins": {
"task": EVALUATE_PROACTIVE_CHECKINS,
"schedule": crontab(), # Every minute
},
},
)

View File

@ -63,13 +63,29 @@ class MessageProcessor:
doc=textwrap.dedent(
"""
A summary of this processor, made by and for AI systems.
The idea here is that AI systems can use this summary to keep notes on the given processor.
These should automatically be injected into the context of the messages that are processed by this processor.
These should automatically be injected into the context of the messages that are processed by this processor.
"""
),
)
proactive_cron = Column(
Text,
nullable=True,
doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
)
proactive_prompt = Column(
Text,
nullable=True,
doc="Custom instructions for proactive check-ins.",
)
last_proactive_at = Column(
DateTime(timezone=True),
nullable=True,
doc="When the last proactive check-in was sent.",
)
@property
def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore

View File

@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))

View File

@ -167,6 +167,25 @@ def _create_scope_group(
url=url and url.strip(),
)
# Proactive command
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
@discord.app_commands.describe(
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
prompt="Optional custom instructions for check-ins",
)
async def proactive_cmd(
interaction: discord.Interaction,
cron: str | None = None,
prompt: str | None = None,
):
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_proactive,
cron=cron and cron.strip(),
prompt=prompt,
)
return group
@ -265,6 +284,28 @@ def _create_user_scope_group(
url=url and url.strip(),
)
# Proactive command
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
@discord.app_commands.describe(
user="Target user",
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
prompt="Optional custom instructions for check-ins",
)
async def proactive_cmd(
interaction: discord.Interaction,
user: discord.User,
cron: str | None = None,
prompt: str | None = None,
):
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_proactive,
target_user=user,
cron=cron and cron.strip(),
prompt=prompt,
)
return group
@ -663,3 +704,68 @@ async def handle_mcp_servers(
except Exception as exc:
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
raise CommandError(f"Error: {exc}") from exc
async def handle_proactive(
context: CommandContext,
*,
cron: str | None = None,
prompt: str | None = None,
) -> CommandResponse:
"""Handle proactive check-in configuration."""
from croniter import croniter
model = context.target
# If no arguments, show current settings
if cron is None and prompt is None:
current_cron = getattr(model, "proactive_cron", None)
current_prompt = getattr(model, "proactive_prompt", None)
if not current_cron:
return CommandResponse(
content=f"Proactive check-ins are disabled for {context.display_name}."
)
lines = [f"Proactive check-ins for {context.display_name}:"]
lines.append(f" Schedule: `{current_cron}`")
if current_prompt:
lines.append(f" Prompt: {current_prompt}")
return CommandResponse(content="\n".join(lines))
# Handle cron setting
if cron is not None:
if cron.lower() == "off":
setattr(model, "proactive_cron", None)
return CommandResponse(
content=f"Proactive check-ins disabled for {context.display_name}."
)
# Validate cron expression
try:
croniter(cron)
except (ValueError, KeyError) as e:
raise CommandError(
f"Invalid cron expression: {cron}\n"
"Examples:\n"
" `0 9 * * *` - 9am daily\n"
" `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
" `0 */4 * * *` - every 4 hours"
) from e
setattr(model, "proactive_cron", cron)
# Handle prompt setting
if prompt is not None:
setattr(model, "proactive_prompt", prompt or None)
# Build response
current_cron = getattr(model, "proactive_cron", None)
current_prompt = getattr(model, "proactive_prompt", None)
lines = [f"Updated proactive settings for {context.display_name}:"]
lines.append(f" Schedule: `{current_cron}`")
if current_prompt:
lines.append(f" Prompt: {current_prompt}")
return CommandResponse(content="\n".join(lines))

View File

@ -11,6 +11,7 @@ from memory.common.celery_app import (
SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS,
BACKUP_ALL,
EVALUATE_PROACTIVE_CHECKINS,
)
logger = logging.getLogger(__name__)
@ -53,4 +54,8 @@ app.conf.beat_schedule = {
"task": BACKUP_ALL,
"schedule": settings.S3_BACKUP_INTERVAL,
},
"evaluate-proactive-checkins": {
"task": EVALUATE_PROACTIVE_CHECKINS,
"schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
},
}

View File

@ -14,20 +14,24 @@ from memory.workers.tasks import (
maintenance,
notes,
observations,
people,
proactive,
scheduled_calls,
) # noqa
__all__ = [
"backup",
"email",
"comic",
"blogs",
"ebook",
"comic",
"discord",
"ebook",
"email",
"forums",
"github",
"maintenance",
"notes",
"observations",
"people",
"proactive",
"scheduled_calls",
]

View File

@ -0,0 +1,341 @@
"""
Celery tasks for proactive Discord check-ins.
"""
import logging
import re
import textwrap
from datetime import datetime, timezone
from typing import Any, Literal, cast
from croniter import croniter
from sqlalchemy import or_
from memory.common import settings
from memory.common.celery_app import app
from memory.common.db.connection import make_session
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
from memory.workers.tasks.content_processing import safe_task_execution
logger = logging.getLogger(__name__)
EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
EntityType = Literal["user", "channel", "server"]
def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
"""Check if a cron expression is due to run now.
Uses croniter to determine if the current time falls within the cron's schedule
and enough time has passed since the last run.
"""
try:
cron = croniter(cron_expr, now)
# Get the previous scheduled time from now
prev_run = cron.get_prev(datetime)
# Get the one before that to determine the interval
cron.get_prev(datetime)
prev_prev_run = cron.get_current(datetime)
# If we haven't run since the last scheduled time, we should run
if last_run is None:
# Never run before - check if current time is within a minute of prev_run
time_since_scheduled = (now - prev_run).total_seconds()
return time_since_scheduled < 120 # Within 2 minutes of scheduled time
# Make sure last_run is timezone aware
if last_run.tzinfo is None:
last_run = last_run.replace(tzinfo=timezone.utc)
# We should run if last_run is before the previous scheduled time
return last_run < prev_run
except Exception as e:
logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
return False
def get_bot_for_entity(
session, entity_type: EntityType, entity_id: int
) -> DiscordUser | None:
"""Get the bot user associated with an entity."""
from memory.common.db.models import DiscordBotUser, DiscordMessage
from sqlalchemy.orm import joinedload
# For servers, find a bot that has sent messages in that server
if entity_type == "server":
# Find bots that have interacted with this server
bot_users = (
session.query(DiscordUser)
.options(joinedload(DiscordUser.system_user))
.join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
.filter(
DiscordMessage.server_id == entity_id,
DiscordUser.system_user_id.isnot(None),
)
.distinct()
.all()
)
# Find one that's actually a bot
for user in bot_users:
if user.system_user and user.system_user.user_type == "discord_bot":
return user
# For channels, check the server the channel belongs to
if entity_type == "channel":
channel = session.get(DiscordChannel, entity_id)
if channel and channel.server_id:
return get_bot_for_entity(session, "server", channel.server_id)
# Fallback: use first available bot
bot = (
session.query(DiscordBotUser)
.options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
.first()
)
if bot and bot.discord_users:
return bot.discord_users[0]
return None
def get_target_user_for_entity(
session, entity_type: EntityType, entity_id: int
) -> DiscordUser | None:
"""Get the target user for sending a proactive message."""
if entity_type == "user":
return session.get(DiscordUser, entity_id)
# For channels and servers, we don't have a specific target user
return None
def get_channel_for_entity(
session, entity_type: EntityType, entity_id: int
) -> DiscordChannel | None:
"""Get the channel for sending a proactive message."""
if entity_type == "channel":
return session.get(DiscordChannel, entity_id)
if entity_type == "server":
# For servers, find the first text channel (prefer "general")
channels = (
session.query(DiscordChannel)
.filter(
DiscordChannel.server_id == entity_id,
DiscordChannel.channel_type == "text",
)
.all()
)
if not channels:
return None
# Prefer a channel named "general" if it exists
for channel in channels:
if channel.name and "general" in channel.name.lower():
return channel
return channels[0]
# For users, we use DMs (no channel)
return None
@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
@safe_task_execution
def evaluate_proactive_checkins() -> dict[str, Any]:
"""
Evaluate which entities need proactive check-ins.
This task runs every minute and checks all entities with proactive_cron set
to see if they're due for a check-in.
"""
now = datetime.now(timezone.utc)
dispatched = []
with make_session() as session:
# Query all entities with proactive_cron set
for model, entity_type in [
(DiscordUser, "user"),
(DiscordChannel, "channel"),
(DiscordServer, "server"),
]:
entities = (
session.query(model)
.filter(model.proactive_cron.isnot(None))
.all()
)
for entity in entities:
cron_expr = cast(str, entity.proactive_cron)
last_run = entity.last_proactive_at
if is_cron_due(cron_expr, last_run, now):
logger.info(
f"Proactive check-in due for {entity_type} {entity.id}"
)
execute_proactive_checkin.delay(entity_type, entity.id)
dispatched.append({"type": entity_type, "id": entity.id})
return {
"evaluated_at": now.isoformat(),
"dispatched": dispatched,
"count": len(dispatched),
}
@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
@safe_task_execution
def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
"""
Execute a proactive check-in for a specific entity.
This evaluates whether the bot should reach out and, if so, generates
and sends a check-in message.
"""
logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
with make_session() as session:
# Get the entity
model_class = {
"user": DiscordUser,
"channel": DiscordChannel,
"server": DiscordServer,
}[entity_type]
entity = session.get(model_class, entity_id)
if not entity:
return {"error": f"{entity_type} {entity_id} not found"}
# Get the bot user
bot_user = get_bot_for_entity(session, entity_type, entity_id)
if not bot_user:
return {"error": "No bot user found"}
# Get target user and channel
target_user = get_target_user_for_entity(session, entity_type, entity_id)
channel = get_channel_for_entity(session, entity_type, entity_id)
if not target_user and not channel:
return {"error": "No target user or channel for proactive check-in"}
# Get chattiness threshold
chattiness = entity.chattiness_threshold or 90
# Build the evaluation prompt
proactive_prompt = entity.proactive_prompt or ""
eval_prompt = textwrap.dedent("""
You are considering whether to proactively reach out to check in.
{proactive_prompt}
Based on your notes and the context of previous conversations:
1. Is there anything worth checking in about?
2. Has enough happened or enough time passed to warrant a check-in?
3. Would reaching out now be welcome or intrusive?
Please return a number between 0 and 100 indicating how strongly you want to check in
(0 = definitely not, 100 = definitely yes).
<response>
<number>50</number>
<reason>Your reasoning here</reason>
</response>
""").format(proactive_prompt=proactive_prompt)
# Build context
system_prompt = comm_channel_prompt(
session, bot_user, target_user, channel
)
# First, evaluate whether we should check in
eval_response = call_llm(
session,
bot_user=bot_user,
from_user=target_user,
channel=channel,
model=settings.SUMMARIZER_MODEL,
system_prompt=system_prompt,
messages=[eval_prompt],
allowed_tools=[
"update_channel_summary",
"update_user_summary",
"update_server_summary",
],
)
if not eval_response:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
# Parse the interest score
match = re.search(r"<number>(\d+)</number>", eval_response)
if not match:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
interest_score = int(match.group(1))
threshold = 100 - chattiness
logger.info(
f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
)
if interest_score < threshold:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {
"status": "below_threshold",
"interest": interest_score,
"threshold": threshold,
"entity_type": entity_type,
"entity_id": entity_id,
}
# Generate the actual check-in message
checkin_prompt = textwrap.dedent("""
You've decided to proactively check in. Generate a natural, friendly check-in message.
{proactive_prompt}
Keep it brief and genuine. Don't be overly formal or robotic.
Reference specific things from your notes if relevant.
""").format(proactive_prompt=proactive_prompt)
response = call_llm(
session,
bot_user=bot_user,
from_user=target_user,
channel=channel,
model=settings.DISCORD_MODEL,
system_prompt=system_prompt,
messages=[checkin_prompt],
)
if not response:
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
# Send the message
bot_id = bot_user.system_user.id if bot_user.system_user else None
if not bot_id:
return {"error": "No system user for bot"}
success = send_discord_response(
bot_id=bot_id,
response=response,
channel_id=channel.id if channel else None,
user_identifier=target_user.username if target_user else None,
)
# Update last_proactive_at
entity.last_proactive_at = datetime.now(timezone.utc)
session.commit()
return {
"status": "sent" if success else "send_failed",
"interest": interest_score,
"entity_type": entity_type,
"entity_id": entity_id,
"response_preview": response[:100] + "..." if len(response) > 100 else response,
}

View File

@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
),
(
0.409,
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
"Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
),
],
},
@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
"semantic": [
(0.489, "I find backend logic more interesting than UI work"),
(0.462, "The user prefers working on backend systems over frontend UI"),
(0.455, "The user said pure functions are yucky"),
(
0.455,
"The user believes functional programming leads to better code quality",
),
(0.455, "The user said pure functions are yucky"),
],
"temporal": [
(

View File

@ -137,10 +137,9 @@ class TestBuildPrompt:
):
prompt = _build_prompt()
assert "lesswrong" in prompt.lower()
assert "comic" in prompt.lower()
assert "Remove" in prompt
assert "Remove meta-language" in prompt
assert "Return ONLY valid JSON" in prompt
assert "recalled_content" in prompt
class TestAnalyzeQuery:

View File

@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
@pytest.mark.asyncio
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
mock_session = MagicMock()
@contextmanager
@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
mcp_server.mcp_server_url = "https://example.com"
mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (200, "Authorized")
mock_mcp_tools.return_value = [{"name": "test_tool"}]
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
assert "Authorization Successful" in body
assert "Authorized" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
mock_session.commit.assert_called_once()
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_handles_failures(
mock_make_session, mock_complete
mock_make_session, mock_complete, mock_mcp_tools
):
mock_session = MagicMock()
@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
mcp_server.mcp_server_url = "https://example.com"
mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (500, "Failure")
mock_mcp_tools.return_value = []
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
assert "Authorization Failed" in body
assert "Failure" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
mock_session.commit.assert_called_once()
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio

View File

@ -18,6 +18,7 @@ from memory.common.db.models import (
DiscordUser,
DiscordMessage,
BotUser,
DiscordBotUser,
HumanUser,
ScheduledLLMCall,
)
@ -67,9 +68,10 @@ def sample_discord_user(db_session):
@pytest.fixture
def sample_bot_user(db_session):
def sample_bot_user(db_session, sample_discord_user):
"""Create a sample bot user for testing."""
bot = BotUser.create_with_api_key(
bot = DiscordBotUser.create_with_api_key(
discord_users=[sample_discord_user],
name="Test Bot",
email="testbot@example.com",
)
@ -209,9 +211,9 @@ def test_schedule_message_with_user(
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message(
user_id=sample_human_user.id,
user=sample_discord_user.id,
channel=None,
bot_id=sample_human_user.id,
recipient_id=sample_discord_user.id,
channel_id=None,
model="test-model",
message="Test message",
date_time=future_time,
@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message(
user_id=sample_human_user.id,
user=None,
channel=sample_discord_channel.id,
bot_id=sample_human_user.id,
recipient_id=None,
channel_id=sample_discord_channel.id,
model="test-model",
message="Test message",
date_time=future_time,
@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
"""Test creating a message scheduler tool for a user."""
tool = make_message_scheduler(
bot=sample_bot_user,
user=sample_discord_user.id,
channel=None,
user_id=sample_discord_user.id,
channel_id=None,
model="test-model",
)
assert tool.name == "schedule_message"
assert tool.name == "schedule_discord_message"
assert "from your chat with this user" in tool.description
assert tool.input_schema["type"] == "object"
assert "message" in tool.input_schema["properties"]
@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
"""Test creating a message scheduler tool for a channel."""
tool = make_message_scheduler(
bot=sample_bot_user,
user=None,
channel=sample_discord_channel.id,
user_id=None,
channel_id=sample_discord_channel.id,
model="test-model",
)
assert tool.name == "schedule_message"
assert tool.name == "schedule_discord_message"
assert "in this channel" in tool.description
assert callable(tool.function)
@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
with pytest.raises(ValueError, match="Either user or channel must be provided"):
make_message_scheduler(
bot=sample_bot_user,
user=None,
channel=None,
user_id=None,
channel_id=None,
model="test-model",
)
@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
"""Test message scheduler handler with valid input."""
tool = make_message_scheduler(
bot=sample_bot_user,
user=sample_discord_user.id,
channel=None,
user_id=sample_discord_user.id,
channel_id=None,
model="test-model",
)
@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
"""Test message scheduler handler with non-dict input."""
tool = make_message_scheduler(
bot=sample_bot_user,
user=sample_discord_user.id,
channel=None,
user_id=sample_discord_user.id,
channel_id=None,
model="test-model",
)
@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
"""Test message scheduler handler with invalid datetime."""
tool = make_message_scheduler(
bot=sample_bot_user,
user=sample_discord_user.id,
channel=None,
user_id=sample_discord_user.id,
channel_id=None,
model="test-model",
)
@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
"""Test message scheduler handler with missing datetime."""
tool = make_message_scheduler(
bot=sample_bot_user,
user=sample_discord_user.id,
channel=None,
user_id=sample_discord_user.id,
channel_id=None,
model="test-model",
)
@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
# Tests for make_prev_messages_tool
def test_make_prev_messages_tool_with_user(sample_discord_user):
def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
"""Test creating a previous messages tool for a user."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
assert tool.name == "previous_messages"
assert "from your chat with this user" in tool.description
@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
assert callable(tool.function)
def test_make_prev_messages_tool_with_channel(sample_discord_channel):
def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
"""Test creating a previous messages tool for a channel."""
tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
assert tool.name == "previous_messages"
assert "in this channel" in tool.description
assert callable(tool.function)
def test_make_prev_messages_tool_without_user_or_channel():
def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
"""Test that creating a tool without user or channel raises error."""
with pytest.raises(ValueError, match="Either user or channel must be provided"):
make_prev_messages_tool(user=None, channel=None)
make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
def test_prev_messages_handler_success(
db_session, sample_discord_user, sample_discord_channel
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
):
"""Test previous messages handler with valid input."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
# Create some actual messages in the database
msg1 = DiscordMessage(
@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
assert "Message 1" in result or "Message 2" in result
def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
"""Test previous messages handler with default values."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
result = tool.function({})
@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
assert isinstance(result, str)
def test_prev_messages_handler_invalid_input(sample_discord_user):
def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-dict input."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Input must be a dictionary"):
tool.function("not a dict")
def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid max_messages (negative value)."""
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
# so we test with -1 which actually triggers the validation
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
tool.function({"max_messages": -1})
def test_prev_messages_handler_invalid_offset(sample_discord_user):
def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid offset."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
tool.function({"offset": -1})
def test_prev_messages_handler_non_integer_values(sample_discord_user):
def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-integer values."""
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
tool.function({"max_messages": "not an int"})
@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
model="test-model",
)
# Should have: schedule_message, previous_messages, update_channel_summary,
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_user_summary, update_server_summary, add_reaction
assert len(tools) == 6
assert "schedule_message" in tools
assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_channel_summary" in tools
assert "update_user_summary" in tools
@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
model="test-model",
)
# Should have: schedule_message, previous_messages, update_user_summary
# Should have: schedule_discord_message, previous_messages, update_user_summary
# Note: Without channel, there's no channel summary tool
assert len(tools) >= 2 # At least schedule and previous messages
assert "schedule_message" in tools
assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_user_summary" in tools
@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
model="test-model",
)
# Should have: schedule_message, previous_messages, update_channel_summary,
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_server_summary, add_reaction (no user summary without author)
assert len(tools) == 5
assert "schedule_message" in tools
assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_channel_summary" in tools
assert "update_server_summary" in tools

View File

@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel",
json={
"bot_id": BOT_ID,
"channel_name": "general",
"channel": "general",
"message": "Announcement!",
},
timeout=10,

View File

@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel",
json={
"bot_id": BOT_ID,
"channel_name": "general",
"channel": "general",
"message": "Announcement!",
},
timeout=10,

View File

@ -15,6 +15,7 @@ from memory.discord.commands import (
handle_chattiness,
handle_ignore,
handle_summary,
handle_proactive,
respond,
with_object_context,
handle_mcp_servers,
@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
await handle_mcp_servers(context, action="list", url=None)
assert "Error: boom" in str(exc.value)
# ============================================================================
# Tests for handle_proactive
# ============================================================================
@pytest.mark.asyncio
async def test_handle_proactive_show_disabled(db_session, interaction, guild):
"""Test showing proactive settings when disabled."""
server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context)
assert "disabled" in response.content.lower()
@pytest.mark.asyncio
async def test_handle_proactive_show_enabled(db_session, interaction, guild):
"""Test showing proactive settings when enabled."""
server = DiscordServer(
id=guild.id,
name="Guild",
proactive_cron="0 9 * * *",
proactive_prompt="Check on projects",
)
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context)
assert "0 9 * * *" in response.content
assert "Check on projects" in response.content
@pytest.mark.asyncio
async def test_handle_proactive_set_cron(db_session, interaction, guild):
"""Test setting proactive cron schedule."""
server = DiscordServer(id=guild.id, name="Guild")
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context, cron="0 9 * * *")
assert "Updated" in response.content
assert "0 9 * * *" in response.content
assert server.proactive_cron == "0 9 * * *"
@pytest.mark.asyncio
async def test_handle_proactive_set_prompt(db_session, interaction, guild):
"""Test setting proactive prompt."""
server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context, prompt="Focus on daily standups")
assert "Updated" in response.content
assert server.proactive_prompt == "Focus on daily standups"
@pytest.mark.asyncio
async def test_handle_proactive_disable(db_session, interaction, guild):
"""Test disabling proactive check-ins."""
server = DiscordServer(
id=guild.id,
name="Guild",
proactive_cron="0 9 * * *",
proactive_prompt="Some prompt",
)
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
response = await handle_proactive(context, cron="off")
assert "disabled" in response.content.lower()
assert server.proactive_cron is None
@pytest.mark.asyncio
async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
"""Test error on invalid cron expression."""
server = DiscordServer(id=guild.id, name="Guild")
db_session.add(server)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="server",
target=server,
display_name="server **Guild**",
)
with pytest.raises(CommandError) as exc:
await handle_proactive(context, cron="not a valid cron")
assert "Invalid cron expression" in str(exc.value)
@pytest.mark.asyncio
async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
"""Test proactive settings for user scope."""
user_model = DiscordUser(
id=discord_user.id, username="testuser", proactive_cron=None
)
db_session.add(user_model)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="me",
target=user_model,
display_name="you (**testuser**)",
)
response = await handle_proactive(context, cron="0 9,17 * * 1-5")
assert "Updated" in response.content
assert user_model.proactive_cron == "0 9,17 * * 1-5"
@pytest.mark.asyncio
async def test_handle_proactive_channel_scope(
db_session, interaction, guild, text_channel
):
"""Test proactive settings for channel scope."""
server = DiscordServer(id=guild.id, name="Guild")
db_session.add(server)
db_session.flush()
channel_model = DiscordChannel(
id=text_channel.id,
name="general",
channel_type="text",
server_id=guild.id,
)
db_session.add(channel_model)
db_session.commit()
context = CommandContext(
session=db_session,
interaction=interaction,
actor=MagicMock(spec=DiscordUser),
scope="channel",
target=channel_model,
display_name="channel **#general**",
)
response = await handle_proactive(context, cron="0 12 * * *")
assert "Updated" in response.content
assert channel_model.proactive_cron == "0 12 * * *"

View File

@ -1,6 +1,5 @@
"""Tests for Discord MCP server management."""
import json
from unittest.mock import AsyncMock, Mock, patch
import aiohttp
@ -8,8 +7,8 @@ import discord
import pytest
from memory.common.db.models import MCPServer, MCPServerAssignment
from memory.common.mcp import mcp_call
from memory.discord.mcp import (
call_mcp_server,
find_mcp_server,
handle_mcp_add,
handle_mcp_connect,
@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
async for data in call_mcp_server(
async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list", {}
):
results.append(data)
@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to call MCP server"):
async for _ in call_mcp_server(
async for _ in mcp_call(
"https://mcp.example.com", "test_token", "tools/list"
):
pass
@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
async for data in call_mcp_server(
async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list"
):
results.append(data)

View File

@ -19,6 +19,7 @@ from memory.common.db.models import (
DiscordChannel,
DiscordServer,
DiscordMessage,
DiscordBotUser,
HumanUser,
ScheduledLLMCall,
)
@ -34,6 +35,19 @@ def sample_discord_user(db_session):
return user
@pytest.fixture
def sample_bot_user(db_session, sample_discord_user):
"""Create a sample Discord bot user."""
bot = DiscordBotUser.create_with_api_key(
discord_users=[sample_discord_user],
name="Test Bot",
email="testbot@example.com",
)
db_session.add(bot)
db_session.commit()
return bot
@pytest.fixture
def sample_discord_channel(db_session):
"""Create a sample Discord channel."""
@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
# Test previous_messages
def test_previous_messages_empty(db_session):
def test_previous_messages_empty(db_session, sample_bot_user):
"""Test getting previous messages when none exist."""
result = previous_messages(db_session, user_id=123, channel_id=456)
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
assert result == []
def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test filtering messages by recipient user."""
# Create some messages
msg1 = DiscordMessage(
@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
db_session.add_all([msg1, msg2])
db_session.commit()
result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
assert len(result) == 2
# Should be in chronological order (oldest first)
assert result[0].message_id == 1
assert result[1].message_id == 2
def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test limiting the number of previous messages."""
# Create 15 messages
for i in range(15):
@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
db_session.commit()
result = previous_messages(
db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
)
assert len(result) == 5
@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
# Test comm_channel_prompt
def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test generating a basic communication channel prompt."""
result = comm_channel_prompt(
db_session, user=sample_discord_user, channel=sample_discord_channel
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
)
assert "You are a bot communicating on Discord" in result
@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
assert len(result) > 0
def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes server context when available."""
server = sample_discord_channel.server
server.summary = "Gaming community server"
db_session.commit()
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "server_context" in result.lower()
assert "Gaming community server" in result
def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes channel context."""
sample_discord_channel.summary = "General discussion channel"
db_session.commit()
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "channel_context" in result.lower()
assert "General discussion channel" in result
def test_comm_channel_prompt_includes_user_notes(
db_session, sample_discord_user, sample_discord_channel
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
):
"""Test that prompt includes user notes from previous messages."""
sample_discord_user.summary = "Helpful community member"
@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
db_session.commit()
result = comm_channel_prompt(
db_session, user=sample_discord_user, channel=sample_discord_channel
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
)
assert "user_notes" in result.lower()
@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
web_tool_instance = MagicMock(name="web_tool")
mock_web_search.return_value = web_tool_instance
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
bot_user = SimpleNamespace(
system_user=SimpleNamespace(discord_id=999888777),
system_prompt="bot prompt"
)
from_user = SimpleNamespace(id=123)
mcp_model = SimpleNamespace(
name="Server",
mcp_server_url="https://mcp.example.com",
access_token="token123",
disabled_tools=[],
)
result = call_llm(
@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
mock_web_search.return_value = MagicMock(name="web_tool")
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
bot_user = SimpleNamespace(
system_user=SimpleNamespace(discord_id=999888777),
system_prompt=None
)
from_user = SimpleNamespace(id=1)
call_llm(

View File

@ -14,12 +14,25 @@ from memory.workers.tasks import discord
@pytest.fixture
def discord_bot_user(db_session):
# Create a discord user for the bot first
bot_discord_user = DiscordUser(
id=999999999,
username="testbot",
)
db_session.add(bot_discord_user)
db_session.flush()
bot = DiscordBotUser.create_with_api_key(
discord_users=[],
discord_users=[bot_discord_user],
name="Test Bot",
email="bot@example.com",
)
db_session.add(bot)
db_session.flush()
# Link the discord user to the system user
bot_discord_user.system_user_id = bot.id
db_session.commit()
return bot
@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.workers.tasks.discord.create_provider")
@patch("memory.workers.tasks.discord.call_llm")
@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
def test_should_process_normal_message(
mock_create_provider,
mock_trigger_typing,
mock_call_llm,
db_session,
mock_discord_user,
mock_discord_server,
mock_discord_channel,
discord_bot_user,
):
"""Test should_process returns True for normal messages."""
# Mock the LLM provider to return "yes"
mock_provider = Mock()
mock_provider.generate.return_value = "<response>yes</response>"
mock_provider.as_messages.return_value = []
mock_create_provider.return_value = mock_provider
# Create a separate recipient user (the bot)
bot_discord_user = discord_bot_user.discord_users[0]
# Mock call_llm to return a high number (100 = always process)
mock_call_llm.return_value = "<response><number>100</number><reason>Test</reason></response>"
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
from_id=mock_discord_user.id,
recipient_id=mock_discord_user.id,
recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@ -207,6 +223,8 @@ def test_should_process_normal_message(
db_session.refresh(message)
assert discord.should_process(message) is True
mock_call_llm.assert_called_once()
mock_trigger_typing.assert_called_once()
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
"""Test adding a Discord message that is a reply."""
sample_message_data["message_reference_id"] = 111222333
sample_message_data["message_type"] = "reply" # Explicitly set message_type
discord.add_discord_message(**sample_message_data)
@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
assert result["status"] == "processed"
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
@patch("memory.workers.tasks.discord.send_discord_response")
@patch("memory.workers.tasks.discord.call_llm")
def test_process_discord_message_success(
mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
):
"""Test processing a Discord message."""
# Mock LLM to return a response
mock_call_llm.return_value = "Test response from bot"
# Mock Discord API to succeed
mock_send_response.return_value = True
# Add a message first
add_result = discord.add_discord_message(**sample_message_data)
message_id = add_result["discordmessage_id"]
@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
assert result["status"] == "processed"
assert result["message_id"] == message_id
mock_call_llm.assert_called_once()
mock_send_response.assert_called_once()
def test_process_discord_message_not_found(db_session):

View File

@ -0,0 +1,536 @@
"""Tests for proactive check-in tasks."""
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch, MagicMock
from memory.common.db.models import (
DiscordBotUser,
DiscordUser,
DiscordChannel,
DiscordServer,
)
from memory.workers.tasks import proactive
from memory.workers.tasks.proactive import is_cron_due
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def bot_user(db_session):
"""Create a bot user for testing."""
bot_discord_user = DiscordUser(
id=999999999,
username="testbot",
)
db_session.add(bot_discord_user)
db_session.flush()
user = DiscordBotUser.create_with_api_key(
discord_users=[bot_discord_user],
name="testbot",
email="bot@example.com",
)
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def target_user(db_session):
"""Create a target Discord user for testing."""
discord_user = DiscordUser(
id=123456789,
username="targetuser",
proactive_cron="0 9 * * *", # 9am daily
chattiness_threshold=50,
)
db_session.add(discord_user)
db_session.commit()
return discord_user
@pytest.fixture
def target_user_no_cron(db_session):
"""Create a target Discord user without proactive cron."""
discord_user = DiscordUser(
id=123456790,
username="nocronuser",
proactive_cron=None,
)
db_session.add(discord_user)
db_session.commit()
return discord_user
@pytest.fixture
def target_server(db_session):
"""Create a target Discord server for testing."""
server = DiscordServer(
id=987654321,
name="Test Server",
proactive_cron="0 */4 * * *", # Every 4 hours
chattiness_threshold=30,
)
db_session.add(server)
db_session.commit()
return server
@pytest.fixture
def target_channel(db_session, target_server):
"""Create a target Discord channel for testing."""
channel = DiscordChannel(
id=111222333,
name="test-channel",
channel_type="text",
server_id=target_server.id,
proactive_cron="0 12 * * 1-5", # Noon on weekdays
chattiness_threshold=70,
)
db_session.add(channel)
db_session.commit()
return channel
# ============================================================================
# Tests for is_cron_due helper
# ============================================================================
@pytest.mark.parametrize(
"cron_expr,now,last_run,expected",
[
# Cron is due when never run before and time matches
(
"0 9 * * *",
datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
None,
True,
),
# Cron is due when last run was before the scheduled time
(
"0 9 * * *",
datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
True,
),
# Cron is NOT due when already run this period
(
"0 9 * * *",
datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
False,
),
# Cron is NOT due when current time is before scheduled time
(
"0 9 * * *",
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
None,
False,
),
# Hourly cron schedule
(
"0 * * * *",
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
True,
),
# Every 4 hours cron schedule
(
"0 */4 * * *",
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
True,
),
],
ids=[
"due_never_run",
"due_last_run_before_schedule",
"not_due_already_run",
"not_due_too_early",
"due_hourly",
"due_every_4_hours",
],
)
def test_is_cron_due(cron_expr, now, last_run, expected):
"""Test is_cron_due with various scenarios."""
assert is_cron_due(cron_expr, last_run, now) is expected
def test_is_cron_due_invalid_expression():
"""Test invalid cron expression returns False."""
now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
assert is_cron_due("invalid cron", None, now) is False
def test_is_cron_due_with_naive_last_run():
"""Test cron handles naive datetime for last_run."""
now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
cron_expr = "0 9 * * *"
last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
assert is_cron_due(cron_expr, last_run, now) is True
# ============================================================================
# Tests for evaluate_proactive_checkins task
# ============================================================================
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.is_cron_due")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_dispatches_due(
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
):
"""Test that due check-ins are dispatched."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_is_cron_due.return_value = True
result = proactive.evaluate_proactive_checkins()
assert result["count"] >= 1
mock_execute.delay.assert_called()
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.is_cron_due")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_skips_not_due(
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
):
"""Test that not-due check-ins are not dispatched."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_is_cron_due.return_value = False
result = proactive.evaluate_proactive_checkins()
assert result["count"] == 0
mock_execute.delay.assert_not_called()
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_skips_no_cron(
mock_make_session, mock_execute, db_session, target_user_no_cron
):
"""Test that entities without proactive_cron are skipped."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
result = proactive.evaluate_proactive_checkins()
for call in mock_execute.delay.call_args_list:
entity_type, entity_id = call[0]
assert entity_id != target_user_no_cron.id
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
@patch("memory.workers.tasks.proactive.is_cron_due")
@patch("memory.workers.tasks.proactive.make_session")
def test_evaluate_proactive_checkins_multiple_entity_types(
mock_make_session,
mock_is_cron_due,
mock_execute,
db_session,
target_user,
target_server,
target_channel,
):
"""Test that check-ins are dispatched for users, channels, and servers."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_is_cron_due.return_value = True
result = proactive.evaluate_proactive_checkins()
assert result["count"] == 3
dispatched_types = {d["type"] for d in result["dispatched"]}
assert "user" in dispatched_types
assert "channel" in dispatched_types
assert "server" in dispatched_types
# ============================================================================
# Tests for execute_proactive_checkin task
# ============================================================================
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_sends_when_above_threshold(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
target_user,
bot_user,
):
"""Test check-in is sent when interest exceeds threshold."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>80</number><reason>Should check in</reason></response>",
"Hey! Just checking in - how are things going?",
]
mock_send.return_value = True
result = proactive.execute_proactive_checkin("user", target_user.id)
assert result["status"] == "sent"
assert result["interest"] == 80
mock_send.assert_called_once()
db_session.refresh(target_user)
assert target_user.last_proactive_at is not None
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_skips_below_threshold(
mock_make_session,
mock_get_bot,
mock_call_llm,
db_session,
target_user,
bot_user,
):
"""Test check-in is skipped when interest is below threshold."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.return_value = (
"<response><number>30</number><reason>Not much to say</reason></response>"
)
result = proactive.execute_proactive_checkin("user", target_user.id)
assert result["status"] == "below_threshold"
assert result["interest"] == 30
assert result["threshold"] == 50
@pytest.mark.parametrize(
"llm_response,expected_status",
[
(None, "no_eval_response"),
("I'm not sure what to say.", "no_score_in_response"),
],
ids=["no_response", "malformed_response"],
)
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_handles_bad_llm_response(
mock_make_session,
mock_get_bot,
mock_call_llm,
llm_response,
expected_status,
db_session,
target_user,
bot_user,
):
"""Test handling of missing or malformed LLM responses."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.return_value = llm_response
result = proactive.execute_proactive_checkin("user", target_user.id)
assert result["status"] == expected_status
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
"""Test handling when entity doesn't exist."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
result = proactive.execute_proactive_checkin("user", 999999)
assert "error" in result
assert "not found" in result["error"]
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_no_bot_user(
mock_make_session, mock_get_bot, db_session, target_user
):
"""Test handling when no bot user is found."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
mock_get_bot.return_value = None
result = proactive.execute_proactive_checkin("user", target_user.id)
assert "error" in result
assert "No bot user" in result["error"]
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_uses_proactive_prompt(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
bot_user,
):
"""Test that proactive_prompt is included in the evaluation."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
user_with_prompt = DiscordUser(
id=555666777,
username="promptuser",
proactive_cron="0 9 * * *",
proactive_prompt="Focus on their coding projects",
chattiness_threshold=50,
)
db_session.add(user_with_prompt)
db_session.commit()
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>80</number><reason>Check on projects</reason></response>",
"How are your coding projects coming along?",
]
mock_send.return_value = True
result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
assert result["status"] == "sent"
call_args = mock_call_llm.call_args_list[0]
messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
assert any("Focus on their coding projects" in str(m) for m in messages_arg)
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_channel(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
target_channel,
bot_user,
):
"""Test check-in to a channel."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>50</number><reason>Check channel</reason></response>",
"Good morning everyone!",
]
mock_send.return_value = True
result = proactive.execute_proactive_checkin("channel", target_channel.id)
assert result["status"] == "sent"
assert result["entity_type"] == "channel"
send_call = mock_send.call_args
assert send_call.kwargs.get("channel_id") == target_channel.id
@patch("memory.workers.tasks.proactive.send_discord_response")
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_updates_last_proactive_at(
mock_make_session,
mock_get_bot,
mock_call_llm,
mock_send,
db_session,
target_user,
bot_user,
):
"""Test that last_proactive_at is updated after successful check-in."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.side_effect = [
"<response><number>80</number><reason>Check in</reason></response>",
"Hey there!",
]
mock_send.return_value = True
before_time = datetime.now(timezone.utc)
proactive.execute_proactive_checkin("user", target_user.id)
after_time = datetime.now(timezone.utc)
db_session.refresh(target_user)
assert target_user.last_proactive_at is not None
assert before_time <= target_user.last_proactive_at <= after_time
@patch("memory.workers.tasks.proactive.call_llm")
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
@patch("memory.workers.tasks.proactive.make_session")
def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
mock_make_session,
mock_get_bot,
mock_call_llm,
db_session,
target_user,
bot_user,
):
"""Test that last_proactive_at is updated even when check-in is skipped."""
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
mock_make_session.return_value.__exit__ = Mock(return_value=False)
bot_discord_user = bot_user.discord_users[0]
bot_discord_user.system_user = bot_user
mock_get_bot.return_value = bot_discord_user
mock_call_llm.return_value = (
"<response><number>10</number><reason>Nothing to say</reason></response>"
)
proactive.execute_proactive_checkin("user", target_user.id)
db_session.refresh(target_user)
assert target_user.last_proactive_at is not None

View File

@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
@pytest.fixture
def sample_user(db_session):
"""Create a sample user for testing."""
# Create a discord user for the bot
bot_discord_user = DiscordUser(
id=999999999,
username="testbot",
)
db_session.add(bot_discord_user)
db_session.flush()
user = DiscordBotUser.create_with_api_key(
discord_users=[],
discord_users=[bot_discord_user],
name="testbot",
email="bot@example.com",
)
@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
return call
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user."""
response = "This is a test response."
scheduled_calls.send_to_discord(pending_scheduled_call, response)
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id,
999999999, # bot_id
"testuser", # username, not ID
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
response,
)
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
@patch("memory.discord.messages.discord.send_to_channel")
def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
"""Test sending to Discord channel."""
response = "This is a channel response."
scheduled_calls.send_to_discord(completed_scheduled_call, response)
scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
mock_broadcast.assert_called_once_with(
completed_scheduled_call.user_id,
"test-channel", # channel name, not ID
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
mock_send_to_channel.assert_called_once_with(
999999999, # bot_id
completed_scheduled_call.discord_channel.id, # channel ID, not name
response,
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
"""Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
# Verify the message was truncated
# With the new implementation, send_discord_response sends the full response
# No truncation happens in _send_to_discord
args, kwargs = mock_send_dm.call_args
assert args[0] == pending_scheduled_call.user_id
assert args[0] == 999999999 # bot_id
message = args[2]
assert len(message) <= 1950 # Should be truncated
assert message.endswith("... (response truncated)")
assert message == long_response
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
"""Test that normal length messages are not truncated."""
normal_response = "This is a normal length response."
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args
assert args[0] == pending_scheduled_call.user_id
assert args[0] == 999999999 # bot_id
message = args[2]
assert not message.endswith("... (response truncated)")
assert "This is a normal length response." in message
assert message == normal_response
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_success(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
# Verify LLM was called with correct parameters
mock_llm_call.assert_called_once_with(
prompt="What is the weather like today?",
model="anthropic/claude-3-5-sonnet-20241022",
system_prompt="You are a helpful assistant.",
)
# Verify LLM was called
mock_llm_call.assert_called_once()
# Verify result
assert result["success"] is True
@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
assert result == {"error": "Scheduled call not found"}
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_not_pending(
mock_llm_call, completed_scheduled_call, db_session
):
@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
mock_llm_call.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_with_default_system_prompt(
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
):
@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
scheduled_calls.execute_scheduled_call(call.id)
# Verify default system prompt was used
mock_llm_call.assert_called_once_with(
prompt="Test prompt",
model="anthropic/claude-3-5-sonnet-20241022",
system_prompt=None, # The code uses system_prompt as-is, not a default
)
# Verify LLM was called
mock_llm_call.assert_called_once()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_discord_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_llm_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test execution when LLM call fails."""
mock_llm_call.side_effect = Exception("LLM API error")
# The safe_task_execution decorator should catch this
# The execute_scheduled_call function catches the exception and returns an error response
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
assert result["status"] == "error"
assert "LLM API error" in result["error"]
assert result["success"] is False
assert "error" in result
assert "LLM call failed" in result["error"]
# Discord should not be called
mock_send_discord.assert_not_called()
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_long_response_truncation(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
mock_execute_delay.delay.assert_called_once_with(due_call.id)
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_status_transition_pending_to_executing_to_completed(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
"has_discord_user,has_discord_channel,expected_method",
[
(True, False, "send_dm"),
(False, True, "broadcast_message"),
(True, True, "send_dm"), # User takes precedence
(False, True, "send_to_channel"),
(True, True, "send_to_channel"), # Channel takes precedence in the implementation
],
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
@patch("memory.discord.messages.discord.send_dm")
@patch("memory.discord.messages.discord.send_to_channel")
def test_discord_destination_priority(
mock_broadcast,
mock_send_to_channel,
mock_send_dm,
has_discord_user,
has_discord_channel,
@ -535,50 +535,39 @@ def test_discord_destination_priority(
db_session.commit()
response = "Test response"
scheduled_calls.send_to_discord(call, response)
scheduled_calls.send_to_discord(999999999, call, response)
if expected_method == "send_dm":
mock_send_dm.assert_called_once()
mock_broadcast.assert_not_called()
mock_send_to_channel.assert_not_called()
else:
mock_broadcast.assert_called_once()
mock_send_to_channel.assert_called_once()
mock_send_dm.assert_not_called()
@pytest.mark.parametrize(
"topic,model,response,expected_in_message",
"topic,model,response",
[
(
"Weather Check",
"anthropic/claude-3-5-sonnet-20241022",
"It's sunny!",
[
"**Topic:** Weather Check",
"**Model:** anthropic/claude-3-5-sonnet-20241022",
"**Response:** It's sunny!",
],
),
(
"Test Topic",
"gpt-4",
"Hello world",
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
),
(
"Long Topic Name Here",
"claude-2",
"Short",
[
"**Topic:** Long Topic Name Here",
"**Model:** claude-2",
"**Response:** Short",
],
),
],
)
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
"""Test the Discord message formatting with different inputs."""
@patch("memory.discord.messages.discord.send_dm")
def test_message_formatting(mock_send_dm, topic, model, response):
"""Test that _send_to_discord sends the response as-is."""
# Create a mock scheduled call with a mock Discord user
mock_discord_user = Mock()
mock_discord_user.username = "testuser"
@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_call.discord_user = mock_discord_user
mock_call.discord_channel = None
scheduled_calls.send_to_discord(mock_call, response)
scheduled_calls.send_to_discord(999999999, mock_call, response)
# Get the actual message that was sent
args, kwargs = mock_send_dm.call_args
assert args[0] == mock_call.user_id
assert args[0] == 999999999 # bot_id
actual_message = args[2]
# Verify all expected parts are in the message
for expected_part in expected_in_message:
assert expected_part in actual_message
# The new implementation sends the response as-is, without formatting
assert actual_message == response
@pytest.mark.parametrize(
@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
("cancelled", False),
],
)
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_status_check(
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
):