From f042f9aed847930ba21e1f667e5405b22adb75dd Mon Sep 17 00:00:00 2001 From: mruwnik Date: Mon, 29 Dec 2025 14:07:12 +0000 Subject: [PATCH] proactive stuff --- .../20251224_160000_add_proactive_checkins.py | 45 ++ requirements/requirements-api.txt | 3 +- requirements/requirements-common.txt | 7 +- requirements/requirements-dev.txt | 6 +- src/memory/api/MCP/__init__.py | 24 +- src/memory/api/MCP/base.py | 121 ++-- src/memory/api/MCP/manifest.py | 119 ---- src/memory/api/MCP/metadata.py | 119 ---- src/memory/api/MCP/oauth_provider.py | 44 +- src/memory/api/MCP/servers/__init__.py | 17 + src/memory/api/MCP/{ => servers}/books.py | 10 +- .../api/MCP/{memory.py => servers/core.py} | 63 +- src/memory/api/MCP/{ => servers}/github.py | 59 +- src/memory/api/MCP/servers/meta.py | 277 +++++++++ src/memory/api/MCP/{ => servers}/people.py | 30 +- .../MCP/{schedules.py => servers/schedule.py} | 37 +- src/memory/api/MCP/tools.py | 74 --- src/memory/api/app.py | 53 +- src/memory/common/celery_app.py | 10 + src/memory/common/db/models/discord.py | 20 +- src/memory/common/settings.py | 1 + src/memory/discord/commands.py | 106 ++++ src/memory/workers/ingest.py | 5 + src/memory/workers/tasks/__init__.py | 10 +- src/memory/workers/tasks/proactive.py | 341 +++++++++++ tests/integration/test_real_queries.py | 4 +- .../memory/api/search/test_query_analysis.py | 5 +- tests/memory/api/test_auth.py | 16 +- .../common/llms/tools/test_discord_tools.py | 98 ++-- tests/memory/common/test_discord.py | 2 +- .../memory/common/test_discord_integration.py | 2 +- tests/memory/discord_tests/test_commands.py | 205 +++++++ tests/memory/discord_tests/test_mcp.py | 9 +- tests/memory/discord_tests/test_messages.py | 53 +- .../workers/tasks/test_discord_tasks.py | 50 +- tests/memory/workers/tasks/test_proactive.py | 536 ++++++++++++++++++ .../workers/tasks/test_scheduled_calls.py | 148 +++-- 37 files changed, 2004 insertions(+), 725 deletions(-) create mode 100644 db/migrations/versions/20251224_160000_add_proactive_checkins.py delete mode 100644 src/memory/api/MCP/manifest.py delete mode 100644 src/memory/api/MCP/metadata.py create mode 100644 src/memory/api/MCP/servers/__init__.py rename src/memory/api/MCP/{ => servers}/books.py (92%) rename src/memory/api/MCP/{memory.py => servers/core.py} (91%) rename src/memory/api/MCP/{ => servers}/github.py (92%) create mode 100644 src/memory/api/MCP/servers/meta.py rename src/memory/api/MCP/{ => servers}/people.py (95%) rename src/memory/api/MCP/{schedules.py => servers/schedule.py} (89%) delete mode 100644 src/memory/api/MCP/tools.py create mode 100644 src/memory/workers/tasks/proactive.py create mode 100644 tests/memory/workers/tasks/test_proactive.py diff --git a/db/migrations/versions/20251224_160000_add_proactive_checkins.py b/db/migrations/versions/20251224_160000_add_proactive_checkins.py new file mode 100644 index 0000000..e8a8a5d --- /dev/null +++ b/db/migrations/versions/20251224_160000_add_proactive_checkins.py @@ -0,0 +1,45 @@ +"""Add proactive check-in fields to Discord entities + +Revision ID: e1f2a3b4c5d6 +Revises: d0e1f2a3b4c5 +Create Date: 2025-12-24 16:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "e1f2a3b4c5d6" +down_revision: Union[str, None] = "d0e1f2a3b4c5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add proactive fields to all MessageProcessor tables + for table in ["discord_servers", "discord_channels", "discord_users"]: + op.add_column( + table, + sa.Column("proactive_cron", sa.Text(), nullable=True), + ) + op.add_column( + table, + sa.Column("proactive_prompt", sa.Text(), nullable=True), + ) + op.add_column( + table, + sa.Column( + "last_proactive_at", sa.DateTime(timezone=True), nullable=True + ), + ) + + +def downgrade() -> None: + for table in ["discord_servers", "discord_channels", "discord_users"]: + op.drop_column(table, "last_proactive_at") + op.drop_column(table, "proactive_prompt") + op.drop_column(table, "proactive_cron") diff --git a/requirements/requirements-api.txt b/requirements/requirements-api.txt index 3e1a402..0fc683f 100644 --- a/requirements/requirements-api.txt +++ b/requirements/requirements-api.txt @@ -1,7 +1,8 @@ -fastapi==0.112.2 +fastapi>=0.115.12 uvicorn==0.29.0 python-jose==3.3.0 python-multipart==0.0.9 sqladmin==0.20.1 mcp==1.10.0 +fastmcp>=2.10.0 slowapi==0.1.9 \ No newline at end of file diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt index 09674e1..2cbcd4e 100644 --- a/requirements/requirements-common.txt +++ b/requirements/requirements-common.txt @@ -1,14 +1,15 @@ sqlalchemy==2.0.30 psycopg2-binary==2.9.9 -pydantic==2.7.2 +pydantic>=2.11.7 alembic==1.13.1 dotenv==0.9.9 voyageai==0.3.2 qdrant-client==1.9.0 anthropic==0.69.0 openai==2.3.0 -# Pin the httpx version, as newer versions break the anthropic client -httpx==0.27.0 +# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1) +httpx>=0.28.1 celery[redis,sqs]==5.3.6 +croniter==2.0.1 cryptography==43.0.0 bcrypt==4.1.2 \ No newline at end of file diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 445b8ba..2a89bfb 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,7 +1,9 @@ pytest==7.4.4 pytest-cov==4.1.0 +pytest-asyncio==0.23.0 black==23.12.1 mypy==1.8.0 -isort==5.13.2 +isort==5.13.2 testcontainers[qdrant]==4.10.0 -click==8.1.7 \ No newline at end of file +click==8.1.7 +croniter==2.0.1 \ No newline at end of file diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py index 07f0efa..540920d 100644 --- a/src/memory/api/MCP/__init__.py +++ b/src/memory/api/MCP/__init__.py @@ -1,8 +1,16 @@ -import memory.api.MCP.tools -import memory.api.MCP.memory -import memory.api.MCP.metadata -import memory.api.MCP.schedules -import memory.api.MCP.books -import memory.api.MCP.manifest -import memory.api.MCP.github -import memory.api.MCP.people +""" +MCP server with composed subservers. + +Subservers are mounted with prefixes: +- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file +- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview +- people: add_person, update_person_info, get_person, list_people, delete_person +- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call +- books: all_books, read_book +- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts +""" + +# Import base to trigger subserver mounting +from memory.api.MCP.base import mcp, get_current_user + +__all__ = ["mcp", "get_current_user"] diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index 92c4ad0..719c4e0 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -1,60 +1,28 @@ import logging import pathlib -from typing import cast -from mcp.server.auth.handlers.authorize import AuthorizationRequest -from mcp.server.auth.handlers.token import ( - AuthorizationCodeRequest, - RefreshTokenRequest, - TokenRequest, -) -from mcp.server.auth.middleware.auth_context import get_access_token -from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions -from mcp.server.fastmcp import FastMCP -from mcp.shared.auth import OAuthClientMetadata -from pydantic import AnyHttpUrl +from fastmcp import FastMCP +from fastmcp.server.dependencies import get_access_token from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse from starlette.templating import Jinja2Templates -from memory.api.MCP.oauth_provider import ( - ALLOWED_SCOPES, - BASE_SCOPES, - SimpleOAuthProvider, -) +from memory.api.MCP.oauth_provider import SimpleOAuthProvider +from memory.api.MCP.servers.books import books_mcp +from memory.api.MCP.servers.core import core_mcp +from memory.api.MCP.servers.github import github_mcp +from memory.api.MCP.servers.meta import meta_mcp +from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth +from memory.api.MCP.servers.people import people_mcp +from memory.api.MCP.servers.schedule import schedule_mcp +from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth from memory.common import settings -from memory.common.db.connection import make_session +from memory.common.db.connection import make_session, get_engine from memory.common.db.models import OAuthState, UserSession from memory.common.db.models.users import HumanUser logger = logging.getLogger(__name__) - - -def validate_metadata(klass: type): - orig_validate = klass.model_validate - - def validate(data: dict): - data = dict(data) - if "redirect_uris" in data: - data["redirect_uris"] = [ - str(uri).replace("cursor://", "http://") - for uri in data["redirect_uris"] - ] - if "redirect_uri" in data: - data["redirect_uri"] = str(data["redirect_uri"]).replace( - "cursor://", "http://" - ) - - return orig_validate(data) - - klass.model_validate = validate - - -validate_metadata(OAuthClientMetadata) -validate_metadata(AuthorizationRequest) -validate_metadata(AuthorizationCodeRequest) -validate_metadata(RefreshTokenRequest) -validate_metadata(TokenRequest) +engine = get_engine() # Setup templates @@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir) oauth_provider = SimpleOAuthProvider() -auth_settings = AuthSettings( - issuer_url=cast(AnyHttpUrl, settings.SERVER_URL), - resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore - client_registration_options=ClientRegistrationOptions( - enabled=True, - valid_scopes=ALLOWED_SCOPES, - default_scopes=BASE_SCOPES, - ), - required_scopes=BASE_SCOPES, -) mcp = FastMCP( "memory", - stateless_http=True, - auth_server_provider=oauth_provider, - auth=auth_settings, + auth=oauth_provider, ) @@ -162,3 +118,52 @@ def get_current_user() -> dict: "client_id": access_token.client_id, "user": user_info, } + + +@mcp.custom_route("/health", methods=["GET"]) +async def health_check(request: Request): + """Health check endpoint that verifies all dependencies are accessible.""" + from sqlalchemy import text + + checks = {"mcp_oauth": "enabled"} + all_healthy = True + + # Check database connection + try: + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + checks["database"] = "healthy" + except Exception as e: + logger.error(f"Database health check failed: {e}") + checks["database"] = "unhealthy" + all_healthy = False + + # Check Qdrant connection + try: + from memory.common.qdrant import get_qdrant_client + + client = get_qdrant_client() + client.get_collections() + checks["qdrant"] = "healthy" + except Exception as e: + logger.error(f"Qdrant health check failed: {e}") + checks["qdrant"] = "unhealthy" + all_healthy = False + + checks["status"] = "healthy" if all_healthy else "degraded" + status_code = 200 if all_healthy else 503 + return JSONResponse(checks, status_code=status_code) + + +# Inject auth provider into subservers that need it +set_schedule_auth(get_current_user) +set_meta_auth(get_current_user) + +# Mount all subservers onto the main MCP server +# Tools will be prefixed with their server name (e.g., core_search_knowledge_base) +mcp.mount(core_mcp, prefix="core") +mcp.mount(github_mcp, prefix="github") +mcp.mount(people_mcp, prefix="people") +mcp.mount(schedule_mcp, prefix="schedule") +mcp.mount(books_mcp, prefix="books") +mcp.mount(meta_mcp, prefix="meta") diff --git a/src/memory/api/MCP/manifest.py b/src/memory/api/MCP/manifest.py deleted file mode 100644 index 6dcd76a..0000000 --- a/src/memory/api/MCP/manifest.py +++ /dev/null @@ -1,119 +0,0 @@ -import asyncio -import aiohttp -from datetime import datetime - -from typing import TypedDict, NotRequired, Literal -from memory.api.MCP.tools import mcp - - -class BinaryProbs(TypedDict): - prob: float - - -class MultiProbs(TypedDict): - answerProbs: dict[str, float] - - -Probs = dict[str, BinaryProbs | MultiProbs] -OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"] - - -class MarketAnswer(TypedDict): - id: str - text: str - resolutionProbability: float - - -class MarketDetails(TypedDict): - id: str - createdTime: int - question: str - outcomeType: OutcomeType - textDescription: str - groupSlugs: list[str] - volume: float - isResolved: bool - answers: list[MarketAnswer] - - -class Market(TypedDict): - id: str - url: str - question: str - volume: int - createdTime: int - outcomeType: OutcomeType - createdAt: NotRequired[str] - description: NotRequired[str] - answers: NotRequired[dict[str, float]] - probability: NotRequired[float] - details: NotRequired[MarketDetails] - - -async def get_details(session: aiohttp.ClientSession, market_id: str): - async with session.get( - f"https://api.manifold.markets/v0/market/{market_id}" - ) as resp: - resp.raise_for_status() - return await resp.json() - - -async def format_market(session: aiohttp.ClientSession, market: Market): - if market.get("outcomeType") != "BINARY": - details = await get_details(session, market["id"]) - market["answers"] = { - answer["text"]: round( - answer.get("resolutionProbability") or answer.get("probability") or 0, 3 - ) - for answer in details["answers"] - } - if creationTime := market.get("createdTime"): - market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat() - - fields = [ - "id", - "name", - "url", - "question", - "volume", - "createdAt", - "details", - "probability", - "answers", - ] - return {k: v for k, v in market.items() if k in fields} - - -async def search_markets(term: str, min_volume: int = 1000, binary: bool = False): - async with aiohttp.ClientSession() as session: - async with session.get( - "https://api.manifold.markets/v0/search-markets", - params={ - "term": term, - "contractType": "BINARY" if binary else "ALL", - }, - ) as resp: - resp.raise_for_status() - markets = await resp.json() - - return await asyncio.gather( - *[ - format_market(session, market) - for market in markets - if market.get("volume", 0) >= min_volume - ] - ) - - -@mcp.tool() -async def get_forecasts( - term: str, min_volume: int = 1000, binary: bool = False -) -> list[dict]: - """Get prediction market forecasts for a given term. - - Args: - term: The term to search for. - min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold. - binary: Whether to only return binary markets. - """ - return await search_markets(term, min_volume, binary) diff --git a/src/memory/api/MCP/metadata.py b/src/memory/api/MCP/metadata.py deleted file mode 100644 index ce54bec..0000000 --- a/src/memory/api/MCP/metadata.py +++ /dev/null @@ -1,119 +0,0 @@ -import logging -from collections import defaultdict -from typing import Annotated, TypedDict, get_args, get_type_hints - -from memory.common import qdrant -from sqlalchemy import func - -from memory.api.MCP.tools import mcp -from memory.common.db.connection import make_session -from memory.common.db.models import SourceItem -from memory.common.db.models.source_items import AgentObservation - -logger = logging.getLogger(__name__) - - -class SchemaArg(TypedDict): - type: str | None - description: str | None - - -class CollectionMetadata(TypedDict): - schema: dict[str, SchemaArg] - size: int - - -def from_annotation(annotation: Annotated) -> SchemaArg | None: - try: - type_, description = get_args(annotation) - type_str = str(type_) - if type_str.startswith("typing."): - type_str = type_str[7:] - elif len((parts := type_str.split("'"))) > 1: - type_str = parts[1] - return SchemaArg(type=type_str, description=description) - except IndexError: - logger.error(f"Error from annotation: {annotation}") - return None - - -def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]: - if not hasattr(klass, "as_payload"): - return {} - - if not (payload_type := get_type_hints(klass.as_payload).get("return")): - return {} - - return { - name: schema - for name, arg in payload_type.__annotations__.items() - if (schema := from_annotation(arg)) - } - - -@mcp.tool() -async def get_metadata_schemas() -> dict[str, CollectionMetadata]: - """Get the metadata schema for each collection used in the knowledge base. - - These schemas can be used to filter the knowledge base. - - Returns: A mapping of collection names to their metadata schemas with field types and descriptions. - - Example: - ``` - { - "mail": {"subject": {"type": "str", "description": "The subject of the email."}}, - "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}} - } - """ - client = qdrant.get_qdrant_client() - sizes = qdrant.get_collection_sizes(client) - schemas = defaultdict(dict) - for klass in SourceItem.__subclasses__(): - for collection in klass.get_collections(): - schemas[collection].update(get_schema(klass)) - - return { - collection: CollectionMetadata(schema=schema, size=size) - for collection, schema in schemas.items() - if (size := sizes.get(collection)) - } - - -@mcp.tool() -async def get_all_tags() -> list[str]: - """Get all unique tags used across the entire knowledge base. - - Returns sorted list of tags from both observations and content. - """ - with make_session() as session: - tags_query = session.query(func.unnest(SourceItem.tags)).distinct() - return sorted({row[0] for row in tags_query if row[0] is not None}) - - -@mcp.tool() -async def get_all_subjects() -> list[str]: - """Get all unique subjects from observations about the user. - - Returns sorted list of subject identifiers used in observations. - """ - with make_session() as session: - return sorted( - r.subject for r in session.query(AgentObservation.subject).distinct() - ) - - -@mcp.tool() -async def get_all_observation_types() -> list[str]: - """Get all observation types that have been used. - - Standard types are belief, preference, behavior, contradiction, general, but there can be more. - """ - with make_session() as session: - return sorted( - { - r.observation_type - for r in session.query(AgentObservation.observation_type).distinct() - if r.observation_type is not None - } - ) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index 3f9ac9f..db5b6f4 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -5,11 +5,12 @@ from datetime import datetime, timezone from typing import Any, Optional, cast from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +from fastmcp.server.auth import OAuthProvider +from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken from mcp.server.auth.provider import ( AccessToken, AuthorizationCode, AuthorizationParams, - OAuthAuthorizationServerProvider, RefreshToken, ) from mcp.shared.auth import OAuthClientInformationFull, OAuthToken @@ -133,8 +134,45 @@ def make_token( ) -class SimpleOAuthProvider(OAuthAuthorizationServerProvider): - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: +class SimpleOAuthProvider(OAuthProvider): + """OAuth provider that extends fastmcp's OAuthProvider with custom login flow.""" + + def __init__(self): + super().__init__( + base_url=settings.SERVER_URL, + issuer_url=settings.SERVER_URL, + client_registration_options=None, # We handle registration ourselves + required_scopes=BASE_SCOPES, + ) + + async def verify_token(self, token: str) -> FastMCPAccessToken | None: + """Verify an access token and return token info if valid.""" + with make_session() as session: + # Try as OAuth access token first + user_session = session.query(UserSession).get(token) + if user_session: + now = datetime.now(timezone.utc).replace(tzinfo=None) + if user_session.expires_at < now: + return None + return FastMCPAccessToken( + token=token, + client_id=user_session.oauth_state.client_id, + scopes=user_session.oauth_state.scopes, + ) + + # Try as bot API key + bot = session.query(User).filter(User.api_key == token).first() + if bot: + logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key") + return FastMCPAccessToken( + token=token, + client_id=cast(str, bot.name or bot.email), + scopes=["read", "write"], + ) + + return None + + def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """Get OAuth client information.""" with make_session() as session: client = session.get(OAuthClientInformation, client_id) diff --git a/src/memory/api/MCP/servers/__init__.py b/src/memory/api/MCP/servers/__init__.py new file mode 100644 index 0000000..68122c0 --- /dev/null +++ b/src/memory/api/MCP/servers/__init__.py @@ -0,0 +1,17 @@ +"""MCP subservers for composable tool organization.""" + +from memory.api.MCP.servers.core import core_mcp +from memory.api.MCP.servers.github import github_mcp +from memory.api.MCP.servers.people import people_mcp +from memory.api.MCP.servers.schedule import schedule_mcp +from memory.api.MCP.servers.books import books_mcp +from memory.api.MCP.servers.meta import meta_mcp + +__all__ = [ + "core_mcp", + "github_mcp", + "people_mcp", + "schedule_mcp", + "books_mcp", + "meta_mcp", +] diff --git a/src/memory/api/MCP/books.py b/src/memory/api/MCP/servers/books.py similarity index 92% rename from src/memory/api/MCP/books.py rename to src/memory/api/MCP/servers/books.py index ad102ef..d5989ff 100644 --- a/src/memory/api/MCP/books.py +++ b/src/memory/api/MCP/servers/books.py @@ -1,15 +1,19 @@ +"""MCP subserver for ebook access.""" + import logging +from fastmcp import FastMCP from sqlalchemy.orm import joinedload -from memory.api.MCP.tools import mcp from memory.common.db.connection import make_session from memory.common.db.models import Book, BookSection, BookSectionPayload logger = logging.getLogger(__name__) +books_mcp = FastMCP("memory-books") -@mcp.tool() + +@books_mcp.tool() async def all_books(sections: bool = False) -> list[dict]: """ Get all books in the database. @@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]: return [book.as_payload(sections=sections) for book in books] -@mcp.tool() +@books_mcp.tool() def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]: """ Read a book from the database. diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/servers/core.py similarity index 91% rename from src/memory/api/MCP/memory.py rename to src/memory/api/MCP/servers/core.py index 8ea3fd0..f73f251 100644 --- a/src/memory/api/MCP/memory.py +++ b/src/memory/api/MCP/servers/core.py @@ -1,53 +1,45 @@ """ -MCP tools for the epistemic sparring partner system. +Core MCP subserver for knowledge base search, observations, and notes. """ import base64 import logging import pathlib from datetime import datetime, timezone -from PIL import Image +from fastmcp import FastMCP +from PIL import Image from pydantic import BaseModel from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.MCP.base import mcp from memory.api.search.search import search -from memory.api.search.types import SearchFilters, SearchConfig +from memory.api.search.types import SearchConfig, SearchFilters from memory.common import extract, settings from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION from memory.common.celery_app import app as celery_app from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS from memory.common.db.connection import make_session -from memory.common.db.models import SourceItem, AgentObservation +from memory.common.db.models import AgentObservation, SourceItem from memory.common.formatters import observation logger = logging.getLogger(__name__) +core_mcp = FastMCP("memory-core") + def validate_path_within_directory( base_dir: pathlib.Path, requested_path: str ) -> pathlib.Path: - """Validate that a requested path resolves within the base directory. - - Prevents path traversal attacks using ../ or similar techniques. - - Args: - base_dir: The allowed base directory - requested_path: The user-provided path - - Returns: - The resolved absolute path if valid - - Raises: - ValueError: If the path would escape the base directory - """ + """Validate that a requested path resolves within the base directory.""" resolved = (base_dir / requested_path.lstrip("/")).resolve() base_resolved = base_dir.resolve() - if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + if ( + not str(resolved).startswith(str(base_resolved) + "/") + and resolved != base_resolved + ): raise ValueError(f"Path escapes allowed directory: {requested_path}") return resolved @@ -63,7 +55,6 @@ def filter_observation_source_ids( items_query = session.query(AgentObservation.id) if tags: - # Use PostgreSQL array overlap operator with proper array casting items_query = items_query.filter( AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))), ) @@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] items_query = session.query(SourceItem.id) if tags: - # Use PostgreSQL array overlap operator with proper array casting items_query = items_query.filter( SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))), ) @@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] return source_ids -@mcp.tool() +@core_mcp.tool() async def search_knowledge_base( query: str, filters: SearchFilters = {}, @@ -141,7 +131,6 @@ async def search_knowledge_base( if not modalities: modalities = set(ALL_COLLECTIONS.keys()) - # Filter to valid collections, excluding observation collections modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS search_filters = SearchFilters(**filters) @@ -167,7 +156,7 @@ class RawObservation(BaseModel): tags: list[str] = [] -@mcp.tool() +@core_mcp.tool() async def observe( observations: list[RawObservation], session_id: str | None = None, @@ -212,23 +201,23 @@ async def observe( logger.info("MCP: Observing") tasks = [ ( - observation, + obs, celery_app.send_task( SYNC_OBSERVATION, queue=f"{settings.CELERY_QUEUE_PREFIX}-notes", kwargs={ - "subject": observation.subject, - "content": observation.content, - "observation_type": observation.observation_type, - "confidences": observation.confidences, - "evidence": observation.evidence, - "tags": observation.tags, + "subject": obs.subject, + "content": obs.content, + "observation_type": obs.observation_type, + "confidences": obs.confidences, + "evidence": obs.evidence, + "tags": obs.tags, "session_id": session_id, "agent_model": agent_model, }, ), ) - for observation in observations + for obs in observations ] def short_content(obs: RawObservation) -> str: @@ -242,7 +231,7 @@ async def observe( } -@mcp.tool() +@core_mcp.tool() async def search_observations( query: str, subject: str = "", @@ -308,7 +297,7 @@ async def search_observations( ] -@mcp.tool() +@core_mcp.tool() async def create_note( subject: str, content: str, @@ -362,7 +351,7 @@ async def create_note( } -@mcp.tool() +@core_mcp.tool() async def note_files(path: str = "/"): """ List note files in the user's note storage. @@ -386,7 +375,7 @@ async def note_files(path: str = "/"): ] -@mcp.tool() +@core_mcp.tool() def fetch_file(filename: str) -> dict: """ Read file content with automatic type detection. diff --git a/src/memory/api/MCP/github.py b/src/memory/api/MCP/servers/github.py similarity index 92% rename from src/memory/api/MCP/github.py rename to src/memory/api/MCP/servers/github.py index 859c224..624c49f 100644 --- a/src/memory/api/MCP/github.py +++ b/src/memory/api/MCP/servers/github.py @@ -1,14 +1,14 @@ -"""MCP tools for GitHub issue tracking and management.""" +"""MCP subserver for GitHub issue tracking and management.""" import logging from datetime import datetime, timezone from typing import Any -from sqlalchemy import Text, case, desc, func, asc +from fastmcp import FastMCP +from sqlalchemy import Text, case, desc, func from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.MCP.base import mcp from memory.api.search.search import search from memory.api.search.types import SearchConfig, SearchFilters from memory.common import extract @@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem logger = logging.getLogger(__name__) +github_mcp = FastMCP("memory-github") + def _build_github_url(repo_path: str, number: int | None, kind: str) -> str: """Build GitHub URL from repo path and issue/PR number.""" @@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st } if include_content: result["content"] = item.content - # Include PR-specific data if available if item.kind == "pr" and item.pr_data: result["pr_data"] = { "additions": item.pr_data.additions, @@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st "files": item.pr_data.files, "reviews": item.pr_data.reviews, "review_comments": item.pr_data.review_comments, - "diff": item.pr_data.diff, # Decompressed via property + "diff": item.pr_data.diff, } return result -@mcp.tool() +@github_mcp.tool() async def list_github_issues( repo: str | None = None, assignee: str | None = None, @@ -109,64 +110,48 @@ async def list_github_issues( with make_session() as session: query = session.query(GithubItem) - # Apply filters if repo: query = query.filter(GithubItem.repo_path == repo) - if assignee: query = query.filter(GithubItem.assignees.any(assignee)) - if author: query = query.filter(GithubItem.author == author) - if state: query = query.filter(GithubItem.state == state) - if kind: query = query.filter(GithubItem.kind == kind) else: - # Exclude comments by default, only show issues and PRs query = query.filter(GithubItem.kind.in_(["issue", "pr"])) - if labels: - # Match any label in the list using PostgreSQL array overlap query = query.filter( GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text))) ) - if project_status: query = query.filter(GithubItem.project_status == project_status) - if project_field: for key, value in project_field.items(): - query = query.filter( - GithubItem.project_fields[key].astext == value - ) - + query = query.filter(GithubItem.project_fields[key].astext == value) if updated_since: since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00")) query = query.filter(GithubItem.github_updated_at >= since_dt) - if updated_before: before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00")) query = query.filter(GithubItem.github_updated_at <= before_dt) - # Apply ordering if order_by == "created": query = query.order_by(desc(GithubItem.created_at)) elif order_by == "number": query = query.order_by(desc(GithubItem.number)) - else: # default: updated + else: query = query.order_by(desc(GithubItem.github_updated_at)) query = query.limit(limit) - items = query.all() return [_serialize_issue(item) for item in items] -@mcp.tool() +@github_mcp.tool() async def search_github_issues( query: str, repo: str | None = None, @@ -191,7 +176,6 @@ async def search_github_issues( limit = min(limit, 100) - # Pre-filter source_ids if repo/state/kind filters are specified source_ids = None if repo or state or kind: with make_session() as session: @@ -206,7 +190,6 @@ async def search_github_issues( q = q.filter(GithubItem.kind.in_(["issue", "pr"])) source_ids = [item.id for item in q.all()] - # Use the existing search infrastructure data = extract.extract_text(query, skip_summary=True) config = SearchConfig(limit=limit, previews=True) filters = SearchFilters() @@ -220,7 +203,6 @@ async def search_github_issues( config=config, ) - # Fetch full issue details for the results output = [] with make_session() as session: for result in results: @@ -233,7 +215,7 @@ async def search_github_issues( return output -@mcp.tool() +@github_mcp.tool() async def github_issue_details( repo: str, number: int, @@ -267,7 +249,7 @@ async def github_issue_details( return _serialize_issue(item, include_content=True) -@mcp.tool() +@github_mcp.tool() async def github_work_summary( since: str, until: str | None = None, @@ -294,7 +276,6 @@ async def github_work_summary( else: until_dt = datetime.now(timezone.utc) - # Map group_by to SQL expression group_mappings = { "client": GithubItem.project_fields["EquiStamp.Client"].astext, "status": GithubItem.project_status, @@ -311,7 +292,6 @@ async def github_work_summary( group_col = group_mappings[group_by] with make_session() as session: - # Build base query for the period base_query = session.query(GithubItem).filter( GithubItem.github_updated_at >= since_dt, GithubItem.github_updated_at <= until_dt, @@ -321,7 +301,6 @@ async def github_work_summary( if repo: base_query = base_query.filter(GithubItem.repo_path == repo) - # Get aggregated counts by group agg_query = ( session.query( group_col.label("group_name"), @@ -346,7 +325,6 @@ async def github_work_summary( groups = agg_query.all() - # Build summary with sample issues for each group summary = [] total_issues = 0 total_prs = 0 @@ -358,7 +336,6 @@ async def github_work_summary( total_issues += issue_count total_prs += pr_count - # Get sample issues for this group sample_query = base_query.filter(group_col == group_name).limit(5) samples = [ { @@ -394,7 +371,7 @@ async def github_work_summary( } -@mcp.tool() +@github_mcp.tool() async def github_repo_overview( repo: str, ) -> dict: @@ -410,13 +387,6 @@ async def github_repo_overview( logger.info(f"github_repo_overview called: repo={repo}") with make_session() as session: - # Base query for this repo - base_query = session.query(GithubItem).filter( - GithubItem.repo_path == repo, - GithubItem.kind.in_(["issue", "pr"]), - ) - - # Get total counts counts_query = session.query( func.count(GithubItem.id).label("total"), func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"), @@ -441,7 +411,6 @@ async def github_repo_overview( counts = counts_query.first() - # Status breakdown (for project_status) status_query = ( session.query( GithubItem.project_status.label("status"), @@ -458,7 +427,6 @@ async def github_repo_overview( status_breakdown = {row.status: row.count for row in status_query.all()} - # Top assignees (open issues only) assignee_query = ( session.query( func.unnest(GithubItem.assignees).label("assignee"), @@ -479,7 +447,6 @@ async def github_repo_overview( for row in assignee_query.all() ] - # Label counts label_query = ( session.query( func.unnest(GithubItem.labels).label("label"), diff --git a/src/memory/api/MCP/servers/meta.py b/src/memory/api/MCP/servers/meta.py new file mode 100644 index 0000000..c9f361d --- /dev/null +++ b/src/memory/api/MCP/servers/meta.py @@ -0,0 +1,277 @@ +"""MCP subserver for metadata, utilities, and forecasting.""" + +import asyncio +import logging +from collections import defaultdict +from datetime import datetime, timezone +from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints + +import aiohttp +from fastmcp import FastMCP +from sqlalchemy import func + +from memory.common import qdrant +from memory.common.db.connection import make_session +from memory.common.db.models import SourceItem +from memory.common.db.models.source_items import AgentObservation + +logger = logging.getLogger(__name__) + +meta_mcp = FastMCP("memory-meta") + +# Auth provider will be injected at mount time +_get_current_user = None + + +def set_auth_provider(get_current_user_func): + """Set the authentication provider function.""" + global _get_current_user + _get_current_user = get_current_user_func + + +def get_current_user() -> dict: + """Get the current authenticated user.""" + if _get_current_user is None: + return {"authenticated": False, "error": "Auth provider not configured"} + return _get_current_user() + + +# --- Metadata tools --- + + +class SchemaArg(TypedDict): + type: str | None + description: str | None + + +class CollectionMetadata(TypedDict): + schema: dict[str, SchemaArg] + size: int + + +def from_annotation(annotation: Annotated) -> SchemaArg | None: + try: + type_, description = get_args(annotation) + type_str = str(type_) + if type_str.startswith("typing."): + type_str = type_str[7:] + elif len((parts := type_str.split("'"))) > 1: + type_str = parts[1] + return SchemaArg(type=type_str, description=description) + except IndexError: + logger.error(f"Error from annotation: {annotation}") + return None + + +def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]: + if not hasattr(klass, "as_payload"): + return {} + + if not (payload_type := get_type_hints(klass.as_payload).get("return")): + return {} + + return { + name: schema + for name, arg in payload_type.__annotations__.items() + if (schema := from_annotation(arg)) + } + + +@meta_mcp.tool() +async def get_metadata_schemas() -> dict[str, CollectionMetadata]: + """Get the metadata schema for each collection used in the knowledge base. + + These schemas can be used to filter the knowledge base. + + Returns: A mapping of collection names to their metadata schemas with field types and descriptions. + + Example: + ``` + { + "mail": {"subject": {"type": "str", "description": "The subject of the email."}}, + "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}} + } + """ + client = qdrant.get_qdrant_client() + sizes = qdrant.get_collection_sizes(client) + schemas = defaultdict(dict) + for klass in SourceItem.__subclasses__(): + for collection in klass.get_collections(): + schemas[collection].update(get_schema(klass)) + + return { + collection: CollectionMetadata(schema=schema, size=size) + for collection, schema in schemas.items() + if (size := sizes.get(collection)) + } + + +@meta_mcp.tool() +async def get_all_tags() -> list[str]: + """Get all unique tags used across the entire knowledge base. + + Returns sorted list of tags from both observations and content. + """ + with make_session() as session: + tags_query = session.query(func.unnest(SourceItem.tags)).distinct() + return sorted({row[0] for row in tags_query if row[0] is not None}) + + +@meta_mcp.tool() +async def get_all_subjects() -> list[str]: + """Get all unique subjects from observations about the user. + + Returns sorted list of subject identifiers used in observations. + """ + with make_session() as session: + return sorted( + r.subject for r in session.query(AgentObservation.subject).distinct() + ) + + +@meta_mcp.tool() +async def get_all_observation_types() -> list[str]: + """Get all observation types that have been used. + + Standard types are belief, preference, behavior, contradiction, general, but there can be more. + """ + with make_session() as session: + return sorted( + { + r.observation_type + for r in session.query(AgentObservation.observation_type).distinct() + if r.observation_type is not None + } + ) + + +# --- Utility tools --- + + +@meta_mcp.tool() +async def get_current_time() -> dict: + """Get the current time in UTC.""" + logger.info("get_current_time tool called") + return {"current_time": datetime.now(timezone.utc).isoformat()} + + +@meta_mcp.tool() +async def get_authenticated_user() -> dict: + """Get information about the authenticated user.""" + return get_current_user() + + +# --- Forecasting tools --- + + +class BinaryProbs(TypedDict): + prob: float + + +class MultiProbs(TypedDict): + answerProbs: dict[str, float] + + +Probs = dict[str, BinaryProbs | MultiProbs] +OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"] + + +class MarketAnswer(TypedDict): + id: str + text: str + resolutionProbability: float + + +class MarketDetails(TypedDict): + id: str + createdTime: int + question: str + outcomeType: OutcomeType + textDescription: str + groupSlugs: list[str] + volume: float + isResolved: bool + answers: list[MarketAnswer] + + +class Market(TypedDict): + id: str + url: str + question: str + volume: int + createdTime: int + outcomeType: OutcomeType + createdAt: NotRequired[str] + description: NotRequired[str] + answers: NotRequired[dict[str, float]] + probability: NotRequired[float] + details: NotRequired[MarketDetails] + + +async def get_details(session: aiohttp.ClientSession, market_id: str): + async with session.get( + f"https://api.manifold.markets/v0/market/{market_id}" + ) as resp: + resp.raise_for_status() + return await resp.json() + + +async def format_market(session: aiohttp.ClientSession, market: Market): + if market.get("outcomeType") != "BINARY": + details = await get_details(session, market["id"]) + market["answers"] = { + answer["text"]: round( + answer.get("resolutionProbability") or answer.get("probability") or 0, 3 + ) + for answer in details["answers"] + } + if creationTime := market.get("createdTime"): + market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat() + + fields = [ + "id", + "name", + "url", + "question", + "volume", + "createdAt", + "details", + "probability", + "answers", + ] + return {k: v for k, v in market.items() if k in fields} + + +async def search_markets(term: str, min_volume: int = 1000, binary: bool = False): + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.manifold.markets/v0/search-markets", + params={ + "term": term, + "contractType": "BINARY" if binary else "ALL", + }, + ) as resp: + resp.raise_for_status() + markets = await resp.json() + + return await asyncio.gather( + *[ + format_market(session, market) + for market in markets + if market.get("volume", 0) >= min_volume + ] + ) + + +@meta_mcp.tool() +async def get_forecasts( + term: str, min_volume: int = 1000, binary: bool = False +) -> list[dict]: + """Get prediction market forecasts for a given term. + + Args: + term: The term to search for. + min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold. + binary: Whether to only return binary markets. + """ + return await search_markets(term, min_volume, binary) diff --git a/src/memory/api/MCP/people.py b/src/memory/api/MCP/servers/people.py similarity index 95% rename from src/memory/api/MCP/people.py rename to src/memory/api/MCP/servers/people.py index 9982317..b92eea8 100644 --- a/src/memory/api/MCP/people.py +++ b/src/memory/api/MCP/servers/people.py @@ -1,23 +1,23 @@ -""" -MCP tools for tracking people. -""" +"""MCP subserver for tracking people.""" import logging from typing import Any +from fastmcp import FastMCP from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.MCP.base import mcp -from memory.common.db.connection import make_session -from memory.common.db.models import Person +from memory.common import settings from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON from memory.common.celery_app import app as celery_app -from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models import Person logger = logging.getLogger(__name__) +people_mcp = FastMCP("memory-people") + def _person_to_dict(person: Person) -> dict[str, Any]: """Convert a Person model to a dictionary for API responses.""" @@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]: } -@mcp.tool() +@people_mcp.tool() async def add_person( identifier: str, display_name: str, @@ -67,7 +67,6 @@ async def add_person( """ logger.info(f"MCP: Adding person: {identifier}") - # Check if person already exists with make_session() as session: existing = session.query(Person).filter(Person.identifier == identifier).first() if existing: @@ -93,7 +92,7 @@ async def add_person( } -@mcp.tool() +@people_mcp.tool() async def update_person_info( identifier: str, display_name: str | None = None, @@ -135,7 +134,6 @@ async def update_person_info( """ logger.info(f"MCP: Updating person: {identifier}") - # Verify person exists with make_session() as session: person = session.query(Person).filter(Person.identifier == identifier).first() if not person: @@ -162,7 +160,7 @@ async def update_person_info( } -@mcp.tool() +@people_mcp.tool() async def get_person(identifier: str) -> dict | None: """ Get a person by their identifier. @@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None: return _person_to_dict(person) -@mcp.tool() +@people_mcp.tool() async def list_people( tags: list[str] | None = None, search: str | None = None, @@ -207,9 +205,7 @@ async def list_people( query = session.query(Person) if tags: - query = query.filter( - Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))) - ) + query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))) if search: search_term = f"%{search.lower()}%" @@ -225,7 +221,7 @@ async def list_people( return [_person_to_dict(p) for p in people] -@mcp.tool() +@people_mcp.tool() async def delete_person(identifier: str) -> dict: """ Delete a person by their identifier. diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/servers/schedule.py similarity index 89% rename from src/memory/api/MCP/schedules.py rename to src/memory/api/MCP/servers/schedule.py index 26b50ad..a44e9f1 100644 --- a/src/memory/api/MCP/schedules.py +++ b/src/memory/api/MCP/servers/schedule.py @@ -1,20 +1,37 @@ -""" -MCP tools for the epistemic sparring partner system. -""" +"""MCP subserver for scheduling messages and LLM calls.""" import logging from datetime import datetime, timezone from typing import Any, cast -from memory.api.MCP.base import get_current_user, mcp +from fastmcp import FastMCP + from memory.common.db.connection import make_session -from memory.common.db.models import ScheduledLLMCall, DiscordBotUser +from memory.common.db.models import DiscordBotUser, ScheduledLLMCall from memory.discord.messages import schedule_discord_message logger = logging.getLogger(__name__) +schedule_mcp = FastMCP("memory-schedule") -@mcp.tool() +# We need access to get_current_user from base - this will be injected at mount time +_get_current_user = None + + +def set_auth_provider(get_current_user_func): + """Set the authentication provider function.""" + global _get_current_user + _get_current_user = get_current_user_func + + +def get_current_user() -> dict: + """Get the current authenticated user.""" + if _get_current_user is None: + return {"authenticated": False, "error": "Auth provider not configured"} + return _get_current_user() + + +@schedule_mcp.tool() async def schedule_message( scheduled_time: str, message: str, @@ -61,10 +78,8 @@ async def schedule_message( if not discord_user and not discord_channel: raise ValueError("Either discord_user or discord_channel must be provided") - # Parse scheduled time try: scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00")) - # Ensure we store as naive datetime (remove timezone info for database storage) if scheduled_dt.tzinfo is not None: scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None) except ValueError: @@ -98,7 +113,7 @@ async def schedule_message( } -@mcp.tool() +@schedule_mcp.tool() async def list_scheduled_llm_calls( status: str | None = None, limit: int | None = 50 ) -> dict[str, Any]: @@ -143,7 +158,7 @@ async def list_scheduled_llm_calls( } -@mcp.tool() +@schedule_mcp.tool() async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]: """ Cancel a scheduled LLM call. @@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]: return {"error": "User not found", "user": current_user} with make_session() as session: - # Find the scheduled call scheduled_call = ( session.query(ScheduledLLMCall) .filter( @@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]: if not scheduled_call.can_be_cancelled(): return {"error": f"Cannot cancel call with status: {scheduled_call.status}"} - # Update the status scheduled_call.status = "cancelled" session.commit() diff --git a/src/memory/api/MCP/tools.py b/src/memory/api/MCP/tools.py deleted file mode 100644 index b3efaf7..0000000 --- a/src/memory/api/MCP/tools.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -MCP tools for the epistemic sparring partner system. -""" - -import logging -from datetime import datetime, timezone - -from sqlalchemy import Text -from sqlalchemy import cast as sql_cast -from sqlalchemy.dialects.postgresql import ARRAY - -from memory.common.db.connection import make_session -from memory.common.db.models import AgentObservation, SourceItem -from memory.api.MCP.base import mcp, get_current_user - -logger = logging.getLogger(__name__) - - -def filter_observation_source_ids( - tags: list[str] | None = None, observation_types: list[str] | None = None -): - if not tags and not observation_types: - return None - - with make_session() as session: - items_query = session.query(AgentObservation.id) - - if tags: - # Use PostgreSQL array overlap operator with proper array casting - items_query = items_query.filter( - AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))), - ) - if observation_types: - items_query = items_query.filter( - AgentObservation.observation_type.in_(observation_types) - ) - source_ids = [item.id for item in items_query.all()] - - return source_ids - - -def filter_source_ids( - modalities: set[str], - tags: list[str] | None = None, -): - if not tags: - return None - - with make_session() as session: - items_query = session.query(SourceItem.id) - - if tags: - # Use PostgreSQL array overlap operator with proper array casting - items_query = items_query.filter( - SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))), - ) - if modalities: - items_query = items_query.filter(SourceItem.modality.in_(modalities)) - source_ids = [item.id for item in items_query.all()] - - return source_ids - - -@mcp.tool() -async def get_current_time() -> dict: - """Get the current time in UTC.""" - logger.info("get_current_time tool called") - return {"current_time": datetime.now(timezone.utc).isoformat()} - - -@mcp.tool() -async def get_authenticated_user() -> dict: - """Get information about the authenticated user.""" - return get_current_user() diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 959f276..44911a3 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -2,7 +2,6 @@ FastAPI application for the knowledge base. """ -import contextlib import os import logging import mimetypes @@ -35,15 +34,10 @@ limiter = Limiter( enabled=settings.API_RATE_LIMIT_ENABLED, ) +# Create the MCP http app to get its lifespan +mcp_http_app = mcp.http_app(stateless_http=True) -@contextlib.asynccontextmanager -async def lifespan(app: FastAPI): - async with contextlib.AsyncExitStack() as stack: - await stack.enter_async_context(mcp.session_manager.run()) - yield - - -app = FastAPI(title="Knowledge Base API", lifespan=lifespan) +app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan) app.state.limiter = limiter # Rate limit exception handler @@ -158,46 +152,9 @@ app.include_router(auth_router) # Add health check to MCP server instead of main app -@mcp.custom_route("/health", methods=["GET"]) -async def health_check(request: Request): - """Health check endpoint that verifies all dependencies are accessible.""" - from fastapi.responses import JSONResponse - from sqlalchemy import text - - checks = {"mcp_oauth": "enabled"} - all_healthy = True - - # Check database connection - try: - with engine.connect() as conn: - conn.execute(text("SELECT 1")) - checks["database"] = "healthy" - except Exception as e: - # Log error details but don't expose in response - logger.error(f"Database health check failed: {e}") - checks["database"] = "unhealthy" - all_healthy = False - - # Check Qdrant connection - try: - from memory.common.qdrant import get_qdrant_client - - client = get_qdrant_client() - client.get_collections() - checks["qdrant"] = "healthy" - except Exception as e: - # Log error details but don't expose in response - logger.error(f"Qdrant health check failed: {e}") - checks["qdrant"] = "unhealthy" - all_healthy = False - - checks["status"] = "healthy" if all_healthy else "degraded" - status_code = 200 if all_healthy else 503 - return JSONResponse(checks, status_code=status_code) - - # Mount MCP server at root - OAuth endpoints need to be at root level -app.mount("/", mcp.streamable_http_app()) +# Health check is defined in MCP/base.py +app.mount("/", mcp_http_app) def main(reload: bool = False): diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index 4ec849e..46d9622 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord" BACKUP_ROOT = "memory.workers.tasks.backup" GITHUB_ROOT = "memory.workers.tasks.github" PEOPLE_ROOT = "memory.workers.tasks.people" +PROACTIVE_ROOT = "memory.workers.tasks.proactive" ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message" EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message" PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message" @@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person" UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person" SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file" +# Proactive check-in tasks +EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins" +EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin" + def get_broker_url() -> str: protocol = settings.CELERY_BROKER_TYPE @@ -130,12 +135,17 @@ app.conf.update( f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"}, f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"}, f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"}, + f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"}, }, beat_schedule={ "sync-github-repos-hourly": { "task": SYNC_ALL_GITHUB_REPOS, "schedule": crontab(minute=0), # Every hour at :00 }, + "evaluate-proactive-checkins": { + "task": EVALUATE_PROACTIVE_CHECKINS, + "schedule": crontab(), # Every minute + }, }, ) diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 0b5c568..da8d70a 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -63,13 +63,29 @@ class MessageProcessor: doc=textwrap.dedent( """ A summary of this processor, made by and for AI systems. - + The idea here is that AI systems can use this summary to keep notes on the given processor. - These should automatically be injected into the context of the messages that are processed by this processor. + These should automatically be injected into the context of the messages that are processed by this processor. """ ), ) + proactive_cron = Column( + Text, + nullable=True, + doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.", + ) + proactive_prompt = Column( + Text, + nullable=True, + doc="Custom instructions for proactive check-ins.", + ) + last_proactive_at = Column( + DateTime(timezone=True), + nullable=True, + doc="When the last proactive check-in was sent.", + ) + @property def entity_type(self) -> str: return self.__class__.__tablename__[8:-1] # type: ignore diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 07a4306..d157cf0 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60)) NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60)) LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24)) SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60)) +PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60)) CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24)) diff --git a/src/memory/discord/commands.py b/src/memory/discord/commands.py index 0c432bf..668ad81 100644 --- a/src/memory/discord/commands.py +++ b/src/memory/discord/commands.py @@ -167,6 +167,25 @@ def _create_scope_group( url=url and url.strip(), ) + # Proactive command + @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins") + @discord.app_commands.describe( + cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable", + prompt="Optional custom instructions for check-ins", + ) + async def proactive_cmd( + interaction: discord.Interaction, + cron: str | None = None, + prompt: str | None = None, + ): + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_proactive, + cron=cron and cron.strip(), + prompt=prompt, + ) + return group @@ -265,6 +284,28 @@ def _create_user_scope_group( url=url and url.strip(), ) + # Proactive command + @group.command(name="proactive", description=f"Configure {name}'s proactive check-ins") + @discord.app_commands.describe( + user="Target user", + cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable", + prompt="Optional custom instructions for check-ins", + ) + async def proactive_cmd( + interaction: discord.Interaction, + user: discord.User, + cron: str | None = None, + prompt: str | None = None, + ): + await _run_interaction_command( + interaction, + scope=scope, + handler=handle_proactive, + target_user=user, + cron=cron and cron.strip(), + prompt=prompt, + ) + return group @@ -663,3 +704,68 @@ async def handle_mcp_servers( except Exception as exc: logger.error(f"Error running MCP server command: {exc}", exc_info=True) raise CommandError(f"Error: {exc}") from exc + + +async def handle_proactive( + context: CommandContext, + *, + cron: str | None = None, + prompt: str | None = None, +) -> CommandResponse: + """Handle proactive check-in configuration.""" + from croniter import croniter + + model = context.target + + # If no arguments, show current settings + if cron is None and prompt is None: + current_cron = getattr(model, "proactive_cron", None) + current_prompt = getattr(model, "proactive_prompt", None) + + if not current_cron: + return CommandResponse( + content=f"Proactive check-ins are disabled for {context.display_name}." + ) + + lines = [f"Proactive check-ins for {context.display_name}:"] + lines.append(f" Schedule: `{current_cron}`") + if current_prompt: + lines.append(f" Prompt: {current_prompt}") + return CommandResponse(content="\n".join(lines)) + + # Handle cron setting + if cron is not None: + if cron.lower() == "off": + setattr(model, "proactive_cron", None) + return CommandResponse( + content=f"Proactive check-ins disabled for {context.display_name}." + ) + + # Validate cron expression + try: + croniter(cron) + except (ValueError, KeyError) as e: + raise CommandError( + f"Invalid cron expression: {cron}\n" + "Examples:\n" + " `0 9 * * *` - 9am daily\n" + " `0 9,17 * * 1-5` - 9am and 5pm weekdays\n" + " `0 */4 * * *` - every 4 hours" + ) from e + + setattr(model, "proactive_cron", cron) + + # Handle prompt setting + if prompt is not None: + setattr(model, "proactive_prompt", prompt or None) + + # Build response + current_cron = getattr(model, "proactive_cron", None) + current_prompt = getattr(model, "proactive_prompt", None) + + lines = [f"Updated proactive settings for {context.display_name}:"] + lines.append(f" Schedule: `{current_cron}`") + if current_prompt: + lines.append(f" Prompt: {current_prompt}") + + return CommandResponse(content="\n".join(lines)) diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py index 6b0b4e8..b18810c 100644 --- a/src/memory/workers/ingest.py +++ b/src/memory/workers/ingest.py @@ -11,6 +11,7 @@ from memory.common.celery_app import ( SYNC_LESSWRONG, RUN_SCHEDULED_CALLS, BACKUP_ALL, + EVALUATE_PROACTIVE_CHECKINS, ) logger = logging.getLogger(__name__) @@ -53,4 +54,8 @@ app.conf.beat_schedule = { "task": BACKUP_ALL, "schedule": settings.S3_BACKUP_INTERVAL, }, + "evaluate-proactive-checkins": { + "task": EVALUATE_PROACTIVE_CHECKINS, + "schedule": settings.PROACTIVE_CHECKIN_INTERVAL, + }, } diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index 33f22aa..f7c0a97 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -14,20 +14,24 @@ from memory.workers.tasks import ( maintenance, notes, observations, + people, + proactive, scheduled_calls, ) # noqa __all__ = [ "backup", - "email", - "comic", "blogs", - "ebook", + "comic", "discord", + "ebook", + "email", "forums", "github", "maintenance", "notes", "observations", + "people", + "proactive", "scheduled_calls", ] diff --git a/src/memory/workers/tasks/proactive.py b/src/memory/workers/tasks/proactive.py new file mode 100644 index 0000000..689cfc9 --- /dev/null +++ b/src/memory/workers/tasks/proactive.py @@ -0,0 +1,341 @@ +""" +Celery tasks for proactive Discord check-ins. +""" + +import logging +import re +import textwrap +from datetime import datetime, timezone +from typing import Any, Literal, cast + +from croniter import croniter +from sqlalchemy import or_ + +from memory.common import settings +from memory.common.celery_app import app +from memory.common.db.connection import make_session +from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser +from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response +from memory.workers.tasks.content_processing import safe_task_execution + +logger = logging.getLogger(__name__) + +EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins" +EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin" + +EntityType = Literal["user", "channel", "server"] + + +def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool: + """Check if a cron expression is due to run now. + + Uses croniter to determine if the current time falls within the cron's schedule + and enough time has passed since the last run. + """ + try: + cron = croniter(cron_expr, now) + # Get the previous scheduled time from now + prev_run = cron.get_prev(datetime) + # Get the one before that to determine the interval + cron.get_prev(datetime) + prev_prev_run = cron.get_current(datetime) + + # If we haven't run since the last scheduled time, we should run + if last_run is None: + # Never run before - check if current time is within a minute of prev_run + time_since_scheduled = (now - prev_run).total_seconds() + return time_since_scheduled < 120 # Within 2 minutes of scheduled time + + # Make sure last_run is timezone aware + if last_run.tzinfo is None: + last_run = last_run.replace(tzinfo=timezone.utc) + + # We should run if last_run is before the previous scheduled time + return last_run < prev_run + except Exception as e: + logger.warning(f"Invalid cron expression '{cron_expr}': {e}") + return False + + +def get_bot_for_entity( + session, entity_type: EntityType, entity_id: int +) -> DiscordUser | None: + """Get the bot user associated with an entity.""" + from memory.common.db.models import DiscordBotUser, DiscordMessage + + from sqlalchemy.orm import joinedload + + # For servers, find a bot that has sent messages in that server + if entity_type == "server": + # Find bots that have interacted with this server + bot_users = ( + session.query(DiscordUser) + .options(joinedload(DiscordUser.system_user)) + .join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id) + .filter( + DiscordMessage.server_id == entity_id, + DiscordUser.system_user_id.isnot(None), + ) + .distinct() + .all() + ) + # Find one that's actually a bot + for user in bot_users: + if user.system_user and user.system_user.user_type == "discord_bot": + return user + + # For channels, check the server the channel belongs to + if entity_type == "channel": + channel = session.get(DiscordChannel, entity_id) + if channel and channel.server_id: + return get_bot_for_entity(session, "server", channel.server_id) + + # Fallback: use first available bot + bot = ( + session.query(DiscordBotUser) + .options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user)) + .first() + ) + if bot and bot.discord_users: + return bot.discord_users[0] + return None + + +def get_target_user_for_entity( + session, entity_type: EntityType, entity_id: int +) -> DiscordUser | None: + """Get the target user for sending a proactive message.""" + if entity_type == "user": + return session.get(DiscordUser, entity_id) + # For channels and servers, we don't have a specific target user + return None + + +def get_channel_for_entity( + session, entity_type: EntityType, entity_id: int +) -> DiscordChannel | None: + """Get the channel for sending a proactive message.""" + if entity_type == "channel": + return session.get(DiscordChannel, entity_id) + if entity_type == "server": + # For servers, find the first text channel (prefer "general") + channels = ( + session.query(DiscordChannel) + .filter( + DiscordChannel.server_id == entity_id, + DiscordChannel.channel_type == "text", + ) + .all() + ) + if not channels: + return None + # Prefer a channel named "general" if it exists + for channel in channels: + if channel.name and "general" in channel.name.lower(): + return channel + return channels[0] + # For users, we use DMs (no channel) + return None + + +@app.task(name=EVALUATE_PROACTIVE_CHECKINS) +@safe_task_execution +def evaluate_proactive_checkins() -> dict[str, Any]: + """ + Evaluate which entities need proactive check-ins. + + This task runs every minute and checks all entities with proactive_cron set + to see if they're due for a check-in. + """ + now = datetime.now(timezone.utc) + dispatched = [] + + with make_session() as session: + # Query all entities with proactive_cron set + for model, entity_type in [ + (DiscordUser, "user"), + (DiscordChannel, "channel"), + (DiscordServer, "server"), + ]: + entities = ( + session.query(model) + .filter(model.proactive_cron.isnot(None)) + .all() + ) + + for entity in entities: + cron_expr = cast(str, entity.proactive_cron) + last_run = entity.last_proactive_at + + if is_cron_due(cron_expr, last_run, now): + logger.info( + f"Proactive check-in due for {entity_type} {entity.id}" + ) + execute_proactive_checkin.delay(entity_type, entity.id) + dispatched.append({"type": entity_type, "id": entity.id}) + + return { + "evaluated_at": now.isoformat(), + "dispatched": dispatched, + "count": len(dispatched), + } + + +@app.task(name=EXECUTE_PROACTIVE_CHECKIN) +@safe_task_execution +def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]: + """ + Execute a proactive check-in for a specific entity. + + This evaluates whether the bot should reach out and, if so, generates + and sends a check-in message. + """ + logger.info(f"Executing proactive check-in for {entity_type} {entity_id}") + + with make_session() as session: + # Get the entity + model_class = { + "user": DiscordUser, + "channel": DiscordChannel, + "server": DiscordServer, + }[entity_type] + + entity = session.get(model_class, entity_id) + if not entity: + return {"error": f"{entity_type} {entity_id} not found"} + + # Get the bot user + bot_user = get_bot_for_entity(session, entity_type, entity_id) + if not bot_user: + return {"error": "No bot user found"} + + # Get target user and channel + target_user = get_target_user_for_entity(session, entity_type, entity_id) + channel = get_channel_for_entity(session, entity_type, entity_id) + + if not target_user and not channel: + return {"error": "No target user or channel for proactive check-in"} + + # Get chattiness threshold + chattiness = entity.chattiness_threshold or 90 + + # Build the evaluation prompt + proactive_prompt = entity.proactive_prompt or "" + eval_prompt = textwrap.dedent(""" + You are considering whether to proactively reach out to check in. + + {proactive_prompt} + + Based on your notes and the context of previous conversations: + 1. Is there anything worth checking in about? + 2. Has enough happened or enough time passed to warrant a check-in? + 3. Would reaching out now be welcome or intrusive? + + Please return a number between 0 and 100 indicating how strongly you want to check in + (0 = definitely not, 100 = definitely yes). + + + 50 + Your reasoning here + + """).format(proactive_prompt=proactive_prompt) + + # Build context + system_prompt = comm_channel_prompt( + session, bot_user, target_user, channel + ) + + # First, evaluate whether we should check in + eval_response = call_llm( + session, + bot_user=bot_user, + from_user=target_user, + channel=channel, + model=settings.SUMMARIZER_MODEL, + system_prompt=system_prompt, + messages=[eval_prompt], + allowed_tools=[ + "update_channel_summary", + "update_user_summary", + "update_server_summary", + ], + ) + + if not eval_response: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id} + + # Parse the interest score + match = re.search(r"(\d+)", eval_response) + if not match: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id} + + interest_score = int(match.group(1)) + threshold = 100 - chattiness + + logger.info( + f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}" + ) + + if interest_score < threshold: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return { + "status": "below_threshold", + "interest": interest_score, + "threshold": threshold, + "entity_type": entity_type, + "entity_id": entity_id, + } + + # Generate the actual check-in message + checkin_prompt = textwrap.dedent(""" + You've decided to proactively check in. Generate a natural, friendly check-in message. + + {proactive_prompt} + + Keep it brief and genuine. Don't be overly formal or robotic. + Reference specific things from your notes if relevant. + """).format(proactive_prompt=proactive_prompt) + + response = call_llm( + session, + bot_user=bot_user, + from_user=target_user, + channel=channel, + model=settings.DISCORD_MODEL, + system_prompt=system_prompt, + messages=[checkin_prompt], + ) + + if not response: + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id} + + # Send the message + bot_id = bot_user.system_user.id if bot_user.system_user else None + if not bot_id: + return {"error": "No system user for bot"} + + success = send_discord_response( + bot_id=bot_id, + response=response, + channel_id=channel.id if channel else None, + user_identifier=target_user.username if target_user else None, + ) + + # Update last_proactive_at + entity.last_proactive_at = datetime.now(timezone.utc) + session.commit() + + return { + "status": "sent" if success else "send_failed", + "interest": interest_score, + "entity_type": entity_type, + "entity_id": entity_id, + "response_preview": response[:100] + "..." if len(response) > 100 else response, + } diff --git a/tests/integration/test_real_queries.py b/tests/integration/test_real_queries.py index df46fbf..0da4a43 100644 --- a/tests/integration/test_real_queries.py +++ b/tests/integration/test_real_queries.py @@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = { ), ( 0.409, - "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI", + "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches", ), ], }, @@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = { "semantic": [ (0.489, "I find backend logic more interesting than UI work"), (0.462, "The user prefers working on backend systems over frontend UI"), + (0.455, "The user said pure functions are yucky"), ( 0.455, "The user believes functional programming leads to better code quality", ), - (0.455, "The user said pure functions are yucky"), ], "temporal": [ ( diff --git a/tests/memory/api/search/test_query_analysis.py b/tests/memory/api/search/test_query_analysis.py index 85ef343..2290d96 100644 --- a/tests/memory/api/search/test_query_analysis.py +++ b/tests/memory/api/search/test_query_analysis.py @@ -137,10 +137,9 @@ class TestBuildPrompt: ): prompt = _build_prompt() - assert "lesswrong" in prompt.lower() - assert "comic" in prompt.lower() - assert "Remove" in prompt + assert "Remove meta-language" in prompt assert "Return ONLY valid JSON" in prompt + assert "recalled_content" in prompt class TestAnalyzeQuery: diff --git a/tests/memory/api/test_auth.py b/tests/memory/api/test_auth.py index 708fb05..dddd814 100644 --- a/tests/memory/api/test_auth.py +++ b/tests/memory/api/test_auth.py @@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session): @pytest.mark.asyncio +@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock) @patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock) @patch("memory.api.auth.make_session") -async def test_oauth_callback_discord_success(mock_make_session, mock_complete): +async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools): mock_session = MagicMock() @contextmanager @@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete): mock_make_session.return_value = session_cm() mcp_server = MagicMock() + mcp_server.mcp_server_url = "https://example.com" + mcp_server.access_token = "token123" mock_session.query.return_value.filter.return_value.first.return_value = mcp_server mock_complete.return_value = (200, "Authorized") + mock_mcp_tools.return_value = [{"name": "test_tool"}] request = make_request("code=abc123&state=state456") response = await auth.oauth_callback_discord(request) @@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete): assert "Authorization Successful" in body assert "Authorized" in body mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456") - mock_session.commit.assert_called_once() + assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list @pytest.mark.asyncio +@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock) @patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock) @patch("memory.api.auth.make_session") async def test_oauth_callback_discord_handles_failures( - mock_make_session, mock_complete + mock_make_session, mock_complete, mock_mcp_tools ): mock_session = MagicMock() @@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures( mock_make_session.return_value = session_cm() mcp_server = MagicMock() + mcp_server.mcp_server_url = "https://example.com" + mcp_server.access_token = "token123" mock_session.query.return_value.filter.return_value.first.return_value = mcp_server mock_complete.return_value = (500, "Failure") + mock_mcp_tools.return_value = [] request = make_request("code=abc123&state=state456") response = await auth.oauth_callback_discord(request) @@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures( assert "Authorization Failed" in body assert "Failure" in body mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456") - mock_session.commit.assert_called_once() + assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list @pytest.mark.asyncio diff --git a/tests/memory/common/llms/tools/test_discord_tools.py b/tests/memory/common/llms/tools/test_discord_tools.py index 29ce374..81a2d19 100644 --- a/tests/memory/common/llms/tools/test_discord_tools.py +++ b/tests/memory/common/llms/tools/test_discord_tools.py @@ -18,6 +18,7 @@ from memory.common.db.models import ( DiscordUser, DiscordMessage, BotUser, + DiscordBotUser, HumanUser, ScheduledLLMCall, ) @@ -67,9 +68,10 @@ def sample_discord_user(db_session): @pytest.fixture -def sample_bot_user(db_session): +def sample_bot_user(db_session, sample_discord_user): """Create a sample bot user for testing.""" - bot = BotUser.create_with_api_key( + bot = DiscordBotUser.create_with_api_key( + discord_users=[sample_discord_user], name="Test Bot", email="testbot@example.com", ) @@ -209,9 +211,9 @@ def test_schedule_message_with_user( future_time = datetime.now(timezone.utc) + timedelta(hours=1) result = schedule_message( - user_id=sample_human_user.id, - user=sample_discord_user.id, - channel=None, + bot_id=sample_human_user.id, + recipient_id=sample_discord_user.id, + channel_id=None, model="test-model", message="Test message", date_time=future_time, @@ -240,9 +242,9 @@ def test_schedule_message_with_channel( future_time = datetime.now(timezone.utc) + timedelta(hours=1) result = schedule_message( - user_id=sample_human_user.id, - user=None, - channel=sample_discord_channel.id, + bot_id=sample_human_user.id, + recipient_id=None, + channel_id=sample_discord_channel.id, model="test-model", message="Test message", date_time=future_time, @@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user): """Test creating a message scheduler tool for a user.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) - assert tool.name == "schedule_message" + assert tool.name == "schedule_discord_message" assert "from your chat with this user" in tool.description assert tool.input_schema["type"] == "object" assert "message" in tool.input_schema["properties"] @@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha """Test creating a message scheduler tool for a channel.""" tool = make_message_scheduler( bot=sample_bot_user, - user=None, - channel=sample_discord_channel.id, + user_id=None, + channel_id=sample_discord_channel.id, model="test-model", ) - assert tool.name == "schedule_message" + assert tool.name == "schedule_discord_message" assert "in this channel" in tool.description assert callable(tool.function) @@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user): with pytest.raises(ValueError, match="Either user or channel must be provided"): make_message_scheduler( bot=sample_bot_user, - user=None, - channel=None, + user_id=None, + channel_id=None, model="test-model", ) @@ -310,8 +312,8 @@ def test_message_scheduler_handler_success( """Test message scheduler handler with valid input.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord """Test message scheduler handler with non-dict input.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime( """Test message scheduler handler with invalid datetime.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime( """Test message scheduler handler with missing datetime.""" tool = make_message_scheduler( bot=sample_bot_user, - user=sample_discord_user.id, - channel=None, + user_id=sample_discord_user.id, + channel_id=None, model="test-model", ) @@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime( # Tests for make_prev_messages_tool -def test_make_prev_messages_tool_with_user(sample_discord_user): +def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user): """Test creating a previous messages tool for a user.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) assert tool.name == "previous_messages" assert "from your chat with this user" in tool.description @@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user): assert callable(tool.function) -def test_make_prev_messages_tool_with_channel(sample_discord_channel): +def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel): """Test creating a previous messages tool for a channel.""" - tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id) assert tool.name == "previous_messages" assert "in this channel" in tool.description assert callable(tool.function) -def test_make_prev_messages_tool_without_user_or_channel(): +def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user): """Test that creating a tool without user or channel raises error.""" with pytest.raises(ValueError, match="Either user or channel must be provided"): - make_prev_messages_tool(user=None, channel=None) + make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None) def test_prev_messages_handler_success( - db_session, sample_discord_user, sample_discord_channel + db_session, sample_bot_user, sample_discord_user, sample_discord_channel ): """Test previous messages handler with valid input.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) # Create some actual messages in the database msg1 = DiscordMessage( @@ -440,9 +442,9 @@ def test_prev_messages_handler_success( assert "Message 1" in result or "Message 2" in result -def test_prev_messages_handler_with_defaults(db_session, sample_discord_user): +def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user): """Test previous messages handler with default values.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) result = tool.function({}) @@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user): assert isinstance(result, str) -def test_prev_messages_handler_invalid_input(sample_discord_user): +def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user): """Test previous messages handler with non-dict input.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Input must be a dictionary"): tool.function("not a dict") -def test_prev_messages_handler_invalid_max_messages(sample_discord_user): +def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user): """Test previous messages handler with invalid max_messages (negative value).""" # Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting, # so we test with -1 which actually triggers the validation - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Max messages must be greater than 0"): tool.function({"max_messages": -1}) -def test_prev_messages_handler_invalid_offset(sample_discord_user): +def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user): """Test previous messages handler with invalid offset.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"): tool.function({"offset": -1}) -def test_prev_messages_handler_non_integer_values(sample_discord_user): +def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user): """Test previous messages handler with non-integer values.""" - tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None) + tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None) with pytest.raises(ValueError, match="Max messages and offset must be integers"): tool.function({"max_messages": "not an int"}) @@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel( model="test-model", ) - # Should have: schedule_message, previous_messages, update_channel_summary, + # Should have: schedule_discord_message, previous_messages, update_channel_summary, # update_user_summary, update_server_summary, add_reaction assert len(tools) == 6 - assert "schedule_message" in tools + assert "schedule_discord_message" in tools assert "previous_messages" in tools assert "update_channel_summary" in tools assert "update_user_summary" in tools @@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user) model="test-model", ) - # Should have: schedule_message, previous_messages, update_user_summary + # Should have: schedule_discord_message, previous_messages, update_user_summary # Note: Without channel, there's no channel summary tool assert len(tools) >= 2 # At least schedule and previous messages - assert "schedule_message" in tools + assert "schedule_discord_message" in tools assert "previous_messages" in tools assert "update_user_summary" in tools @@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch model="test-model", ) - # Should have: schedule_message, previous_messages, update_channel_summary, + # Should have: schedule_discord_message, previous_messages, update_channel_summary, # update_server_summary, add_reaction (no user summary without author) assert len(tools) == 5 - assert "schedule_message" in tools + assert "schedule_discord_message" in tools assert "previous_messages" in tools assert "update_channel_summary" in tools assert "update_server_summary" in tools diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py index ce4cf3f..55db891 100644 --- a/tests/memory/common/test_discord.py +++ b/tests/memory/common/test_discord.py @@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url): "http://localhost:8000/send_channel", json={ "bot_id": BOT_ID, - "channel_name": "general", + "channel": "general", "message": "Announcement!", }, timeout=10, diff --git a/tests/memory/common/test_discord_integration.py b/tests/memory/common/test_discord_integration.py index ce4cf3f..55db891 100644 --- a/tests/memory/common/test_discord_integration.py +++ b/tests/memory/common/test_discord_integration.py @@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url): "http://localhost:8000/send_channel", json={ "bot_id": BOT_ID, - "channel_name": "general", + "channel": "general", "message": "Announcement!", }, timeout=10, diff --git a/tests/memory/discord_tests/test_commands.py b/tests/memory/discord_tests/test_commands.py index fd2bf9d..9d05195 100644 --- a/tests/memory/discord_tests/test_commands.py +++ b/tests/memory/discord_tests/test_commands.py @@ -15,6 +15,7 @@ from memory.discord.commands import ( handle_chattiness, handle_ignore, handle_summary, + handle_proactive, respond, with_object_context, handle_mcp_servers, @@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction): await handle_mcp_servers(context, action="list", url=None) assert "Error: boom" in str(exc.value) + + +# ============================================================================ +# Tests for handle_proactive +# ============================================================================ + + +@pytest.mark.asyncio +async def test_handle_proactive_show_disabled(db_session, interaction, guild): + """Test showing proactive settings when disabled.""" + server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None) + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context) + + assert "disabled" in response.content.lower() + + +@pytest.mark.asyncio +async def test_handle_proactive_show_enabled(db_session, interaction, guild): + """Test showing proactive settings when enabled.""" + server = DiscordServer( + id=guild.id, + name="Guild", + proactive_cron="0 9 * * *", + proactive_prompt="Check on projects", + ) + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context) + + assert "0 9 * * *" in response.content + assert "Check on projects" in response.content + + +@pytest.mark.asyncio +async def test_handle_proactive_set_cron(db_session, interaction, guild): + """Test setting proactive cron schedule.""" + server = DiscordServer(id=guild.id, name="Guild") + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context, cron="0 9 * * *") + + assert "Updated" in response.content + assert "0 9 * * *" in response.content + assert server.proactive_cron == "0 9 * * *" + + +@pytest.mark.asyncio +async def test_handle_proactive_set_prompt(db_session, interaction, guild): + """Test setting proactive prompt.""" + server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *") + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context, prompt="Focus on daily standups") + + assert "Updated" in response.content + assert server.proactive_prompt == "Focus on daily standups" + + +@pytest.mark.asyncio +async def test_handle_proactive_disable(db_session, interaction, guild): + """Test disabling proactive check-ins.""" + server = DiscordServer( + id=guild.id, + name="Guild", + proactive_cron="0 9 * * *", + proactive_prompt="Some prompt", + ) + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + response = await handle_proactive(context, cron="off") + + assert "disabled" in response.content.lower() + assert server.proactive_cron is None + + +@pytest.mark.asyncio +async def test_handle_proactive_invalid_cron(db_session, interaction, guild): + """Test error on invalid cron expression.""" + server = DiscordServer(id=guild.id, name="Guild") + db_session.add(server) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="server", + target=server, + display_name="server **Guild**", + ) + + with pytest.raises(CommandError) as exc: + await handle_proactive(context, cron="not a valid cron") + + assert "Invalid cron expression" in str(exc.value) + + +@pytest.mark.asyncio +async def test_handle_proactive_user_scope(db_session, interaction, discord_user): + """Test proactive settings for user scope.""" + user_model = DiscordUser( + id=discord_user.id, username="testuser", proactive_cron=None + ) + db_session.add(user_model) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="me", + target=user_model, + display_name="you (**testuser**)", + ) + + response = await handle_proactive(context, cron="0 9,17 * * 1-5") + + assert "Updated" in response.content + assert user_model.proactive_cron == "0 9,17 * * 1-5" + + +@pytest.mark.asyncio +async def test_handle_proactive_channel_scope( + db_session, interaction, guild, text_channel +): + """Test proactive settings for channel scope.""" + server = DiscordServer(id=guild.id, name="Guild") + db_session.add(server) + db_session.flush() + + channel_model = DiscordChannel( + id=text_channel.id, + name="general", + channel_type="text", + server_id=guild.id, + ) + db_session.add(channel_model) + db_session.commit() + + context = CommandContext( + session=db_session, + interaction=interaction, + actor=MagicMock(spec=DiscordUser), + scope="channel", + target=channel_model, + display_name="channel **#general**", + ) + + response = await handle_proactive(context, cron="0 12 * * *") + + assert "Updated" in response.content + assert channel_model.proactive_cron == "0 12 * * *" diff --git a/tests/memory/discord_tests/test_mcp.py b/tests/memory/discord_tests/test_mcp.py index 37a7649..0e59f89 100644 --- a/tests/memory/discord_tests/test_mcp.py +++ b/tests/memory/discord_tests/test_mcp.py @@ -1,6 +1,5 @@ """Tests for Discord MCP server management.""" -import json from unittest.mock import AsyncMock, Mock, patch import aiohttp @@ -8,8 +7,8 @@ import discord import pytest from memory.common.db.models import MCPServer, MCPServerAssignment +from memory.common.mcp import mcp_call from memory.discord.mcp import ( - call_mcp_server, find_mcp_server, handle_mcp_add, handle_mcp_connect, @@ -142,7 +141,7 @@ async def test_call_mcp_server_success(): with patch("aiohttp.ClientSession", return_value=mock_session_ctx): results = [] - async for data in call_mcp_server( + async for data in mcp_call( "https://mcp.example.com", "test_token", "tools/list", {} ): results.append(data) @@ -172,7 +171,7 @@ async def test_call_mcp_server_error(): with patch("aiohttp.ClientSession", return_value=mock_session_ctx): with pytest.raises(ValueError, match="Failed to call MCP server"): - async for _ in call_mcp_server( + async for _ in mcp_call( "https://mcp.example.com", "test_token", "tools/list" ): pass @@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json(): with patch("aiohttp.ClientSession", return_value=mock_session_ctx): results = [] - async for data in call_mcp_server( + async for data in mcp_call( "https://mcp.example.com", "test_token", "tools/list" ): results.append(data) diff --git a/tests/memory/discord_tests/test_messages.py b/tests/memory/discord_tests/test_messages.py index e2f60a9..edabcc2 100644 --- a/tests/memory/discord_tests/test_messages.py +++ b/tests/memory/discord_tests/test_messages.py @@ -19,6 +19,7 @@ from memory.common.db.models import ( DiscordChannel, DiscordServer, DiscordMessage, + DiscordBotUser, HumanUser, ScheduledLLMCall, ) @@ -34,6 +35,19 @@ def sample_discord_user(db_session): return user +@pytest.fixture +def sample_bot_user(db_session, sample_discord_user): + """Create a sample Discord bot user.""" + bot = DiscordBotUser.create_with_api_key( + discord_users=[sample_discord_user], + name="Test Bot", + email="testbot@example.com", + ) + db_session.add(bot) + db_session.commit() + return bot + + @pytest.fixture def sample_discord_channel(db_session): """Create a sample Discord channel.""" @@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call( # Test previous_messages -def test_previous_messages_empty(db_session): +def test_previous_messages_empty(db_session, sample_bot_user): """Test getting previous messages when none exist.""" - result = previous_messages(db_session, user_id=123, channel_id=456) + result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456) assert result == [] -def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel): +def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel): """Test filtering messages by recipient user.""" # Create some messages msg1 = DiscordMessage( @@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp db_session.add_all([msg1, msg2]) db_session.commit() - result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None) + result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None) assert len(result) == 2 # Should be in chronological order (oldest first) assert result[0].message_id == 1 assert result[1].message_id == 2 -def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel): +def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel): """Test limiting the number of previous messages.""" # Create 15 messages for i in range(15): @@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl db_session.commit() result = previous_messages( - db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5 + db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5 ) assert len(result) == 5 @@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl # Test comm_channel_prompt -def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel): +def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel): """Test generating a basic communication channel prompt.""" result = comm_channel_prompt( - db_session, user=sample_discord_user, channel=sample_discord_channel + db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel ) assert "You are a bot communicating on Discord" in result @@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco assert len(result) > 0 -def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel): +def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel): """Test that prompt includes server context when available.""" server = sample_discord_channel.server server.summary = "Gaming community server" db_session.commit() - result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel) + result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel) assert "server_context" in result.lower() assert "Gaming community server" in result -def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel): +def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel): """Test that prompt includes channel context.""" sample_discord_channel.summary = "General discussion channel" db_session.commit() - result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel) + result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel) assert "channel_context" in result.lower() assert "General discussion channel" in result def test_comm_channel_prompt_includes_user_notes( - db_session, sample_discord_user, sample_discord_channel + db_session, sample_bot_user, sample_discord_user, sample_discord_channel ): """Test that prompt includes user notes from previous messages.""" sample_discord_user.summary = "Helpful community member" @@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes( db_session.commit() result = comm_channel_prompt( - db_session, user=sample_discord_user, channel=sample_discord_channel + db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel ) assert "user_notes" in result.lower() @@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers( web_tool_instance = MagicMock(name="web_tool") mock_web_search.return_value = web_tool_instance - bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt") + bot_user = SimpleNamespace( + system_user=SimpleNamespace(discord_id=999888777), + system_prompt="bot prompt" + ) from_user = SimpleNamespace(id=123) mcp_model = SimpleNamespace( name="Server", mcp_server_url="https://mcp.example.com", access_token="token123", + disabled_tools=[], ) result = call_llm( @@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools( mock_web_search.return_value = MagicMock(name="web_tool") - bot_user = SimpleNamespace(system_user="system-user", system_prompt=None) + bot_user = SimpleNamespace( + system_user=SimpleNamespace(discord_id=999888777), + system_prompt=None + ) from_user = SimpleNamespace(id=1) call_llm( diff --git a/tests/memory/workers/tasks/test_discord_tasks.py b/tests/memory/workers/tasks/test_discord_tasks.py index 8decc7d..9650ba7 100644 --- a/tests/memory/workers/tasks/test_discord_tasks.py +++ b/tests/memory/workers/tasks/test_discord_tasks.py @@ -14,12 +14,25 @@ from memory.workers.tasks import discord @pytest.fixture def discord_bot_user(db_session): + # Create a discord user for the bot first + bot_discord_user = DiscordUser( + id=999999999, + username="testbot", + ) + db_session.add(bot_discord_user) + db_session.flush() + bot = DiscordBotUser.create_with_api_key( - discord_users=[], + discord_users=[bot_discord_user], name="Test Bot", email="bot@example.com", ) db_session.add(bot) + db_session.flush() + + # Link the discord user to the system user + bot_discord_user.system_user_id = bot.id + db_session.commit() return bot @@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel): @patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) -@patch("memory.workers.tasks.discord.create_provider") +@patch("memory.workers.tasks.discord.call_llm") +@patch("memory.workers.tasks.discord.discord.trigger_typing_channel") def test_should_process_normal_message( - mock_create_provider, + mock_trigger_typing, + mock_call_llm, db_session, mock_discord_user, mock_discord_server, mock_discord_channel, + discord_bot_user, ): """Test should_process returns True for normal messages.""" - # Mock the LLM provider to return "yes" - mock_provider = Mock() - mock_provider.generate.return_value = "yes" - mock_provider.as_messages.return_value = [] - mock_create_provider.return_value = mock_provider + # Create a separate recipient user (the bot) + bot_discord_user = discord_bot_user.discord_users[0] + + # Mock call_llm to return a high number (100 = always process) + mock_call_llm.return_value = "100Test" message = DiscordMessage( message_id=1, channel_id=mock_discord_channel.id, from_id=mock_discord_user.id, - recipient_id=mock_discord_user.id, + recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user server_id=mock_discord_server.id, content="Test", sent_at=datetime.now(timezone.utc), @@ -207,6 +223,8 @@ def test_should_process_normal_message( db_session.refresh(message) assert discord.should_process(message) is True + mock_call_llm.assert_called_once() + mock_trigger_typing.assert_called_once() @patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False) @@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant): def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant): """Test adding a Discord message that is a reply.""" sample_message_data["message_reference_id"] = 111222333 + sample_message_data["message_type"] = "reply" # Explicitly set message_type discord.add_discord_message(**sample_message_data) @@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context( assert result["status"] == "processed" -def test_process_discord_message_success(db_session, sample_message_data, qdrant): +@patch("memory.workers.tasks.discord.send_discord_response") +@patch("memory.workers.tasks.discord.call_llm") +def test_process_discord_message_success( + mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant +): """Test processing a Discord message.""" + # Mock LLM to return a response + mock_call_llm.return_value = "Test response from bot" + # Mock Discord API to succeed + mock_send_response.return_value = True + # Add a message first add_result = discord.add_discord_message(**sample_message_data) message_id = add_result["discordmessage_id"] @@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant assert result["status"] == "processed" assert result["message_id"] == message_id + mock_call_llm.assert_called_once() + mock_send_response.assert_called_once() def test_process_discord_message_not_found(db_session): diff --git a/tests/memory/workers/tasks/test_proactive.py b/tests/memory/workers/tasks/test_proactive.py new file mode 100644 index 0000000..3edcd38 --- /dev/null +++ b/tests/memory/workers/tasks/test_proactive.py @@ -0,0 +1,536 @@ +"""Tests for proactive check-in tasks.""" + +import pytest +from datetime import datetime, timezone, timedelta +from unittest.mock import Mock, patch, MagicMock + +from memory.common.db.models import ( + DiscordBotUser, + DiscordUser, + DiscordChannel, + DiscordServer, +) +from memory.workers.tasks import proactive +from memory.workers.tasks.proactive import is_cron_due + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def bot_user(db_session): + """Create a bot user for testing.""" + bot_discord_user = DiscordUser( + id=999999999, + username="testbot", + ) + db_session.add(bot_discord_user) + db_session.flush() + + user = DiscordBotUser.create_with_api_key( + discord_users=[bot_discord_user], + name="testbot", + email="bot@example.com", + ) + db_session.add(user) + db_session.commit() + return user + + +@pytest.fixture +def target_user(db_session): + """Create a target Discord user for testing.""" + discord_user = DiscordUser( + id=123456789, + username="targetuser", + proactive_cron="0 9 * * *", # 9am daily + chattiness_threshold=50, + ) + db_session.add(discord_user) + db_session.commit() + return discord_user + + +@pytest.fixture +def target_user_no_cron(db_session): + """Create a target Discord user without proactive cron.""" + discord_user = DiscordUser( + id=123456790, + username="nocronuser", + proactive_cron=None, + ) + db_session.add(discord_user) + db_session.commit() + return discord_user + + +@pytest.fixture +def target_server(db_session): + """Create a target Discord server for testing.""" + server = DiscordServer( + id=987654321, + name="Test Server", + proactive_cron="0 */4 * * *", # Every 4 hours + chattiness_threshold=30, + ) + db_session.add(server) + db_session.commit() + return server + + +@pytest.fixture +def target_channel(db_session, target_server): + """Create a target Discord channel for testing.""" + channel = DiscordChannel( + id=111222333, + name="test-channel", + channel_type="text", + server_id=target_server.id, + proactive_cron="0 12 * * 1-5", # Noon on weekdays + chattiness_threshold=70, + ) + db_session.add(channel) + db_session.commit() + return channel + + +# ============================================================================ +# Tests for is_cron_due helper +# ============================================================================ + + +@pytest.mark.parametrize( + "cron_expr,now,last_run,expected", + [ + # Cron is due when never run before and time matches + ( + "0 9 * * *", + datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc), + None, + True, + ), + # Cron is due when last run was before the scheduled time + ( + "0 9 * * *", + datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc), + datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc), + True, + ), + # Cron is NOT due when already run this period + ( + "0 9 * * *", + datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc), + datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc), + False, + ), + # Cron is NOT due when current time is before scheduled time + ( + "0 9 * * *", + datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc), + None, + False, + ), + # Hourly cron schedule + ( + "0 * * * *", + datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc), + datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc), + True, + ), + # Every 4 hours cron schedule + ( + "0 */4 * * *", + datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc), + datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc), + True, + ), + ], + ids=[ + "due_never_run", + "due_last_run_before_schedule", + "not_due_already_run", + "not_due_too_early", + "due_hourly", + "due_every_4_hours", + ], +) +def test_is_cron_due(cron_expr, now, last_run, expected): + """Test is_cron_due with various scenarios.""" + assert is_cron_due(cron_expr, last_run, now) is expected + + +def test_is_cron_due_invalid_expression(): + """Test invalid cron expression returns False.""" + now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc) + assert is_cron_due("invalid cron", None, now) is False + + +def test_is_cron_due_with_naive_last_run(): + """Test cron handles naive datetime for last_run.""" + now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc) + cron_expr = "0 9 * * *" + last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime + assert is_cron_due(cron_expr, last_run, now) is True + + +# ============================================================================ +# Tests for evaluate_proactive_checkins task +# ============================================================================ + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.is_cron_due") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_dispatches_due( + mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user +): + """Test that due check-ins are dispatched.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_is_cron_due.return_value = True + + result = proactive.evaluate_proactive_checkins() + + assert result["count"] >= 1 + mock_execute.delay.assert_called() + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.is_cron_due") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_skips_not_due( + mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user +): + """Test that not-due check-ins are not dispatched.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_is_cron_due.return_value = False + + result = proactive.evaluate_proactive_checkins() + + assert result["count"] == 0 + mock_execute.delay.assert_not_called() + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_skips_no_cron( + mock_make_session, mock_execute, db_session, target_user_no_cron +): + """Test that entities without proactive_cron are skipped.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + result = proactive.evaluate_proactive_checkins() + + for call in mock_execute.delay.call_args_list: + entity_type, entity_id = call[0] + assert entity_id != target_user_no_cron.id + + +@patch("memory.workers.tasks.proactive.execute_proactive_checkin") +@patch("memory.workers.tasks.proactive.is_cron_due") +@patch("memory.workers.tasks.proactive.make_session") +def test_evaluate_proactive_checkins_multiple_entity_types( + mock_make_session, + mock_is_cron_due, + mock_execute, + db_session, + target_user, + target_server, + target_channel, +): + """Test that check-ins are dispatched for users, channels, and servers.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_is_cron_due.return_value = True + + result = proactive.evaluate_proactive_checkins() + + assert result["count"] == 3 + dispatched_types = {d["type"] for d in result["dispatched"]} + assert "user" in dispatched_types + assert "channel" in dispatched_types + assert "server" in dispatched_types + + +# ============================================================================ +# Tests for execute_proactive_checkin task +# ============================================================================ + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_sends_when_above_threshold( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + target_user, + bot_user, +): + """Test check-in is sent when interest exceeds threshold.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "80Should check in", + "Hey! Just checking in - how are things going?", + ] + mock_send.return_value = True + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert result["status"] == "sent" + assert result["interest"] == 80 + mock_send.assert_called_once() + + db_session.refresh(target_user) + assert target_user.last_proactive_at is not None + + +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_skips_below_threshold( + mock_make_session, + mock_get_bot, + mock_call_llm, + db_session, + target_user, + bot_user, +): + """Test check-in is skipped when interest is below threshold.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.return_value = ( + "30Not much to say" + ) + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert result["status"] == "below_threshold" + assert result["interest"] == 30 + assert result["threshold"] == 50 + + +@pytest.mark.parametrize( + "llm_response,expected_status", + [ + (None, "no_eval_response"), + ("I'm not sure what to say.", "no_score_in_response"), + ], + ids=["no_response", "malformed_response"], +) +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_handles_bad_llm_response( + mock_make_session, + mock_get_bot, + mock_call_llm, + llm_response, + expected_status, + db_session, + target_user, + bot_user, +): + """Test handling of missing or malformed LLM responses.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + mock_call_llm.return_value = llm_response + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert result["status"] == expected_status + + +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session): + """Test handling when entity doesn't exist.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + result = proactive.execute_proactive_checkin("user", 999999) + + assert "error" in result + assert "not found" in result["error"] + + +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_no_bot_user( + mock_make_session, mock_get_bot, db_session, target_user +): + """Test handling when no bot user is found.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + mock_get_bot.return_value = None + + result = proactive.execute_proactive_checkin("user", target_user.id) + + assert "error" in result + assert "No bot user" in result["error"] + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_uses_proactive_prompt( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + bot_user, +): + """Test that proactive_prompt is included in the evaluation.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + user_with_prompt = DiscordUser( + id=555666777, + username="promptuser", + proactive_cron="0 9 * * *", + proactive_prompt="Focus on their coding projects", + chattiness_threshold=50, + ) + db_session.add(user_with_prompt) + db_session.commit() + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "80Check on projects", + "How are your coding projects coming along?", + ] + mock_send.return_value = True + + result = proactive.execute_proactive_checkin("user", user_with_prompt.id) + + assert result["status"] == "sent" + call_args = mock_call_llm.call_args_list[0] + messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages") + assert any("Focus on their coding projects" in str(m) for m in messages_arg) + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_channel( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + target_channel, + bot_user, +): + """Test check-in to a channel.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "50Check channel", + "Good morning everyone!", + ] + mock_send.return_value = True + + result = proactive.execute_proactive_checkin("channel", target_channel.id) + + assert result["status"] == "sent" + assert result["entity_type"] == "channel" + + send_call = mock_send.call_args + assert send_call.kwargs.get("channel_id") == target_channel.id + + +@patch("memory.workers.tasks.proactive.send_discord_response") +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_updates_last_proactive_at( + mock_make_session, + mock_get_bot, + mock_call_llm, + mock_send, + db_session, + target_user, + bot_user, +): + """Test that last_proactive_at is updated after successful check-in.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.side_effect = [ + "80Check in", + "Hey there!", + ] + mock_send.return_value = True + + before_time = datetime.now(timezone.utc) + proactive.execute_proactive_checkin("user", target_user.id) + after_time = datetime.now(timezone.utc) + + db_session.refresh(target_user) + assert target_user.last_proactive_at is not None + assert before_time <= target_user.last_proactive_at <= after_time + + +@patch("memory.workers.tasks.proactive.call_llm") +@patch("memory.workers.tasks.proactive.get_bot_for_entity") +@patch("memory.workers.tasks.proactive.make_session") +def test_execute_proactive_checkin_updates_last_proactive_at_on_skip( + mock_make_session, + mock_get_bot, + mock_call_llm, + db_session, + target_user, + bot_user, +): + """Test that last_proactive_at is updated even when check-in is skipped.""" + mock_make_session.return_value.__enter__ = Mock(return_value=db_session) + mock_make_session.return_value.__exit__ = Mock(return_value=False) + + bot_discord_user = bot_user.discord_users[0] + bot_discord_user.system_user = bot_user + mock_get_bot.return_value = bot_discord_user + + mock_call_llm.return_value = ( + "10Nothing to say" + ) + + proactive.execute_proactive_checkin("user", target_user.id) + + db_session.refresh(target_user) + assert target_user.last_proactive_at is not None diff --git a/tests/memory/workers/tasks/test_scheduled_calls.py b/tests/memory/workers/tasks/test_scheduled_calls.py index 44d567d..ceb28d5 100644 --- a/tests/memory/workers/tasks/test_scheduled_calls.py +++ b/tests/memory/workers/tasks/test_scheduled_calls.py @@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls @pytest.fixture def sample_user(db_session): """Create a sample user for testing.""" + # Create a discord user for the bot + bot_discord_user = DiscordUser( + id=999999999, + username="testbot", + ) + db_session.add(bot_discord_user) + db_session.flush() + user = DiscordBotUser.create_with_api_key( - discord_users=[], + discord_users=[bot_discord_user], name="testbot", email="bot@example.com", ) @@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user): return call -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") +@patch("memory.discord.messages.discord.send_dm") def test_send_to_discord_user(mock_send_dm, pending_scheduled_call): """Test sending to Discord user.""" response = "This is a test response." - scheduled_calls.send_to_discord(pending_scheduled_call, response) + scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response) mock_send_dm.assert_called_once_with( - pending_scheduled_call.user_id, + 999999999, # bot_id "testuser", # username, not ID - "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.", + response, ) -@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message") -def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call): +@patch("memory.discord.messages.discord.send_to_channel") +def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call): """Test sending to Discord channel.""" response = "This is a channel response." - scheduled_calls.send_to_discord(completed_scheduled_call, response) + scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response) - mock_broadcast.assert_called_once_with( - completed_scheduled_call.user_id, - "test-channel", # channel name, not ID - "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.", + mock_send_to_channel.assert_called_once_with( + 999999999, # bot_id + completed_scheduled_call.discord_channel.id, # channel ID, not name + response, ) -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") +@patch("memory.discord.messages.discord.send_dm") def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call): """Test message truncation for long responses.""" long_response = "A" * 2500 # Very long response - scheduled_calls.send_to_discord(pending_scheduled_call, long_response) + scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response) - # Verify the message was truncated + # With the new implementation, send_discord_response sends the full response + # No truncation happens in _send_to_discord args, kwargs = mock_send_dm.call_args - assert args[0] == pending_scheduled_call.user_id + assert args[0] == 999999999 # bot_id message = args[2] - assert len(message) <= 1950 # Should be truncated - assert message.endswith("... (response truncated)") + assert message == long_response -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") +@patch("memory.discord.messages.discord.send_dm") def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call): """Test that normal length messages are not truncated.""" normal_response = "This is a normal length response." - scheduled_calls.send_to_discord(pending_scheduled_call, normal_response) + scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response) args, kwargs = mock_send_dm.call_args - assert args[0] == pending_scheduled_call.user_id + assert args[0] == 999999999 # bot_id message = args[2] - assert not message.endswith("... (response truncated)") - assert "This is a normal length response." in message + assert message == normal_response -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_success( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -189,12 +196,8 @@ def test_execute_scheduled_call_success( result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id) - # Verify LLM was called with correct parameters - mock_llm_call.assert_called_once_with( - prompt="What is the weather like today?", - model="anthropic/claude-3-5-sonnet-20241022", - system_prompt="You are a helpful assistant.", - ) + # Verify LLM was called + mock_llm_call.assert_called_once() # Verify result assert result["success"] is True @@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session): assert result == {"error": "Scheduled call not found"} -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_not_pending( mock_llm_call, completed_scheduled_call, db_session ): @@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending( mock_llm_call.assert_not_called() -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_with_default_system_prompt( mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user ): @@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt( scheduled_calls.execute_scheduled_call(call.id) - # Verify default system prompt was used - mock_llm_call.assert_called_once_with( - prompt="Test prompt", - model="anthropic/claude-3-5-sonnet-20241022", - system_prompt=None, # The code uses system_prompt as-is, not a default - ) + # Verify LLM was called + mock_llm_call.assert_called_once() -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_discord_error( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error( assert pending_scheduled_call.data["discord_error"] == "Discord API error" -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_llm_error( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): """Test execution when LLM call fails.""" mock_llm_call.side_effect = Exception("LLM API error") - # The safe_task_execution decorator should catch this + # The execute_scheduled_call function catches the exception and returns an error response result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id) - assert result["status"] == "error" - assert "LLM API error" in result["error"] + assert result["success"] is False + assert "error" in result + assert "LLM call failed" in result["error"] # Discord should not be called mock_send_discord.assert_not_called() -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_long_response_truncation( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling( mock_execute_delay.delay.assert_called_once_with(due_call.id) -@patch("memory.workers.tasks.scheduled_calls._send_to_discord") -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.send_to_discord") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_status_transition_pending_to_executing_to_completed( mock_llm_call, mock_send_discord, pending_scheduled_call, db_session ): @@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed( "has_discord_user,has_discord_channel,expected_method", [ (True, False, "send_dm"), - (False, True, "broadcast_message"), - (True, True, "send_dm"), # User takes precedence + (False, True, "send_to_channel"), + (True, True, "send_to_channel"), # Channel takes precedence in the implementation ], ) -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") -@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message") +@patch("memory.discord.messages.discord.send_dm") +@patch("memory.discord.messages.discord.send_to_channel") def test_discord_destination_priority( - mock_broadcast, + mock_send_to_channel, mock_send_dm, has_discord_user, has_discord_channel, @@ -535,50 +535,39 @@ def test_discord_destination_priority( db_session.commit() response = "Test response" - scheduled_calls.send_to_discord(call, response) + scheduled_calls.send_to_discord(999999999, call, response) if expected_method == "send_dm": mock_send_dm.assert_called_once() - mock_broadcast.assert_not_called() + mock_send_to_channel.assert_not_called() else: - mock_broadcast.assert_called_once() + mock_send_to_channel.assert_called_once() mock_send_dm.assert_not_called() @pytest.mark.parametrize( - "topic,model,response,expected_in_message", + "topic,model,response", [ ( "Weather Check", "anthropic/claude-3-5-sonnet-20241022", "It's sunny!", - [ - "**Topic:** Weather Check", - "**Model:** anthropic/claude-3-5-sonnet-20241022", - "**Response:** It's sunny!", - ], ), ( "Test Topic", "gpt-4", "Hello world", - ["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"], ), ( "Long Topic Name Here", "claude-2", "Short", - [ - "**Topic:** Long Topic Name Here", - "**Model:** claude-2", - "**Response:** Short", - ], ), ], ) -@patch("memory.workers.tasks.scheduled_calls.discord.send_dm") -def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message): - """Test the Discord message formatting with different inputs.""" +@patch("memory.discord.messages.discord.send_dm") +def test_message_formatting(mock_send_dm, topic, model, response): + """Test that _send_to_discord sends the response as-is.""" # Create a mock scheduled call with a mock Discord user mock_discord_user = Mock() mock_discord_user.username = "testuser" @@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me mock_call.discord_user = mock_discord_user mock_call.discord_channel = None - scheduled_calls.send_to_discord(mock_call, response) + scheduled_calls.send_to_discord(999999999, mock_call, response) # Get the actual message that was sent args, kwargs = mock_send_dm.call_args - assert args[0] == mock_call.user_id + assert args[0] == 999999999 # bot_id actual_message = args[2] - # Verify all expected parts are in the message - for expected_part in expected_in_message: - assert expected_part in actual_message + # The new implementation sends the response as-is, without formatting + assert actual_message == response @pytest.mark.parametrize( @@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me ("cancelled", False), ], ) -@patch("memory.workers.tasks.scheduled_calls.llms.summarize") +@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled") def test_execute_scheduled_call_status_check( mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user ):