diff --git a/db/migrations/versions/20251224_160000_add_proactive_checkins.py b/db/migrations/versions/20251224_160000_add_proactive_checkins.py
new file mode 100644
index 0000000..e8a8a5d
--- /dev/null
+++ b/db/migrations/versions/20251224_160000_add_proactive_checkins.py
@@ -0,0 +1,45 @@
+"""Add proactive check-in fields to Discord entities
+
+Revision ID: e1f2a3b4c5d6
+Revises: d0e1f2a3b4c5
+Create Date: 2025-12-24 16:00:00.000000
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "e1f2a3b4c5d6"
+down_revision: Union[str, None] = "d0e1f2a3b4c5"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # Add proactive fields to all MessageProcessor tables
+ for table in ["discord_servers", "discord_channels", "discord_users"]:
+ op.add_column(
+ table,
+ sa.Column("proactive_cron", sa.Text(), nullable=True),
+ )
+ op.add_column(
+ table,
+ sa.Column("proactive_prompt", sa.Text(), nullable=True),
+ )
+ op.add_column(
+ table,
+ sa.Column(
+ "last_proactive_at", sa.DateTime(timezone=True), nullable=True
+ ),
+ )
+
+
+def downgrade() -> None:
+ for table in ["discord_servers", "discord_channels", "discord_users"]:
+ op.drop_column(table, "last_proactive_at")
+ op.drop_column(table, "proactive_prompt")
+ op.drop_column(table, "proactive_cron")
diff --git a/requirements/requirements-api.txt b/requirements/requirements-api.txt
index 3e1a402..0fc683f 100644
--- a/requirements/requirements-api.txt
+++ b/requirements/requirements-api.txt
@@ -1,7 +1,8 @@
-fastapi==0.112.2
+fastapi>=0.115.12
uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
sqladmin==0.20.1
mcp==1.10.0
+fastmcp>=2.10.0
slowapi==0.1.9
\ No newline at end of file
diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt
index 09674e1..2cbcd4e 100644
--- a/requirements/requirements-common.txt
+++ b/requirements/requirements-common.txt
@@ -1,14 +1,15 @@
sqlalchemy==2.0.30
psycopg2-binary==2.9.9
-pydantic==2.7.2
+pydantic>=2.11.7
alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.69.0
openai==2.3.0
-# Pin the httpx version, as newer versions break the anthropic client
-httpx==0.27.0
+# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
+httpx>=0.28.1
celery[redis,sqs]==5.3.6
+croniter==2.0.1
cryptography==43.0.0
bcrypt==4.1.2
\ No newline at end of file
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index 445b8ba..2a89bfb 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -1,7 +1,9 @@
pytest==7.4.4
pytest-cov==4.1.0
+pytest-asyncio==0.23.0
black==23.12.1
mypy==1.8.0
-isort==5.13.2
+isort==5.13.2
testcontainers[qdrant]==4.10.0
-click==8.1.7
\ No newline at end of file
+click==8.1.7
+croniter==2.0.1
\ No newline at end of file
diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py
index 07f0efa..540920d 100644
--- a/src/memory/api/MCP/__init__.py
+++ b/src/memory/api/MCP/__init__.py
@@ -1,8 +1,16 @@
-import memory.api.MCP.tools
-import memory.api.MCP.memory
-import memory.api.MCP.metadata
-import memory.api.MCP.schedules
-import memory.api.MCP.books
-import memory.api.MCP.manifest
-import memory.api.MCP.github
-import memory.api.MCP.people
+"""
+MCP server with composed subservers.
+
+Subservers are mounted with prefixes:
+- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file
+- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview
+- people: add_person, update_person_info, get_person, list_people, delete_person
+- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call
+- books: all_books, read_book
+- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts
+"""
+
+# Import base to trigger subserver mounting
+from memory.api.MCP.base import mcp, get_current_user
+
+__all__ = ["mcp", "get_current_user"]
diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py
index 92c4ad0..719c4e0 100644
--- a/src/memory/api/MCP/base.py
+++ b/src/memory/api/MCP/base.py
@@ -1,60 +1,28 @@
import logging
import pathlib
-from typing import cast
-from mcp.server.auth.handlers.authorize import AuthorizationRequest
-from mcp.server.auth.handlers.token import (
- AuthorizationCodeRequest,
- RefreshTokenRequest,
- TokenRequest,
-)
-from mcp.server.auth.middleware.auth_context import get_access_token
-from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
-from mcp.server.fastmcp import FastMCP
-from mcp.shared.auth import OAuthClientMetadata
-from pydantic import AnyHttpUrl
+from fastmcp import FastMCP
+from fastmcp.server.dependencies import get_access_token
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
-from memory.api.MCP.oauth_provider import (
- ALLOWED_SCOPES,
- BASE_SCOPES,
- SimpleOAuthProvider,
-)
+from memory.api.MCP.oauth_provider import SimpleOAuthProvider
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth
from memory.common import settings
-from memory.common.db.connection import make_session
+from memory.common.db.connection import make_session, get_engine
from memory.common.db.models import OAuthState, UserSession
from memory.common.db.models.users import HumanUser
logger = logging.getLogger(__name__)
-
-
-def validate_metadata(klass: type):
- orig_validate = klass.model_validate
-
- def validate(data: dict):
- data = dict(data)
- if "redirect_uris" in data:
- data["redirect_uris"] = [
- str(uri).replace("cursor://", "http://")
- for uri in data["redirect_uris"]
- ]
- if "redirect_uri" in data:
- data["redirect_uri"] = str(data["redirect_uri"]).replace(
- "cursor://", "http://"
- )
-
- return orig_validate(data)
-
- klass.model_validate = validate
-
-
-validate_metadata(OAuthClientMetadata)
-validate_metadata(AuthorizationRequest)
-validate_metadata(AuthorizationCodeRequest)
-validate_metadata(RefreshTokenRequest)
-validate_metadata(TokenRequest)
+engine = get_engine()
# Setup templates
@@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir)
oauth_provider = SimpleOAuthProvider()
-auth_settings = AuthSettings(
- issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
- resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
- client_registration_options=ClientRegistrationOptions(
- enabled=True,
- valid_scopes=ALLOWED_SCOPES,
- default_scopes=BASE_SCOPES,
- ),
- required_scopes=BASE_SCOPES,
-)
mcp = FastMCP(
"memory",
- stateless_http=True,
- auth_server_provider=oauth_provider,
- auth=auth_settings,
+ auth=oauth_provider,
)
@@ -162,3 +118,52 @@ def get_current_user() -> dict:
"client_id": access_token.client_id,
"user": user_info,
}
+
+
+@mcp.custom_route("/health", methods=["GET"])
+async def health_check(request: Request):
+ """Health check endpoint that verifies all dependencies are accessible."""
+ from sqlalchemy import text
+
+ checks = {"mcp_oauth": "enabled"}
+ all_healthy = True
+
+ # Check database connection
+ try:
+ with engine.connect() as conn:
+ conn.execute(text("SELECT 1"))
+ checks["database"] = "healthy"
+ except Exception as e:
+ logger.error(f"Database health check failed: {e}")
+ checks["database"] = "unhealthy"
+ all_healthy = False
+
+ # Check Qdrant connection
+ try:
+ from memory.common.qdrant import get_qdrant_client
+
+ client = get_qdrant_client()
+ client.get_collections()
+ checks["qdrant"] = "healthy"
+ except Exception as e:
+ logger.error(f"Qdrant health check failed: {e}")
+ checks["qdrant"] = "unhealthy"
+ all_healthy = False
+
+ checks["status"] = "healthy" if all_healthy else "degraded"
+ status_code = 200 if all_healthy else 503
+ return JSONResponse(checks, status_code=status_code)
+
+
+# Inject auth provider into subservers that need it
+set_schedule_auth(get_current_user)
+set_meta_auth(get_current_user)
+
+# Mount all subservers onto the main MCP server
+# Tools will be prefixed with their server name (e.g., core_search_knowledge_base)
+mcp.mount(core_mcp, prefix="core")
+mcp.mount(github_mcp, prefix="github")
+mcp.mount(people_mcp, prefix="people")
+mcp.mount(schedule_mcp, prefix="schedule")
+mcp.mount(books_mcp, prefix="books")
+mcp.mount(meta_mcp, prefix="meta")
diff --git a/src/memory/api/MCP/manifest.py b/src/memory/api/MCP/manifest.py
deleted file mode 100644
index 6dcd76a..0000000
--- a/src/memory/api/MCP/manifest.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import asyncio
-import aiohttp
-from datetime import datetime
-
-from typing import TypedDict, NotRequired, Literal
-from memory.api.MCP.tools import mcp
-
-
-class BinaryProbs(TypedDict):
- prob: float
-
-
-class MultiProbs(TypedDict):
- answerProbs: dict[str, float]
-
-
-Probs = dict[str, BinaryProbs | MultiProbs]
-OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
-
-
-class MarketAnswer(TypedDict):
- id: str
- text: str
- resolutionProbability: float
-
-
-class MarketDetails(TypedDict):
- id: str
- createdTime: int
- question: str
- outcomeType: OutcomeType
- textDescription: str
- groupSlugs: list[str]
- volume: float
- isResolved: bool
- answers: list[MarketAnswer]
-
-
-class Market(TypedDict):
- id: str
- url: str
- question: str
- volume: int
- createdTime: int
- outcomeType: OutcomeType
- createdAt: NotRequired[str]
- description: NotRequired[str]
- answers: NotRequired[dict[str, float]]
- probability: NotRequired[float]
- details: NotRequired[MarketDetails]
-
-
-async def get_details(session: aiohttp.ClientSession, market_id: str):
- async with session.get(
- f"https://api.manifold.markets/v0/market/{market_id}"
- ) as resp:
- resp.raise_for_status()
- return await resp.json()
-
-
-async def format_market(session: aiohttp.ClientSession, market: Market):
- if market.get("outcomeType") != "BINARY":
- details = await get_details(session, market["id"])
- market["answers"] = {
- answer["text"]: round(
- answer.get("resolutionProbability") or answer.get("probability") or 0, 3
- )
- for answer in details["answers"]
- }
- if creationTime := market.get("createdTime"):
- market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
-
- fields = [
- "id",
- "name",
- "url",
- "question",
- "volume",
- "createdAt",
- "details",
- "probability",
- "answers",
- ]
- return {k: v for k, v in market.items() if k in fields}
-
-
-async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
- async with aiohttp.ClientSession() as session:
- async with session.get(
- "https://api.manifold.markets/v0/search-markets",
- params={
- "term": term,
- "contractType": "BINARY" if binary else "ALL",
- },
- ) as resp:
- resp.raise_for_status()
- markets = await resp.json()
-
- return await asyncio.gather(
- *[
- format_market(session, market)
- for market in markets
- if market.get("volume", 0) >= min_volume
- ]
- )
-
-
-@mcp.tool()
-async def get_forecasts(
- term: str, min_volume: int = 1000, binary: bool = False
-) -> list[dict]:
- """Get prediction market forecasts for a given term.
-
- Args:
- term: The term to search for.
- min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
- binary: Whether to only return binary markets.
- """
- return await search_markets(term, min_volume, binary)
diff --git a/src/memory/api/MCP/metadata.py b/src/memory/api/MCP/metadata.py
deleted file mode 100644
index ce54bec..0000000
--- a/src/memory/api/MCP/metadata.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import logging
-from collections import defaultdict
-from typing import Annotated, TypedDict, get_args, get_type_hints
-
-from memory.common import qdrant
-from sqlalchemy import func
-
-from memory.api.MCP.tools import mcp
-from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem
-from memory.common.db.models.source_items import AgentObservation
-
-logger = logging.getLogger(__name__)
-
-
-class SchemaArg(TypedDict):
- type: str | None
- description: str | None
-
-
-class CollectionMetadata(TypedDict):
- schema: dict[str, SchemaArg]
- size: int
-
-
-def from_annotation(annotation: Annotated) -> SchemaArg | None:
- try:
- type_, description = get_args(annotation)
- type_str = str(type_)
- if type_str.startswith("typing."):
- type_str = type_str[7:]
- elif len((parts := type_str.split("'"))) > 1:
- type_str = parts[1]
- return SchemaArg(type=type_str, description=description)
- except IndexError:
- logger.error(f"Error from annotation: {annotation}")
- return None
-
-
-def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
- if not hasattr(klass, "as_payload"):
- return {}
-
- if not (payload_type := get_type_hints(klass.as_payload).get("return")):
- return {}
-
- return {
- name: schema
- for name, arg in payload_type.__annotations__.items()
- if (schema := from_annotation(arg))
- }
-
-
-@mcp.tool()
-async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
- """Get the metadata schema for each collection used in the knowledge base.
-
- These schemas can be used to filter the knowledge base.
-
- Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
-
- Example:
- ```
- {
- "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
- "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
- }
- """
- client = qdrant.get_qdrant_client()
- sizes = qdrant.get_collection_sizes(client)
- schemas = defaultdict(dict)
- for klass in SourceItem.__subclasses__():
- for collection in klass.get_collections():
- schemas[collection].update(get_schema(klass))
-
- return {
- collection: CollectionMetadata(schema=schema, size=size)
- for collection, schema in schemas.items()
- if (size := sizes.get(collection))
- }
-
-
-@mcp.tool()
-async def get_all_tags() -> list[str]:
- """Get all unique tags used across the entire knowledge base.
-
- Returns sorted list of tags from both observations and content.
- """
- with make_session() as session:
- tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
- return sorted({row[0] for row in tags_query if row[0] is not None})
-
-
-@mcp.tool()
-async def get_all_subjects() -> list[str]:
- """Get all unique subjects from observations about the user.
-
- Returns sorted list of subject identifiers used in observations.
- """
- with make_session() as session:
- return sorted(
- r.subject for r in session.query(AgentObservation.subject).distinct()
- )
-
-
-@mcp.tool()
-async def get_all_observation_types() -> list[str]:
- """Get all observation types that have been used.
-
- Standard types are belief, preference, behavior, contradiction, general, but there can be more.
- """
- with make_session() as session:
- return sorted(
- {
- r.observation_type
- for r in session.query(AgentObservation.observation_type).distinct()
- if r.observation_type is not None
- }
- )
diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py
index 3f9ac9f..db5b6f4 100644
--- a/src/memory/api/MCP/oauth_provider.py
+++ b/src/memory/api/MCP/oauth_provider.py
@@ -5,11 +5,12 @@ from datetime import datetime, timezone
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
+from fastmcp.server.auth import OAuthProvider
+from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken
from mcp.server.auth.provider import (
AccessToken,
AuthorizationCode,
AuthorizationParams,
- OAuthAuthorizationServerProvider,
RefreshToken,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
@@ -133,8 +134,45 @@ def make_token(
)
-class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
- async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
+class SimpleOAuthProvider(OAuthProvider):
+ """OAuth provider that extends fastmcp's OAuthProvider with custom login flow."""
+
+ def __init__(self):
+ super().__init__(
+ base_url=settings.SERVER_URL,
+ issuer_url=settings.SERVER_URL,
+ client_registration_options=None, # We handle registration ourselves
+ required_scopes=BASE_SCOPES,
+ )
+
+ async def verify_token(self, token: str) -> FastMCPAccessToken | None:
+ """Verify an access token and return token info if valid."""
+ with make_session() as session:
+ # Try as OAuth access token first
+ user_session = session.query(UserSession).get(token)
+ if user_session:
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
+ if user_session.expires_at < now:
+ return None
+ return FastMCPAccessToken(
+ token=token,
+ client_id=user_session.oauth_state.client_id,
+ scopes=user_session.oauth_state.scopes,
+ )
+
+ # Try as bot API key
+ bot = session.query(User).filter(User.api_key == token).first()
+ if bot:
+ logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
+ return FastMCPAccessToken(
+ token=token,
+ client_id=cast(str, bot.name or bot.email),
+ scopes=["read", "write"],
+ )
+
+ return None
+
+ def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""Get OAuth client information."""
with make_session() as session:
client = session.get(OAuthClientInformation, client_id)
diff --git a/src/memory/api/MCP/servers/__init__.py b/src/memory/api/MCP/servers/__init__.py
new file mode 100644
index 0000000..68122c0
--- /dev/null
+++ b/src/memory/api/MCP/servers/__init__.py
@@ -0,0 +1,17 @@
+"""MCP subservers for composable tool organization."""
+
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+
+__all__ = [
+ "core_mcp",
+ "github_mcp",
+ "people_mcp",
+ "schedule_mcp",
+ "books_mcp",
+ "meta_mcp",
+]
diff --git a/src/memory/api/MCP/books.py b/src/memory/api/MCP/servers/books.py
similarity index 92%
rename from src/memory/api/MCP/books.py
rename to src/memory/api/MCP/servers/books.py
index ad102ef..d5989ff 100644
--- a/src/memory/api/MCP/books.py
+++ b/src/memory/api/MCP/servers/books.py
@@ -1,15 +1,19 @@
+"""MCP subserver for ebook access."""
+
import logging
+from fastmcp import FastMCP
from sqlalchemy.orm import joinedload
-from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Book, BookSection, BookSectionPayload
logger = logging.getLogger(__name__)
+books_mcp = FastMCP("memory-books")
-@mcp.tool()
+
+@books_mcp.tool()
async def all_books(sections: bool = False) -> list[dict]:
"""
Get all books in the database.
@@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
return [book.as_payload(sections=sections) for book in books]
-@mcp.tool()
+@books_mcp.tool()
def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
"""
Read a book from the database.
diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/servers/core.py
similarity index 91%
rename from src/memory/api/MCP/memory.py
rename to src/memory/api/MCP/servers/core.py
index 8ea3fd0..f73f251 100644
--- a/src/memory/api/MCP/memory.py
+++ b/src/memory/api/MCP/servers/core.py
@@ -1,53 +1,45 @@
"""
-MCP tools for the epistemic sparring partner system.
+Core MCP subserver for knowledge base search, observations, and notes.
"""
import base64
import logging
import pathlib
from datetime import datetime, timezone
-from PIL import Image
+from fastmcp import FastMCP
+from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
-from memory.api.MCP.base import mcp
from memory.api.search.search import search
-from memory.api.search.types import SearchFilters, SearchConfig
+from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract, settings
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
from memory.common.celery_app import app as celery_app
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem, AgentObservation
+from memory.common.db.models import AgentObservation, SourceItem
from memory.common.formatters import observation
logger = logging.getLogger(__name__)
+core_mcp = FastMCP("memory-core")
+
def validate_path_within_directory(
base_dir: pathlib.Path, requested_path: str
) -> pathlib.Path:
- """Validate that a requested path resolves within the base directory.
-
- Prevents path traversal attacks using ../ or similar techniques.
-
- Args:
- base_dir: The allowed base directory
- requested_path: The user-provided path
-
- Returns:
- The resolved absolute path if valid
-
- Raises:
- ValueError: If the path would escape the base directory
- """
+ """Validate that a requested path resolves within the base directory."""
resolved = (base_dir / requested_path.lstrip("/")).resolve()
base_resolved = base_dir.resolve()
- if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+ if (
+ not str(resolved).startswith(str(base_resolved) + "/")
+ and resolved != base_resolved
+ ):
raise ValueError(f"Path escapes allowed directory: {requested_path}")
return resolved
@@ -63,7 +55,6 @@ def filter_observation_source_ids(
items_query = session.query(AgentObservation.id)
if tags:
- # Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
@@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
items_query = session.query(SourceItem.id)
if tags:
- # Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
@@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
return source_ids
-@mcp.tool()
+@core_mcp.tool()
async def search_knowledge_base(
query: str,
filters: SearchFilters = {},
@@ -141,7 +131,6 @@ async def search_knowledge_base(
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
- # Filter to valid collections, excluding observation collections
modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
search_filters = SearchFilters(**filters)
@@ -167,7 +156,7 @@ class RawObservation(BaseModel):
tags: list[str] = []
-@mcp.tool()
+@core_mcp.tool()
async def observe(
observations: list[RawObservation],
session_id: str | None = None,
@@ -212,23 +201,23 @@ async def observe(
logger.info("MCP: Observing")
tasks = [
(
- observation,
+ obs,
celery_app.send_task(
SYNC_OBSERVATION,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
- "subject": observation.subject,
- "content": observation.content,
- "observation_type": observation.observation_type,
- "confidences": observation.confidences,
- "evidence": observation.evidence,
- "tags": observation.tags,
+ "subject": obs.subject,
+ "content": obs.content,
+ "observation_type": obs.observation_type,
+ "confidences": obs.confidences,
+ "evidence": obs.evidence,
+ "tags": obs.tags,
"session_id": session_id,
"agent_model": agent_model,
},
),
)
- for observation in observations
+ for obs in observations
]
def short_content(obs: RawObservation) -> str:
@@ -242,7 +231,7 @@ async def observe(
}
-@mcp.tool()
+@core_mcp.tool()
async def search_observations(
query: str,
subject: str = "",
@@ -308,7 +297,7 @@ async def search_observations(
]
-@mcp.tool()
+@core_mcp.tool()
async def create_note(
subject: str,
content: str,
@@ -362,7 +351,7 @@ async def create_note(
}
-@mcp.tool()
+@core_mcp.tool()
async def note_files(path: str = "/"):
"""
List note files in the user's note storage.
@@ -386,7 +375,7 @@ async def note_files(path: str = "/"):
]
-@mcp.tool()
+@core_mcp.tool()
def fetch_file(filename: str) -> dict:
"""
Read file content with automatic type detection.
diff --git a/src/memory/api/MCP/github.py b/src/memory/api/MCP/servers/github.py
similarity index 92%
rename from src/memory/api/MCP/github.py
rename to src/memory/api/MCP/servers/github.py
index 859c224..624c49f 100644
--- a/src/memory/api/MCP/github.py
+++ b/src/memory/api/MCP/servers/github.py
@@ -1,14 +1,14 @@
-"""MCP tools for GitHub issue tracking and management."""
+"""MCP subserver for GitHub issue tracking and management."""
import logging
from datetime import datetime, timezone
from typing import Any
-from sqlalchemy import Text, case, desc, func, asc
+from fastmcp import FastMCP
+from sqlalchemy import Text, case, desc, func
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
-from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract
@@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem
logger = logging.getLogger(__name__)
+github_mcp = FastMCP("memory-github")
+
def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
"""Build GitHub URL from repo path and issue/PR number."""
@@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
}
if include_content:
result["content"] = item.content
- # Include PR-specific data if available
if item.kind == "pr" and item.pr_data:
result["pr_data"] = {
"additions": item.pr_data.additions,
@@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
"files": item.pr_data.files,
"reviews": item.pr_data.reviews,
"review_comments": item.pr_data.review_comments,
- "diff": item.pr_data.diff, # Decompressed via property
+ "diff": item.pr_data.diff,
}
return result
-@mcp.tool()
+@github_mcp.tool()
async def list_github_issues(
repo: str | None = None,
assignee: str | None = None,
@@ -109,64 +110,48 @@ async def list_github_issues(
with make_session() as session:
query = session.query(GithubItem)
- # Apply filters
if repo:
query = query.filter(GithubItem.repo_path == repo)
-
if assignee:
query = query.filter(GithubItem.assignees.any(assignee))
-
if author:
query = query.filter(GithubItem.author == author)
-
if state:
query = query.filter(GithubItem.state == state)
-
if kind:
query = query.filter(GithubItem.kind == kind)
else:
- # Exclude comments by default, only show issues and PRs
query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
-
if labels:
- # Match any label in the list using PostgreSQL array overlap
query = query.filter(
GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
)
-
if project_status:
query = query.filter(GithubItem.project_status == project_status)
-
if project_field:
for key, value in project_field.items():
- query = query.filter(
- GithubItem.project_fields[key].astext == value
- )
-
+ query = query.filter(GithubItem.project_fields[key].astext == value)
if updated_since:
since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at >= since_dt)
-
if updated_before:
before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at <= before_dt)
- # Apply ordering
if order_by == "created":
query = query.order_by(desc(GithubItem.created_at))
elif order_by == "number":
query = query.order_by(desc(GithubItem.number))
- else: # default: updated
+ else:
query = query.order_by(desc(GithubItem.github_updated_at))
query = query.limit(limit)
-
items = query.all()
return [_serialize_issue(item) for item in items]
-@mcp.tool()
+@github_mcp.tool()
async def search_github_issues(
query: str,
repo: str | None = None,
@@ -191,7 +176,6 @@ async def search_github_issues(
limit = min(limit, 100)
- # Pre-filter source_ids if repo/state/kind filters are specified
source_ids = None
if repo or state or kind:
with make_session() as session:
@@ -206,7 +190,6 @@ async def search_github_issues(
q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
source_ids = [item.id for item in q.all()]
- # Use the existing search infrastructure
data = extract.extract_text(query, skip_summary=True)
config = SearchConfig(limit=limit, previews=True)
filters = SearchFilters()
@@ -220,7 +203,6 @@ async def search_github_issues(
config=config,
)
- # Fetch full issue details for the results
output = []
with make_session() as session:
for result in results:
@@ -233,7 +215,7 @@ async def search_github_issues(
return output
-@mcp.tool()
+@github_mcp.tool()
async def github_issue_details(
repo: str,
number: int,
@@ -267,7 +249,7 @@ async def github_issue_details(
return _serialize_issue(item, include_content=True)
-@mcp.tool()
+@github_mcp.tool()
async def github_work_summary(
since: str,
until: str | None = None,
@@ -294,7 +276,6 @@ async def github_work_summary(
else:
until_dt = datetime.now(timezone.utc)
- # Map group_by to SQL expression
group_mappings = {
"client": GithubItem.project_fields["EquiStamp.Client"].astext,
"status": GithubItem.project_status,
@@ -311,7 +292,6 @@ async def github_work_summary(
group_col = group_mappings[group_by]
with make_session() as session:
- # Build base query for the period
base_query = session.query(GithubItem).filter(
GithubItem.github_updated_at >= since_dt,
GithubItem.github_updated_at <= until_dt,
@@ -321,7 +301,6 @@ async def github_work_summary(
if repo:
base_query = base_query.filter(GithubItem.repo_path == repo)
- # Get aggregated counts by group
agg_query = (
session.query(
group_col.label("group_name"),
@@ -346,7 +325,6 @@ async def github_work_summary(
groups = agg_query.all()
- # Build summary with sample issues for each group
summary = []
total_issues = 0
total_prs = 0
@@ -358,7 +336,6 @@ async def github_work_summary(
total_issues += issue_count
total_prs += pr_count
- # Get sample issues for this group
sample_query = base_query.filter(group_col == group_name).limit(5)
samples = [
{
@@ -394,7 +371,7 @@ async def github_work_summary(
}
-@mcp.tool()
+@github_mcp.tool()
async def github_repo_overview(
repo: str,
) -> dict:
@@ -410,13 +387,6 @@ async def github_repo_overview(
logger.info(f"github_repo_overview called: repo={repo}")
with make_session() as session:
- # Base query for this repo
- base_query = session.query(GithubItem).filter(
- GithubItem.repo_path == repo,
- GithubItem.kind.in_(["issue", "pr"]),
- )
-
- # Get total counts
counts_query = session.query(
func.count(GithubItem.id).label("total"),
func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
@@ -441,7 +411,6 @@ async def github_repo_overview(
counts = counts_query.first()
- # Status breakdown (for project_status)
status_query = (
session.query(
GithubItem.project_status.label("status"),
@@ -458,7 +427,6 @@ async def github_repo_overview(
status_breakdown = {row.status: row.count for row in status_query.all()}
- # Top assignees (open issues only)
assignee_query = (
session.query(
func.unnest(GithubItem.assignees).label("assignee"),
@@ -479,7 +447,6 @@ async def github_repo_overview(
for row in assignee_query.all()
]
- # Label counts
label_query = (
session.query(
func.unnest(GithubItem.labels).label("label"),
diff --git a/src/memory/api/MCP/servers/meta.py b/src/memory/api/MCP/servers/meta.py
new file mode 100644
index 0000000..c9f361d
--- /dev/null
+++ b/src/memory/api/MCP/servers/meta.py
@@ -0,0 +1,277 @@
+"""MCP subserver for metadata, utilities, and forecasting."""
+
+import asyncio
+import logging
+from collections import defaultdict
+from datetime import datetime, timezone
+from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints
+
+import aiohttp
+from fastmcp import FastMCP
+from sqlalchemy import func
+
+from memory.common import qdrant
+from memory.common.db.connection import make_session
+from memory.common.db.models import SourceItem
+from memory.common.db.models.source_items import AgentObservation
+
+logger = logging.getLogger(__name__)
+
+meta_mcp = FastMCP("memory-meta")
+
+# Auth provider will be injected at mount time
+_get_current_user = None
+
+
+def set_auth_provider(get_current_user_func):
+ """Set the authentication provider function."""
+ global _get_current_user
+ _get_current_user = get_current_user_func
+
+
+def get_current_user() -> dict:
+ """Get the current authenticated user."""
+ if _get_current_user is None:
+ return {"authenticated": False, "error": "Auth provider not configured"}
+ return _get_current_user()
+
+
+# --- Metadata tools ---
+
+
+class SchemaArg(TypedDict):
+ type: str | None
+ description: str | None
+
+
+class CollectionMetadata(TypedDict):
+ schema: dict[str, SchemaArg]
+ size: int
+
+
+def from_annotation(annotation: Annotated) -> SchemaArg | None:
+ try:
+ type_, description = get_args(annotation)
+ type_str = str(type_)
+ if type_str.startswith("typing."):
+ type_str = type_str[7:]
+ elif len((parts := type_str.split("'"))) > 1:
+ type_str = parts[1]
+ return SchemaArg(type=type_str, description=description)
+ except IndexError:
+ logger.error(f"Error from annotation: {annotation}")
+ return None
+
+
+def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
+ if not hasattr(klass, "as_payload"):
+ return {}
+
+ if not (payload_type := get_type_hints(klass.as_payload).get("return")):
+ return {}
+
+ return {
+ name: schema
+ for name, arg in payload_type.__annotations__.items()
+ if (schema := from_annotation(arg))
+ }
+
+
+@meta_mcp.tool()
+async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
+ """Get the metadata schema for each collection used in the knowledge base.
+
+ These schemas can be used to filter the knowledge base.
+
+ Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
+
+ Example:
+ ```
+ {
+ "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
+ "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
+ }
+ """
+ client = qdrant.get_qdrant_client()
+ sizes = qdrant.get_collection_sizes(client)
+ schemas = defaultdict(dict)
+ for klass in SourceItem.__subclasses__():
+ for collection in klass.get_collections():
+ schemas[collection].update(get_schema(klass))
+
+ return {
+ collection: CollectionMetadata(schema=schema, size=size)
+ for collection, schema in schemas.items()
+ if (size := sizes.get(collection))
+ }
+
+
+@meta_mcp.tool()
+async def get_all_tags() -> list[str]:
+ """Get all unique tags used across the entire knowledge base.
+
+ Returns sorted list of tags from both observations and content.
+ """
+ with make_session() as session:
+ tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
+ return sorted({row[0] for row in tags_query if row[0] is not None})
+
+
+@meta_mcp.tool()
+async def get_all_subjects() -> list[str]:
+ """Get all unique subjects from observations about the user.
+
+ Returns sorted list of subject identifiers used in observations.
+ """
+ with make_session() as session:
+ return sorted(
+ r.subject for r in session.query(AgentObservation.subject).distinct()
+ )
+
+
+@meta_mcp.tool()
+async def get_all_observation_types() -> list[str]:
+ """Get all observation types that have been used.
+
+ Standard types are belief, preference, behavior, contradiction, general, but there can be more.
+ """
+ with make_session() as session:
+ return sorted(
+ {
+ r.observation_type
+ for r in session.query(AgentObservation.observation_type).distinct()
+ if r.observation_type is not None
+ }
+ )
+
+
+# --- Utility tools ---
+
+
+@meta_mcp.tool()
+async def get_current_time() -> dict:
+ """Get the current time in UTC."""
+ logger.info("get_current_time tool called")
+ return {"current_time": datetime.now(timezone.utc).isoformat()}
+
+
+@meta_mcp.tool()
+async def get_authenticated_user() -> dict:
+ """Get information about the authenticated user."""
+ return get_current_user()
+
+
+# --- Forecasting tools ---
+
+
+class BinaryProbs(TypedDict):
+ prob: float
+
+
+class MultiProbs(TypedDict):
+ answerProbs: dict[str, float]
+
+
+Probs = dict[str, BinaryProbs | MultiProbs]
+OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
+
+
+class MarketAnswer(TypedDict):
+ id: str
+ text: str
+ resolutionProbability: float
+
+
+class MarketDetails(TypedDict):
+ id: str
+ createdTime: int
+ question: str
+ outcomeType: OutcomeType
+ textDescription: str
+ groupSlugs: list[str]
+ volume: float
+ isResolved: bool
+ answers: list[MarketAnswer]
+
+
+class Market(TypedDict):
+ id: str
+ url: str
+ question: str
+ volume: int
+ createdTime: int
+ outcomeType: OutcomeType
+ createdAt: NotRequired[str]
+ description: NotRequired[str]
+ answers: NotRequired[dict[str, float]]
+ probability: NotRequired[float]
+ details: NotRequired[MarketDetails]
+
+
+async def get_details(session: aiohttp.ClientSession, market_id: str):
+ async with session.get(
+ f"https://api.manifold.markets/v0/market/{market_id}"
+ ) as resp:
+ resp.raise_for_status()
+ return await resp.json()
+
+
+async def format_market(session: aiohttp.ClientSession, market: Market):
+ if market.get("outcomeType") != "BINARY":
+ details = await get_details(session, market["id"])
+ market["answers"] = {
+ answer["text"]: round(
+ answer.get("resolutionProbability") or answer.get("probability") or 0, 3
+ )
+ for answer in details["answers"]
+ }
+ if creationTime := market.get("createdTime"):
+ market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
+
+ fields = [
+ "id",
+ "name",
+ "url",
+ "question",
+ "volume",
+ "createdAt",
+ "details",
+ "probability",
+ "answers",
+ ]
+ return {k: v for k, v in market.items() if k in fields}
+
+
+async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ "https://api.manifold.markets/v0/search-markets",
+ params={
+ "term": term,
+ "contractType": "BINARY" if binary else "ALL",
+ },
+ ) as resp:
+ resp.raise_for_status()
+ markets = await resp.json()
+
+ return await asyncio.gather(
+ *[
+ format_market(session, market)
+ for market in markets
+ if market.get("volume", 0) >= min_volume
+ ]
+ )
+
+
+@meta_mcp.tool()
+async def get_forecasts(
+ term: str, min_volume: int = 1000, binary: bool = False
+) -> list[dict]:
+ """Get prediction market forecasts for a given term.
+
+ Args:
+ term: The term to search for.
+ min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
+ binary: Whether to only return binary markets.
+ """
+ return await search_markets(term, min_volume, binary)
diff --git a/src/memory/api/MCP/people.py b/src/memory/api/MCP/servers/people.py
similarity index 95%
rename from src/memory/api/MCP/people.py
rename to src/memory/api/MCP/servers/people.py
index 9982317..b92eea8 100644
--- a/src/memory/api/MCP/people.py
+++ b/src/memory/api/MCP/servers/people.py
@@ -1,23 +1,23 @@
-"""
-MCP tools for tracking people.
-"""
+"""MCP subserver for tracking people."""
import logging
from typing import Any
+from fastmcp import FastMCP
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
-from memory.api.MCP.base import mcp
-from memory.common.db.connection import make_session
-from memory.common.db.models import Person
+from memory.common import settings
from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
from memory.common.celery_app import app as celery_app
-from memory.common import settings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Person
logger = logging.getLogger(__name__)
+people_mcp = FastMCP("memory-people")
+
def _person_to_dict(person: Person) -> dict[str, Any]:
"""Convert a Person model to a dictionary for API responses."""
@@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]:
}
-@mcp.tool()
+@people_mcp.tool()
async def add_person(
identifier: str,
display_name: str,
@@ -67,7 +67,6 @@ async def add_person(
"""
logger.info(f"MCP: Adding person: {identifier}")
- # Check if person already exists
with make_session() as session:
existing = session.query(Person).filter(Person.identifier == identifier).first()
if existing:
@@ -93,7 +92,7 @@ async def add_person(
}
-@mcp.tool()
+@people_mcp.tool()
async def update_person_info(
identifier: str,
display_name: str | None = None,
@@ -135,7 +134,6 @@ async def update_person_info(
"""
logger.info(f"MCP: Updating person: {identifier}")
- # Verify person exists
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
@@ -162,7 +160,7 @@ async def update_person_info(
}
-@mcp.tool()
+@people_mcp.tool()
async def get_person(identifier: str) -> dict | None:
"""
Get a person by their identifier.
@@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None:
return _person_to_dict(person)
-@mcp.tool()
+@people_mcp.tool()
async def list_people(
tags: list[str] | None = None,
search: str | None = None,
@@ -207,9 +205,7 @@ async def list_people(
query = session.query(Person)
if tags:
- query = query.filter(
- Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
- )
+ query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))))
if search:
search_term = f"%{search.lower()}%"
@@ -225,7 +221,7 @@ async def list_people(
return [_person_to_dict(p) for p in people]
-@mcp.tool()
+@people_mcp.tool()
async def delete_person(identifier: str) -> dict:
"""
Delete a person by their identifier.
diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/servers/schedule.py
similarity index 89%
rename from src/memory/api/MCP/schedules.py
rename to src/memory/api/MCP/servers/schedule.py
index 26b50ad..a44e9f1 100644
--- a/src/memory/api/MCP/schedules.py
+++ b/src/memory/api/MCP/servers/schedule.py
@@ -1,20 +1,37 @@
-"""
-MCP tools for the epistemic sparring partner system.
-"""
+"""MCP subserver for scheduling messages and LLM calls."""
import logging
from datetime import datetime, timezone
from typing import Any, cast
-from memory.api.MCP.base import get_current_user, mcp
+from fastmcp import FastMCP
+
from memory.common.db.connection import make_session
-from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
+from memory.common.db.models import DiscordBotUser, ScheduledLLMCall
from memory.discord.messages import schedule_discord_message
logger = logging.getLogger(__name__)
+schedule_mcp = FastMCP("memory-schedule")
-@mcp.tool()
+# We need access to get_current_user from base - this will be injected at mount time
+_get_current_user = None
+
+
+def set_auth_provider(get_current_user_func):
+ """Set the authentication provider function."""
+ global _get_current_user
+ _get_current_user = get_current_user_func
+
+
+def get_current_user() -> dict:
+ """Get the current authenticated user."""
+ if _get_current_user is None:
+ return {"authenticated": False, "error": "Auth provider not configured"}
+ return _get_current_user()
+
+
+@schedule_mcp.tool()
async def schedule_message(
scheduled_time: str,
message: str,
@@ -61,10 +78,8 @@ async def schedule_message(
if not discord_user and not discord_channel:
raise ValueError("Either discord_user or discord_channel must be provided")
- # Parse scheduled time
try:
scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
- # Ensure we store as naive datetime (remove timezone info for database storage)
if scheduled_dt.tzinfo is not None:
scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
except ValueError:
@@ -98,7 +113,7 @@ async def schedule_message(
}
-@mcp.tool()
+@schedule_mcp.tool()
async def list_scheduled_llm_calls(
status: str | None = None, limit: int | None = 50
) -> dict[str, Any]:
@@ -143,7 +158,7 @@ async def list_scheduled_llm_calls(
}
-@mcp.tool()
+@schedule_mcp.tool()
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
"""
Cancel a scheduled LLM call.
@@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
return {"error": "User not found", "user": current_user}
with make_session() as session:
- # Find the scheduled call
scheduled_call = (
session.query(ScheduledLLMCall)
.filter(
@@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
if not scheduled_call.can_be_cancelled():
return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
- # Update the status
scheduled_call.status = "cancelled"
session.commit()
diff --git a/src/memory/api/MCP/tools.py b/src/memory/api/MCP/tools.py
deleted file mode 100644
index b3efaf7..0000000
--- a/src/memory/api/MCP/tools.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""
-MCP tools for the epistemic sparring partner system.
-"""
-
-import logging
-from datetime import datetime, timezone
-
-from sqlalchemy import Text
-from sqlalchemy import cast as sql_cast
-from sqlalchemy.dialects.postgresql import ARRAY
-
-from memory.common.db.connection import make_session
-from memory.common.db.models import AgentObservation, SourceItem
-from memory.api.MCP.base import mcp, get_current_user
-
-logger = logging.getLogger(__name__)
-
-
-def filter_observation_source_ids(
- tags: list[str] | None = None, observation_types: list[str] | None = None
-):
- if not tags and not observation_types:
- return None
-
- with make_session() as session:
- items_query = session.query(AgentObservation.id)
-
- if tags:
- # Use PostgreSQL array overlap operator with proper array casting
- items_query = items_query.filter(
- AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
- )
- if observation_types:
- items_query = items_query.filter(
- AgentObservation.observation_type.in_(observation_types)
- )
- source_ids = [item.id for item in items_query.all()]
-
- return source_ids
-
-
-def filter_source_ids(
- modalities: set[str],
- tags: list[str] | None = None,
-):
- if not tags:
- return None
-
- with make_session() as session:
- items_query = session.query(SourceItem.id)
-
- if tags:
- # Use PostgreSQL array overlap operator with proper array casting
- items_query = items_query.filter(
- SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
- )
- if modalities:
- items_query = items_query.filter(SourceItem.modality.in_(modalities))
- source_ids = [item.id for item in items_query.all()]
-
- return source_ids
-
-
-@mcp.tool()
-async def get_current_time() -> dict:
- """Get the current time in UTC."""
- logger.info("get_current_time tool called")
- return {"current_time": datetime.now(timezone.utc).isoformat()}
-
-
-@mcp.tool()
-async def get_authenticated_user() -> dict:
- """Get information about the authenticated user."""
- return get_current_user()
diff --git a/src/memory/api/app.py b/src/memory/api/app.py
index 959f276..44911a3 100644
--- a/src/memory/api/app.py
+++ b/src/memory/api/app.py
@@ -2,7 +2,6 @@
FastAPI application for the knowledge base.
"""
-import contextlib
import os
import logging
import mimetypes
@@ -35,15 +34,10 @@ limiter = Limiter(
enabled=settings.API_RATE_LIMIT_ENABLED,
)
+# Create the MCP http app to get its lifespan
+mcp_http_app = mcp.http_app(stateless_http=True)
-@contextlib.asynccontextmanager
-async def lifespan(app: FastAPI):
- async with contextlib.AsyncExitStack() as stack:
- await stack.enter_async_context(mcp.session_manager.run())
- yield
-
-
-app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
+app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan)
app.state.limiter = limiter
# Rate limit exception handler
@@ -158,46 +152,9 @@ app.include_router(auth_router)
# Add health check to MCP server instead of main app
-@mcp.custom_route("/health", methods=["GET"])
-async def health_check(request: Request):
- """Health check endpoint that verifies all dependencies are accessible."""
- from fastapi.responses import JSONResponse
- from sqlalchemy import text
-
- checks = {"mcp_oauth": "enabled"}
- all_healthy = True
-
- # Check database connection
- try:
- with engine.connect() as conn:
- conn.execute(text("SELECT 1"))
- checks["database"] = "healthy"
- except Exception as e:
- # Log error details but don't expose in response
- logger.error(f"Database health check failed: {e}")
- checks["database"] = "unhealthy"
- all_healthy = False
-
- # Check Qdrant connection
- try:
- from memory.common.qdrant import get_qdrant_client
-
- client = get_qdrant_client()
- client.get_collections()
- checks["qdrant"] = "healthy"
- except Exception as e:
- # Log error details but don't expose in response
- logger.error(f"Qdrant health check failed: {e}")
- checks["qdrant"] = "unhealthy"
- all_healthy = False
-
- checks["status"] = "healthy" if all_healthy else "degraded"
- status_code = 200 if all_healthy else 503
- return JSONResponse(checks, status_code=status_code)
-
-
# Mount MCP server at root - OAuth endpoints need to be at root level
-app.mount("/", mcp.streamable_http_app())
+# Health check is defined in MCP/base.py
+app.mount("/", mcp_http_app)
def main(reload: bool = False):
diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py
index 4ec849e..46d9622 100644
--- a/src/memory/common/celery_app.py
+++ b/src/memory/common/celery_app.py
@@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
GITHUB_ROOT = "memory.workers.tasks.github"
PEOPLE_ROOT = "memory.workers.tasks.people"
+PROACTIVE_ROOT = "memory.workers.tasks.proactive"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
+# Proactive check-in tasks
+EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
+EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
+
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
@@ -130,12 +135,17 @@ app.conf.update(
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
+ f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
},
beat_schedule={
"sync-github-repos-hourly": {
"task": SYNC_ALL_GITHUB_REPOS,
"schedule": crontab(minute=0), # Every hour at :00
},
+ "evaluate-proactive-checkins": {
+ "task": EVALUATE_PROACTIVE_CHECKINS,
+ "schedule": crontab(), # Every minute
+ },
},
)
diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py
index 0b5c568..da8d70a 100644
--- a/src/memory/common/db/models/discord.py
+++ b/src/memory/common/db/models/discord.py
@@ -63,13 +63,29 @@ class MessageProcessor:
doc=textwrap.dedent(
"""
A summary of this processor, made by and for AI systems.
-
+
The idea here is that AI systems can use this summary to keep notes on the given processor.
- These should automatically be injected into the context of the messages that are processed by this processor.
+ These should automatically be injected into the context of the messages that are processed by this processor.
"""
),
)
+ proactive_cron = Column(
+ Text,
+ nullable=True,
+ doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
+ )
+ proactive_prompt = Column(
+ Text,
+ nullable=True,
+ doc="Custom instructions for proactive check-ins.",
+ )
+ last_proactive_at = Column(
+ DateTime(timezone=True),
+ nullable=True,
+ doc="When the last proactive check-in was sent.",
+ )
+
@property
def entity_type(self) -> str:
return self.__class__.__tablename__[8:-1] # type: ignore
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 07a4306..d157cf0 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
+PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py
index 0c432bf..668ad81 100644
--- a/src/memory/discord/commands.py
+++ b/src/memory/discord/commands.py
@@ -167,6 +167,25 @@ def _create_scope_group(
url=url and url.strip(),
)
+ # Proactive command
+ @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
+ @discord.app_commands.describe(
+ cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
+ prompt="Optional custom instructions for check-ins",
+ )
+ async def proactive_cmd(
+ interaction: discord.Interaction,
+ cron: str | None = None,
+ prompt: str | None = None,
+ ):
+ await _run_interaction_command(
+ interaction,
+ scope=scope,
+ handler=handle_proactive,
+ cron=cron and cron.strip(),
+ prompt=prompt,
+ )
+
return group
@@ -265,6 +284,28 @@ def _create_user_scope_group(
url=url and url.strip(),
)
+ # Proactive command
+ @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
+ @discord.app_commands.describe(
+ user="Target user",
+ cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
+ prompt="Optional custom instructions for check-ins",
+ )
+ async def proactive_cmd(
+ interaction: discord.Interaction,
+ user: discord.User,
+ cron: str | None = None,
+ prompt: str | None = None,
+ ):
+ await _run_interaction_command(
+ interaction,
+ scope=scope,
+ handler=handle_proactive,
+ target_user=user,
+ cron=cron and cron.strip(),
+ prompt=prompt,
+ )
+
return group
@@ -663,3 +704,68 @@ async def handle_mcp_servers(
except Exception as exc:
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
raise CommandError(f"Error: {exc}") from exc
+
+
+async def handle_proactive(
+ context: CommandContext,
+ *,
+ cron: str | None = None,
+ prompt: str | None = None,
+) -> CommandResponse:
+ """Handle proactive check-in configuration."""
+ from croniter import croniter
+
+ model = context.target
+
+ # If no arguments, show current settings
+ if cron is None and prompt is None:
+ current_cron = getattr(model, "proactive_cron", None)
+ current_prompt = getattr(model, "proactive_prompt", None)
+
+ if not current_cron:
+ return CommandResponse(
+ content=f"Proactive check-ins are disabled for {context.display_name}."
+ )
+
+ lines = [f"Proactive check-ins for {context.display_name}:"]
+ lines.append(f" Schedule: `{current_cron}`")
+ if current_prompt:
+ lines.append(f" Prompt: {current_prompt}")
+ return CommandResponse(content="\n".join(lines))
+
+ # Handle cron setting
+ if cron is not None:
+ if cron.lower() == "off":
+ setattr(model, "proactive_cron", None)
+ return CommandResponse(
+ content=f"Proactive check-ins disabled for {context.display_name}."
+ )
+
+ # Validate cron expression
+ try:
+ croniter(cron)
+ except (ValueError, KeyError) as e:
+ raise CommandError(
+ f"Invalid cron expression: {cron}\n"
+ "Examples:\n"
+ " `0 9 * * *` - 9am daily\n"
+ " `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
+ " `0 */4 * * *` - every 4 hours"
+ ) from e
+
+ setattr(model, "proactive_cron", cron)
+
+ # Handle prompt setting
+ if prompt is not None:
+ setattr(model, "proactive_prompt", prompt or None)
+
+ # Build response
+ current_cron = getattr(model, "proactive_cron", None)
+ current_prompt = getattr(model, "proactive_prompt", None)
+
+ lines = [f"Updated proactive settings for {context.display_name}:"]
+ lines.append(f" Schedule: `{current_cron}`")
+ if current_prompt:
+ lines.append(f" Prompt: {current_prompt}")
+
+ return CommandResponse(content="\n".join(lines))
diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py
index 6b0b4e8..b18810c 100644
--- a/src/memory/workers/ingest.py
+++ b/src/memory/workers/ingest.py
@@ -11,6 +11,7 @@ from memory.common.celery_app import (
SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS,
BACKUP_ALL,
+ EVALUATE_PROACTIVE_CHECKINS,
)
logger = logging.getLogger(__name__)
@@ -53,4 +54,8 @@ app.conf.beat_schedule = {
"task": BACKUP_ALL,
"schedule": settings.S3_BACKUP_INTERVAL,
},
+ "evaluate-proactive-checkins": {
+ "task": EVALUATE_PROACTIVE_CHECKINS,
+ "schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
+ },
}
diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py
index 33f22aa..f7c0a97 100644
--- a/src/memory/workers/tasks/__init__.py
+++ b/src/memory/workers/tasks/__init__.py
@@ -14,20 +14,24 @@ from memory.workers.tasks import (
maintenance,
notes,
observations,
+ people,
+ proactive,
scheduled_calls,
) # noqa
__all__ = [
"backup",
- "email",
- "comic",
"blogs",
- "ebook",
+ "comic",
"discord",
+ "ebook",
+ "email",
"forums",
"github",
"maintenance",
"notes",
"observations",
+ "people",
+ "proactive",
"scheduled_calls",
]
diff --git a/src/memory/workers/tasks/proactive.py b/src/memory/workers/tasks/proactive.py
new file mode 100644
index 0000000..689cfc9
--- /dev/null
+++ b/src/memory/workers/tasks/proactive.py
@@ -0,0 +1,341 @@
+"""
+Celery tasks for proactive Discord check-ins.
+"""
+
+import logging
+import re
+import textwrap
+from datetime import datetime, timezone
+from typing import Any, Literal, cast
+
+from croniter import croniter
+from sqlalchemy import or_
+
+from memory.common import settings
+from memory.common.celery_app import app
+from memory.common.db.connection import make_session
+from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
+from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
+from memory.workers.tasks.content_processing import safe_task_execution
+
+logger = logging.getLogger(__name__)
+
+EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
+EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
+
+EntityType = Literal["user", "channel", "server"]
+
+
+def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
+ """Check if a cron expression is due to run now.
+
+ Uses croniter to determine if the current time falls within the cron's schedule
+ and enough time has passed since the last run.
+ """
+ try:
+ cron = croniter(cron_expr, now)
+ # Get the previous scheduled time from now
+ prev_run = cron.get_prev(datetime)
+ # Get the one before that to determine the interval
+ cron.get_prev(datetime)
+ prev_prev_run = cron.get_current(datetime)
+
+ # If we haven't run since the last scheduled time, we should run
+ if last_run is None:
+ # Never run before - check if current time is within a minute of prev_run
+ time_since_scheduled = (now - prev_run).total_seconds()
+ return time_since_scheduled < 120 # Within 2 minutes of scheduled time
+
+ # Make sure last_run is timezone aware
+ if last_run.tzinfo is None:
+ last_run = last_run.replace(tzinfo=timezone.utc)
+
+ # We should run if last_run is before the previous scheduled time
+ return last_run < prev_run
+ except Exception as e:
+ logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
+ return False
+
+
+def get_bot_for_entity(
+ session, entity_type: EntityType, entity_id: int
+) -> DiscordUser | None:
+ """Get the bot user associated with an entity."""
+ from memory.common.db.models import DiscordBotUser, DiscordMessage
+
+ from sqlalchemy.orm import joinedload
+
+ # For servers, find a bot that has sent messages in that server
+ if entity_type == "server":
+ # Find bots that have interacted with this server
+ bot_users = (
+ session.query(DiscordUser)
+ .options(joinedload(DiscordUser.system_user))
+ .join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
+ .filter(
+ DiscordMessage.server_id == entity_id,
+ DiscordUser.system_user_id.isnot(None),
+ )
+ .distinct()
+ .all()
+ )
+ # Find one that's actually a bot
+ for user in bot_users:
+ if user.system_user and user.system_user.user_type == "discord_bot":
+ return user
+
+ # For channels, check the server the channel belongs to
+ if entity_type == "channel":
+ channel = session.get(DiscordChannel, entity_id)
+ if channel and channel.server_id:
+ return get_bot_for_entity(session, "server", channel.server_id)
+
+ # Fallback: use first available bot
+ bot = (
+ session.query(DiscordBotUser)
+ .options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
+ .first()
+ )
+ if bot and bot.discord_users:
+ return bot.discord_users[0]
+ return None
+
+
+def get_target_user_for_entity(
+ session, entity_type: EntityType, entity_id: int
+) -> DiscordUser | None:
+ """Get the target user for sending a proactive message."""
+ if entity_type == "user":
+ return session.get(DiscordUser, entity_id)
+ # For channels and servers, we don't have a specific target user
+ return None
+
+
+def get_channel_for_entity(
+ session, entity_type: EntityType, entity_id: int
+) -> DiscordChannel | None:
+ """Get the channel for sending a proactive message."""
+ if entity_type == "channel":
+ return session.get(DiscordChannel, entity_id)
+ if entity_type == "server":
+ # For servers, find the first text channel (prefer "general")
+ channels = (
+ session.query(DiscordChannel)
+ .filter(
+ DiscordChannel.server_id == entity_id,
+ DiscordChannel.channel_type == "text",
+ )
+ .all()
+ )
+ if not channels:
+ return None
+ # Prefer a channel named "general" if it exists
+ for channel in channels:
+ if channel.name and "general" in channel.name.lower():
+ return channel
+ return channels[0]
+ # For users, we use DMs (no channel)
+ return None
+
+
+@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
+@safe_task_execution
+def evaluate_proactive_checkins() -> dict[str, Any]:
+ """
+ Evaluate which entities need proactive check-ins.
+
+ This task runs every minute and checks all entities with proactive_cron set
+ to see if they're due for a check-in.
+ """
+ now = datetime.now(timezone.utc)
+ dispatched = []
+
+ with make_session() as session:
+ # Query all entities with proactive_cron set
+ for model, entity_type in [
+ (DiscordUser, "user"),
+ (DiscordChannel, "channel"),
+ (DiscordServer, "server"),
+ ]:
+ entities = (
+ session.query(model)
+ .filter(model.proactive_cron.isnot(None))
+ .all()
+ )
+
+ for entity in entities:
+ cron_expr = cast(str, entity.proactive_cron)
+ last_run = entity.last_proactive_at
+
+ if is_cron_due(cron_expr, last_run, now):
+ logger.info(
+ f"Proactive check-in due for {entity_type} {entity.id}"
+ )
+ execute_proactive_checkin.delay(entity_type, entity.id)
+ dispatched.append({"type": entity_type, "id": entity.id})
+
+ return {
+ "evaluated_at": now.isoformat(),
+ "dispatched": dispatched,
+ "count": len(dispatched),
+ }
+
+
+@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
+@safe_task_execution
+def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
+ """
+ Execute a proactive check-in for a specific entity.
+
+ This evaluates whether the bot should reach out and, if so, generates
+ and sends a check-in message.
+ """
+ logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
+
+ with make_session() as session:
+ # Get the entity
+ model_class = {
+ "user": DiscordUser,
+ "channel": DiscordChannel,
+ "server": DiscordServer,
+ }[entity_type]
+
+ entity = session.get(model_class, entity_id)
+ if not entity:
+ return {"error": f"{entity_type} {entity_id} not found"}
+
+ # Get the bot user
+ bot_user = get_bot_for_entity(session, entity_type, entity_id)
+ if not bot_user:
+ return {"error": "No bot user found"}
+
+ # Get target user and channel
+ target_user = get_target_user_for_entity(session, entity_type, entity_id)
+ channel = get_channel_for_entity(session, entity_type, entity_id)
+
+ if not target_user and not channel:
+ return {"error": "No target user or channel for proactive check-in"}
+
+ # Get chattiness threshold
+ chattiness = entity.chattiness_threshold or 90
+
+ # Build the evaluation prompt
+ proactive_prompt = entity.proactive_prompt or ""
+ eval_prompt = textwrap.dedent("""
+ You are considering whether to proactively reach out to check in.
+
+ {proactive_prompt}
+
+ Based on your notes and the context of previous conversations:
+ 1. Is there anything worth checking in about?
+ 2. Has enough happened or enough time passed to warrant a check-in?
+ 3. Would reaching out now be welcome or intrusive?
+
+ Please return a number between 0 and 100 indicating how strongly you want to check in
+ (0 = definitely not, 100 = definitely yes).
+
+
+ 50
+ Your reasoning here
+
+ """).format(proactive_prompt=proactive_prompt)
+
+ # Build context
+ system_prompt = comm_channel_prompt(
+ session, bot_user, target_user, channel
+ )
+
+ # First, evaluate whether we should check in
+ eval_response = call_llm(
+ session,
+ bot_user=bot_user,
+ from_user=target_user,
+ channel=channel,
+ model=settings.SUMMARIZER_MODEL,
+ system_prompt=system_prompt,
+ messages=[eval_prompt],
+ allowed_tools=[
+ "update_channel_summary",
+ "update_user_summary",
+ "update_server_summary",
+ ],
+ )
+
+ if not eval_response:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
+
+ # Parse the interest score
+ match = re.search(r"(\d+)", eval_response)
+ if not match:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
+
+ interest_score = int(match.group(1))
+ threshold = 100 - chattiness
+
+ logger.info(
+ f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
+ )
+
+ if interest_score < threshold:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {
+ "status": "below_threshold",
+ "interest": interest_score,
+ "threshold": threshold,
+ "entity_type": entity_type,
+ "entity_id": entity_id,
+ }
+
+ # Generate the actual check-in message
+ checkin_prompt = textwrap.dedent("""
+ You've decided to proactively check in. Generate a natural, friendly check-in message.
+
+ {proactive_prompt}
+
+ Keep it brief and genuine. Don't be overly formal or robotic.
+ Reference specific things from your notes if relevant.
+ """).format(proactive_prompt=proactive_prompt)
+
+ response = call_llm(
+ session,
+ bot_user=bot_user,
+ from_user=target_user,
+ channel=channel,
+ model=settings.DISCORD_MODEL,
+ system_prompt=system_prompt,
+ messages=[checkin_prompt],
+ )
+
+ if not response:
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+ return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
+
+ # Send the message
+ bot_id = bot_user.system_user.id if bot_user.system_user else None
+ if not bot_id:
+ return {"error": "No system user for bot"}
+
+ success = send_discord_response(
+ bot_id=bot_id,
+ response=response,
+ channel_id=channel.id if channel else None,
+ user_identifier=target_user.username if target_user else None,
+ )
+
+ # Update last_proactive_at
+ entity.last_proactive_at = datetime.now(timezone.utc)
+ session.commit()
+
+ return {
+ "status": "sent" if success else "send_failed",
+ "interest": interest_score,
+ "entity_type": entity_type,
+ "entity_id": entity_id,
+ "response_preview": response[:100] + "..." if len(response) > 100 else response,
+ }
diff --git a/tests/integration/test_real_queries.py b/tests/integration/test_real_queries.py
index df46fbf..0da4a43 100644
--- a/tests/integration/test_real_queries.py
+++ b/tests/integration/test_real_queries.py
@@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
),
(
0.409,
- "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
+ "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
),
],
},
@@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
"semantic": [
(0.489, "I find backend logic more interesting than UI work"),
(0.462, "The user prefers working on backend systems over frontend UI"),
+ (0.455, "The user said pure functions are yucky"),
(
0.455,
"The user believes functional programming leads to better code quality",
),
- (0.455, "The user said pure functions are yucky"),
],
"temporal": [
(
diff --git a/tests/memory/api/search/test_query_analysis.py b/tests/memory/api/search/test_query_analysis.py
index 85ef343..2290d96 100644
--- a/tests/memory/api/search/test_query_analysis.py
+++ b/tests/memory/api/search/test_query_analysis.py
@@ -137,10 +137,9 @@ class TestBuildPrompt:
):
prompt = _build_prompt()
- assert "lesswrong" in prompt.lower()
- assert "comic" in prompt.lower()
- assert "Remove" in prompt
+ assert "Remove meta-language" in prompt
assert "Return ONLY valid JSON" in prompt
+ assert "recalled_content" in prompt
class TestAnalyzeQuery:
diff --git a/tests/memory/api/test_auth.py b/tests/memory/api/test_auth.py
index 708fb05..dddd814 100644
--- a/tests/memory/api/test_auth.py
+++ b/tests/memory/api/test_auth.py
@@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
@pytest.mark.asyncio
+@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
-async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
+async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
mock_session = MagicMock()
@contextmanager
@@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
+ mcp_server.mcp_server_url = "https://example.com"
+ mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (200, "Authorized")
+ mock_mcp_tools.return_value = [{"name": "test_tool"}]
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
@@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
assert "Authorization Successful" in body
assert "Authorized" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
- mock_session.commit.assert_called_once()
+ assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio
+@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
@patch("memory.api.auth.make_session")
async def test_oauth_callback_discord_handles_failures(
- mock_make_session, mock_complete
+ mock_make_session, mock_complete, mock_mcp_tools
):
mock_session = MagicMock()
@@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
mock_make_session.return_value = session_cm()
mcp_server = MagicMock()
+ mcp_server.mcp_server_url = "https://example.com"
+ mcp_server.access_token = "token123"
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
mock_complete.return_value = (500, "Failure")
+ mock_mcp_tools.return_value = []
request = make_request("code=abc123&state=state456")
response = await auth.oauth_callback_discord(request)
@@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
assert "Authorization Failed" in body
assert "Failure" in body
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
- mock_session.commit.assert_called_once()
+ assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
@pytest.mark.asyncio
diff --git a/tests/memory/common/llms/tools/test_discord_tools.py b/tests/memory/common/llms/tools/test_discord_tools.py
index 29ce374..81a2d19 100644
--- a/tests/memory/common/llms/tools/test_discord_tools.py
+++ b/tests/memory/common/llms/tools/test_discord_tools.py
@@ -18,6 +18,7 @@ from memory.common.db.models import (
DiscordUser,
DiscordMessage,
BotUser,
+ DiscordBotUser,
HumanUser,
ScheduledLLMCall,
)
@@ -67,9 +68,10 @@ def sample_discord_user(db_session):
@pytest.fixture
-def sample_bot_user(db_session):
+def sample_bot_user(db_session, sample_discord_user):
"""Create a sample bot user for testing."""
- bot = BotUser.create_with_api_key(
+ bot = DiscordBotUser.create_with_api_key(
+ discord_users=[sample_discord_user],
name="Test Bot",
email="testbot@example.com",
)
@@ -209,9 +211,9 @@ def test_schedule_message_with_user(
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message(
- user_id=sample_human_user.id,
- user=sample_discord_user.id,
- channel=None,
+ bot_id=sample_human_user.id,
+ recipient_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
message="Test message",
date_time=future_time,
@@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
result = schedule_message(
- user_id=sample_human_user.id,
- user=None,
- channel=sample_discord_channel.id,
+ bot_id=sample_human_user.id,
+ recipient_id=None,
+ channel_id=sample_discord_channel.id,
model="test-model",
message="Test message",
date_time=future_time,
@@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
"""Test creating a message scheduler tool for a user."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
- assert tool.name == "schedule_message"
+ assert tool.name == "schedule_discord_message"
assert "from your chat with this user" in tool.description
assert tool.input_schema["type"] == "object"
assert "message" in tool.input_schema["properties"]
@@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
"""Test creating a message scheduler tool for a channel."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=None,
- channel=sample_discord_channel.id,
+ user_id=None,
+ channel_id=sample_discord_channel.id,
model="test-model",
)
- assert tool.name == "schedule_message"
+ assert tool.name == "schedule_discord_message"
assert "in this channel" in tool.description
assert callable(tool.function)
@@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
with pytest.raises(ValueError, match="Either user or channel must be provided"):
make_message_scheduler(
bot=sample_bot_user,
- user=None,
- channel=None,
+ user_id=None,
+ channel_id=None,
model="test-model",
)
@@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
"""Test message scheduler handler with valid input."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
"""Test message scheduler handler with non-dict input."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
"""Test message scheduler handler with invalid datetime."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
"""Test message scheduler handler with missing datetime."""
tool = make_message_scheduler(
bot=sample_bot_user,
- user=sample_discord_user.id,
- channel=None,
+ user_id=sample_discord_user.id,
+ channel_id=None,
model="test-model",
)
@@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
# Tests for make_prev_messages_tool
-def test_make_prev_messages_tool_with_user(sample_discord_user):
+def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
"""Test creating a previous messages tool for a user."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
assert tool.name == "previous_messages"
assert "from your chat with this user" in tool.description
@@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
assert callable(tool.function)
-def test_make_prev_messages_tool_with_channel(sample_discord_channel):
+def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
"""Test creating a previous messages tool for a channel."""
- tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
assert tool.name == "previous_messages"
assert "in this channel" in tool.description
assert callable(tool.function)
-def test_make_prev_messages_tool_without_user_or_channel():
+def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
"""Test that creating a tool without user or channel raises error."""
with pytest.raises(ValueError, match="Either user or channel must be provided"):
- make_prev_messages_tool(user=None, channel=None)
+ make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
def test_prev_messages_handler_success(
- db_session, sample_discord_user, sample_discord_channel
+ db_session, sample_bot_user, sample_discord_user, sample_discord_channel
):
"""Test previous messages handler with valid input."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
# Create some actual messages in the database
msg1 = DiscordMessage(
@@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
assert "Message 1" in result or "Message 2" in result
-def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
+def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
"""Test previous messages handler with default values."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
result = tool.function({})
@@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
assert isinstance(result, str)
-def test_prev_messages_handler_invalid_input(sample_discord_user):
+def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-dict input."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Input must be a dictionary"):
tool.function("not a dict")
-def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
+def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid max_messages (negative value)."""
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
# so we test with -1 which actually triggers the validation
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
tool.function({"max_messages": -1})
-def test_prev_messages_handler_invalid_offset(sample_discord_user):
+def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
"""Test previous messages handler with invalid offset."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
tool.function({"offset": -1})
-def test_prev_messages_handler_non_integer_values(sample_discord_user):
+def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
"""Test previous messages handler with non-integer values."""
- tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
+ tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
tool.function({"max_messages": "not an int"})
@@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
model="test-model",
)
- # Should have: schedule_message, previous_messages, update_channel_summary,
+ # Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_user_summary, update_server_summary, add_reaction
assert len(tools) == 6
- assert "schedule_message" in tools
+ assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_channel_summary" in tools
assert "update_user_summary" in tools
@@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
model="test-model",
)
- # Should have: schedule_message, previous_messages, update_user_summary
+ # Should have: schedule_discord_message, previous_messages, update_user_summary
# Note: Without channel, there's no channel summary tool
assert len(tools) >= 2 # At least schedule and previous messages
- assert "schedule_message" in tools
+ assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_user_summary" in tools
@@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
model="test-model",
)
- # Should have: schedule_message, previous_messages, update_channel_summary,
+ # Should have: schedule_discord_message, previous_messages, update_channel_summary,
# update_server_summary, add_reaction (no user summary without author)
assert len(tools) == 5
- assert "schedule_message" in tools
+ assert "schedule_discord_message" in tools
assert "previous_messages" in tools
assert "update_channel_summary" in tools
assert "update_server_summary" in tools
diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py
index ce4cf3f..55db891 100644
--- a/tests/memory/common/test_discord.py
+++ b/tests/memory/common/test_discord.py
@@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel",
json={
"bot_id": BOT_ID,
- "channel_name": "general",
+ "channel": "general",
"message": "Announcement!",
},
timeout=10,
diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py
index ce4cf3f..55db891 100644
--- a/tests/memory/common/test_discord_integration.py
+++ b/tests/memory/common/test_discord_integration.py
@@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
"http://localhost:8000/send_channel",
json={
"bot_id": BOT_ID,
- "channel_name": "general",
+ "channel": "general",
"message": "Announcement!",
},
timeout=10,
diff --git a/tests/memory/discord_tests/test_commands.py b/tests/memory/discord_tests/test_commands.py
index fd2bf9d..9d05195 100644
--- a/tests/memory/discord_tests/test_commands.py
+++ b/tests/memory/discord_tests/test_commands.py
@@ -15,6 +15,7 @@ from memory.discord.commands import (
handle_chattiness,
handle_ignore,
handle_summary,
+ handle_proactive,
respond,
with_object_context,
handle_mcp_servers,
@@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
await handle_mcp_servers(context, action="list", url=None)
assert "Error: boom" in str(exc.value)
+
+
+# ============================================================================
+# Tests for handle_proactive
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_show_disabled(db_session, interaction, guild):
+ """Test showing proactive settings when disabled."""
+ server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context)
+
+ assert "disabled" in response.content.lower()
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_show_enabled(db_session, interaction, guild):
+ """Test showing proactive settings when enabled."""
+ server = DiscordServer(
+ id=guild.id,
+ name="Guild",
+ proactive_cron="0 9 * * *",
+ proactive_prompt="Check on projects",
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context)
+
+ assert "0 9 * * *" in response.content
+ assert "Check on projects" in response.content
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_set_cron(db_session, interaction, guild):
+ """Test setting proactive cron schedule."""
+ server = DiscordServer(id=guild.id, name="Guild")
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context, cron="0 9 * * *")
+
+ assert "Updated" in response.content
+ assert "0 9 * * *" in response.content
+ assert server.proactive_cron == "0 9 * * *"
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_set_prompt(db_session, interaction, guild):
+ """Test setting proactive prompt."""
+ server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context, prompt="Focus on daily standups")
+
+ assert "Updated" in response.content
+ assert server.proactive_prompt == "Focus on daily standups"
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_disable(db_session, interaction, guild):
+ """Test disabling proactive check-ins."""
+ server = DiscordServer(
+ id=guild.id,
+ name="Guild",
+ proactive_cron="0 9 * * *",
+ proactive_prompt="Some prompt",
+ )
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ response = await handle_proactive(context, cron="off")
+
+ assert "disabled" in response.content.lower()
+ assert server.proactive_cron is None
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
+ """Test error on invalid cron expression."""
+ server = DiscordServer(id=guild.id, name="Guild")
+ db_session.add(server)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="server",
+ target=server,
+ display_name="server **Guild**",
+ )
+
+ with pytest.raises(CommandError) as exc:
+ await handle_proactive(context, cron="not a valid cron")
+
+ assert "Invalid cron expression" in str(exc.value)
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
+ """Test proactive settings for user scope."""
+ user_model = DiscordUser(
+ id=discord_user.id, username="testuser", proactive_cron=None
+ )
+ db_session.add(user_model)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="me",
+ target=user_model,
+ display_name="you (**testuser**)",
+ )
+
+ response = await handle_proactive(context, cron="0 9,17 * * 1-5")
+
+ assert "Updated" in response.content
+ assert user_model.proactive_cron == "0 9,17 * * 1-5"
+
+
+@pytest.mark.asyncio
+async def test_handle_proactive_channel_scope(
+ db_session, interaction, guild, text_channel
+):
+ """Test proactive settings for channel scope."""
+ server = DiscordServer(id=guild.id, name="Guild")
+ db_session.add(server)
+ db_session.flush()
+
+ channel_model = DiscordChannel(
+ id=text_channel.id,
+ name="general",
+ channel_type="text",
+ server_id=guild.id,
+ )
+ db_session.add(channel_model)
+ db_session.commit()
+
+ context = CommandContext(
+ session=db_session,
+ interaction=interaction,
+ actor=MagicMock(spec=DiscordUser),
+ scope="channel",
+ target=channel_model,
+ display_name="channel **#general**",
+ )
+
+ response = await handle_proactive(context, cron="0 12 * * *")
+
+ assert "Updated" in response.content
+ assert channel_model.proactive_cron == "0 12 * * *"
diff --git a/tests/memory/discord_tests/test_mcp.py b/tests/memory/discord_tests/test_mcp.py
index 37a7649..0e59f89 100644
--- a/tests/memory/discord_tests/test_mcp.py
+++ b/tests/memory/discord_tests/test_mcp.py
@@ -1,6 +1,5 @@
"""Tests for Discord MCP server management."""
-import json
from unittest.mock import AsyncMock, Mock, patch
import aiohttp
@@ -8,8 +7,8 @@ import discord
import pytest
from memory.common.db.models import MCPServer, MCPServerAssignment
+from memory.common.mcp import mcp_call
from memory.discord.mcp import (
- call_mcp_server,
find_mcp_server,
handle_mcp_add,
handle_mcp_connect,
@@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
- async for data in call_mcp_server(
+ async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list", {}
):
results.append(data)
@@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
with pytest.raises(ValueError, match="Failed to call MCP server"):
- async for _ in call_mcp_server(
+ async for _ in mcp_call(
"https://mcp.example.com", "test_token", "tools/list"
):
pass
@@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
results = []
- async for data in call_mcp_server(
+ async for data in mcp_call(
"https://mcp.example.com", "test_token", "tools/list"
):
results.append(data)
diff --git a/tests/memory/discord_tests/test_messages.py b/tests/memory/discord_tests/test_messages.py
index e2f60a9..edabcc2 100644
--- a/tests/memory/discord_tests/test_messages.py
+++ b/tests/memory/discord_tests/test_messages.py
@@ -19,6 +19,7 @@ from memory.common.db.models import (
DiscordChannel,
DiscordServer,
DiscordMessage,
+ DiscordBotUser,
HumanUser,
ScheduledLLMCall,
)
@@ -34,6 +35,19 @@ def sample_discord_user(db_session):
return user
+@pytest.fixture
+def sample_bot_user(db_session, sample_discord_user):
+ """Create a sample Discord bot user."""
+ bot = DiscordBotUser.create_with_api_key(
+ discord_users=[sample_discord_user],
+ name="Test Bot",
+ email="testbot@example.com",
+ )
+ db_session.add(bot)
+ db_session.commit()
+ return bot
+
+
@pytest.fixture
def sample_discord_channel(db_session):
"""Create a sample Discord channel."""
@@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
# Test previous_messages
-def test_previous_messages_empty(db_session):
+def test_previous_messages_empty(db_session, sample_bot_user):
"""Test getting previous messages when none exist."""
- result = previous_messages(db_session, user_id=123, channel_id=456)
+ result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
assert result == []
-def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
+def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test filtering messages by recipient user."""
# Create some messages
msg1 = DiscordMessage(
@@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
db_session.add_all([msg1, msg2])
db_session.commit()
- result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
+ result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
assert len(result) == 2
# Should be in chronological order (oldest first)
assert result[0].message_id == 1
assert result[1].message_id == 2
-def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
+def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test limiting the number of previous messages."""
# Create 15 messages
for i in range(15):
@@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
db_session.commit()
result = previous_messages(
- db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
+ db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
)
assert len(result) == 5
@@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
# Test comm_channel_prompt
-def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
+def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
"""Test generating a basic communication channel prompt."""
result = comm_channel_prompt(
- db_session, user=sample_discord_user, channel=sample_discord_channel
+ db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
)
assert "You are a bot communicating on Discord" in result
@@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
assert len(result) > 0
-def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
+def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes server context when available."""
server = sample_discord_channel.server
server.summary = "Gaming community server"
db_session.commit()
- result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
+ result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "server_context" in result.lower()
assert "Gaming community server" in result
-def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
+def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
"""Test that prompt includes channel context."""
sample_discord_channel.summary = "General discussion channel"
db_session.commit()
- result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
+ result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
assert "channel_context" in result.lower()
assert "General discussion channel" in result
def test_comm_channel_prompt_includes_user_notes(
- db_session, sample_discord_user, sample_discord_channel
+ db_session, sample_bot_user, sample_discord_user, sample_discord_channel
):
"""Test that prompt includes user notes from previous messages."""
sample_discord_user.summary = "Helpful community member"
@@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
db_session.commit()
result = comm_channel_prompt(
- db_session, user=sample_discord_user, channel=sample_discord_channel
+ db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
)
assert "user_notes" in result.lower()
@@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
web_tool_instance = MagicMock(name="web_tool")
mock_web_search.return_value = web_tool_instance
- bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
+ bot_user = SimpleNamespace(
+ system_user=SimpleNamespace(discord_id=999888777),
+ system_prompt="bot prompt"
+ )
from_user = SimpleNamespace(id=123)
mcp_model = SimpleNamespace(
name="Server",
mcp_server_url="https://mcp.example.com",
access_token="token123",
+ disabled_tools=[],
)
result = call_llm(
@@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
mock_web_search.return_value = MagicMock(name="web_tool")
- bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
+ bot_user = SimpleNamespace(
+ system_user=SimpleNamespace(discord_id=999888777),
+ system_prompt=None
+ )
from_user = SimpleNamespace(id=1)
call_llm(
diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py
index 8decc7d..9650ba7 100644
--- a/tests/memory/workers/tasks/test_discord_tasks.py
+++ b/tests/memory/workers/tasks/test_discord_tasks.py
@@ -14,12 +14,25 @@ from memory.workers.tasks import discord
@pytest.fixture
def discord_bot_user(db_session):
+ # Create a discord user for the bot first
+ bot_discord_user = DiscordUser(
+ id=999999999,
+ username="testbot",
+ )
+ db_session.add(bot_discord_user)
+ db_session.flush()
+
bot = DiscordBotUser.create_with_api_key(
- discord_users=[],
+ discord_users=[bot_discord_user],
name="Test Bot",
email="bot@example.com",
)
db_session.add(bot)
+ db_session.flush()
+
+ # Link the discord user to the system user
+ bot_discord_user.system_user_id = bot.id
+
db_session.commit()
return bot
@@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
-@patch("memory.workers.tasks.discord.create_provider")
+@patch("memory.workers.tasks.discord.call_llm")
+@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
def test_should_process_normal_message(
- mock_create_provider,
+ mock_trigger_typing,
+ mock_call_llm,
db_session,
mock_discord_user,
mock_discord_server,
mock_discord_channel,
+ discord_bot_user,
):
"""Test should_process returns True for normal messages."""
- # Mock the LLM provider to return "yes"
- mock_provider = Mock()
- mock_provider.generate.return_value = "yes"
- mock_provider.as_messages.return_value = []
- mock_create_provider.return_value = mock_provider
+ # Create a separate recipient user (the bot)
+ bot_discord_user = discord_bot_user.discord_users[0]
+
+ # Mock call_llm to return a high number (100 = always process)
+ mock_call_llm.return_value = "100Test"
message = DiscordMessage(
message_id=1,
channel_id=mock_discord_channel.id,
from_id=mock_discord_user.id,
- recipient_id=mock_discord_user.id,
+ recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
server_id=mock_discord_server.id,
content="Test",
sent_at=datetime.now(timezone.utc),
@@ -207,6 +223,8 @@ def test_should_process_normal_message(
db_session.refresh(message)
assert discord.should_process(message) is True
+ mock_call_llm.assert_called_once()
+ mock_trigger_typing.assert_called_once()
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
@@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
"""Test adding a Discord message that is a reply."""
sample_message_data["message_reference_id"] = 111222333
+ sample_message_data["message_type"] = "reply" # Explicitly set message_type
discord.add_discord_message(**sample_message_data)
@@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
assert result["status"] == "processed"
-def test_process_discord_message_success(db_session, sample_message_data, qdrant):
+@patch("memory.workers.tasks.discord.send_discord_response")
+@patch("memory.workers.tasks.discord.call_llm")
+def test_process_discord_message_success(
+ mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
+):
"""Test processing a Discord message."""
+ # Mock LLM to return a response
+ mock_call_llm.return_value = "Test response from bot"
+ # Mock Discord API to succeed
+ mock_send_response.return_value = True
+
# Add a message first
add_result = discord.add_discord_message(**sample_message_data)
message_id = add_result["discordmessage_id"]
@@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
assert result["status"] == "processed"
assert result["message_id"] == message_id
+ mock_call_llm.assert_called_once()
+ mock_send_response.assert_called_once()
def test_process_discord_message_not_found(db_session):
diff --git a/tests/memory/workers/tasks/test_proactive.py b/tests/memory/workers/tasks/test_proactive.py
new file mode 100644
index 0000000..3edcd38
--- /dev/null
+++ b/tests/memory/workers/tasks/test_proactive.py
@@ -0,0 +1,536 @@
+"""Tests for proactive check-in tasks."""
+
+import pytest
+from datetime import datetime, timezone, timedelta
+from unittest.mock import Mock, patch, MagicMock
+
+from memory.common.db.models import (
+ DiscordBotUser,
+ DiscordUser,
+ DiscordChannel,
+ DiscordServer,
+)
+from memory.workers.tasks import proactive
+from memory.workers.tasks.proactive import is_cron_due
+
+
+# ============================================================================
+# Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def bot_user(db_session):
+ """Create a bot user for testing."""
+ bot_discord_user = DiscordUser(
+ id=999999999,
+ username="testbot",
+ )
+ db_session.add(bot_discord_user)
+ db_session.flush()
+
+ user = DiscordBotUser.create_with_api_key(
+ discord_users=[bot_discord_user],
+ name="testbot",
+ email="bot@example.com",
+ )
+ db_session.add(user)
+ db_session.commit()
+ return user
+
+
+@pytest.fixture
+def target_user(db_session):
+ """Create a target Discord user for testing."""
+ discord_user = DiscordUser(
+ id=123456789,
+ username="targetuser",
+ proactive_cron="0 9 * * *", # 9am daily
+ chattiness_threshold=50,
+ )
+ db_session.add(discord_user)
+ db_session.commit()
+ return discord_user
+
+
+@pytest.fixture
+def target_user_no_cron(db_session):
+ """Create a target Discord user without proactive cron."""
+ discord_user = DiscordUser(
+ id=123456790,
+ username="nocronuser",
+ proactive_cron=None,
+ )
+ db_session.add(discord_user)
+ db_session.commit()
+ return discord_user
+
+
+@pytest.fixture
+def target_server(db_session):
+ """Create a target Discord server for testing."""
+ server = DiscordServer(
+ id=987654321,
+ name="Test Server",
+ proactive_cron="0 */4 * * *", # Every 4 hours
+ chattiness_threshold=30,
+ )
+ db_session.add(server)
+ db_session.commit()
+ return server
+
+
+@pytest.fixture
+def target_channel(db_session, target_server):
+ """Create a target Discord channel for testing."""
+ channel = DiscordChannel(
+ id=111222333,
+ name="test-channel",
+ channel_type="text",
+ server_id=target_server.id,
+ proactive_cron="0 12 * * 1-5", # Noon on weekdays
+ chattiness_threshold=70,
+ )
+ db_session.add(channel)
+ db_session.commit()
+ return channel
+
+
+# ============================================================================
+# Tests for is_cron_due helper
+# ============================================================================
+
+
+@pytest.mark.parametrize(
+ "cron_expr,now,last_run,expected",
+ [
+ # Cron is due when never run before and time matches
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
+ None,
+ True,
+ ),
+ # Cron is due when last run was before the scheduled time
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
+ datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
+ True,
+ ),
+ # Cron is NOT due when already run this period
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
+ datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
+ False,
+ ),
+ # Cron is NOT due when current time is before scheduled time
+ (
+ "0 9 * * *",
+ datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
+ None,
+ False,
+ ),
+ # Hourly cron schedule
+ (
+ "0 * * * *",
+ datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
+ datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
+ True,
+ ),
+ # Every 4 hours cron schedule
+ (
+ "0 */4 * * *",
+ datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
+ datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
+ True,
+ ),
+ ],
+ ids=[
+ "due_never_run",
+ "due_last_run_before_schedule",
+ "not_due_already_run",
+ "not_due_too_early",
+ "due_hourly",
+ "due_every_4_hours",
+ ],
+)
+def test_is_cron_due(cron_expr, now, last_run, expected):
+ """Test is_cron_due with various scenarios."""
+ assert is_cron_due(cron_expr, last_run, now) is expected
+
+
+def test_is_cron_due_invalid_expression():
+ """Test invalid cron expression returns False."""
+ now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
+ assert is_cron_due("invalid cron", None, now) is False
+
+
+def test_is_cron_due_with_naive_last_run():
+ """Test cron handles naive datetime for last_run."""
+ now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
+ cron_expr = "0 9 * * *"
+ last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
+ assert is_cron_due(cron_expr, last_run, now) is True
+
+
+# ============================================================================
+# Tests for evaluate_proactive_checkins task
+# ============================================================================
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.is_cron_due")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_dispatches_due(
+ mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
+):
+ """Test that due check-ins are dispatched."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_is_cron_due.return_value = True
+
+ result = proactive.evaluate_proactive_checkins()
+
+ assert result["count"] >= 1
+ mock_execute.delay.assert_called()
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.is_cron_due")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_skips_not_due(
+ mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
+):
+ """Test that not-due check-ins are not dispatched."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_is_cron_due.return_value = False
+
+ result = proactive.evaluate_proactive_checkins()
+
+ assert result["count"] == 0
+ mock_execute.delay.assert_not_called()
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_skips_no_cron(
+ mock_make_session, mock_execute, db_session, target_user_no_cron
+):
+ """Test that entities without proactive_cron are skipped."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ result = proactive.evaluate_proactive_checkins()
+
+ for call in mock_execute.delay.call_args_list:
+ entity_type, entity_id = call[0]
+ assert entity_id != target_user_no_cron.id
+
+
+@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
+@patch("memory.workers.tasks.proactive.is_cron_due")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_evaluate_proactive_checkins_multiple_entity_types(
+ mock_make_session,
+ mock_is_cron_due,
+ mock_execute,
+ db_session,
+ target_user,
+ target_server,
+ target_channel,
+):
+ """Test that check-ins are dispatched for users, channels, and servers."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_is_cron_due.return_value = True
+
+ result = proactive.evaluate_proactive_checkins()
+
+ assert result["count"] == 3
+ dispatched_types = {d["type"] for d in result["dispatched"]}
+ assert "user" in dispatched_types
+ assert "channel" in dispatched_types
+ assert "server" in dispatched_types
+
+
+# ============================================================================
+# Tests for execute_proactive_checkin task
+# ============================================================================
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_sends_when_above_threshold(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test check-in is sent when interest exceeds threshold."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "80Should check in",
+ "Hey! Just checking in - how are things going?",
+ ]
+ mock_send.return_value = True
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert result["status"] == "sent"
+ assert result["interest"] == 80
+ mock_send.assert_called_once()
+
+ db_session.refresh(target_user)
+ assert target_user.last_proactive_at is not None
+
+
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_skips_below_threshold(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test check-in is skipped when interest is below threshold."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.return_value = (
+ "30Not much to say"
+ )
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert result["status"] == "below_threshold"
+ assert result["interest"] == 30
+ assert result["threshold"] == 50
+
+
+@pytest.mark.parametrize(
+ "llm_response,expected_status",
+ [
+ (None, "no_eval_response"),
+ ("I'm not sure what to say.", "no_score_in_response"),
+ ],
+ ids=["no_response", "malformed_response"],
+)
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_handles_bad_llm_response(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ llm_response,
+ expected_status,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test handling of missing or malformed LLM responses."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+ mock_call_llm.return_value = llm_response
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert result["status"] == expected_status
+
+
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
+ """Test handling when entity doesn't exist."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ result = proactive.execute_proactive_checkin("user", 999999)
+
+ assert "error" in result
+ assert "not found" in result["error"]
+
+
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_no_bot_user(
+ mock_make_session, mock_get_bot, db_session, target_user
+):
+ """Test handling when no bot user is found."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+ mock_get_bot.return_value = None
+
+ result = proactive.execute_proactive_checkin("user", target_user.id)
+
+ assert "error" in result
+ assert "No bot user" in result["error"]
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_uses_proactive_prompt(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ bot_user,
+):
+ """Test that proactive_prompt is included in the evaluation."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ user_with_prompt = DiscordUser(
+ id=555666777,
+ username="promptuser",
+ proactive_cron="0 9 * * *",
+ proactive_prompt="Focus on their coding projects",
+ chattiness_threshold=50,
+ )
+ db_session.add(user_with_prompt)
+ db_session.commit()
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "80Check on projects",
+ "How are your coding projects coming along?",
+ ]
+ mock_send.return_value = True
+
+ result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
+
+ assert result["status"] == "sent"
+ call_args = mock_call_llm.call_args_list[0]
+ messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
+ assert any("Focus on their coding projects" in str(m) for m in messages_arg)
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_channel(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ target_channel,
+ bot_user,
+):
+ """Test check-in to a channel."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "50Check channel",
+ "Good morning everyone!",
+ ]
+ mock_send.return_value = True
+
+ result = proactive.execute_proactive_checkin("channel", target_channel.id)
+
+ assert result["status"] == "sent"
+ assert result["entity_type"] == "channel"
+
+ send_call = mock_send.call_args
+ assert send_call.kwargs.get("channel_id") == target_channel.id
+
+
+@patch("memory.workers.tasks.proactive.send_discord_response")
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_updates_last_proactive_at(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ mock_send,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test that last_proactive_at is updated after successful check-in."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.side_effect = [
+ "80Check in",
+ "Hey there!",
+ ]
+ mock_send.return_value = True
+
+ before_time = datetime.now(timezone.utc)
+ proactive.execute_proactive_checkin("user", target_user.id)
+ after_time = datetime.now(timezone.utc)
+
+ db_session.refresh(target_user)
+ assert target_user.last_proactive_at is not None
+ assert before_time <= target_user.last_proactive_at <= after_time
+
+
+@patch("memory.workers.tasks.proactive.call_llm")
+@patch("memory.workers.tasks.proactive.get_bot_for_entity")
+@patch("memory.workers.tasks.proactive.make_session")
+def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
+ mock_make_session,
+ mock_get_bot,
+ mock_call_llm,
+ db_session,
+ target_user,
+ bot_user,
+):
+ """Test that last_proactive_at is updated even when check-in is skipped."""
+ mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
+ mock_make_session.return_value.__exit__ = Mock(return_value=False)
+
+ bot_discord_user = bot_user.discord_users[0]
+ bot_discord_user.system_user = bot_user
+ mock_get_bot.return_value = bot_discord_user
+
+ mock_call_llm.return_value = (
+ "10Nothing to say"
+ )
+
+ proactive.execute_proactive_checkin("user", target_user.id)
+
+ db_session.refresh(target_user)
+ assert target_user.last_proactive_at is not None
diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py
index 44d567d..ceb28d5 100644
--- a/tests/memory/workers/tasks/test_scheduled_calls.py
+++ b/tests/memory/workers/tasks/test_scheduled_calls.py
@@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
@pytest.fixture
def sample_user(db_session):
"""Create a sample user for testing."""
+ # Create a discord user for the bot
+ bot_discord_user = DiscordUser(
+ id=999999999,
+ username="testbot",
+ )
+ db_session.add(bot_discord_user)
+ db_session.flush()
+
user = DiscordBotUser.create_with_api_key(
- discord_users=[],
+ discord_users=[bot_discord_user],
name="testbot",
email="bot@example.com",
)
@@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
return call
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
+@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
"""Test sending to Discord user."""
response = "This is a test response."
- scheduled_calls.send_to_discord(pending_scheduled_call, response)
+ scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
- pending_scheduled_call.user_id,
+ 999999999, # bot_id
"testuser", # username, not ID
- "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
+ response,
)
-@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
-def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
+@patch("memory.discord.messages.discord.send_to_channel")
+def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
"""Test sending to Discord channel."""
response = "This is a channel response."
- scheduled_calls.send_to_discord(completed_scheduled_call, response)
+ scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
- mock_broadcast.assert_called_once_with(
- completed_scheduled_call.user_id,
- "test-channel", # channel name, not ID
- "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
+ mock_send_to_channel.assert_called_once_with(
+ 999999999, # bot_id
+ completed_scheduled_call.discord_channel.id, # channel ID, not name
+ response,
)
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
+@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
"""Test message truncation for long responses."""
long_response = "A" * 2500 # Very long response
- scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
+ scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
- # Verify the message was truncated
+ # With the new implementation, send_discord_response sends the full response
+ # No truncation happens in _send_to_discord
args, kwargs = mock_send_dm.call_args
- assert args[0] == pending_scheduled_call.user_id
+ assert args[0] == 999999999 # bot_id
message = args[2]
- assert len(message) <= 1950 # Should be truncated
- assert message.endswith("... (response truncated)")
+ assert message == long_response
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
+@patch("memory.discord.messages.discord.send_dm")
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
"""Test that normal length messages are not truncated."""
normal_response = "This is a normal length response."
- scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
+ scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args
- assert args[0] == pending_scheduled_call.user_id
+ assert args[0] == 999999999 # bot_id
message = args[2]
- assert not message.endswith("... (response truncated)")
- assert "This is a normal length response." in message
+ assert message == normal_response
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_success(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
- # Verify LLM was called with correct parameters
- mock_llm_call.assert_called_once_with(
- prompt="What is the weather like today?",
- model="anthropic/claude-3-5-sonnet-20241022",
- system_prompt="You are a helpful assistant.",
- )
+ # Verify LLM was called
+ mock_llm_call.assert_called_once()
# Verify result
assert result["success"] is True
@@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
assert result == {"error": "Scheduled call not found"}
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_not_pending(
mock_llm_call, completed_scheduled_call, db_session
):
@@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
mock_llm_call.assert_not_called()
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_with_default_system_prompt(
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
):
@@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
scheduled_calls.execute_scheduled_call(call.id)
- # Verify default system prompt was used
- mock_llm_call.assert_called_once_with(
- prompt="Test prompt",
- model="anthropic/claude-3-5-sonnet-20241022",
- system_prompt=None, # The code uses system_prompt as-is, not a default
- )
+ # Verify LLM was called
+ mock_llm_call.assert_called_once()
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_discord_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_llm_error(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
"""Test execution when LLM call fails."""
mock_llm_call.side_effect = Exception("LLM API error")
- # The safe_task_execution decorator should catch this
+ # The execute_scheduled_call function catches the exception and returns an error response
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
- assert result["status"] == "error"
- assert "LLM API error" in result["error"]
+ assert result["success"] is False
+ assert "error" in result
+ assert "LLM call failed" in result["error"]
# Discord should not be called
mock_send_discord.assert_not_called()
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_long_response_truncation(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
mock_execute_delay.delay.assert_called_once_with(due_call.id)
-@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_status_transition_pending_to_executing_to_completed(
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
):
@@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
"has_discord_user,has_discord_channel,expected_method",
[
(True, False, "send_dm"),
- (False, True, "broadcast_message"),
- (True, True, "send_dm"), # User takes precedence
+ (False, True, "send_to_channel"),
+ (True, True, "send_to_channel"), # Channel takes precedence in the implementation
],
)
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
-@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
+@patch("memory.discord.messages.discord.send_dm")
+@patch("memory.discord.messages.discord.send_to_channel")
def test_discord_destination_priority(
- mock_broadcast,
+ mock_send_to_channel,
mock_send_dm,
has_discord_user,
has_discord_channel,
@@ -535,50 +535,39 @@ def test_discord_destination_priority(
db_session.commit()
response = "Test response"
- scheduled_calls.send_to_discord(call, response)
+ scheduled_calls.send_to_discord(999999999, call, response)
if expected_method == "send_dm":
mock_send_dm.assert_called_once()
- mock_broadcast.assert_not_called()
+ mock_send_to_channel.assert_not_called()
else:
- mock_broadcast.assert_called_once()
+ mock_send_to_channel.assert_called_once()
mock_send_dm.assert_not_called()
@pytest.mark.parametrize(
- "topic,model,response,expected_in_message",
+ "topic,model,response",
[
(
"Weather Check",
"anthropic/claude-3-5-sonnet-20241022",
"It's sunny!",
- [
- "**Topic:** Weather Check",
- "**Model:** anthropic/claude-3-5-sonnet-20241022",
- "**Response:** It's sunny!",
- ],
),
(
"Test Topic",
"gpt-4",
"Hello world",
- ["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
),
(
"Long Topic Name Here",
"claude-2",
"Short",
- [
- "**Topic:** Long Topic Name Here",
- "**Model:** claude-2",
- "**Response:** Short",
- ],
),
],
)
-@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
-def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
- """Test the Discord message formatting with different inputs."""
+@patch("memory.discord.messages.discord.send_dm")
+def test_message_formatting(mock_send_dm, topic, model, response):
+ """Test that _send_to_discord sends the response as-is."""
# Create a mock scheduled call with a mock Discord user
mock_discord_user = Mock()
mock_discord_user.username = "testuser"
@@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_call.discord_user = mock_discord_user
mock_call.discord_channel = None
- scheduled_calls.send_to_discord(mock_call, response)
+ scheduled_calls.send_to_discord(999999999, mock_call, response)
# Get the actual message that was sent
args, kwargs = mock_send_dm.call_args
- assert args[0] == mock_call.user_id
+ assert args[0] == 999999999 # bot_id
actual_message = args[2]
- # Verify all expected parts are in the message
- for expected_part in expected_in_message:
- assert expected_part in actual_message
+ # The new implementation sends the response as-is, without formatting
+ assert actual_message == response
@pytest.mark.parametrize(
@@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
("cancelled", False),
],
)
-@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
+@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
def test_execute_scheduled_call_status_check(
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
):