mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
proactive stuff
This commit is contained in:
parent
47180e1e71
commit
f042f9aed8
@ -0,0 +1,45 @@
|
||||
"""Add proactive check-in fields to Discord entities
|
||||
|
||||
Revision ID: e1f2a3b4c5d6
|
||||
Revises: d0e1f2a3b4c5
|
||||
Create Date: 2025-12-24 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e1f2a3b4c5d6"
|
||||
down_revision: Union[str, None] = "d0e1f2a3b4c5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add proactive fields to all MessageProcessor tables
|
||||
for table in ["discord_servers", "discord_channels", "discord_users"]:
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column("proactive_cron", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column("proactive_prompt", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column(
|
||||
"last_proactive_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ["discord_servers", "discord_channels", "discord_users"]:
|
||||
op.drop_column(table, "last_proactive_at")
|
||||
op.drop_column(table, "proactive_prompt")
|
||||
op.drop_column(table, "proactive_cron")
|
||||
@ -1,7 +1,8 @@
|
||||
fastapi==0.112.2
|
||||
fastapi>=0.115.12
|
||||
uvicorn==0.29.0
|
||||
python-jose==3.3.0
|
||||
python-multipart==0.0.9
|
||||
sqladmin==0.20.1
|
||||
mcp==1.10.0
|
||||
fastmcp>=2.10.0
|
||||
slowapi==0.1.9
|
||||
@ -1,14 +1,15 @@
|
||||
sqlalchemy==2.0.30
|
||||
psycopg2-binary==2.9.9
|
||||
pydantic==2.7.2
|
||||
pydantic>=2.11.7
|
||||
alembic==1.13.1
|
||||
dotenv==0.9.9
|
||||
voyageai==0.3.2
|
||||
qdrant-client==1.9.0
|
||||
anthropic==0.69.0
|
||||
openai==2.3.0
|
||||
# Pin the httpx version, as newer versions break the anthropic client
|
||||
httpx==0.27.0
|
||||
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
|
||||
httpx>=0.28.1
|
||||
celery[redis,sqs]==5.3.6
|
||||
croniter==2.0.1
|
||||
cryptography==43.0.0
|
||||
bcrypt==4.1.2
|
||||
@ -1,7 +1,9 @@
|
||||
pytest==7.4.4
|
||||
pytest-cov==4.1.0
|
||||
pytest-asyncio==0.23.0
|
||||
black==23.12.1
|
||||
mypy==1.8.0
|
||||
isort==5.13.2
|
||||
isort==5.13.2
|
||||
testcontainers[qdrant]==4.10.0
|
||||
click==8.1.7
|
||||
click==8.1.7
|
||||
croniter==2.0.1
|
||||
@ -1,8 +1,16 @@
|
||||
import memory.api.MCP.tools
|
||||
import memory.api.MCP.memory
|
||||
import memory.api.MCP.metadata
|
||||
import memory.api.MCP.schedules
|
||||
import memory.api.MCP.books
|
||||
import memory.api.MCP.manifest
|
||||
import memory.api.MCP.github
|
||||
import memory.api.MCP.people
|
||||
"""
|
||||
MCP server with composed subservers.
|
||||
|
||||
Subservers are mounted with prefixes:
|
||||
- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file
|
||||
- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview
|
||||
- people: add_person, update_person_info, get_person, list_people, delete_person
|
||||
- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call
|
||||
- books: all_books, read_book
|
||||
- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts
|
||||
"""
|
||||
|
||||
# Import base to trigger subserver mounting
|
||||
from memory.api.MCP.base import mcp, get_current_user
|
||||
|
||||
__all__ = ["mcp", "get_current_user"]
|
||||
|
||||
@ -1,60 +1,28 @@
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import cast
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationRequest
|
||||
from mcp.server.auth.handlers.token import (
|
||||
AuthorizationCodeRequest,
|
||||
RefreshTokenRequest,
|
||||
TokenRequest,
|
||||
)
|
||||
from mcp.server.auth.middleware.auth_context import get_access_token
|
||||
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
from pydantic import AnyHttpUrl
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
||||
from memory.api.MCP.oauth_provider import (
|
||||
ALLOWED_SCOPES,
|
||||
BASE_SCOPES,
|
||||
SimpleOAuthProvider,
|
||||
)
|
||||
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
||||
from memory.api.MCP.servers.books import books_mcp
|
||||
from memory.api.MCP.servers.core import core_mcp
|
||||
from memory.api.MCP.servers.github import github_mcp
|
||||
from memory.api.MCP.servers.meta import meta_mcp
|
||||
from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth
|
||||
from memory.api.MCP.servers.people import people_mcp
|
||||
from memory.api.MCP.servers.schedule import schedule_mcp
|
||||
from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.connection import make_session, get_engine
|
||||
from memory.common.db.models import OAuthState, UserSession
|
||||
from memory.common.db.models.users import HumanUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_metadata(klass: type):
|
||||
orig_validate = klass.model_validate
|
||||
|
||||
def validate(data: dict):
|
||||
data = dict(data)
|
||||
if "redirect_uris" in data:
|
||||
data["redirect_uris"] = [
|
||||
str(uri).replace("cursor://", "http://")
|
||||
for uri in data["redirect_uris"]
|
||||
]
|
||||
if "redirect_uri" in data:
|
||||
data["redirect_uri"] = str(data["redirect_uri"]).replace(
|
||||
"cursor://", "http://"
|
||||
)
|
||||
|
||||
return orig_validate(data)
|
||||
|
||||
klass.model_validate = validate
|
||||
|
||||
|
||||
validate_metadata(OAuthClientMetadata)
|
||||
validate_metadata(AuthorizationRequest)
|
||||
validate_metadata(AuthorizationCodeRequest)
|
||||
validate_metadata(RefreshTokenRequest)
|
||||
validate_metadata(TokenRequest)
|
||||
engine = get_engine()
|
||||
|
||||
|
||||
# Setup templates
|
||||
@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir)
|
||||
|
||||
|
||||
oauth_provider = SimpleOAuthProvider()
|
||||
auth_settings = AuthSettings(
|
||||
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
|
||||
client_registration_options=ClientRegistrationOptions(
|
||||
enabled=True,
|
||||
valid_scopes=ALLOWED_SCOPES,
|
||||
default_scopes=BASE_SCOPES,
|
||||
),
|
||||
required_scopes=BASE_SCOPES,
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
"memory",
|
||||
stateless_http=True,
|
||||
auth_server_provider=oauth_provider,
|
||||
auth=auth_settings,
|
||||
auth=oauth_provider,
|
||||
)
|
||||
|
||||
|
||||
@ -162,3 +118,52 @@ def get_current_user() -> dict:
|
||||
"client_id": access_token.client_id,
|
||||
"user": user_info,
|
||||
}
|
||||
|
||||
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request):
|
||||
"""Health check endpoint that verifies all dependencies are accessible."""
|
||||
from sqlalchemy import text
|
||||
|
||||
checks = {"mcp_oauth": "enabled"}
|
||||
all_healthy = True
|
||||
|
||||
# Check database connection
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
checks["database"] = "healthy"
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
checks["database"] = "unhealthy"
|
||||
all_healthy = False
|
||||
|
||||
# Check Qdrant connection
|
||||
try:
|
||||
from memory.common.qdrant import get_qdrant_client
|
||||
|
||||
client = get_qdrant_client()
|
||||
client.get_collections()
|
||||
checks["qdrant"] = "healthy"
|
||||
except Exception as e:
|
||||
logger.error(f"Qdrant health check failed: {e}")
|
||||
checks["qdrant"] = "unhealthy"
|
||||
all_healthy = False
|
||||
|
||||
checks["status"] = "healthy" if all_healthy else "degraded"
|
||||
status_code = 200 if all_healthy else 503
|
||||
return JSONResponse(checks, status_code=status_code)
|
||||
|
||||
|
||||
# Inject auth provider into subservers that need it
|
||||
set_schedule_auth(get_current_user)
|
||||
set_meta_auth(get_current_user)
|
||||
|
||||
# Mount all subservers onto the main MCP server
|
||||
# Tools will be prefixed with their server name (e.g., core_search_knowledge_base)
|
||||
mcp.mount(core_mcp, prefix="core")
|
||||
mcp.mount(github_mcp, prefix="github")
|
||||
mcp.mount(people_mcp, prefix="people")
|
||||
mcp.mount(schedule_mcp, prefix="schedule")
|
||||
mcp.mount(books_mcp, prefix="books")
|
||||
mcp.mount(meta_mcp, prefix="meta")
|
||||
|
||||
@ -1,119 +0,0 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from datetime import datetime
|
||||
|
||||
from typing import TypedDict, NotRequired, Literal
|
||||
from memory.api.MCP.tools import mcp
|
||||
|
||||
|
||||
class BinaryProbs(TypedDict):
|
||||
prob: float
|
||||
|
||||
|
||||
class MultiProbs(TypedDict):
|
||||
answerProbs: dict[str, float]
|
||||
|
||||
|
||||
Probs = dict[str, BinaryProbs | MultiProbs]
|
||||
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
|
||||
|
||||
|
||||
class MarketAnswer(TypedDict):
|
||||
id: str
|
||||
text: str
|
||||
resolutionProbability: float
|
||||
|
||||
|
||||
class MarketDetails(TypedDict):
|
||||
id: str
|
||||
createdTime: int
|
||||
question: str
|
||||
outcomeType: OutcomeType
|
||||
textDescription: str
|
||||
groupSlugs: list[str]
|
||||
volume: float
|
||||
isResolved: bool
|
||||
answers: list[MarketAnswer]
|
||||
|
||||
|
||||
class Market(TypedDict):
|
||||
id: str
|
||||
url: str
|
||||
question: str
|
||||
volume: int
|
||||
createdTime: int
|
||||
outcomeType: OutcomeType
|
||||
createdAt: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
answers: NotRequired[dict[str, float]]
|
||||
probability: NotRequired[float]
|
||||
details: NotRequired[MarketDetails]
|
||||
|
||||
|
||||
async def get_details(session: aiohttp.ClientSession, market_id: str):
|
||||
async with session.get(
|
||||
f"https://api.manifold.markets/v0/market/{market_id}"
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
|
||||
|
||||
async def format_market(session: aiohttp.ClientSession, market: Market):
|
||||
if market.get("outcomeType") != "BINARY":
|
||||
details = await get_details(session, market["id"])
|
||||
market["answers"] = {
|
||||
answer["text"]: round(
|
||||
answer.get("resolutionProbability") or answer.get("probability") or 0, 3
|
||||
)
|
||||
for answer in details["answers"]
|
||||
}
|
||||
if creationTime := market.get("createdTime"):
|
||||
market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"url",
|
||||
"question",
|
||||
"volume",
|
||||
"createdAt",
|
||||
"details",
|
||||
"probability",
|
||||
"answers",
|
||||
]
|
||||
return {k: v for k, v in market.items() if k in fields}
|
||||
|
||||
|
||||
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.manifold.markets/v0/search-markets",
|
||||
params={
|
||||
"term": term,
|
||||
"contractType": "BINARY" if binary else "ALL",
|
||||
},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
markets = await resp.json()
|
||||
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
format_market(session, market)
|
||||
for market in markets
|
||||
if market.get("volume", 0) >= min_volume
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_forecasts(
|
||||
term: str, min_volume: int = 1000, binary: bool = False
|
||||
) -> list[dict]:
|
||||
"""Get prediction market forecasts for a given term.
|
||||
|
||||
Args:
|
||||
term: The term to search for.
|
||||
min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
|
||||
binary: Whether to only return binary markets.
|
||||
"""
|
||||
return await search_markets(term, min_volume, binary)
|
||||
@ -1,119 +0,0 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, TypedDict, get_args, get_type_hints
|
||||
|
||||
from memory.common import qdrant
|
||||
from sqlalchemy import func
|
||||
|
||||
from memory.api.MCP.tools import mcp
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import SourceItem
|
||||
from memory.common.db.models.source_items import AgentObservation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchemaArg(TypedDict):
|
||||
type: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
class CollectionMetadata(TypedDict):
|
||||
schema: dict[str, SchemaArg]
|
||||
size: int
|
||||
|
||||
|
||||
def from_annotation(annotation: Annotated) -> SchemaArg | None:
|
||||
try:
|
||||
type_, description = get_args(annotation)
|
||||
type_str = str(type_)
|
||||
if type_str.startswith("typing."):
|
||||
type_str = type_str[7:]
|
||||
elif len((parts := type_str.split("'"))) > 1:
|
||||
type_str = parts[1]
|
||||
return SchemaArg(type=type_str, description=description)
|
||||
except IndexError:
|
||||
logger.error(f"Error from annotation: {annotation}")
|
||||
return None
|
||||
|
||||
|
||||
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
|
||||
if not hasattr(klass, "as_payload"):
|
||||
return {}
|
||||
|
||||
if not (payload_type := get_type_hints(klass.as_payload).get("return")):
|
||||
return {}
|
||||
|
||||
return {
|
||||
name: schema
|
||||
for name, arg in payload_type.__annotations__.items()
|
||||
if (schema := from_annotation(arg))
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
|
||||
"""Get the metadata schema for each collection used in the knowledge base.
|
||||
|
||||
These schemas can be used to filter the knowledge base.
|
||||
|
||||
Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
|
||||
|
||||
Example:
|
||||
```
|
||||
{
|
||||
"mail": {"subject": {"type": "str", "description": "The subject of the email."}},
|
||||
"chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
|
||||
}
|
||||
"""
|
||||
client = qdrant.get_qdrant_client()
|
||||
sizes = qdrant.get_collection_sizes(client)
|
||||
schemas = defaultdict(dict)
|
||||
for klass in SourceItem.__subclasses__():
|
||||
for collection in klass.get_collections():
|
||||
schemas[collection].update(get_schema(klass))
|
||||
|
||||
return {
|
||||
collection: CollectionMetadata(schema=schema, size=size)
|
||||
for collection, schema in schemas.items()
|
||||
if (size := sizes.get(collection))
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_tags() -> list[str]:
|
||||
"""Get all unique tags used across the entire knowledge base.
|
||||
|
||||
Returns sorted list of tags from both observations and content.
|
||||
"""
|
||||
with make_session() as session:
|
||||
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
||||
return sorted({row[0] for row in tags_query if row[0] is not None})
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_subjects() -> list[str]:
|
||||
"""Get all unique subjects from observations about the user.
|
||||
|
||||
Returns sorted list of subject identifiers used in observations.
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
r.subject for r in session.query(AgentObservation.subject).distinct()
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_observation_types() -> list[str]:
|
||||
"""Get all observation types that have been used.
|
||||
|
||||
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
{
|
||||
r.observation_type
|
||||
for r in session.query(AgentObservation.observation_type).distinct()
|
||||
if r.observation_type is not None
|
||||
}
|
||||
)
|
||||
@ -5,11 +5,12 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from fastmcp.server.auth import OAuthProvider
|
||||
from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
AuthorizationCode,
|
||||
AuthorizationParams,
|
||||
OAuthAuthorizationServerProvider,
|
||||
RefreshToken,
|
||||
)
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
@ -133,8 +134,45 @@ def make_token(
|
||||
)
|
||||
|
||||
|
||||
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
class SimpleOAuthProvider(OAuthProvider):
|
||||
"""OAuth provider that extends fastmcp's OAuthProvider with custom login flow."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
base_url=settings.SERVER_URL,
|
||||
issuer_url=settings.SERVER_URL,
|
||||
client_registration_options=None, # We handle registration ourselves
|
||||
required_scopes=BASE_SCOPES,
|
||||
)
|
||||
|
||||
async def verify_token(self, token: str) -> FastMCPAccessToken | None:
|
||||
"""Verify an access token and return token info if valid."""
|
||||
with make_session() as session:
|
||||
# Try as OAuth access token first
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if user_session:
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
if user_session.expires_at < now:
|
||||
return None
|
||||
return FastMCPAccessToken(
|
||||
token=token,
|
||||
client_id=user_session.oauth_state.client_id,
|
||||
scopes=user_session.oauth_state.scopes,
|
||||
)
|
||||
|
||||
# Try as bot API key
|
||||
bot = session.query(User).filter(User.api_key == token).first()
|
||||
if bot:
|
||||
logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
|
||||
return FastMCPAccessToken(
|
||||
token=token,
|
||||
client_id=cast(str, bot.name or bot.email),
|
||||
scopes=["read", "write"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""Get OAuth client information."""
|
||||
with make_session() as session:
|
||||
client = session.get(OAuthClientInformation, client_id)
|
||||
|
||||
17
src/memory/api/MCP/servers/__init__.py
Normal file
17
src/memory/api/MCP/servers/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""MCP subservers for composable tool organization."""
|
||||
|
||||
from memory.api.MCP.servers.core import core_mcp
|
||||
from memory.api.MCP.servers.github import github_mcp
|
||||
from memory.api.MCP.servers.people import people_mcp
|
||||
from memory.api.MCP.servers.schedule import schedule_mcp
|
||||
from memory.api.MCP.servers.books import books_mcp
|
||||
from memory.api.MCP.servers.meta import meta_mcp
|
||||
|
||||
__all__ = [
|
||||
"core_mcp",
|
||||
"github_mcp",
|
||||
"people_mcp",
|
||||
"schedule_mcp",
|
||||
"books_mcp",
|
||||
"meta_mcp",
|
||||
]
|
||||
@ -1,15 +1,19 @@
|
||||
"""MCP subserver for ebook access."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from memory.api.MCP.tools import mcp
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Book, BookSection, BookSectionPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
books_mcp = FastMCP("memory-books")
|
||||
|
||||
@mcp.tool()
|
||||
|
||||
@books_mcp.tool()
|
||||
async def all_books(sections: bool = False) -> list[dict]:
|
||||
"""
|
||||
Get all books in the database.
|
||||
@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
|
||||
return [book.as_payload(sections=sections) for book in books]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@books_mcp.tool()
|
||||
def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
|
||||
"""
|
||||
Read a book from the database.
|
||||
@ -1,53 +1,45 @@
|
||||
"""
|
||||
MCP tools for the epistemic sparring partner system.
|
||||
Core MCP subserver for knowledge base search, observations, and notes.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import pathlib
|
||||
from datetime import datetime, timezone
|
||||
from PIL import Image
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.api.MCP.base import mcp
|
||||
from memory.api.search.search import search
|
||||
from memory.api.search.types import SearchFilters, SearchConfig
|
||||
from memory.api.search.types import SearchConfig, SearchFilters
|
||||
from memory.common import extract, settings
|
||||
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
|
||||
from memory.common.celery_app import app as celery_app
|
||||
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import SourceItem, AgentObservation
|
||||
from memory.common.db.models import AgentObservation, SourceItem
|
||||
from memory.common.formatters import observation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
core_mcp = FastMCP("memory-core")
|
||||
|
||||
|
||||
def validate_path_within_directory(
|
||||
base_dir: pathlib.Path, requested_path: str
|
||||
) -> pathlib.Path:
|
||||
"""Validate that a requested path resolves within the base directory.
|
||||
|
||||
Prevents path traversal attacks using ../ or similar techniques.
|
||||
|
||||
Args:
|
||||
base_dir: The allowed base directory
|
||||
requested_path: The user-provided path
|
||||
|
||||
Returns:
|
||||
The resolved absolute path if valid
|
||||
|
||||
Raises:
|
||||
ValueError: If the path would escape the base directory
|
||||
"""
|
||||
"""Validate that a requested path resolves within the base directory."""
|
||||
resolved = (base_dir / requested_path.lstrip("/")).resolve()
|
||||
base_resolved = base_dir.resolve()
|
||||
|
||||
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
|
||||
if (
|
||||
not str(resolved).startswith(str(base_resolved) + "/")
|
||||
and resolved != base_resolved
|
||||
):
|
||||
raise ValueError(f"Path escapes allowed directory: {requested_path}")
|
||||
|
||||
return resolved
|
||||
@ -63,7 +55,6 @@ def filter_observation_source_ids(
|
||||
items_query = session.query(AgentObservation.id)
|
||||
|
||||
if tags:
|
||||
# Use PostgreSQL array overlap operator with proper array casting
|
||||
items_query = items_query.filter(
|
||||
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||
)
|
||||
@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
|
||||
items_query = session.query(SourceItem.id)
|
||||
|
||||
if tags:
|
||||
# Use PostgreSQL array overlap operator with proper array casting
|
||||
items_query = items_query.filter(
|
||||
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||
)
|
||||
@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
|
||||
return source_ids
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@core_mcp.tool()
|
||||
async def search_knowledge_base(
|
||||
query: str,
|
||||
filters: SearchFilters = {},
|
||||
@ -141,7 +131,6 @@ async def search_knowledge_base(
|
||||
|
||||
if not modalities:
|
||||
modalities = set(ALL_COLLECTIONS.keys())
|
||||
# Filter to valid collections, excluding observation collections
|
||||
modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
|
||||
|
||||
search_filters = SearchFilters(**filters)
|
||||
@ -167,7 +156,7 @@ class RawObservation(BaseModel):
|
||||
tags: list[str] = []
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@core_mcp.tool()
|
||||
async def observe(
|
||||
observations: list[RawObservation],
|
||||
session_id: str | None = None,
|
||||
@ -212,23 +201,23 @@ async def observe(
|
||||
logger.info("MCP: Observing")
|
||||
tasks = [
|
||||
(
|
||||
observation,
|
||||
obs,
|
||||
celery_app.send_task(
|
||||
SYNC_OBSERVATION,
|
||||
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
||||
kwargs={
|
||||
"subject": observation.subject,
|
||||
"content": observation.content,
|
||||
"observation_type": observation.observation_type,
|
||||
"confidences": observation.confidences,
|
||||
"evidence": observation.evidence,
|
||||
"tags": observation.tags,
|
||||
"subject": obs.subject,
|
||||
"content": obs.content,
|
||||
"observation_type": obs.observation_type,
|
||||
"confidences": obs.confidences,
|
||||
"evidence": obs.evidence,
|
||||
"tags": obs.tags,
|
||||
"session_id": session_id,
|
||||
"agent_model": agent_model,
|
||||
},
|
||||
),
|
||||
)
|
||||
for observation in observations
|
||||
for obs in observations
|
||||
]
|
||||
|
||||
def short_content(obs: RawObservation) -> str:
|
||||
@ -242,7 +231,7 @@ async def observe(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@core_mcp.tool()
|
||||
async def search_observations(
|
||||
query: str,
|
||||
subject: str = "",
|
||||
@ -308,7 +297,7 @@ async def search_observations(
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@core_mcp.tool()
|
||||
async def create_note(
|
||||
subject: str,
|
||||
content: str,
|
||||
@ -362,7 +351,7 @@ async def create_note(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@core_mcp.tool()
|
||||
async def note_files(path: str = "/"):
|
||||
"""
|
||||
List note files in the user's note storage.
|
||||
@ -386,7 +375,7 @@ async def note_files(path: str = "/"):
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@core_mcp.tool()
|
||||
def fetch_file(filename: str) -> dict:
|
||||
"""
|
||||
Read file content with automatic type detection.
|
||||
@ -1,14 +1,14 @@
|
||||
"""MCP tools for GitHub issue tracking and management."""
|
||||
"""MCP subserver for GitHub issue tracking and management."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Text, case, desc, func, asc
|
||||
from fastmcp import FastMCP
|
||||
from sqlalchemy import Text, case, desc, func
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.api.MCP.base import mcp
|
||||
from memory.api.search.search import search
|
||||
from memory.api.search.types import SearchConfig, SearchFilters
|
||||
from memory.common import extract
|
||||
@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
github_mcp = FastMCP("memory-github")
|
||||
|
||||
|
||||
def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
|
||||
"""Build GitHub URL from repo path and issue/PR number."""
|
||||
@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
|
||||
}
|
||||
if include_content:
|
||||
result["content"] = item.content
|
||||
# Include PR-specific data if available
|
||||
if item.kind == "pr" and item.pr_data:
|
||||
result["pr_data"] = {
|
||||
"additions": item.pr_data.additions,
|
||||
@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
|
||||
"files": item.pr_data.files,
|
||||
"reviews": item.pr_data.reviews,
|
||||
"review_comments": item.pr_data.review_comments,
|
||||
"diff": item.pr_data.diff, # Decompressed via property
|
||||
"diff": item.pr_data.diff,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@github_mcp.tool()
|
||||
async def list_github_issues(
|
||||
repo: str | None = None,
|
||||
assignee: str | None = None,
|
||||
@ -109,64 +110,48 @@ async def list_github_issues(
|
||||
with make_session() as session:
|
||||
query = session.query(GithubItem)
|
||||
|
||||
# Apply filters
|
||||
if repo:
|
||||
query = query.filter(GithubItem.repo_path == repo)
|
||||
|
||||
if assignee:
|
||||
query = query.filter(GithubItem.assignees.any(assignee))
|
||||
|
||||
if author:
|
||||
query = query.filter(GithubItem.author == author)
|
||||
|
||||
if state:
|
||||
query = query.filter(GithubItem.state == state)
|
||||
|
||||
if kind:
|
||||
query = query.filter(GithubItem.kind == kind)
|
||||
else:
|
||||
# Exclude comments by default, only show issues and PRs
|
||||
query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
|
||||
|
||||
if labels:
|
||||
# Match any label in the list using PostgreSQL array overlap
|
||||
query = query.filter(
|
||||
GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
|
||||
)
|
||||
|
||||
if project_status:
|
||||
query = query.filter(GithubItem.project_status == project_status)
|
||||
|
||||
if project_field:
|
||||
for key, value in project_field.items():
|
||||
query = query.filter(
|
||||
GithubItem.project_fields[key].astext == value
|
||||
)
|
||||
|
||||
query = query.filter(GithubItem.project_fields[key].astext == value)
|
||||
if updated_since:
|
||||
since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
|
||||
query = query.filter(GithubItem.github_updated_at >= since_dt)
|
||||
|
||||
if updated_before:
|
||||
before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
|
||||
query = query.filter(GithubItem.github_updated_at <= before_dt)
|
||||
|
||||
# Apply ordering
|
||||
if order_by == "created":
|
||||
query = query.order_by(desc(GithubItem.created_at))
|
||||
elif order_by == "number":
|
||||
query = query.order_by(desc(GithubItem.number))
|
||||
else: # default: updated
|
||||
else:
|
||||
query = query.order_by(desc(GithubItem.github_updated_at))
|
||||
|
||||
query = query.limit(limit)
|
||||
|
||||
items = query.all()
|
||||
|
||||
return [_serialize_issue(item) for item in items]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@github_mcp.tool()
|
||||
async def search_github_issues(
|
||||
query: str,
|
||||
repo: str | None = None,
|
||||
@ -191,7 +176,6 @@ async def search_github_issues(
|
||||
|
||||
limit = min(limit, 100)
|
||||
|
||||
# Pre-filter source_ids if repo/state/kind filters are specified
|
||||
source_ids = None
|
||||
if repo or state or kind:
|
||||
with make_session() as session:
|
||||
@ -206,7 +190,6 @@ async def search_github_issues(
|
||||
q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
|
||||
source_ids = [item.id for item in q.all()]
|
||||
|
||||
# Use the existing search infrastructure
|
||||
data = extract.extract_text(query, skip_summary=True)
|
||||
config = SearchConfig(limit=limit, previews=True)
|
||||
filters = SearchFilters()
|
||||
@ -220,7 +203,6 @@ async def search_github_issues(
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Fetch full issue details for the results
|
||||
output = []
|
||||
with make_session() as session:
|
||||
for result in results:
|
||||
@ -233,7 +215,7 @@ async def search_github_issues(
|
||||
return output
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@github_mcp.tool()
|
||||
async def github_issue_details(
|
||||
repo: str,
|
||||
number: int,
|
||||
@ -267,7 +249,7 @@ async def github_issue_details(
|
||||
return _serialize_issue(item, include_content=True)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@github_mcp.tool()
|
||||
async def github_work_summary(
|
||||
since: str,
|
||||
until: str | None = None,
|
||||
@ -294,7 +276,6 @@ async def github_work_summary(
|
||||
else:
|
||||
until_dt = datetime.now(timezone.utc)
|
||||
|
||||
# Map group_by to SQL expression
|
||||
group_mappings = {
|
||||
"client": GithubItem.project_fields["EquiStamp.Client"].astext,
|
||||
"status": GithubItem.project_status,
|
||||
@ -311,7 +292,6 @@ async def github_work_summary(
|
||||
group_col = group_mappings[group_by]
|
||||
|
||||
with make_session() as session:
|
||||
# Build base query for the period
|
||||
base_query = session.query(GithubItem).filter(
|
||||
GithubItem.github_updated_at >= since_dt,
|
||||
GithubItem.github_updated_at <= until_dt,
|
||||
@ -321,7 +301,6 @@ async def github_work_summary(
|
||||
if repo:
|
||||
base_query = base_query.filter(GithubItem.repo_path == repo)
|
||||
|
||||
# Get aggregated counts by group
|
||||
agg_query = (
|
||||
session.query(
|
||||
group_col.label("group_name"),
|
||||
@ -346,7 +325,6 @@ async def github_work_summary(
|
||||
|
||||
groups = agg_query.all()
|
||||
|
||||
# Build summary with sample issues for each group
|
||||
summary = []
|
||||
total_issues = 0
|
||||
total_prs = 0
|
||||
@ -358,7 +336,6 @@ async def github_work_summary(
|
||||
total_issues += issue_count
|
||||
total_prs += pr_count
|
||||
|
||||
# Get sample issues for this group
|
||||
sample_query = base_query.filter(group_col == group_name).limit(5)
|
||||
samples = [
|
||||
{
|
||||
@ -394,7 +371,7 @@ async def github_work_summary(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@github_mcp.tool()
|
||||
async def github_repo_overview(
|
||||
repo: str,
|
||||
) -> dict:
|
||||
@ -410,13 +387,6 @@ async def github_repo_overview(
|
||||
logger.info(f"github_repo_overview called: repo={repo}")
|
||||
|
||||
with make_session() as session:
|
||||
# Base query for this repo
|
||||
base_query = session.query(GithubItem).filter(
|
||||
GithubItem.repo_path == repo,
|
||||
GithubItem.kind.in_(["issue", "pr"]),
|
||||
)
|
||||
|
||||
# Get total counts
|
||||
counts_query = session.query(
|
||||
func.count(GithubItem.id).label("total"),
|
||||
func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
|
||||
@ -441,7 +411,6 @@ async def github_repo_overview(
|
||||
|
||||
counts = counts_query.first()
|
||||
|
||||
# Status breakdown (for project_status)
|
||||
status_query = (
|
||||
session.query(
|
||||
GithubItem.project_status.label("status"),
|
||||
@ -458,7 +427,6 @@ async def github_repo_overview(
|
||||
|
||||
status_breakdown = {row.status: row.count for row in status_query.all()}
|
||||
|
||||
# Top assignees (open issues only)
|
||||
assignee_query = (
|
||||
session.query(
|
||||
func.unnest(GithubItem.assignees).label("assignee"),
|
||||
@ -479,7 +447,6 @@ async def github_repo_overview(
|
||||
for row in assignee_query.all()
|
||||
]
|
||||
|
||||
# Label counts
|
||||
label_query = (
|
||||
session.query(
|
||||
func.unnest(GithubItem.labels).label("label"),
|
||||
277
src/memory/api/MCP/servers/meta.py
Normal file
277
src/memory/api/MCP/servers/meta.py
Normal file
@ -0,0 +1,277 @@
|
||||
"""MCP subserver for metadata, utilities, and forecasting."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints
|
||||
|
||||
import aiohttp
|
||||
from fastmcp import FastMCP
|
||||
from sqlalchemy import func
|
||||
|
||||
from memory.common import qdrant
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import SourceItem
|
||||
from memory.common.db.models.source_items import AgentObservation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
meta_mcp = FastMCP("memory-meta")
|
||||
|
||||
# Auth provider will be injected at mount time
|
||||
_get_current_user = None
|
||||
|
||||
|
||||
def set_auth_provider(get_current_user_func):
|
||||
"""Set the authentication provider function."""
|
||||
global _get_current_user
|
||||
_get_current_user = get_current_user_func
|
||||
|
||||
|
||||
def get_current_user() -> dict:
|
||||
"""Get the current authenticated user."""
|
||||
if _get_current_user is None:
|
||||
return {"authenticated": False, "error": "Auth provider not configured"}
|
||||
return _get_current_user()
|
||||
|
||||
|
||||
# --- Metadata tools ---
|
||||
|
||||
|
||||
class SchemaArg(TypedDict):
|
||||
type: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
class CollectionMetadata(TypedDict):
|
||||
schema: dict[str, SchemaArg]
|
||||
size: int
|
||||
|
||||
|
||||
def from_annotation(annotation: Annotated) -> SchemaArg | None:
|
||||
try:
|
||||
type_, description = get_args(annotation)
|
||||
type_str = str(type_)
|
||||
if type_str.startswith("typing."):
|
||||
type_str = type_str[7:]
|
||||
elif len((parts := type_str.split("'"))) > 1:
|
||||
type_str = parts[1]
|
||||
return SchemaArg(type=type_str, description=description)
|
||||
except IndexError:
|
||||
logger.error(f"Error from annotation: {annotation}")
|
||||
return None
|
||||
|
||||
|
||||
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
|
||||
if not hasattr(klass, "as_payload"):
|
||||
return {}
|
||||
|
||||
if not (payload_type := get_type_hints(klass.as_payload).get("return")):
|
||||
return {}
|
||||
|
||||
return {
|
||||
name: schema
|
||||
for name, arg in payload_type.__annotations__.items()
|
||||
if (schema := from_annotation(arg))
|
||||
}
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
|
||||
"""Get the metadata schema for each collection used in the knowledge base.
|
||||
|
||||
These schemas can be used to filter the knowledge base.
|
||||
|
||||
Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
|
||||
|
||||
Example:
|
||||
```
|
||||
{
|
||||
"mail": {"subject": {"type": "str", "description": "The subject of the email."}},
|
||||
"chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
|
||||
}
|
||||
"""
|
||||
client = qdrant.get_qdrant_client()
|
||||
sizes = qdrant.get_collection_sizes(client)
|
||||
schemas = defaultdict(dict)
|
||||
for klass in SourceItem.__subclasses__():
|
||||
for collection in klass.get_collections():
|
||||
schemas[collection].update(get_schema(klass))
|
||||
|
||||
return {
|
||||
collection: CollectionMetadata(schema=schema, size=size)
|
||||
for collection, schema in schemas.items()
|
||||
if (size := sizes.get(collection))
|
||||
}
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_all_tags() -> list[str]:
|
||||
"""Get all unique tags used across the entire knowledge base.
|
||||
|
||||
Returns sorted list of tags from both observations and content.
|
||||
"""
|
||||
with make_session() as session:
|
||||
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
||||
return sorted({row[0] for row in tags_query if row[0] is not None})
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_all_subjects() -> list[str]:
|
||||
"""Get all unique subjects from observations about the user.
|
||||
|
||||
Returns sorted list of subject identifiers used in observations.
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
r.subject for r in session.query(AgentObservation.subject).distinct()
|
||||
)
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_all_observation_types() -> list[str]:
|
||||
"""Get all observation types that have been used.
|
||||
|
||||
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
{
|
||||
r.observation_type
|
||||
for r in session.query(AgentObservation.observation_type).distinct()
|
||||
if r.observation_type is not None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# --- Utility tools ---
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_current_time() -> dict:
|
||||
"""Get the current time in UTC."""
|
||||
logger.info("get_current_time tool called")
|
||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_authenticated_user() -> dict:
|
||||
"""Get information about the authenticated user."""
|
||||
return get_current_user()
|
||||
|
||||
|
||||
# --- Forecasting tools ---
|
||||
|
||||
|
||||
class BinaryProbs(TypedDict):
|
||||
prob: float
|
||||
|
||||
|
||||
class MultiProbs(TypedDict):
|
||||
answerProbs: dict[str, float]
|
||||
|
||||
|
||||
Probs = dict[str, BinaryProbs | MultiProbs]
|
||||
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
|
||||
|
||||
|
||||
class MarketAnswer(TypedDict):
|
||||
id: str
|
||||
text: str
|
||||
resolutionProbability: float
|
||||
|
||||
|
||||
class MarketDetails(TypedDict):
|
||||
id: str
|
||||
createdTime: int
|
||||
question: str
|
||||
outcomeType: OutcomeType
|
||||
textDescription: str
|
||||
groupSlugs: list[str]
|
||||
volume: float
|
||||
isResolved: bool
|
||||
answers: list[MarketAnswer]
|
||||
|
||||
|
||||
class Market(TypedDict):
|
||||
id: str
|
||||
url: str
|
||||
question: str
|
||||
volume: int
|
||||
createdTime: int
|
||||
outcomeType: OutcomeType
|
||||
createdAt: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
answers: NotRequired[dict[str, float]]
|
||||
probability: NotRequired[float]
|
||||
details: NotRequired[MarketDetails]
|
||||
|
||||
|
||||
async def get_details(session: aiohttp.ClientSession, market_id: str):
|
||||
async with session.get(
|
||||
f"https://api.manifold.markets/v0/market/{market_id}"
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
|
||||
|
||||
async def format_market(session: aiohttp.ClientSession, market: Market):
|
||||
if market.get("outcomeType") != "BINARY":
|
||||
details = await get_details(session, market["id"])
|
||||
market["answers"] = {
|
||||
answer["text"]: round(
|
||||
answer.get("resolutionProbability") or answer.get("probability") or 0, 3
|
||||
)
|
||||
for answer in details["answers"]
|
||||
}
|
||||
if creationTime := market.get("createdTime"):
|
||||
market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"url",
|
||||
"question",
|
||||
"volume",
|
||||
"createdAt",
|
||||
"details",
|
||||
"probability",
|
||||
"answers",
|
||||
]
|
||||
return {k: v for k, v in market.items() if k in fields}
|
||||
|
||||
|
||||
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.manifold.markets/v0/search-markets",
|
||||
params={
|
||||
"term": term,
|
||||
"contractType": "BINARY" if binary else "ALL",
|
||||
},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
markets = await resp.json()
|
||||
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
format_market(session, market)
|
||||
for market in markets
|
||||
if market.get("volume", 0) >= min_volume
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@meta_mcp.tool()
|
||||
async def get_forecasts(
|
||||
term: str, min_volume: int = 1000, binary: bool = False
|
||||
) -> list[dict]:
|
||||
"""Get prediction market forecasts for a given term.
|
||||
|
||||
Args:
|
||||
term: The term to search for.
|
||||
min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
|
||||
binary: Whether to only return binary markets.
|
||||
"""
|
||||
return await search_markets(term, min_volume, binary)
|
||||
@ -1,23 +1,23 @@
|
||||
"""
|
||||
MCP tools for tracking people.
|
||||
"""
|
||||
"""MCP subserver for tracking people."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.api.MCP.base import mcp
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Person
|
||||
from memory.common import settings
|
||||
from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
|
||||
from memory.common.celery_app import app as celery_app
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Person
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
people_mcp = FastMCP("memory-people")
|
||||
|
||||
|
||||
def _person_to_dict(person: Person) -> dict[str, Any]:
|
||||
"""Convert a Person model to a dictionary for API responses."""
|
||||
@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@people_mcp.tool()
|
||||
async def add_person(
|
||||
identifier: str,
|
||||
display_name: str,
|
||||
@ -67,7 +67,6 @@ async def add_person(
|
||||
"""
|
||||
logger.info(f"MCP: Adding person: {identifier}")
|
||||
|
||||
# Check if person already exists
|
||||
with make_session() as session:
|
||||
existing = session.query(Person).filter(Person.identifier == identifier).first()
|
||||
if existing:
|
||||
@ -93,7 +92,7 @@ async def add_person(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@people_mcp.tool()
|
||||
async def update_person_info(
|
||||
identifier: str,
|
||||
display_name: str | None = None,
|
||||
@ -135,7 +134,6 @@ async def update_person_info(
|
||||
"""
|
||||
logger.info(f"MCP: Updating person: {identifier}")
|
||||
|
||||
# Verify person exists
|
||||
with make_session() as session:
|
||||
person = session.query(Person).filter(Person.identifier == identifier).first()
|
||||
if not person:
|
||||
@ -162,7 +160,7 @@ async def update_person_info(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@people_mcp.tool()
|
||||
async def get_person(identifier: str) -> dict | None:
|
||||
"""
|
||||
Get a person by their identifier.
|
||||
@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None:
|
||||
return _person_to_dict(person)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@people_mcp.tool()
|
||||
async def list_people(
|
||||
tags: list[str] | None = None,
|
||||
search: str | None = None,
|
||||
@ -207,9 +205,7 @@ async def list_people(
|
||||
query = session.query(Person)
|
||||
|
||||
if tags:
|
||||
query = query.filter(
|
||||
Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
|
||||
)
|
||||
query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))))
|
||||
|
||||
if search:
|
||||
search_term = f"%{search.lower()}%"
|
||||
@ -225,7 +221,7 @@ async def list_people(
|
||||
return [_person_to_dict(p) for p in people]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@people_mcp.tool()
|
||||
async def delete_person(identifier: str) -> dict:
|
||||
"""
|
||||
Delete a person by their identifier.
|
||||
@ -1,20 +1,37 @@
|
||||
"""
|
||||
MCP tools for the epistemic sparring partner system.
|
||||
"""
|
||||
"""MCP subserver for scheduling messages and LLM calls."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from memory.api.MCP.base import get_current_user, mcp
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
|
||||
from memory.common.db.models import DiscordBotUser, ScheduledLLMCall
|
||||
from memory.discord.messages import schedule_discord_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
schedule_mcp = FastMCP("memory-schedule")
|
||||
|
||||
@mcp.tool()
|
||||
# We need access to get_current_user from base - this will be injected at mount time
|
||||
_get_current_user = None
|
||||
|
||||
|
||||
def set_auth_provider(get_current_user_func):
|
||||
"""Set the authentication provider function."""
|
||||
global _get_current_user
|
||||
_get_current_user = get_current_user_func
|
||||
|
||||
|
||||
def get_current_user() -> dict:
|
||||
"""Get the current authenticated user."""
|
||||
if _get_current_user is None:
|
||||
return {"authenticated": False, "error": "Auth provider not configured"}
|
||||
return _get_current_user()
|
||||
|
||||
|
||||
@schedule_mcp.tool()
|
||||
async def schedule_message(
|
||||
scheduled_time: str,
|
||||
message: str,
|
||||
@ -61,10 +78,8 @@ async def schedule_message(
|
||||
if not discord_user and not discord_channel:
|
||||
raise ValueError("Either discord_user or discord_channel must be provided")
|
||||
|
||||
# Parse scheduled time
|
||||
try:
|
||||
scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
|
||||
# Ensure we store as naive datetime (remove timezone info for database storage)
|
||||
if scheduled_dt.tzinfo is not None:
|
||||
scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
except ValueError:
|
||||
@ -98,7 +113,7 @@ async def schedule_message(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@schedule_mcp.tool()
|
||||
async def list_scheduled_llm_calls(
|
||||
status: str | None = None, limit: int | None = 50
|
||||
) -> dict[str, Any]:
|
||||
@ -143,7 +158,7 @@ async def list_scheduled_llm_calls(
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@schedule_mcp.tool()
|
||||
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Cancel a scheduled LLM call.
|
||||
@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
|
||||
return {"error": "User not found", "user": current_user}
|
||||
|
||||
with make_session() as session:
|
||||
# Find the scheduled call
|
||||
scheduled_call = (
|
||||
session.query(ScheduledLLMCall)
|
||||
.filter(
|
||||
@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
|
||||
if not scheduled_call.can_be_cancelled():
|
||||
return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
|
||||
|
||||
# Update the status
|
||||
scheduled_call.status = "cancelled"
|
||||
session.commit()
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
"""
|
||||
MCP tools for the epistemic sparring partner system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import AgentObservation, SourceItem
|
||||
from memory.api.MCP.base import mcp, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_observation_source_ids(
|
||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||
):
|
||||
if not tags and not observation_types:
|
||||
return None
|
||||
|
||||
with make_session() as session:
|
||||
items_query = session.query(AgentObservation.id)
|
||||
|
||||
if tags:
|
||||
# Use PostgreSQL array overlap operator with proper array casting
|
||||
items_query = items_query.filter(
|
||||
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||
)
|
||||
if observation_types:
|
||||
items_query = items_query.filter(
|
||||
AgentObservation.observation_type.in_(observation_types)
|
||||
)
|
||||
source_ids = [item.id for item in items_query.all()]
|
||||
|
||||
return source_ids
|
||||
|
||||
|
||||
def filter_source_ids(
|
||||
modalities: set[str],
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
if not tags:
|
||||
return None
|
||||
|
||||
with make_session() as session:
|
||||
items_query = session.query(SourceItem.id)
|
||||
|
||||
if tags:
|
||||
# Use PostgreSQL array overlap operator with proper array casting
|
||||
items_query = items_query.filter(
|
||||
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||
)
|
||||
if modalities:
|
||||
items_query = items_query.filter(SourceItem.modality.in_(modalities))
|
||||
source_ids = [item.id for item in items_query.all()]
|
||||
|
||||
return source_ids
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_current_time() -> dict:
|
||||
"""Get the current time in UTC."""
|
||||
logger.info("get_current_time tool called")
|
||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_authenticated_user() -> dict:
|
||||
"""Get information about the authenticated user."""
|
||||
return get_current_user()
|
||||
@ -2,7 +2,6 @@
|
||||
FastAPI application for the knowledge base.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import logging
|
||||
import mimetypes
|
||||
@ -35,15 +34,10 @@ limiter = Limiter(
|
||||
enabled=settings.API_RATE_LIMIT_ENABLED,
|
||||
)
|
||||
|
||||
# Create the MCP http app to get its lifespan
|
||||
mcp_http_app = mcp.http_app(stateless_http=True)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async with contextlib.AsyncExitStack() as stack:
|
||||
await stack.enter_async_context(mcp.session_manager.run())
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||
app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan)
|
||||
app.state.limiter = limiter
|
||||
|
||||
# Rate limit exception handler
|
||||
@ -158,46 +152,9 @@ app.include_router(auth_router)
|
||||
|
||||
|
||||
# Add health check to MCP server instead of main app
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request):
|
||||
"""Health check endpoint that verifies all dependencies are accessible."""
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import text
|
||||
|
||||
checks = {"mcp_oauth": "enabled"}
|
||||
all_healthy = True
|
||||
|
||||
# Check database connection
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
checks["database"] = "healthy"
|
||||
except Exception as e:
|
||||
# Log error details but don't expose in response
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
checks["database"] = "unhealthy"
|
||||
all_healthy = False
|
||||
|
||||
# Check Qdrant connection
|
||||
try:
|
||||
from memory.common.qdrant import get_qdrant_client
|
||||
|
||||
client = get_qdrant_client()
|
||||
client.get_collections()
|
||||
checks["qdrant"] = "healthy"
|
||||
except Exception as e:
|
||||
# Log error details but don't expose in response
|
||||
logger.error(f"Qdrant health check failed: {e}")
|
||||
checks["qdrant"] = "unhealthy"
|
||||
all_healthy = False
|
||||
|
||||
checks["status"] = "healthy" if all_healthy else "degraded"
|
||||
status_code = 200 if all_healthy else 503
|
||||
return JSONResponse(checks, status_code=status_code)
|
||||
|
||||
|
||||
# Mount MCP server at root - OAuth endpoints need to be at root level
|
||||
app.mount("/", mcp.streamable_http_app())
|
||||
# Health check is defined in MCP/base.py
|
||||
app.mount("/", mcp_http_app)
|
||||
|
||||
|
||||
def main(reload: bool = False):
|
||||
|
||||
@ -17,6 +17,7 @@ DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||
BACKUP_ROOT = "memory.workers.tasks.backup"
|
||||
GITHUB_ROOT = "memory.workers.tasks.github"
|
||||
PEOPLE_ROOT = "memory.workers.tasks.people"
|
||||
PROACTIVE_ROOT = "memory.workers.tasks.proactive"
|
||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||
@ -73,6 +74,10 @@ SYNC_PERSON = f"{PEOPLE_ROOT}.sync_person"
|
||||
UPDATE_PERSON = f"{PEOPLE_ROOT}.update_person"
|
||||
SYNC_PROFILE_FROM_FILE = f"{PEOPLE_ROOT}.sync_profile_from_file"
|
||||
|
||||
# Proactive check-in tasks
|
||||
EVALUATE_PROACTIVE_CHECKINS = f"{PROACTIVE_ROOT}.evaluate_proactive_checkins"
|
||||
EXECUTE_PROACTIVE_CHECKIN = f"{PROACTIVE_ROOT}.execute_proactive_checkin"
|
||||
|
||||
|
||||
def get_broker_url() -> str:
|
||||
protocol = settings.CELERY_BROKER_TYPE
|
||||
@ -130,12 +135,17 @@ app.conf.update(
|
||||
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
|
||||
f"{GITHUB_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-github"},
|
||||
f"{PEOPLE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-people"},
|
||||
f"{PROACTIVE_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-discord"},
|
||||
},
|
||||
beat_schedule={
|
||||
"sync-github-repos-hourly": {
|
||||
"task": SYNC_ALL_GITHUB_REPOS,
|
||||
"schedule": crontab(minute=0), # Every hour at :00
|
||||
},
|
||||
"evaluate-proactive-checkins": {
|
||||
"task": EVALUATE_PROACTIVE_CHECKINS,
|
||||
"schedule": crontab(), # Every minute
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -63,13 +63,29 @@ class MessageProcessor:
|
||||
doc=textwrap.dedent(
|
||||
"""
|
||||
A summary of this processor, made by and for AI systems.
|
||||
|
||||
|
||||
The idea here is that AI systems can use this summary to keep notes on the given processor.
|
||||
These should automatically be injected into the context of the messages that are processed by this processor.
|
||||
These should automatically be injected into the context of the messages that are processed by this processor.
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
proactive_cron = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
doc="Cron schedule for proactive check-ins (e.g., '0 9 * * *' for 9am daily). None = disabled.",
|
||||
)
|
||||
proactive_prompt = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
doc="Custom instructions for proactive check-ins.",
|
||||
)
|
||||
last_proactive_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
doc="When the last proactive check-in was sent.",
|
||||
)
|
||||
|
||||
@property
|
||||
def entity_type(self) -> str:
|
||||
return self.__class__.__tablename__[8:-1] # type: ignore
|
||||
|
||||
@ -132,6 +132,7 @@ CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
|
||||
NOTES_SYNC_INTERVAL = int(os.getenv("NOTES_SYNC_INTERVAL", 15 * 60))
|
||||
LESSWRONG_SYNC_INTERVAL = int(os.getenv("LESSWRONG_SYNC_INTERVAL", 60 * 60 * 24))
|
||||
SCHEDULED_CALL_RUN_INTERVAL = int(os.getenv("SCHEDULED_CALL_RUN_INTERVAL", 60))
|
||||
PROACTIVE_CHECKIN_INTERVAL = int(os.getenv("PROACTIVE_CHECKIN_INTERVAL", 60))
|
||||
|
||||
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
|
||||
|
||||
|
||||
@ -167,6 +167,25 @@ def _create_scope_group(
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
# Proactive command
|
||||
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
|
||||
@discord.app_commands.describe(
|
||||
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
|
||||
prompt="Optional custom instructions for check-ins",
|
||||
)
|
||||
async def proactive_cmd(
|
||||
interaction: discord.Interaction,
|
||||
cron: str | None = None,
|
||||
prompt: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_proactive,
|
||||
cron=cron and cron.strip(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
@ -265,6 +284,28 @@ def _create_user_scope_group(
|
||||
url=url and url.strip(),
|
||||
)
|
||||
|
||||
# Proactive command
|
||||
@group.command(name="proactive", description=f"Configure {name}'s proactive check-ins")
|
||||
@discord.app_commands.describe(
|
||||
user="Target user",
|
||||
cron="Cron schedule (e.g., '0 9 * * *' for 9am daily) or 'off' to disable",
|
||||
prompt="Optional custom instructions for check-ins",
|
||||
)
|
||||
async def proactive_cmd(
|
||||
interaction: discord.Interaction,
|
||||
user: discord.User,
|
||||
cron: str | None = None,
|
||||
prompt: str | None = None,
|
||||
):
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_proactive,
|
||||
target_user=user,
|
||||
cron=cron and cron.strip(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
@ -663,3 +704,68 @@ async def handle_mcp_servers(
|
||||
except Exception as exc:
|
||||
logger.error(f"Error running MCP server command: {exc}", exc_info=True)
|
||||
raise CommandError(f"Error: {exc}") from exc
|
||||
|
||||
|
||||
async def handle_proactive(
|
||||
context: CommandContext,
|
||||
*,
|
||||
cron: str | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> CommandResponse:
|
||||
"""Handle proactive check-in configuration."""
|
||||
from croniter import croniter
|
||||
|
||||
model = context.target
|
||||
|
||||
# If no arguments, show current settings
|
||||
if cron is None and prompt is None:
|
||||
current_cron = getattr(model, "proactive_cron", None)
|
||||
current_prompt = getattr(model, "proactive_prompt", None)
|
||||
|
||||
if not current_cron:
|
||||
return CommandResponse(
|
||||
content=f"Proactive check-ins are disabled for {context.display_name}."
|
||||
)
|
||||
|
||||
lines = [f"Proactive check-ins for {context.display_name}:"]
|
||||
lines.append(f" Schedule: `{current_cron}`")
|
||||
if current_prompt:
|
||||
lines.append(f" Prompt: {current_prompt}")
|
||||
return CommandResponse(content="\n".join(lines))
|
||||
|
||||
# Handle cron setting
|
||||
if cron is not None:
|
||||
if cron.lower() == "off":
|
||||
setattr(model, "proactive_cron", None)
|
||||
return CommandResponse(
|
||||
content=f"Proactive check-ins disabled for {context.display_name}."
|
||||
)
|
||||
|
||||
# Validate cron expression
|
||||
try:
|
||||
croniter(cron)
|
||||
except (ValueError, KeyError) as e:
|
||||
raise CommandError(
|
||||
f"Invalid cron expression: {cron}\n"
|
||||
"Examples:\n"
|
||||
" `0 9 * * *` - 9am daily\n"
|
||||
" `0 9,17 * * 1-5` - 9am and 5pm weekdays\n"
|
||||
" `0 */4 * * *` - every 4 hours"
|
||||
) from e
|
||||
|
||||
setattr(model, "proactive_cron", cron)
|
||||
|
||||
# Handle prompt setting
|
||||
if prompt is not None:
|
||||
setattr(model, "proactive_prompt", prompt or None)
|
||||
|
||||
# Build response
|
||||
current_cron = getattr(model, "proactive_cron", None)
|
||||
current_prompt = getattr(model, "proactive_prompt", None)
|
||||
|
||||
lines = [f"Updated proactive settings for {context.display_name}:"]
|
||||
lines.append(f" Schedule: `{current_cron}`")
|
||||
if current_prompt:
|
||||
lines.append(f" Prompt: {current_prompt}")
|
||||
|
||||
return CommandResponse(content="\n".join(lines))
|
||||
|
||||
@ -11,6 +11,7 @@ from memory.common.celery_app import (
|
||||
SYNC_LESSWRONG,
|
||||
RUN_SCHEDULED_CALLS,
|
||||
BACKUP_ALL,
|
||||
EVALUATE_PROACTIVE_CHECKINS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -53,4 +54,8 @@ app.conf.beat_schedule = {
|
||||
"task": BACKUP_ALL,
|
||||
"schedule": settings.S3_BACKUP_INTERVAL,
|
||||
},
|
||||
"evaluate-proactive-checkins": {
|
||||
"task": EVALUATE_PROACTIVE_CHECKINS,
|
||||
"schedule": settings.PROACTIVE_CHECKIN_INTERVAL,
|
||||
},
|
||||
}
|
||||
|
||||
@ -14,20 +14,24 @@ from memory.workers.tasks import (
|
||||
maintenance,
|
||||
notes,
|
||||
observations,
|
||||
people,
|
||||
proactive,
|
||||
scheduled_calls,
|
||||
) # noqa
|
||||
|
||||
__all__ = [
|
||||
"backup",
|
||||
"email",
|
||||
"comic",
|
||||
"blogs",
|
||||
"ebook",
|
||||
"comic",
|
||||
"discord",
|
||||
"ebook",
|
||||
"email",
|
||||
"forums",
|
||||
"github",
|
||||
"maintenance",
|
||||
"notes",
|
||||
"observations",
|
||||
"people",
|
||||
"proactive",
|
||||
"scheduled_calls",
|
||||
]
|
||||
|
||||
341
src/memory/workers/tasks/proactive.py
Normal file
341
src/memory/workers/tasks/proactive.py
Normal file
@ -0,0 +1,341 @@
|
||||
"""
|
||||
Celery tasks for proactive Discord check-ins.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from croniter import croniter
|
||||
from sqlalchemy import or_
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.celery_app import app
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import DiscordChannel, DiscordServer, DiscordUser
|
||||
from memory.discord.messages import call_llm, comm_channel_prompt, send_discord_response
|
||||
from memory.workers.tasks.content_processing import safe_task_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVALUATE_PROACTIVE_CHECKINS = "memory.workers.tasks.proactive.evaluate_proactive_checkins"
|
||||
EXECUTE_PROACTIVE_CHECKIN = "memory.workers.tasks.proactive.execute_proactive_checkin"
|
||||
|
||||
EntityType = Literal["user", "channel", "server"]
|
||||
|
||||
|
||||
def is_cron_due(cron_expr: str, last_run: datetime | None, now: datetime) -> bool:
|
||||
"""Check if a cron expression is due to run now.
|
||||
|
||||
Uses croniter to determine if the current time falls within the cron's schedule
|
||||
and enough time has passed since the last run.
|
||||
"""
|
||||
try:
|
||||
cron = croniter(cron_expr, now)
|
||||
# Get the previous scheduled time from now
|
||||
prev_run = cron.get_prev(datetime)
|
||||
# Get the one before that to determine the interval
|
||||
cron.get_prev(datetime)
|
||||
prev_prev_run = cron.get_current(datetime)
|
||||
|
||||
# If we haven't run since the last scheduled time, we should run
|
||||
if last_run is None:
|
||||
# Never run before - check if current time is within a minute of prev_run
|
||||
time_since_scheduled = (now - prev_run).total_seconds()
|
||||
return time_since_scheduled < 120 # Within 2 minutes of scheduled time
|
||||
|
||||
# Make sure last_run is timezone aware
|
||||
if last_run.tzinfo is None:
|
||||
last_run = last_run.replace(tzinfo=timezone.utc)
|
||||
|
||||
# We should run if last_run is before the previous scheduled time
|
||||
return last_run < prev_run
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid cron expression '{cron_expr}': {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_bot_for_entity(
|
||||
session, entity_type: EntityType, entity_id: int
|
||||
) -> DiscordUser | None:
|
||||
"""Get the bot user associated with an entity."""
|
||||
from memory.common.db.models import DiscordBotUser, DiscordMessage
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
# For servers, find a bot that has sent messages in that server
|
||||
if entity_type == "server":
|
||||
# Find bots that have interacted with this server
|
||||
bot_users = (
|
||||
session.query(DiscordUser)
|
||||
.options(joinedload(DiscordUser.system_user))
|
||||
.join(DiscordMessage, DiscordMessage.from_id == DiscordUser.id)
|
||||
.filter(
|
||||
DiscordMessage.server_id == entity_id,
|
||||
DiscordUser.system_user_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
# Find one that's actually a bot
|
||||
for user in bot_users:
|
||||
if user.system_user and user.system_user.user_type == "discord_bot":
|
||||
return user
|
||||
|
||||
# For channels, check the server the channel belongs to
|
||||
if entity_type == "channel":
|
||||
channel = session.get(DiscordChannel, entity_id)
|
||||
if channel and channel.server_id:
|
||||
return get_bot_for_entity(session, "server", channel.server_id)
|
||||
|
||||
# Fallback: use first available bot
|
||||
bot = (
|
||||
session.query(DiscordBotUser)
|
||||
.options(joinedload(DiscordBotUser.discord_users).joinedload(DiscordUser.system_user))
|
||||
.first()
|
||||
)
|
||||
if bot and bot.discord_users:
|
||||
return bot.discord_users[0]
|
||||
return None
|
||||
|
||||
|
||||
def get_target_user_for_entity(
|
||||
session, entity_type: EntityType, entity_id: int
|
||||
) -> DiscordUser | None:
|
||||
"""Get the target user for sending a proactive message."""
|
||||
if entity_type == "user":
|
||||
return session.get(DiscordUser, entity_id)
|
||||
# For channels and servers, we don't have a specific target user
|
||||
return None
|
||||
|
||||
|
||||
def get_channel_for_entity(
|
||||
session, entity_type: EntityType, entity_id: int
|
||||
) -> DiscordChannel | None:
|
||||
"""Get the channel for sending a proactive message."""
|
||||
if entity_type == "channel":
|
||||
return session.get(DiscordChannel, entity_id)
|
||||
if entity_type == "server":
|
||||
# For servers, find the first text channel (prefer "general")
|
||||
channels = (
|
||||
session.query(DiscordChannel)
|
||||
.filter(
|
||||
DiscordChannel.server_id == entity_id,
|
||||
DiscordChannel.channel_type == "text",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not channels:
|
||||
return None
|
||||
# Prefer a channel named "general" if it exists
|
||||
for channel in channels:
|
||||
if channel.name and "general" in channel.name.lower():
|
||||
return channel
|
||||
return channels[0]
|
||||
# For users, we use DMs (no channel)
|
||||
return None
|
||||
|
||||
|
||||
@app.task(name=EVALUATE_PROACTIVE_CHECKINS)
|
||||
@safe_task_execution
|
||||
def evaluate_proactive_checkins() -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate which entities need proactive check-ins.
|
||||
|
||||
This task runs every minute and checks all entities with proactive_cron set
|
||||
to see if they're due for a check-in.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
dispatched = []
|
||||
|
||||
with make_session() as session:
|
||||
# Query all entities with proactive_cron set
|
||||
for model, entity_type in [
|
||||
(DiscordUser, "user"),
|
||||
(DiscordChannel, "channel"),
|
||||
(DiscordServer, "server"),
|
||||
]:
|
||||
entities = (
|
||||
session.query(model)
|
||||
.filter(model.proactive_cron.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
for entity in entities:
|
||||
cron_expr = cast(str, entity.proactive_cron)
|
||||
last_run = entity.last_proactive_at
|
||||
|
||||
if is_cron_due(cron_expr, last_run, now):
|
||||
logger.info(
|
||||
f"Proactive check-in due for {entity_type} {entity.id}"
|
||||
)
|
||||
execute_proactive_checkin.delay(entity_type, entity.id)
|
||||
dispatched.append({"type": entity_type, "id": entity.id})
|
||||
|
||||
return {
|
||||
"evaluated_at": now.isoformat(),
|
||||
"dispatched": dispatched,
|
||||
"count": len(dispatched),
|
||||
}
|
||||
|
||||
|
||||
@app.task(name=EXECUTE_PROACTIVE_CHECKIN)
|
||||
@safe_task_execution
|
||||
def execute_proactive_checkin(entity_type: EntityType, entity_id: int) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a proactive check-in for a specific entity.
|
||||
|
||||
This evaluates whether the bot should reach out and, if so, generates
|
||||
and sends a check-in message.
|
||||
"""
|
||||
logger.info(f"Executing proactive check-in for {entity_type} {entity_id}")
|
||||
|
||||
with make_session() as session:
|
||||
# Get the entity
|
||||
model_class = {
|
||||
"user": DiscordUser,
|
||||
"channel": DiscordChannel,
|
||||
"server": DiscordServer,
|
||||
}[entity_type]
|
||||
|
||||
entity = session.get(model_class, entity_id)
|
||||
if not entity:
|
||||
return {"error": f"{entity_type} {entity_id} not found"}
|
||||
|
||||
# Get the bot user
|
||||
bot_user = get_bot_for_entity(session, entity_type, entity_id)
|
||||
if not bot_user:
|
||||
return {"error": "No bot user found"}
|
||||
|
||||
# Get target user and channel
|
||||
target_user = get_target_user_for_entity(session, entity_type, entity_id)
|
||||
channel = get_channel_for_entity(session, entity_type, entity_id)
|
||||
|
||||
if not target_user and not channel:
|
||||
return {"error": "No target user or channel for proactive check-in"}
|
||||
|
||||
# Get chattiness threshold
|
||||
chattiness = entity.chattiness_threshold or 90
|
||||
|
||||
# Build the evaluation prompt
|
||||
proactive_prompt = entity.proactive_prompt or ""
|
||||
eval_prompt = textwrap.dedent("""
|
||||
You are considering whether to proactively reach out to check in.
|
||||
|
||||
{proactive_prompt}
|
||||
|
||||
Based on your notes and the context of previous conversations:
|
||||
1. Is there anything worth checking in about?
|
||||
2. Has enough happened or enough time passed to warrant a check-in?
|
||||
3. Would reaching out now be welcome or intrusive?
|
||||
|
||||
Please return a number between 0 and 100 indicating how strongly you want to check in
|
||||
(0 = definitely not, 100 = definitely yes).
|
||||
|
||||
<response>
|
||||
<number>50</number>
|
||||
<reason>Your reasoning here</reason>
|
||||
</response>
|
||||
""").format(proactive_prompt=proactive_prompt)
|
||||
|
||||
# Build context
|
||||
system_prompt = comm_channel_prompt(
|
||||
session, bot_user, target_user, channel
|
||||
)
|
||||
|
||||
# First, evaluate whether we should check in
|
||||
eval_response = call_llm(
|
||||
session,
|
||||
bot_user=bot_user,
|
||||
from_user=target_user,
|
||||
channel=channel,
|
||||
model=settings.SUMMARIZER_MODEL,
|
||||
system_prompt=system_prompt,
|
||||
messages=[eval_prompt],
|
||||
allowed_tools=[
|
||||
"update_channel_summary",
|
||||
"update_user_summary",
|
||||
"update_server_summary",
|
||||
],
|
||||
)
|
||||
|
||||
if not eval_response:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {"status": "no_eval_response", "entity_type": entity_type, "entity_id": entity_id}
|
||||
|
||||
# Parse the interest score
|
||||
match = re.search(r"<number>(\d+)</number>", eval_response)
|
||||
if not match:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {"status": "no_score_in_response", "entity_type": entity_type, "entity_id": entity_id}
|
||||
|
||||
interest_score = int(match.group(1))
|
||||
threshold = 100 - chattiness
|
||||
|
||||
logger.info(
|
||||
f"Proactive check-in eval: interest={interest_score}, threshold={threshold}, chattiness={chattiness}"
|
||||
)
|
||||
|
||||
if interest_score < threshold:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {
|
||||
"status": "below_threshold",
|
||||
"interest": interest_score,
|
||||
"threshold": threshold,
|
||||
"entity_type": entity_type,
|
||||
"entity_id": entity_id,
|
||||
}
|
||||
|
||||
# Generate the actual check-in message
|
||||
checkin_prompt = textwrap.dedent("""
|
||||
You've decided to proactively check in. Generate a natural, friendly check-in message.
|
||||
|
||||
{proactive_prompt}
|
||||
|
||||
Keep it brief and genuine. Don't be overly formal or robotic.
|
||||
Reference specific things from your notes if relevant.
|
||||
""").format(proactive_prompt=proactive_prompt)
|
||||
|
||||
response = call_llm(
|
||||
session,
|
||||
bot_user=bot_user,
|
||||
from_user=target_user,
|
||||
channel=channel,
|
||||
model=settings.DISCORD_MODEL,
|
||||
system_prompt=system_prompt,
|
||||
messages=[checkin_prompt],
|
||||
)
|
||||
|
||||
if not response:
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
return {"status": "no_message_generated", "entity_type": entity_type, "entity_id": entity_id}
|
||||
|
||||
# Send the message
|
||||
bot_id = bot_user.system_user.id if bot_user.system_user else None
|
||||
if not bot_id:
|
||||
return {"error": "No system user for bot"}
|
||||
|
||||
success = send_discord_response(
|
||||
bot_id=bot_id,
|
||||
response=response,
|
||||
channel_id=channel.id if channel else None,
|
||||
user_identifier=target_user.username if target_user else None,
|
||||
)
|
||||
|
||||
# Update last_proactive_at
|
||||
entity.last_proactive_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"status": "sent" if success else "send_failed",
|
||||
"interest": interest_score,
|
||||
"entity_type": entity_type,
|
||||
"entity_id": entity_id,
|
||||
"response_preview": response[:100] + "..." if len(response) > 100 else response,
|
||||
}
|
||||
@ -766,7 +766,7 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
),
|
||||
(
|
||||
0.409,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
|
||||
),
|
||||
],
|
||||
},
|
||||
@ -835,11 +835,11 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"semantic": [
|
||||
(0.489, "I find backend logic more interesting than UI work"),
|
||||
(0.462, "The user prefers working on backend systems over frontend UI"),
|
||||
(0.455, "The user said pure functions are yucky"),
|
||||
(
|
||||
0.455,
|
||||
"The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(0.455, "The user said pure functions are yucky"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
|
||||
@ -137,10 +137,9 @@ class TestBuildPrompt:
|
||||
):
|
||||
prompt = _build_prompt()
|
||||
|
||||
assert "lesswrong" in prompt.lower()
|
||||
assert "comic" in prompt.lower()
|
||||
assert "Remove" in prompt
|
||||
assert "Remove meta-language" in prompt
|
||||
assert "Return ONLY valid JSON" in prompt
|
||||
assert "recalled_content" in prompt
|
||||
|
||||
|
||||
class TestAnalyzeQuery:
|
||||
|
||||
@ -83,9 +83,10 @@ def test_logout_handles_missing_session(mock_get_user_session):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
async def test_oauth_callback_discord_success(mock_make_session, mock_complete, mock_mcp_tools):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@contextmanager
|
||||
@ -95,9 +96,12 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mcp_server.mcp_server_url = "https://example.com"
|
||||
mcp_server.access_token = "token123"
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (200, "Authorized")
|
||||
mock_mcp_tools.return_value = [{"name": "test_tool"}]
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
@ -107,14 +111,15 @@ async def test_oauth_callback_discord_success(mock_make_session, mock_complete):
|
||||
assert "Authorization Successful" in body
|
||||
assert "Authorized" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("memory.api.auth.mcp_tools_list", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.complete_oauth_flow", new_callable=AsyncMock)
|
||||
@patch("memory.api.auth.make_session")
|
||||
async def test_oauth_callback_discord_handles_failures(
|
||||
mock_make_session, mock_complete
|
||||
mock_make_session, mock_complete, mock_mcp_tools
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
|
||||
@ -125,9 +130,12 @@ async def test_oauth_callback_discord_handles_failures(
|
||||
mock_make_session.return_value = session_cm()
|
||||
|
||||
mcp_server = MagicMock()
|
||||
mcp_server.mcp_server_url = "https://example.com"
|
||||
mcp_server.access_token = "token123"
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mcp_server
|
||||
|
||||
mock_complete.return_value = (500, "Failure")
|
||||
mock_mcp_tools.return_value = []
|
||||
|
||||
request = make_request("code=abc123&state=state456")
|
||||
response = await auth.oauth_callback_discord(request)
|
||||
@ -137,7 +145,7 @@ async def test_oauth_callback_discord_handles_failures(
|
||||
assert "Authorization Failed" in body
|
||||
assert "Failure" in body
|
||||
mock_complete.assert_awaited_once_with(mcp_server, "abc123", "state456")
|
||||
mock_session.commit.assert_called_once()
|
||||
assert mock_session.commit.call_count == 2 # Once after complete_oauth_flow, once after tools list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
||||
DiscordUser,
|
||||
DiscordMessage,
|
||||
BotUser,
|
||||
DiscordBotUser,
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
@ -67,9 +68,10 @@ def sample_discord_user(db_session):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_bot_user(db_session):
|
||||
def sample_bot_user(db_session, sample_discord_user):
|
||||
"""Create a sample bot user for testing."""
|
||||
bot = BotUser.create_with_api_key(
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[sample_discord_user],
|
||||
name="Test Bot",
|
||||
email="testbot@example.com",
|
||||
)
|
||||
@ -209,9 +211,9 @@ def test_schedule_message_with_user(
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_message(
|
||||
user_id=sample_human_user.id,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
bot_id=sample_human_user.id,
|
||||
recipient_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
message="Test message",
|
||||
date_time=future_time,
|
||||
@ -240,9 +242,9 @@ def test_schedule_message_with_channel(
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
result = schedule_message(
|
||||
user_id=sample_human_user.id,
|
||||
user=None,
|
||||
channel=sample_discord_channel.id,
|
||||
bot_id=sample_human_user.id,
|
||||
recipient_id=None,
|
||||
channel_id=sample_discord_channel.id,
|
||||
model="test-model",
|
||||
message="Test message",
|
||||
date_time=future_time,
|
||||
@ -265,12 +267,12 @@ def test_make_message_scheduler_with_user(sample_bot_user, sample_discord_user):
|
||||
"""Test creating a message scheduler tool for a user."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert tool.name == "schedule_message"
|
||||
assert tool.name == "schedule_discord_message"
|
||||
assert "from your chat with this user" in tool.description
|
||||
assert tool.input_schema["type"] == "object"
|
||||
assert "message" in tool.input_schema["properties"]
|
||||
@ -282,12 +284,12 @@ def test_make_message_scheduler_with_channel(sample_bot_user, sample_discord_cha
|
||||
"""Test creating a message scheduler tool for a channel."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=None,
|
||||
channel=sample_discord_channel.id,
|
||||
user_id=None,
|
||||
channel_id=sample_discord_channel.id,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert tool.name == "schedule_message"
|
||||
assert tool.name == "schedule_discord_message"
|
||||
assert "in this channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
@ -297,8 +299,8 @@ def test_make_message_scheduler_without_user_or_channel(sample_bot_user):
|
||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||
make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=None,
|
||||
channel=None,
|
||||
user_id=None,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -310,8 +312,8 @@ def test_message_scheduler_handler_success(
|
||||
"""Test message scheduler handler with valid input."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -330,8 +332,8 @@ def test_message_scheduler_handler_invalid_input(sample_bot_user, sample_discord
|
||||
"""Test message scheduler handler with non-dict input."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -345,8 +347,8 @@ def test_message_scheduler_handler_invalid_datetime(
|
||||
"""Test message scheduler handler with invalid datetime."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -365,8 +367,8 @@ def test_message_scheduler_handler_missing_datetime(
|
||||
"""Test message scheduler handler with missing datetime."""
|
||||
tool = make_message_scheduler(
|
||||
bot=sample_bot_user,
|
||||
user=sample_discord_user.id,
|
||||
channel=None,
|
||||
user_id=sample_discord_user.id,
|
||||
channel_id=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
@ -375,9 +377,9 @@ def test_message_scheduler_handler_missing_datetime(
|
||||
|
||||
|
||||
# Tests for make_prev_messages_tool
|
||||
def test_make_prev_messages_tool_with_user(sample_discord_user):
|
||||
def test_make_prev_messages_tool_with_user(sample_bot_user, sample_discord_user):
|
||||
"""Test creating a previous messages tool for a user."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
assert tool.name == "previous_messages"
|
||||
assert "from your chat with this user" in tool.description
|
||||
@ -387,26 +389,26 @@ def test_make_prev_messages_tool_with_user(sample_discord_user):
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_prev_messages_tool_with_channel(sample_discord_channel):
|
||||
def test_make_prev_messages_tool_with_channel(sample_bot_user, sample_discord_channel):
|
||||
"""Test creating a previous messages tool for a channel."""
|
||||
tool = make_prev_messages_tool(user=None, channel=sample_discord_channel.id)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=sample_discord_channel.id)
|
||||
|
||||
assert tool.name == "previous_messages"
|
||||
assert "in this channel" in tool.description
|
||||
assert callable(tool.function)
|
||||
|
||||
|
||||
def test_make_prev_messages_tool_without_user_or_channel():
|
||||
def test_make_prev_messages_tool_without_user_or_channel(sample_bot_user):
|
||||
"""Test that creating a tool without user or channel raises error."""
|
||||
with pytest.raises(ValueError, match="Either user or channel must be provided"):
|
||||
make_prev_messages_tool(user=None, channel=None)
|
||||
make_prev_messages_tool(bot=sample_bot_user, user_id=None, channel_id=None)
|
||||
|
||||
|
||||
def test_prev_messages_handler_success(
|
||||
db_session, sample_discord_user, sample_discord_channel
|
||||
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test previous messages handler with valid input."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
# Create some actual messages in the database
|
||||
msg1 = DiscordMessage(
|
||||
@ -440,9 +442,9 @@ def test_prev_messages_handler_success(
|
||||
assert "Message 1" in result or "Message 2" in result
|
||||
|
||||
|
||||
def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
||||
def test_prev_messages_handler_with_defaults(db_session, sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with default values."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
result = tool.function({})
|
||||
|
||||
@ -450,35 +452,35 @@ def test_prev_messages_handler_with_defaults(db_session, sample_discord_user):
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_input(sample_discord_user):
|
||||
def test_prev_messages_handler_invalid_input(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with non-dict input."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Input must be a dictionary"):
|
||||
tool.function("not a dict")
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_max_messages(sample_discord_user):
|
||||
def test_prev_messages_handler_invalid_max_messages(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with invalid max_messages (negative value)."""
|
||||
# Note: max_messages=0 doesn't trigger validation due to `or 10` defaulting,
|
||||
# so we test with -1 which actually triggers the validation
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Max messages must be greater than 0"):
|
||||
tool.function({"max_messages": -1})
|
||||
|
||||
|
||||
def test_prev_messages_handler_invalid_offset(sample_discord_user):
|
||||
def test_prev_messages_handler_invalid_offset(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with invalid offset."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Offset must be greater than or equal to 0"):
|
||||
tool.function({"offset": -1})
|
||||
|
||||
|
||||
def test_prev_messages_handler_non_integer_values(sample_discord_user):
|
||||
def test_prev_messages_handler_non_integer_values(sample_bot_user, sample_discord_user):
|
||||
"""Test previous messages handler with non-integer values."""
|
||||
tool = make_prev_messages_tool(user=sample_discord_user.id, channel=None)
|
||||
tool = make_prev_messages_tool(bot=sample_bot_user, user_id=sample_discord_user.id, channel_id=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Max messages and offset must be integers"):
|
||||
tool.function({"max_messages": "not an int"})
|
||||
@ -496,10 +498,10 @@ def test_make_discord_tools_with_user_and_channel(
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
|
||||
# update_user_summary, update_server_summary, add_reaction
|
||||
assert len(tools) == 6
|
||||
assert "schedule_message" in tools
|
||||
assert "schedule_discord_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_user_summary" in tools
|
||||
@ -516,10 +518,10 @@ def test_make_discord_tools_with_user_only(sample_bot_user, sample_discord_user)
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_user_summary
|
||||
# Should have: schedule_discord_message, previous_messages, update_user_summary
|
||||
# Note: Without channel, there's no channel summary tool
|
||||
assert len(tools) >= 2 # At least schedule and previous messages
|
||||
assert "schedule_message" in tools
|
||||
assert "schedule_discord_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_user_summary" in tools
|
||||
|
||||
@ -533,10 +535,10 @@ def test_make_discord_tools_with_channel_only(sample_bot_user, sample_discord_ch
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should have: schedule_message, previous_messages, update_channel_summary,
|
||||
# Should have: schedule_discord_message, previous_messages, update_channel_summary,
|
||||
# update_server_summary, add_reaction (no user summary without author)
|
||||
assert len(tools) == 5
|
||||
assert "schedule_message" in tools
|
||||
assert "schedule_discord_message" in tools
|
||||
assert "previous_messages" in tools
|
||||
assert "update_channel_summary" in tools
|
||||
assert "update_server_summary" in tools
|
||||
|
||||
@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"http://localhost:8000/send_channel",
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"channel": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
|
||||
@ -91,7 +91,7 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
"http://localhost:8000/send_channel",
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"channel": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
|
||||
@ -15,6 +15,7 @@ from memory.discord.commands import (
|
||||
handle_chattiness,
|
||||
handle_ignore,
|
||||
handle_summary,
|
||||
handle_proactive,
|
||||
respond,
|
||||
with_object_context,
|
||||
handle_mcp_servers,
|
||||
@ -377,3 +378,207 @@ async def test_handle_mcp_servers_wraps_errors(mock_run_mcp, interaction):
|
||||
await handle_mcp_servers(context, action="list", url=None)
|
||||
|
||||
assert "Error: boom" in str(exc.value)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for handle_proactive
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_show_disabled(db_session, interaction, guild):
|
||||
"""Test showing proactive settings when disabled."""
|
||||
server = DiscordServer(id=guild.id, name="Guild", proactive_cron=None)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context)
|
||||
|
||||
assert "disabled" in response.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_show_enabled(db_session, interaction, guild):
|
||||
"""Test showing proactive settings when enabled."""
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name="Guild",
|
||||
proactive_cron="0 9 * * *",
|
||||
proactive_prompt="Check on projects",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context)
|
||||
|
||||
assert "0 9 * * *" in response.content
|
||||
assert "Check on projects" in response.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_set_cron(db_session, interaction, guild):
|
||||
"""Test setting proactive cron schedule."""
|
||||
server = DiscordServer(id=guild.id, name="Guild")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="0 9 * * *")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert "0 9 * * *" in response.content
|
||||
assert server.proactive_cron == "0 9 * * *"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_set_prompt(db_session, interaction, guild):
|
||||
"""Test setting proactive prompt."""
|
||||
server = DiscordServer(id=guild.id, name="Guild", proactive_cron="0 9 * * *")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, prompt="Focus on daily standups")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert server.proactive_prompt == "Focus on daily standups"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_disable(db_session, interaction, guild):
|
||||
"""Test disabling proactive check-ins."""
|
||||
server = DiscordServer(
|
||||
id=guild.id,
|
||||
name="Guild",
|
||||
proactive_cron="0 9 * * *",
|
||||
proactive_prompt="Some prompt",
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="off")
|
||||
|
||||
assert "disabled" in response.content.lower()
|
||||
assert server.proactive_cron is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_invalid_cron(db_session, interaction, guild):
|
||||
"""Test error on invalid cron expression."""
|
||||
server = DiscordServer(id=guild.id, name="Guild")
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="server",
|
||||
target=server,
|
||||
display_name="server **Guild**",
|
||||
)
|
||||
|
||||
with pytest.raises(CommandError) as exc:
|
||||
await handle_proactive(context, cron="not a valid cron")
|
||||
|
||||
assert "Invalid cron expression" in str(exc.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_user_scope(db_session, interaction, discord_user):
|
||||
"""Test proactive settings for user scope."""
|
||||
user_model = DiscordUser(
|
||||
id=discord_user.id, username="testuser", proactive_cron=None
|
||||
)
|
||||
db_session.add(user_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="me",
|
||||
target=user_model,
|
||||
display_name="you (**testuser**)",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="0 9,17 * * 1-5")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert user_model.proactive_cron == "0 9,17 * * 1-5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_proactive_channel_scope(
|
||||
db_session, interaction, guild, text_channel
|
||||
):
|
||||
"""Test proactive settings for channel scope."""
|
||||
server = DiscordServer(id=guild.id, name="Guild")
|
||||
db_session.add(server)
|
||||
db_session.flush()
|
||||
|
||||
channel_model = DiscordChannel(
|
||||
id=text_channel.id,
|
||||
name="general",
|
||||
channel_type="text",
|
||||
server_id=guild.id,
|
||||
)
|
||||
db_session.add(channel_model)
|
||||
db_session.commit()
|
||||
|
||||
context = CommandContext(
|
||||
session=db_session,
|
||||
interaction=interaction,
|
||||
actor=MagicMock(spec=DiscordUser),
|
||||
scope="channel",
|
||||
target=channel_model,
|
||||
display_name="channel **#general**",
|
||||
)
|
||||
|
||||
response = await handle_proactive(context, cron="0 12 * * *")
|
||||
|
||||
assert "Updated" in response.content
|
||||
assert channel_model.proactive_cron == "0 12 * * *"
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Tests for Discord MCP server management."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
@ -8,8 +7,8 @@ import discord
|
||||
import pytest
|
||||
|
||||
from memory.common.db.models import MCPServer, MCPServerAssignment
|
||||
from memory.common.mcp import mcp_call
|
||||
from memory.discord.mcp import (
|
||||
call_mcp_server,
|
||||
find_mcp_server,
|
||||
handle_mcp_add,
|
||||
handle_mcp_connect,
|
||||
@ -142,7 +141,7 @@ async def test_call_mcp_server_success():
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
async for data in mcp_call(
|
||||
"https://mcp.example.com", "test_token", "tools/list", {}
|
||||
):
|
||||
results.append(data)
|
||||
@ -172,7 +171,7 @@ async def test_call_mcp_server_error():
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
with pytest.raises(ValueError, match="Failed to call MCP server"):
|
||||
async for _ in call_mcp_server(
|
||||
async for _ in mcp_call(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
pass
|
||||
@ -203,7 +202,7 @@ async def test_call_mcp_server_invalid_json():
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session_ctx):
|
||||
results = []
|
||||
async for data in call_mcp_server(
|
||||
async for data in mcp_call(
|
||||
"https://mcp.example.com", "test_token", "tools/list"
|
||||
):
|
||||
results.append(data)
|
||||
|
||||
@ -19,6 +19,7 @@ from memory.common.db.models import (
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
DiscordMessage,
|
||||
DiscordBotUser,
|
||||
HumanUser,
|
||||
ScheduledLLMCall,
|
||||
)
|
||||
@ -34,6 +35,19 @@ def sample_discord_user(db_session):
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_bot_user(db_session, sample_discord_user):
|
||||
"""Create a sample Discord bot user."""
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[sample_discord_user],
|
||||
name="Test Bot",
|
||||
email="testbot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discord_channel(db_session):
|
||||
"""Create a sample Discord channel."""
|
||||
@ -290,13 +304,13 @@ def test_upsert_scheduled_message_cancels_earlier_call(
|
||||
# Test previous_messages
|
||||
|
||||
|
||||
def test_previous_messages_empty(db_session):
|
||||
def test_previous_messages_empty(db_session, sample_bot_user):
|
||||
"""Test getting previous messages when none exist."""
|
||||
result = previous_messages(db_session, user_id=123, channel_id=456)
|
||||
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=123, channel_id=456)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_previous_messages_filters_by_user(db_session, sample_discord_user, sample_discord_channel):
|
||||
def test_previous_messages_filters_by_user(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||
"""Test filtering messages by recipient user."""
|
||||
# Create some messages
|
||||
msg1 = DiscordMessage(
|
||||
@ -322,14 +336,14 @@ def test_previous_messages_filters_by_user(db_session, sample_discord_user, samp
|
||||
db_session.add_all([msg1, msg2])
|
||||
db_session.commit()
|
||||
|
||||
result = previous_messages(db_session, user_id=sample_discord_user.id, channel_id=None)
|
||||
result = previous_messages(db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None)
|
||||
assert len(result) == 2
|
||||
# Should be in chronological order (oldest first)
|
||||
assert result[0].message_id == 1
|
||||
assert result[1].message_id == 2
|
||||
|
||||
|
||||
def test_previous_messages_limits_results(db_session, sample_discord_user, sample_discord_channel):
|
||||
def test_previous_messages_limits_results(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||
"""Test limiting the number of previous messages."""
|
||||
# Create 15 messages
|
||||
for i in range(15):
|
||||
@ -347,7 +361,7 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
|
||||
db_session.commit()
|
||||
|
||||
result = previous_messages(
|
||||
db_session, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
||||
db_session, bot_id=sample_bot_user.discord_id, user_id=sample_discord_user.id, channel_id=None, max_messages=5
|
||||
)
|
||||
assert len(result) == 5
|
||||
|
||||
@ -355,10 +369,10 @@ def test_previous_messages_limits_results(db_session, sample_discord_user, sampl
|
||||
# Test comm_channel_prompt
|
||||
|
||||
|
||||
def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_discord_channel):
|
||||
def test_comm_channel_prompt_basic(db_session, sample_bot_user, sample_discord_user, sample_discord_channel):
|
||||
"""Test generating a basic communication channel prompt."""
|
||||
result = comm_channel_prompt(
|
||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
||||
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
|
||||
)
|
||||
|
||||
assert "You are a bot communicating on Discord" in result
|
||||
@ -366,31 +380,31 @@ def test_comm_channel_prompt_basic(db_session, sample_discord_user, sample_disco
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_server_context(db_session, sample_discord_channel):
|
||||
def test_comm_channel_prompt_includes_server_context(db_session, sample_bot_user, sample_discord_channel):
|
||||
"""Test that prompt includes server context when available."""
|
||||
server = sample_discord_channel.server
|
||||
server.summary = "Gaming community server"
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
||||
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
|
||||
|
||||
assert "server_context" in result.lower()
|
||||
assert "Gaming community server" in result
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_channel_context(db_session, sample_discord_channel):
|
||||
def test_comm_channel_prompt_includes_channel_context(db_session, sample_bot_user, sample_discord_channel):
|
||||
"""Test that prompt includes channel context."""
|
||||
sample_discord_channel.summary = "General discussion channel"
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(db_session, user=None, channel=sample_discord_channel)
|
||||
result = comm_channel_prompt(db_session, bot=sample_bot_user.discord_id, user=None, channel=sample_discord_channel)
|
||||
|
||||
assert "channel_context" in result.lower()
|
||||
assert "General discussion channel" in result
|
||||
|
||||
|
||||
def test_comm_channel_prompt_includes_user_notes(
|
||||
db_session, sample_discord_user, sample_discord_channel
|
||||
db_session, sample_bot_user, sample_discord_user, sample_discord_channel
|
||||
):
|
||||
"""Test that prompt includes user notes from previous messages."""
|
||||
sample_discord_user.summary = "Helpful community member"
|
||||
@ -411,7 +425,7 @@ def test_comm_channel_prompt_includes_user_notes(
|
||||
db_session.commit()
|
||||
|
||||
result = comm_channel_prompt(
|
||||
db_session, user=sample_discord_user, channel=sample_discord_channel
|
||||
db_session, bot=sample_bot_user.discord_id, user=sample_discord_user, channel=sample_discord_channel
|
||||
)
|
||||
|
||||
assert "user_notes" in result.lower()
|
||||
@ -442,12 +456,16 @@ def test_call_llm_includes_web_search_and_mcp_servers(
|
||||
web_tool_instance = MagicMock(name="web_tool")
|
||||
mock_web_search.return_value = web_tool_instance
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt="bot prompt")
|
||||
bot_user = SimpleNamespace(
|
||||
system_user=SimpleNamespace(discord_id=999888777),
|
||||
system_prompt="bot prompt"
|
||||
)
|
||||
from_user = SimpleNamespace(id=123)
|
||||
mcp_model = SimpleNamespace(
|
||||
name="Server",
|
||||
mcp_server_url="https://mcp.example.com",
|
||||
access_token="token123",
|
||||
disabled_tools=[],
|
||||
)
|
||||
|
||||
result = call_llm(
|
||||
@ -502,7 +520,10 @@ def test_call_llm_filters_disallowed_tools(
|
||||
|
||||
mock_web_search.return_value = MagicMock(name="web_tool")
|
||||
|
||||
bot_user = SimpleNamespace(system_user="system-user", system_prompt=None)
|
||||
bot_user = SimpleNamespace(
|
||||
system_user=SimpleNamespace(discord_id=999888777),
|
||||
system_prompt=None
|
||||
)
|
||||
from_user = SimpleNamespace(id=1)
|
||||
|
||||
call_llm(
|
||||
|
||||
@ -14,12 +14,25 @@ from memory.workers.tasks import discord
|
||||
|
||||
@pytest.fixture
|
||||
def discord_bot_user(db_session):
|
||||
# Create a discord user for the bot first
|
||||
bot_discord_user = DiscordUser(
|
||||
id=999999999,
|
||||
username="testbot",
|
||||
)
|
||||
db_session.add(bot_discord_user)
|
||||
db_session.flush()
|
||||
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
discord_users=[bot_discord_user],
|
||||
name="Test Bot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.flush()
|
||||
|
||||
# Link the discord user to the system user
|
||||
bot_discord_user.system_user_id = bot.id
|
||||
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
@ -176,26 +189,29 @@ def test_get_prev_empty_channel(db_session, mock_discord_channel):
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", True)
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.workers.tasks.discord.create_provider")
|
||||
@patch("memory.workers.tasks.discord.call_llm")
|
||||
@patch("memory.workers.tasks.discord.discord.trigger_typing_channel")
|
||||
def test_should_process_normal_message(
|
||||
mock_create_provider,
|
||||
mock_trigger_typing,
|
||||
mock_call_llm,
|
||||
db_session,
|
||||
mock_discord_user,
|
||||
mock_discord_server,
|
||||
mock_discord_channel,
|
||||
discord_bot_user,
|
||||
):
|
||||
"""Test should_process returns True for normal messages."""
|
||||
# Mock the LLM provider to return "yes"
|
||||
mock_provider = Mock()
|
||||
mock_provider.generate.return_value = "<response>yes</response>"
|
||||
mock_provider.as_messages.return_value = []
|
||||
mock_create_provider.return_value = mock_provider
|
||||
# Create a separate recipient user (the bot)
|
||||
bot_discord_user = discord_bot_user.discord_users[0]
|
||||
|
||||
# Mock call_llm to return a high number (100 = always process)
|
||||
mock_call_llm.return_value = "<response><number>100</number><reason>Test</reason></response>"
|
||||
|
||||
message = DiscordMessage(
|
||||
message_id=1,
|
||||
channel_id=mock_discord_channel.id,
|
||||
from_id=mock_discord_user.id,
|
||||
recipient_id=mock_discord_user.id,
|
||||
recipient_id=bot_discord_user.id, # Bot is recipient, not the from_user
|
||||
server_id=mock_discord_server.id,
|
||||
content="Test",
|
||||
sent_at=datetime.now(timezone.utc),
|
||||
@ -207,6 +223,8 @@ def test_should_process_normal_message(
|
||||
db_session.refresh(message)
|
||||
|
||||
assert discord.should_process(message) is True
|
||||
mock_call_llm.assert_called_once()
|
||||
mock_trigger_typing.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_PROCESS_MESSAGES", False)
|
||||
@ -344,6 +362,7 @@ def test_add_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
def test_add_discord_message_with_reply(db_session, sample_message_data, qdrant):
|
||||
"""Test adding a Discord message that is a reply."""
|
||||
sample_message_data["message_reference_id"] = 111222333
|
||||
sample_message_data["message_type"] = "reply" # Explicitly set message_type
|
||||
|
||||
discord.add_discord_message(**sample_message_data)
|
||||
|
||||
@ -523,8 +542,17 @@ def test_edit_discord_message_updates_context(
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
def test_process_discord_message_success(db_session, sample_message_data, qdrant):
|
||||
@patch("memory.workers.tasks.discord.send_discord_response")
|
||||
@patch("memory.workers.tasks.discord.call_llm")
|
||||
def test_process_discord_message_success(
|
||||
mock_call_llm, mock_send_response, db_session, sample_message_data, qdrant
|
||||
):
|
||||
"""Test processing a Discord message."""
|
||||
# Mock LLM to return a response
|
||||
mock_call_llm.return_value = "Test response from bot"
|
||||
# Mock Discord API to succeed
|
||||
mock_send_response.return_value = True
|
||||
|
||||
# Add a message first
|
||||
add_result = discord.add_discord_message(**sample_message_data)
|
||||
message_id = add_result["discordmessage_id"]
|
||||
@ -534,6 +562,8 @@ def test_process_discord_message_success(db_session, sample_message_data, qdrant
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert result["message_id"] == message_id
|
||||
mock_call_llm.assert_called_once()
|
||||
mock_send_response.assert_called_once()
|
||||
|
||||
|
||||
def test_process_discord_message_not_found(db_session):
|
||||
|
||||
536
tests/memory/workers/tasks/test_proactive.py
Normal file
536
tests/memory/workers/tasks/test_proactive.py
Normal file
@ -0,0 +1,536 @@
|
||||
"""Tests for proactive check-in tasks."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from memory.common.db.models import (
|
||||
DiscordBotUser,
|
||||
DiscordUser,
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
)
|
||||
from memory.workers.tasks import proactive
|
||||
from memory.workers.tasks.proactive import is_cron_due
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bot_user(db_session):
|
||||
"""Create a bot user for testing."""
|
||||
bot_discord_user = DiscordUser(
|
||||
id=999999999,
|
||||
username="testbot",
|
||||
)
|
||||
db_session.add(bot_discord_user)
|
||||
db_session.flush()
|
||||
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[bot_discord_user],
|
||||
name="testbot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user(db_session):
|
||||
"""Create a target Discord user for testing."""
|
||||
discord_user = DiscordUser(
|
||||
id=123456789,
|
||||
username="targetuser",
|
||||
proactive_cron="0 9 * * *", # 9am daily
|
||||
chattiness_threshold=50,
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
return discord_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_no_cron(db_session):
|
||||
"""Create a target Discord user without proactive cron."""
|
||||
discord_user = DiscordUser(
|
||||
id=123456790,
|
||||
username="nocronuser",
|
||||
proactive_cron=None,
|
||||
)
|
||||
db_session.add(discord_user)
|
||||
db_session.commit()
|
||||
return discord_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_server(db_session):
|
||||
"""Create a target Discord server for testing."""
|
||||
server = DiscordServer(
|
||||
id=987654321,
|
||||
name="Test Server",
|
||||
proactive_cron="0 */4 * * *", # Every 4 hours
|
||||
chattiness_threshold=30,
|
||||
)
|
||||
db_session.add(server)
|
||||
db_session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_channel(db_session, target_server):
|
||||
"""Create a target Discord channel for testing."""
|
||||
channel = DiscordChannel(
|
||||
id=111222333,
|
||||
name="test-channel",
|
||||
channel_type="text",
|
||||
server_id=target_server.id,
|
||||
proactive_cron="0 12 * * 1-5", # Noon on weekdays
|
||||
chattiness_threshold=70,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
return channel
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for is_cron_due helper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cron_expr,now,last_run,expected",
|
||||
[
|
||||
# Cron is due when never run before and time matches
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 9, 0, 30, tzinfo=timezone.utc),
|
||||
None,
|
||||
True,
|
||||
),
|
||||
# Cron is due when last run was before the scheduled time
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 23, 9, 0, 0, tzinfo=timezone.utc),
|
||||
True,
|
||||
),
|
||||
# Cron is NOT due when already run this period
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 9, 30, 0, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 24, 9, 5, 0, tzinfo=timezone.utc),
|
||||
False,
|
||||
),
|
||||
# Cron is NOT due when current time is before scheduled time
|
||||
(
|
||||
"0 9 * * *",
|
||||
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
|
||||
None,
|
||||
False,
|
||||
),
|
||||
# Hourly cron schedule
|
||||
(
|
||||
"0 * * * *",
|
||||
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 24, 11, 0, 0, tzinfo=timezone.utc),
|
||||
True,
|
||||
),
|
||||
# Every 4 hours cron schedule
|
||||
(
|
||||
"0 */4 * * *",
|
||||
datetime(2025, 12, 24, 12, 0, 30, tzinfo=timezone.utc),
|
||||
datetime(2025, 12, 24, 8, 0, 0, tzinfo=timezone.utc),
|
||||
True,
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"due_never_run",
|
||||
"due_last_run_before_schedule",
|
||||
"not_due_already_run",
|
||||
"not_due_too_early",
|
||||
"due_hourly",
|
||||
"due_every_4_hours",
|
||||
],
|
||||
)
|
||||
def test_is_cron_due(cron_expr, now, last_run, expected):
|
||||
"""Test is_cron_due with various scenarios."""
|
||||
assert is_cron_due(cron_expr, last_run, now) is expected
|
||||
|
||||
|
||||
def test_is_cron_due_invalid_expression():
|
||||
"""Test invalid cron expression returns False."""
|
||||
now = datetime(2025, 12, 24, 9, 0, 0, tzinfo=timezone.utc)
|
||||
assert is_cron_due("invalid cron", None, now) is False
|
||||
|
||||
|
||||
def test_is_cron_due_with_naive_last_run():
|
||||
"""Test cron handles naive datetime for last_run."""
|
||||
now = datetime(2025, 12, 24, 9, 1, 0, tzinfo=timezone.utc)
|
||||
cron_expr = "0 9 * * *"
|
||||
last_run = datetime(2025, 12, 23, 9, 0, 0) # Naive datetime
|
||||
assert is_cron_due(cron_expr, last_run, now) is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for evaluate_proactive_checkins task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_dispatches_due(
|
||||
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
|
||||
):
|
||||
"""Test that due check-ins are dispatched."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_is_cron_due.return_value = True
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
assert result["count"] >= 1
|
||||
mock_execute.delay.assert_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_skips_not_due(
|
||||
mock_make_session, mock_is_cron_due, mock_execute, db_session, target_user
|
||||
):
|
||||
"""Test that not-due check-ins are not dispatched."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_is_cron_due.return_value = False
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
assert result["count"] == 0
|
||||
mock_execute.delay.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_skips_no_cron(
|
||||
mock_make_session, mock_execute, db_session, target_user_no_cron
|
||||
):
|
||||
"""Test that entities without proactive_cron are skipped."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
for call in mock_execute.delay.call_args_list:
|
||||
entity_type, entity_id = call[0]
|
||||
assert entity_id != target_user_no_cron.id
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.execute_proactive_checkin")
|
||||
@patch("memory.workers.tasks.proactive.is_cron_due")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_evaluate_proactive_checkins_multiple_entity_types(
|
||||
mock_make_session,
|
||||
mock_is_cron_due,
|
||||
mock_execute,
|
||||
db_session,
|
||||
target_user,
|
||||
target_server,
|
||||
target_channel,
|
||||
):
|
||||
"""Test that check-ins are dispatched for users, channels, and servers."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_is_cron_due.return_value = True
|
||||
|
||||
result = proactive.evaluate_proactive_checkins()
|
||||
|
||||
assert result["count"] == 3
|
||||
dispatched_types = {d["type"] for d in result["dispatched"]}
|
||||
assert "user" in dispatched_types
|
||||
assert "channel" in dispatched_types
|
||||
assert "server" in dispatched_types
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for execute_proactive_checkin task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_sends_when_above_threshold(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test check-in is sent when interest exceeds threshold."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>80</number><reason>Should check in</reason></response>",
|
||||
"Hey! Just checking in - how are things going?",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert result["status"] == "sent"
|
||||
assert result["interest"] == 80
|
||||
mock_send.assert_called_once()
|
||||
|
||||
db_session.refresh(target_user)
|
||||
assert target_user.last_proactive_at is not None
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_skips_below_threshold(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test check-in is skipped when interest is below threshold."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.return_value = (
|
||||
"<response><number>30</number><reason>Not much to say</reason></response>"
|
||||
)
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert result["status"] == "below_threshold"
|
||||
assert result["interest"] == 30
|
||||
assert result["threshold"] == 50
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_response,expected_status",
|
||||
[
|
||||
(None, "no_eval_response"),
|
||||
("I'm not sure what to say.", "no_score_in_response"),
|
||||
],
|
||||
ids=["no_response", "malformed_response"],
|
||||
)
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_handles_bad_llm_response(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
llm_response,
|
||||
expected_status,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test handling of missing or malformed LLM responses."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
mock_call_llm.return_value = llm_response
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert result["status"] == expected_status
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_nonexistent_entity(mock_make_session, db_session):
|
||||
"""Test handling when entity doesn't exist."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", 999999)
|
||||
|
||||
assert "error" in result
|
||||
assert "not found" in result["error"]
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_no_bot_user(
|
||||
mock_make_session, mock_get_bot, db_session, target_user
|
||||
):
|
||||
"""Test handling when no bot user is found."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
mock_get_bot.return_value = None
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
assert "error" in result
|
||||
assert "No bot user" in result["error"]
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_uses_proactive_prompt(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
bot_user,
|
||||
):
|
||||
"""Test that proactive_prompt is included in the evaluation."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
user_with_prompt = DiscordUser(
|
||||
id=555666777,
|
||||
username="promptuser",
|
||||
proactive_cron="0 9 * * *",
|
||||
proactive_prompt="Focus on their coding projects",
|
||||
chattiness_threshold=50,
|
||||
)
|
||||
db_session.add(user_with_prompt)
|
||||
db_session.commit()
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>80</number><reason>Check on projects</reason></response>",
|
||||
"How are your coding projects coming along?",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
result = proactive.execute_proactive_checkin("user", user_with_prompt.id)
|
||||
|
||||
assert result["status"] == "sent"
|
||||
call_args = mock_call_llm.call_args_list[0]
|
||||
messages_arg = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
assert any("Focus on their coding projects" in str(m) for m in messages_arg)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_channel(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
target_channel,
|
||||
bot_user,
|
||||
):
|
||||
"""Test check-in to a channel."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>50</number><reason>Check channel</reason></response>",
|
||||
"Good morning everyone!",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
result = proactive.execute_proactive_checkin("channel", target_channel.id)
|
||||
|
||||
assert result["status"] == "sent"
|
||||
assert result["entity_type"] == "channel"
|
||||
|
||||
send_call = mock_send.call_args
|
||||
assert send_call.kwargs.get("channel_id") == target_channel.id
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.send_discord_response")
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_updates_last_proactive_at(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
mock_send,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test that last_proactive_at is updated after successful check-in."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.side_effect = [
|
||||
"<response><number>80</number><reason>Check in</reason></response>",
|
||||
"Hey there!",
|
||||
]
|
||||
mock_send.return_value = True
|
||||
|
||||
before_time = datetime.now(timezone.utc)
|
||||
proactive.execute_proactive_checkin("user", target_user.id)
|
||||
after_time = datetime.now(timezone.utc)
|
||||
|
||||
db_session.refresh(target_user)
|
||||
assert target_user.last_proactive_at is not None
|
||||
assert before_time <= target_user.last_proactive_at <= after_time
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.proactive.call_llm")
|
||||
@patch("memory.workers.tasks.proactive.get_bot_for_entity")
|
||||
@patch("memory.workers.tasks.proactive.make_session")
|
||||
def test_execute_proactive_checkin_updates_last_proactive_at_on_skip(
|
||||
mock_make_session,
|
||||
mock_get_bot,
|
||||
mock_call_llm,
|
||||
db_session,
|
||||
target_user,
|
||||
bot_user,
|
||||
):
|
||||
"""Test that last_proactive_at is updated even when check-in is skipped."""
|
||||
mock_make_session.return_value.__enter__ = Mock(return_value=db_session)
|
||||
mock_make_session.return_value.__exit__ = Mock(return_value=False)
|
||||
|
||||
bot_discord_user = bot_user.discord_users[0]
|
||||
bot_discord_user.system_user = bot_user
|
||||
mock_get_bot.return_value = bot_discord_user
|
||||
|
||||
mock_call_llm.return_value = (
|
||||
"<response><number>10</number><reason>Nothing to say</reason></response>"
|
||||
)
|
||||
|
||||
proactive.execute_proactive_checkin("user", target_user.id)
|
||||
|
||||
db_session.refresh(target_user)
|
||||
assert target_user.last_proactive_at is not None
|
||||
@ -16,8 +16,16 @@ from memory.workers.tasks import scheduled_calls
|
||||
@pytest.fixture
|
||||
def sample_user(db_session):
|
||||
"""Create a sample user for testing."""
|
||||
# Create a discord user for the bot
|
||||
bot_discord_user = DiscordUser(
|
||||
id=999999999,
|
||||
username="testbot",
|
||||
)
|
||||
db_session.add(bot_discord_user)
|
||||
db_session.flush()
|
||||
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
discord_users=[bot_discord_user],
|
||||
name="testbot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
@ -122,65 +130,64 @@ def future_scheduled_call(db_session, sample_user, sample_discord_user):
|
||||
return call
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
"""Test sending to Discord user."""
|
||||
response = "This is a test response."
|
||||
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, response)
|
||||
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
pending_scheduled_call.user_id,
|
||||
999999999, # bot_id
|
||||
"testuser", # username, not ID
|
||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||
response,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||
def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
@patch("memory.discord.messages.discord.send_to_channel")
|
||||
def test_send_to_discord_channel(mock_send_to_channel, completed_scheduled_call):
|
||||
"""Test sending to Discord channel."""
|
||||
response = "This is a channel response."
|
||||
|
||||
scheduled_calls.send_to_discord(completed_scheduled_call, response)
|
||||
scheduled_calls.send_to_discord(999999999, completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
completed_scheduled_call.user_id,
|
||||
"test-channel", # channel name, not ID
|
||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||
mock_send_to_channel.assert_called_once_with(
|
||||
999999999, # bot_id
|
||||
completed_scheduled_call.discord_channel.id, # channel ID, not name
|
||||
response,
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled_call):
|
||||
"""Test message truncation for long responses."""
|
||||
long_response = "A" * 2500 # Very long response
|
||||
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, long_response)
|
||||
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, long_response)
|
||||
|
||||
# Verify the message was truncated
|
||||
# With the new implementation, send_discord_response sends the full response
|
||||
# No truncation happens in _send_to_discord
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
assert args[0] == 999999999 # bot_id
|
||||
message = args[2]
|
||||
assert len(message) <= 1950 # Should be truncated
|
||||
assert message.endswith("... (response truncated)")
|
||||
assert message == long_response
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_call):
|
||||
"""Test that normal length messages are not truncated."""
|
||||
normal_response = "This is a normal length response."
|
||||
|
||||
scheduled_calls.send_to_discord(pending_scheduled_call, normal_response)
|
||||
scheduled_calls.send_to_discord(999999999, pending_scheduled_call, normal_response)
|
||||
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
assert args[0] == 999999999 # bot_id
|
||||
message = args[2]
|
||||
assert not message.endswith("... (response truncated)")
|
||||
assert "This is a normal length response." in message
|
||||
assert message == normal_response
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_success(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -189,12 +196,8 @@ def test_execute_scheduled_call_success(
|
||||
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
# Verify LLM was called with correct parameters
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="What is the weather like today?",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
# Verify LLM was called
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result["success"] is True
|
||||
@ -218,7 +221,7 @@ def test_execute_scheduled_call_not_found(db_session):
|
||||
assert result == {"error": "Scheduled call not found"}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_not_pending(
|
||||
mock_llm_call, completed_scheduled_call, db_session
|
||||
):
|
||||
@ -229,8 +232,8 @@ def test_execute_scheduled_call_not_pending(
|
||||
mock_llm_call.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_with_default_system_prompt(
|
||||
mock_llm_call, mock_send_discord, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
@ -254,16 +257,12 @@ def test_execute_scheduled_call_with_default_system_prompt(
|
||||
|
||||
scheduled_calls.execute_scheduled_call(call.id)
|
||||
|
||||
# Verify default system prompt was used
|
||||
mock_llm_call.assert_called_once_with(
|
||||
prompt="Test prompt",
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
system_prompt=None, # The code uses system_prompt as-is, not a default
|
||||
)
|
||||
# Verify LLM was called
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_discord_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -286,26 +285,27 @@ def test_execute_scheduled_call_discord_error(
|
||||
assert pending_scheduled_call.data["discord_error"] == "Discord API error"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_llm_error(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
"""Test execution when LLM call fails."""
|
||||
mock_llm_call.side_effect = Exception("LLM API error")
|
||||
|
||||
# The safe_task_execution decorator should catch this
|
||||
# The execute_scheduled_call function catches the exception and returns an error response
|
||||
result = scheduled_calls.execute_scheduled_call(pending_scheduled_call.id)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "LLM API error" in result["error"]
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
assert "LLM call failed" in result["error"]
|
||||
|
||||
# Discord should not be called
|
||||
mock_send_discord.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_long_response_truncation(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -477,8 +477,8 @@ def test_run_scheduled_calls_timezone_handling(
|
||||
mock_execute_delay.delay.assert_called_once_with(due_call.id)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.scheduled_calls._send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.send_to_discord")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_status_transition_pending_to_executing_to_completed(
|
||||
mock_llm_call, mock_send_discord, pending_scheduled_call, db_session
|
||||
):
|
||||
@ -502,14 +502,14 @@ def test_status_transition_pending_to_executing_to_completed(
|
||||
"has_discord_user,has_discord_channel,expected_method",
|
||||
[
|
||||
(True, False, "send_dm"),
|
||||
(False, True, "broadcast_message"),
|
||||
(True, True, "send_dm"), # User takes precedence
|
||||
(False, True, "send_to_channel"),
|
||||
(True, True, "send_to_channel"), # Channel takes precedence in the implementation
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.broadcast_message")
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
@patch("memory.discord.messages.discord.send_to_channel")
|
||||
def test_discord_destination_priority(
|
||||
mock_broadcast,
|
||||
mock_send_to_channel,
|
||||
mock_send_dm,
|
||||
has_discord_user,
|
||||
has_discord_channel,
|
||||
@ -535,50 +535,39 @@ def test_discord_destination_priority(
|
||||
db_session.commit()
|
||||
|
||||
response = "Test response"
|
||||
scheduled_calls.send_to_discord(call, response)
|
||||
scheduled_calls.send_to_discord(999999999, call, response)
|
||||
|
||||
if expected_method == "send_dm":
|
||||
mock_send_dm.assert_called_once()
|
||||
mock_broadcast.assert_not_called()
|
||||
mock_send_to_channel.assert_not_called()
|
||||
else:
|
||||
mock_broadcast.assert_called_once()
|
||||
mock_send_to_channel.assert_called_once()
|
||||
mock_send_dm.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic,model,response,expected_in_message",
|
||||
"topic,model,response",
|
||||
[
|
||||
(
|
||||
"Weather Check",
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
"It's sunny!",
|
||||
[
|
||||
"**Topic:** Weather Check",
|
||||
"**Model:** anthropic/claude-3-5-sonnet-20241022",
|
||||
"**Response:** It's sunny!",
|
||||
],
|
||||
),
|
||||
(
|
||||
"Test Topic",
|
||||
"gpt-4",
|
||||
"Hello world",
|
||||
["**Topic:** Test Topic", "**Model:** gpt-4", "**Response:** Hello world"],
|
||||
),
|
||||
(
|
||||
"Long Topic Name Here",
|
||||
"claude-2",
|
||||
"Short",
|
||||
[
|
||||
"**Topic:** Long Topic Name Here",
|
||||
"**Model:** claude-2",
|
||||
"**Response:** Short",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.discord.send_dm")
|
||||
def test_message_formatting(mock_send_dm, topic, model, response, expected_in_message):
|
||||
"""Test the Discord message formatting with different inputs."""
|
||||
@patch("memory.discord.messages.discord.send_dm")
|
||||
def test_message_formatting(mock_send_dm, topic, model, response):
|
||||
"""Test that _send_to_discord sends the response as-is."""
|
||||
# Create a mock scheduled call with a mock Discord user
|
||||
mock_discord_user = Mock()
|
||||
mock_discord_user.username = "testuser"
|
||||
@ -590,16 +579,15 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
mock_call.discord_user = mock_discord_user
|
||||
mock_call.discord_channel = None
|
||||
|
||||
scheduled_calls.send_to_discord(mock_call, response)
|
||||
scheduled_calls.send_to_discord(999999999, mock_call, response)
|
||||
|
||||
# Get the actual message that was sent
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
assert args[0] == mock_call.user_id
|
||||
assert args[0] == 999999999 # bot_id
|
||||
actual_message = args[2]
|
||||
|
||||
# Verify all expected parts are in the message
|
||||
for expected_part in expected_in_message:
|
||||
assert expected_part in actual_message
|
||||
# The new implementation sends the response as-is, without formatting
|
||||
assert actual_message == response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -612,7 +600,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
("cancelled", False),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.scheduled_calls.llms.summarize")
|
||||
@patch("memory.workers.tasks.scheduled_calls.call_llm_for_scheduled")
|
||||
def test_execute_scheduled_call_status_check(
|
||||
mock_llm_call, status, should_execute, db_session, sample_user, sample_discord_user
|
||||
):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user