migrate to fastmcp

This commit is contained in:
Daniel O'Connell 2025-12-24 23:19:41 +01:00
parent d3d71edf1d
commit 48c380b903
18 changed files with 703 additions and 549 deletions

View File

@ -1,7 +1,8 @@
fastapi==0.112.2
fastapi>=0.115.12
uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
sqladmin==0.20.1
mcp==1.10.0
fastmcp>=2.10.0
slowapi==0.1.9

View File

@ -1,14 +1,14 @@
sqlalchemy==2.0.30
psycopg2-binary==2.9.9
pydantic==2.7.2
pydantic>=2.11.7
alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.69.0
openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
httpx>=0.28.1
celery[redis,sqs]==5.3.6
cryptography==43.0.0
bcrypt==4.1.2

View File

@ -1,8 +1,16 @@
import memory.api.MCP.tools
import memory.api.MCP.memory
import memory.api.MCP.metadata
import memory.api.MCP.schedules
import memory.api.MCP.books
import memory.api.MCP.manifest
import memory.api.MCP.github
import memory.api.MCP.people
"""
MCP server with composed subservers.
Subservers are mounted with prefixes:
- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file
- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview
- people: add_person, update_person_info, get_person, list_people, delete_person
- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call
- books: all_books, read_book
- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts
"""
# Import base to trigger subserver mounting
from memory.api.MCP.base import mcp, get_current_user
__all__ = ["mcp", "get_current_user"]

View File

@ -1,60 +1,28 @@
import logging
import pathlib
from typing import cast
from mcp.server.auth.handlers.authorize import AuthorizationRequest
from mcp.server.auth.handlers.token import (
AuthorizationCodeRequest,
RefreshTokenRequest,
TokenRequest,
)
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
from mcp.server.fastmcp import FastMCP
from mcp.shared.auth import OAuthClientMetadata
from pydantic import AnyHttpUrl
from fastmcp import FastMCP
from fastmcp.server.dependencies import get_access_token
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
from memory.api.MCP.oauth_provider import (
ALLOWED_SCOPES,
BASE_SCOPES,
SimpleOAuthProvider,
)
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
from memory.api.MCP.servers.books import books_mcp
from memory.api.MCP.servers.core import core_mcp
from memory.api.MCP.servers.github import github_mcp
from memory.api.MCP.servers.meta import meta_mcp
from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth
from memory.api.MCP.servers.people import people_mcp
from memory.api.MCP.servers.schedule import schedule_mcp
from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.connection import make_session, get_engine
from memory.common.db.models import OAuthState, UserSession
from memory.common.db.models.users import HumanUser
logger = logging.getLogger(__name__)
def validate_metadata(klass: type):
    """Patch ``klass.model_validate`` so ``cursor://`` redirect URIs validate as ``http://``.

    The replacement copies the incoming mapping, rewrites the ``redirect_uris``
    list and the ``redirect_uri`` value when present, and then delegates to the
    original ``model_validate``.

    Args:
        klass: A class exposing a ``model_validate`` callable; it is modified in place.
    """
    orig_validate = klass.model_validate

    def _as_http(uri) -> str:
        return str(uri).replace("cursor://", "http://")

    def validate(data: dict, *args, **kwargs):
        # Work on a copy so the caller's mapping is never mutated.
        data = dict(data)
        if "redirect_uris" in data:
            data["redirect_uris"] = [_as_http(uri) for uri in data["redirect_uris"]]
        if "redirect_uri" in data:
            data["redirect_uri"] = _as_http(data["redirect_uri"])
        # Forward any extra validation options (e.g. strict=...) instead of
        # rejecting them with a TypeError.
        return orig_validate(data, *args, **kwargs)

    # NOTE(review): orig_validate stays bound to ``klass``, so subclasses that
    # inherit this attribute will validate as ``klass`` — confirm that is intended.
    klass.model_validate = validate
validate_metadata(OAuthClientMetadata)
validate_metadata(AuthorizationRequest)
validate_metadata(AuthorizationCodeRequest)
validate_metadata(RefreshTokenRequest)
validate_metadata(TokenRequest)
engine = get_engine()
# Setup templates
@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir)
oauth_provider = SimpleOAuthProvider()
auth_settings = AuthSettings(
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=ALLOWED_SCOPES,
default_scopes=BASE_SCOPES,
),
required_scopes=BASE_SCOPES,
)
mcp = FastMCP(
"memory",
stateless_http=True,
auth_server_provider=oauth_provider,
auth=auth_settings,
auth=oauth_provider,
)
@ -162,3 +118,52 @@ def get_current_user() -> dict:
"client_id": access_token.client_id,
"user": user_info,
}
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
    """Health check endpoint that verifies all dependencies are accessible.

    Probes the database with ``SELECT 1`` and Qdrant with ``get_collections()``.
    Responds 200 when both succeed and 503 when either fails; the JSON body
    lists the per-dependency result plus an overall ``status``.
    """
    from sqlalchemy import text

    def probe_database():
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))

    def probe_qdrant():
        # Imported lazily so an import failure is reported as an unhealthy probe.
        from memory.common.qdrant import get_qdrant_client

        get_qdrant_client().get_collections()

    checks = {"mcp_oauth": "enabled"}
    failed = 0
    probes = (
        ("database", "Database", probe_database),
        ("qdrant", "Qdrant", probe_qdrant),
    )
    for key, label, probe in probes:
        try:
            probe()
        except Exception as e:
            logger.error(f"{label} health check failed: {e}")
            checks[key] = "unhealthy"
            failed += 1
        else:
            checks[key] = "healthy"

    if failed:
        checks["status"] = "degraded"
        status_code = 503
    else:
        checks["status"] = "healthy"
        status_code = 200
    return JSONResponse(checks, status_code=status_code)
# Inject auth provider into subservers that need it
set_schedule_auth(get_current_user)
set_meta_auth(get_current_user)
# Mount all subservers onto the main MCP server
# Tools will be prefixed with their server name (e.g., core_search_knowledge_base)
mcp.mount(core_mcp, prefix="core")
mcp.mount(github_mcp, prefix="github")
mcp.mount(people_mcp, prefix="people")
mcp.mount(schedule_mcp, prefix="schedule")
mcp.mount(books_mcp, prefix="books")
mcp.mount(meta_mcp, prefix="meta")

View File

@ -1,119 +0,0 @@
import asyncio
import aiohttp
from datetime import datetime
from typing import TypedDict, NotRequired, Literal
from memory.api.MCP.tools import mcp
# Typed shapes for the Manifold Markets payloads handled by the helpers in this module.
class BinaryProbs(TypedDict):
    # NOTE(review): not referenced by the code visible here; presumably the
    # probability of a binary market — confirm.
    prob: float


class MultiProbs(TypedDict):
    # NOTE(review): not referenced by the code visible here; presumably maps an
    # answer to its probability — confirm.
    answerProbs: dict[str, float]


Probs = dict[str, BinaryProbs | MultiProbs]
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]


class MarketAnswer(TypedDict):
    # One entry of a market's "answers" list as read by format_market.
    id: str
    text: str  # used as the key of the formatted answers mapping
    resolutionProbability: float  # format_market falls back to "probability", then 0


class MarketDetails(TypedDict):
    # Response body of the per-market details endpoint (see get_details).
    id: str
    createdTime: int
    question: str
    outcomeType: OutcomeType
    textDescription: str
    groupSlugs: list[str]
    volume: float
    isResolved: bool
    answers: list[MarketAnswer]


class Market(TypedDict):
    # One search result; format_market adds the NotRequired keys in place.
    id: str
    url: str
    question: str
    volume: int  # compared against min_volume in search_markets
    createdTime: int  # divided by 1000 before datetime.fromtimestamp, i.e. treated as milliseconds
    outcomeType: OutcomeType
    createdAt: NotRequired[str]  # ISO string derived from createdTime
    description: NotRequired[str]
    answers: NotRequired[dict[str, float]]  # answer text -> probability rounded to 3 places
    probability: NotRequired[float]
    details: NotRequired[MarketDetails]
async def get_details(session: aiohttp.ClientSession, market_id: str):
    """Fetch the JSON record for a single market.

    Args:
        session: Open aiohttp session used for the request.
        market_id: Identifier interpolated into the market URL.

    Returns:
        The decoded JSON body. ``raise_for_status`` is called on the response first.
    """
    url = f"https://api.manifold.markets/v0/market/{market_id}"
    async with session.get(url) as response:
        response.raise_for_status()
        payload = await response.json()
    return payload
async def format_market(session: aiohttp.ClientSession, market: Market):
    """Enrich a search result and reduce it to the fields exposed to callers.

    Non-binary markets get an ``answers`` mapping (answer text -> probability
    rounded to 3 places) built from the market's details record. When
    ``createdTime`` is present, an ISO ``createdAt`` string is added. ``market``
    is modified in place before the filtered copy is returned.

    Args:
        session: Open aiohttp session used for the details request.
        market: Search result to format.

    Returns:
        A dict containing only the whitelisted keys that are present on ``market``.
    """
    if market.get("outcomeType") != "BINARY":
        details = await get_details(session, market["id"])
        # The details payload is not guaranteed to carry an "answers" list for
        # every non-binary outcome type; treat a missing list as empty instead
        # of failing the whole batch with a KeyError.
        market["answers"] = {
            answer["text"]: round(
                answer.get("resolutionProbability") or answer.get("probability") or 0, 3
            )
            for answer in details.get("answers") or []
        }
    if created_ms := market.get("createdTime"):
        # createdTime is treated as milliseconds since the epoch.
        # NOTE(review): fromtimestamp() without tz yields naive local time — confirm intended.
        market["createdAt"] = datetime.fromtimestamp(created_ms / 1000).isoformat()

    fields = {
        "id",
        "name",
        "url",
        "question",
        "volume",
        "createdAt",
        "details",
        "probability",
        "answers",
    }
    return {k: v for k, v in market.items() if k in fields}
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
    """Search Manifold for markets matching ``term`` and format the results.

    Args:
        term: Search term sent to the search endpoint.
        min_volume: Results whose ``volume`` (default 0 when absent) is below this are dropped.
        binary: When true, request only ``BINARY`` contracts; otherwise ``ALL``.

    Returns:
        The formatted markets, in the order the search endpoint returned them.
    """
    query = {
        "term": term,
        "contractType": "BINARY" if binary else "ALL",
    }
    async with aiohttp.ClientSession() as session:
        async with session.get(
            "https://api.manifold.markets/v0/search-markets", params=query
        ) as resp:
            resp.raise_for_status()
            markets = await resp.json()

        # Format every sufficiently traded market concurrently on the same session.
        pending = [
            format_market(session, market)
            for market in markets
            if market.get("volume", 0) >= min_volume
        ]
        return await asyncio.gather(*pending)
@mcp.tool()
async def get_forecasts(
    term: str, min_volume: int = 1000, binary: bool = False
) -> list[dict]:
    """Get prediction market forecasts for a given term.

    Args:
        term: The term to search for.
        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
        binary: Whether to only return binary markets.
    """
    # Thin tool wrapper: all of the work happens in search_markets.
    forecasts = await search_markets(term, min_volume, binary)
    return forecasts

View File

@ -1,119 +0,0 @@
import logging
from collections import defaultdict
from typing import Annotated, TypedDict, get_args, get_type_hints
from memory.common import qdrant
from sqlalchemy import func
from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation
logger = logging.getLogger(__name__)
class SchemaArg(TypedDict):
    # One field of a payload schema, as produced by from_annotation.
    type: str | None  # string form of the annotated type
    description: str | None  # second argument of the Annotated hint


class CollectionMetadata(TypedDict):
    # Per-collection entry returned by get_metadata_schemas.
    schema: dict[str, SchemaArg]  # field name -> SchemaArg, merged over the classes listing the collection
    size: int  # value reported by qdrant.get_collection_sizes for the collection
def from_annotation(annotation: Annotated) -> SchemaArg | None:
    """Convert an ``Annotated[type, description]`` hint into a SchemaArg.

    Args:
        annotation: The type hint to inspect.

    Returns:
        A SchemaArg with the type rendered as a short string, or None (after
        logging an error) when the hint does not provide both a type and a
        description.
    """
    args = get_args(annotation)
    if len(args) < 2:
        # Unpacking a short tuple raises ValueError, which the former
        # ``except IndexError`` never caught; check the length explicitly so
        # plain (non-Annotated) fields are skipped instead of crashing.
        logger.error(f"Error from annotation: {annotation}")
        return None

    type_, description = args[0], args[1]
    type_str = str(type_)
    if type_str.startswith("typing."):
        # e.g. "typing.Optional[str]" -> "Optional[str]"
        type_str = type_str[len("typing."):]
    elif len(parts := type_str.split("'")) > 1:
        # e.g. "<class 'str'>" -> "str"
        type_str = parts[1]
    return SchemaArg(type=type_str, description=description)
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
    """Derive a field schema from the return annotation of ``klass.as_payload``.

    Returns an empty dict when the class has no ``as_payload`` or that method
    has no usable return hint. Fields for which from_annotation yields nothing
    are left out.
    """
    if not hasattr(klass, "as_payload"):
        return {}

    payload_type = get_type_hints(klass.as_payload).get("return")
    if not payload_type:
        return {}

    schema = {}
    for name, arg in payload_type.__annotations__.items():
        entry = from_annotation(arg)
        if entry:
            schema[name] = entry
    return schema
@mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
    """Get the metadata schema for each collection used in the knowledge base.

    These schemas can be used to filter the knowledge base.

    Returns: A mapping of collection names to their metadata schema (field types
    and descriptions) and size. Collections without a non-zero reported size are omitted.

    Example:
    ```
    {
        "mail": {
            "schema": {"subject": {"type": "str", "description": "The subject of the email."}},
            "size": 120
        },
        "chat": {
            "schema": {"subject": {"type": "str", "description": "The subject of the chat message."}},
            "size": 45
        }
    }
    ```
    """
    client = qdrant.get_qdrant_client()
    sizes = qdrant.get_collection_sizes(client)
    schemas = defaultdict(dict)
    # Merge the payload schema of every direct SourceItem subclass into each
    # collection that class reports.
    for klass in SourceItem.__subclasses__():
        for collection in klass.get_collections():
            schemas[collection].update(get_schema(klass))

    return {
        collection: CollectionMetadata(schema=schema, size=size)
        for collection, schema in schemas.items()
        if (size := sizes.get(collection))
    }
@mcp.tool()
async def get_all_tags() -> list[str]:
    """Get all unique tags used across the entire knowledge base.

    Returns sorted list of tags from both observations and content.
    """
    with make_session() as session:
        rows = session.query(func.unnest(SourceItem.tags)).distinct()
        unique_tags = set()
        for row in rows:
            tag = row[0]
            # NULL array elements are dropped so the sort only compares strings.
            if tag is not None:
                unique_tags.add(tag)
        return sorted(unique_tags)
@mcp.tool()
async def get_all_subjects() -> list[str]:
    """Get all unique subjects from observations about the user.

    Returns sorted list of subject identifiers used in observations.
    """
    with make_session() as session:
        # Drop NULL subjects before sorting, as get_all_observation_types does
        # for its column: sorted() cannot order None against str.
        return sorted(
            {
                r.subject
                for r in session.query(AgentObservation.subject).distinct()
                if r.subject is not None
            }
        )
@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """Get all observation types that have been used.

    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
    """
    with make_session() as session:
        rows = session.query(AgentObservation.observation_type).distinct()
        seen = set()
        for row in rows:
            # Rows with a NULL type are skipped so the sort only compares strings.
            if row.observation_type is not None:
                seen.add(row.observation_type)
        return sorted(seen)

View File

@ -5,11 +5,12 @@ from datetime import datetime, timezone
from typing import Any, Optional, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from fastmcp.server.auth import OAuthProvider
from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken
from mcp.server.auth.provider import (
AccessToken,
AuthorizationCode,
AuthorizationParams,
OAuthAuthorizationServerProvider,
RefreshToken,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
@ -133,8 +134,45 @@ def make_token(
)
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
class SimpleOAuthProvider(OAuthProvider):
"""OAuth provider that extends fastmcp's OAuthProvider with custom login flow."""
def __init__(self):
    """Initialise the base OAuthProvider from ``settings.SERVER_URL`` and ``BASE_SCOPES``."""
    server_url = settings.SERVER_URL
    super().__init__(
        base_url=server_url,
        issuer_url=server_url,
        # We handle registration ourselves
        client_registration_options=None,
        required_scopes=BASE_SCOPES,
    )
async def verify_token(self, token: str) -> FastMCPAccessToken | None:
    """Verify an access token and return token info if valid.

    The token is first looked up as the primary key of a UserSession; when no
    session matches, it is tried as a user's API key.

    Args:
        token: The bearer token presented by the client.

    Returns:
        A FastMCPAccessToken for a live session or a matching API key, or None
        when the session has expired or the token is unknown.
    """
    with make_session() as session:
        # Try as OAuth access token first. Session.get is the SQLAlchemy 2.0
        # primary-key lookup (Query.get is legacy) and is the same style
        # get_client uses.
        user_session = session.get(UserSession, token)
        if user_session:
            # expires_at is compared against a naive UTC "now".
            # NOTE(review): presumably stored as naive UTC — confirm.
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            if user_session.expires_at < now:
                return None
            return FastMCPAccessToken(
                token=token,
                client_id=user_session.oauth_state.client_id,
                scopes=user_session.oauth_state.scopes,
            )

        # Try as bot API key
        bot = session.query(User).filter(User.api_key == token).first()
        if bot:
            logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
            return FastMCPAccessToken(
                token=token,
                client_id=cast(str, bot.name or bot.email),
                scopes=["read", "write"],
            )
        return None
def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""Get OAuth client information."""
with make_session() as session:
client = session.get(OAuthClientInformation, client_id)

View File

@ -0,0 +1,17 @@
"""MCP subservers for composable tool organization."""
from memory.api.MCP.servers.core import core_mcp
from memory.api.MCP.servers.github import github_mcp
from memory.api.MCP.servers.people import people_mcp
from memory.api.MCP.servers.schedule import schedule_mcp
from memory.api.MCP.servers.books import books_mcp
from memory.api.MCP.servers.meta import meta_mcp
__all__ = [
"core_mcp",
"github_mcp",
"people_mcp",
"schedule_mcp",
"books_mcp",
"meta_mcp",
]

View File

@ -1,15 +1,19 @@
"""MCP subserver for ebook access."""
import logging
from fastmcp import FastMCP
from sqlalchemy.orm import joinedload
from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Book, BookSection, BookSectionPayload
logger = logging.getLogger(__name__)
books_mcp = FastMCP("memory-books")
@mcp.tool()
@books_mcp.tool()
async def all_books(sections: bool = False) -> list[dict]:
"""
Get all books in the database.
@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
return [book.as_payload(sections=sections) for book in books]
@mcp.tool()
@books_mcp.tool()
def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
"""
Read a book from the database.

View File

@ -1,53 +1,45 @@
"""
MCP tools for the epistemic sparring partner system.
Core MCP subserver for knowledge base search, observations, and notes.
"""
import base64
import logging
import pathlib
from datetime import datetime, timezone
from PIL import Image
from fastmcp import FastMCP
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchFilters, SearchConfig
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract, settings
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
from memory.common.celery_app import app as celery_app
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem, AgentObservation
from memory.common.db.models import AgentObservation, SourceItem
from memory.common.formatters import observation
logger = logging.getLogger(__name__)
core_mcp = FastMCP("memory-core")
def validate_path_within_directory(
    base_dir: pathlib.Path, requested_path: str
) -> pathlib.Path:
    """Validate that a requested path resolves within the base directory.

    Prevents path traversal attacks using ../ or similar techniques.

    Args:
        base_dir: The allowed base directory
        requested_path: The user-provided path

    Returns:
        The resolved absolute path if valid

    Raises:
        ValueError: If the path would escape the base directory
    """
    # Leading slashes are stripped so an absolute-looking request is still
    # joined underneath base_dir rather than replacing it.
    resolved = (base_dir / requested_path.lstrip("/")).resolve()
    base_resolved = base_dir.resolve()
    # Compare path components rather than string prefixes: this accepts the
    # base itself, rejects siblings that merely share a name prefix, and also
    # works when the base is the filesystem root or the separator is not "/".
    if not resolved.is_relative_to(base_resolved):
        raise ValueError(f"Path escapes allowed directory: {requested_path}")
    return resolved
@ -63,7 +55,6 @@ def filter_observation_source_ids(
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
items_query = session.query(SourceItem.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
return source_ids
@mcp.tool()
@core_mcp.tool()
async def search_knowledge_base(
query: str,
filters: SearchFilters = {},
@ -141,7 +131,6 @@ async def search_knowledge_base(
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
# Filter to valid collections, excluding observation collections
modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
search_filters = SearchFilters(**filters)
@ -167,7 +156,7 @@ class RawObservation(BaseModel):
tags: list[str] = []
@mcp.tool()
@core_mcp.tool()
async def observe(
observations: list[RawObservation],
session_id: str | None = None,
@ -212,23 +201,23 @@ async def observe(
logger.info("MCP: Observing")
tasks = [
(
observation,
obs,
celery_app.send_task(
SYNC_OBSERVATION,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
"subject": observation.subject,
"content": observation.content,
"observation_type": observation.observation_type,
"confidences": observation.confidences,
"evidence": observation.evidence,
"tags": observation.tags,
"subject": obs.subject,
"content": obs.content,
"observation_type": obs.observation_type,
"confidences": obs.confidences,
"evidence": obs.evidence,
"tags": obs.tags,
"session_id": session_id,
"agent_model": agent_model,
},
),
)
for observation in observations
for obs in observations
]
def short_content(obs: RawObservation) -> str:
@ -242,7 +231,7 @@ async def observe(
}
@mcp.tool()
@core_mcp.tool()
async def search_observations(
query: str,
subject: str = "",
@ -308,7 +297,7 @@ async def search_observations(
]
@mcp.tool()
@core_mcp.tool()
async def create_note(
subject: str,
content: str,
@ -362,7 +351,7 @@ async def create_note(
}
@mcp.tool()
@core_mcp.tool()
async def note_files(path: str = "/"):
"""
List note files in the user's note storage.
@ -386,7 +375,7 @@ async def note_files(path: str = "/"):
]
@mcp.tool()
@core_mcp.tool()
def fetch_file(filename: str) -> dict:
"""
Read file content with automatic type detection.

View File

@ -1,14 +1,14 @@
"""MCP tools for GitHub issue tracking and management."""
"""MCP subserver for GitHub issue tracking and management."""
import logging
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import Text, case, desc, func, asc
from fastmcp import FastMCP
from sqlalchemy import Text, case, desc, func
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract
@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem
logger = logging.getLogger(__name__)
github_mcp = FastMCP("memory-github")
def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
"""Build GitHub URL from repo path and issue/PR number."""
@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
}
if include_content:
result["content"] = item.content
# Include PR-specific data if available
if item.kind == "pr" and item.pr_data:
result["pr_data"] = {
"additions": item.pr_data.additions,
@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
"files": item.pr_data.files,
"reviews": item.pr_data.reviews,
"review_comments": item.pr_data.review_comments,
"diff": item.pr_data.diff, # Decompressed via property
"diff": item.pr_data.diff,
}
return result
@mcp.tool()
@github_mcp.tool()
async def list_github_issues(
repo: str | None = None,
assignee: str | None = None,
@ -109,64 +110,48 @@ async def list_github_issues(
with make_session() as session:
query = session.query(GithubItem)
# Apply filters
if repo:
query = query.filter(GithubItem.repo_path == repo)
if assignee:
query = query.filter(GithubItem.assignees.any(assignee))
if author:
query = query.filter(GithubItem.author == author)
if state:
query = query.filter(GithubItem.state == state)
if kind:
query = query.filter(GithubItem.kind == kind)
else:
# Exclude comments by default, only show issues and PRs
query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
if labels:
# Match any label in the list using PostgreSQL array overlap
query = query.filter(
GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
)
if project_status:
query = query.filter(GithubItem.project_status == project_status)
if project_field:
for key, value in project_field.items():
query = query.filter(
GithubItem.project_fields[key].astext == value
)
query = query.filter(GithubItem.project_fields[key].astext == value)
if updated_since:
since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at >= since_dt)
if updated_before:
before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
query = query.filter(GithubItem.github_updated_at <= before_dt)
# Apply ordering
if order_by == "created":
query = query.order_by(desc(GithubItem.created_at))
elif order_by == "number":
query = query.order_by(desc(GithubItem.number))
else: # default: updated
else:
query = query.order_by(desc(GithubItem.github_updated_at))
query = query.limit(limit)
items = query.all()
return [_serialize_issue(item) for item in items]
@mcp.tool()
@github_mcp.tool()
async def search_github_issues(
query: str,
repo: str | None = None,
@ -191,7 +176,6 @@ async def search_github_issues(
limit = min(limit, 100)
# Pre-filter source_ids if repo/state/kind filters are specified
source_ids = None
if repo or state or kind:
with make_session() as session:
@ -206,7 +190,6 @@ async def search_github_issues(
q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
source_ids = [item.id for item in q.all()]
# Use the existing search infrastructure
data = extract.extract_text(query, skip_summary=True)
config = SearchConfig(limit=limit, previews=True)
filters = SearchFilters()
@ -220,7 +203,6 @@ async def search_github_issues(
config=config,
)
# Fetch full issue details for the results
output = []
with make_session() as session:
for result in results:
@ -233,7 +215,7 @@ async def search_github_issues(
return output
@mcp.tool()
@github_mcp.tool()
async def github_issue_details(
repo: str,
number: int,
@ -267,7 +249,7 @@ async def github_issue_details(
return _serialize_issue(item, include_content=True)
@mcp.tool()
@github_mcp.tool()
async def github_work_summary(
since: str,
until: str | None = None,
@ -294,7 +276,6 @@ async def github_work_summary(
else:
until_dt = datetime.now(timezone.utc)
# Map group_by to SQL expression
group_mappings = {
"client": GithubItem.project_fields["EquiStamp.Client"].astext,
"status": GithubItem.project_status,
@ -311,7 +292,6 @@ async def github_work_summary(
group_col = group_mappings[group_by]
with make_session() as session:
# Build base query for the period
base_query = session.query(GithubItem).filter(
GithubItem.github_updated_at >= since_dt,
GithubItem.github_updated_at <= until_dt,
@ -321,7 +301,6 @@ async def github_work_summary(
if repo:
base_query = base_query.filter(GithubItem.repo_path == repo)
# Get aggregated counts by group
agg_query = (
session.query(
group_col.label("group_name"),
@ -346,7 +325,6 @@ async def github_work_summary(
groups = agg_query.all()
# Build summary with sample issues for each group
summary = []
total_issues = 0
total_prs = 0
@ -358,7 +336,6 @@ async def github_work_summary(
total_issues += issue_count
total_prs += pr_count
# Get sample issues for this group
sample_query = base_query.filter(group_col == group_name).limit(5)
samples = [
{
@ -394,7 +371,7 @@ async def github_work_summary(
}
@mcp.tool()
@github_mcp.tool()
async def github_repo_overview(
repo: str,
) -> dict:
@ -410,13 +387,6 @@ async def github_repo_overview(
logger.info(f"github_repo_overview called: repo={repo}")
with make_session() as session:
# Base query for this repo
base_query = session.query(GithubItem).filter(
GithubItem.repo_path == repo,
GithubItem.kind.in_(["issue", "pr"]),
)
# Get total counts
counts_query = session.query(
func.count(GithubItem.id).label("total"),
func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
@ -441,7 +411,6 @@ async def github_repo_overview(
counts = counts_query.first()
# Status breakdown (for project_status)
status_query = (
session.query(
GithubItem.project_status.label("status"),
@ -458,7 +427,6 @@ async def github_repo_overview(
status_breakdown = {row.status: row.count for row in status_query.all()}
# Top assignees (open issues only)
assignee_query = (
session.query(
func.unnest(GithubItem.assignees).label("assignee"),
@ -479,7 +447,6 @@ async def github_repo_overview(
for row in assignee_query.all()
]
# Label counts
label_query = (
session.query(
func.unnest(GithubItem.labels).label("label"),

View File

@ -0,0 +1,277 @@
"""MCP subserver for metadata, utilities, and forecasting."""
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints
import aiohttp
from fastmcp import FastMCP
from sqlalchemy import func
from memory.common import qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation
logger = logging.getLogger(__name__)
meta_mcp = FastMCP("memory-meta")
# Auth provider will be injected at mount time
_get_current_user = None


def set_auth_provider(get_current_user_func):
    """Register the callable that get_current_user delegates to."""
    global _get_current_user
    _get_current_user = get_current_user_func


def get_current_user() -> dict:
    """Return the registered provider's result, or an unauthenticated stub if none is set."""
    provider = _get_current_user
    if provider is not None:
        return provider()
    return {"authenticated": False, "error": "Auth provider not configured"}
# --- Metadata tools ---
class SchemaArg(TypedDict):
    # One field of a payload schema, as produced by from_annotation.
    type: str | None  # string form of the annotated type
    description: str | None  # second argument of the Annotated hint


class CollectionMetadata(TypedDict):
    # Per-collection entry returned by get_metadata_schemas.
    schema: dict[str, SchemaArg]  # field name -> SchemaArg, merged over the classes listing the collection
    size: int  # value reported by qdrant.get_collection_sizes for the collection
def from_annotation(annotation: Annotated) -> SchemaArg | None:
    """Convert an ``Annotated[type, description]`` hint into a SchemaArg.

    Args:
        annotation: The type hint to inspect.

    Returns:
        A SchemaArg with the type rendered as a short string, or None (after
        logging an error) when the hint does not provide both a type and a
        description.
    """
    args = get_args(annotation)
    if len(args) < 2:
        # Unpacking a short tuple raises ValueError, which the former
        # ``except IndexError`` never caught; check the length explicitly so
        # plain (non-Annotated) fields are skipped instead of crashing.
        logger.error(f"Error from annotation: {annotation}")
        return None

    type_, description = args[0], args[1]
    type_str = str(type_)
    if type_str.startswith("typing."):
        # e.g. "typing.Optional[str]" -> "Optional[str]"
        type_str = type_str[len("typing."):]
    elif len(parts := type_str.split("'")) > 1:
        # e.g. "<class 'str'>" -> "str"
        type_str = parts[1]
    return SchemaArg(type=type_str, description=description)
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
    """Derive a field schema from the return annotation of ``klass.as_payload``.

    Returns an empty dict when the class has no ``as_payload`` or that method
    has no usable return hint. Fields for which from_annotation yields nothing
    are left out.
    """
    if not hasattr(klass, "as_payload"):
        return {}

    payload_type = get_type_hints(klass.as_payload).get("return")
    if not payload_type:
        return {}

    schema = {}
    for name, arg in payload_type.__annotations__.items():
        entry = from_annotation(arg)
        if entry:
            schema[name] = entry
    return schema
@meta_mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
    """Get the metadata schema for each collection used in the knowledge base.

    These schemas can be used to filter the knowledge base.

    Returns: A mapping of collection names to their metadata schema (field types
    and descriptions) and size. Collections without a non-zero reported size are omitted.

    Example:
    ```
    {
        "mail": {
            "schema": {"subject": {"type": "str", "description": "The subject of the email."}},
            "size": 120
        },
        "chat": {
            "schema": {"subject": {"type": "str", "description": "The subject of the chat message."}},
            "size": 45
        }
    }
    ```
    """
    client = qdrant.get_qdrant_client()
    sizes = qdrant.get_collection_sizes(client)
    schemas = defaultdict(dict)
    # Merge the payload schema of every direct SourceItem subclass into each
    # collection that class reports.
    for klass in SourceItem.__subclasses__():
        for collection in klass.get_collections():
            schemas[collection].update(get_schema(klass))

    return {
        collection: CollectionMetadata(schema=schema, size=size)
        for collection, schema in schemas.items()
        if (size := sizes.get(collection))
    }
@meta_mcp.tool()
async def get_all_tags() -> list[str]:
    """Get all unique tags used across the entire knowledge base.

    Returns sorted list of tags from both observations and content.
    """
    with make_session() as session:
        rows = session.query(func.unnest(SourceItem.tags)).distinct()
        unique_tags = set()
        for row in rows:
            tag = row[0]
            # NULL array elements are dropped so the sort only compares strings.
            if tag is not None:
                unique_tags.add(tag)
        return sorted(unique_tags)
@meta_mcp.tool()
async def get_all_subjects() -> list[str]:
    """Get all unique subjects from observations about the user.

    Returns sorted list of subject identifiers used in observations.
    """
    with make_session() as session:
        # Drop NULL subjects before sorting, as get_all_observation_types does
        # for its column: sorted() cannot order None against str.
        return sorted(
            {
                r.subject
                for r in session.query(AgentObservation.subject).distinct()
                if r.subject is not None
            }
        )
@meta_mcp.tool()
async def get_all_observation_types() -> list[str]:
    """Get all observation types that have been used.

    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
    """
    with make_session() as session:
        rows = session.query(AgentObservation.observation_type).distinct()
        seen = set()
        for row in rows:
            # Rows with a NULL type are skipped so the sort only compares strings.
            if row.observation_type is not None:
                seen.add(row.observation_type)
        return sorted(seen)
# --- Utility tools ---
@meta_mcp.tool()
async def get_current_time() -> dict:
    """Get the current time in UTC."""
    logger.info("get_current_time tool called")
    # Timezone-aware "now", serialised as an ISO 8601 string.
    now_utc = datetime.now(timezone.utc)
    return {"current_time": now_utc.isoformat()}
@meta_mcp.tool()
async def get_authenticated_user() -> dict:
    """Get information about the authenticated user."""
    # Delegates to the module-level get_current_user wrapper.
    user_info = get_current_user()
    return user_info
# --- Forecasting tools ---
# Typed shapes for the Manifold Markets payloads handled by the forecasting helpers in this module.
class BinaryProbs(TypedDict):
    # NOTE(review): not referenced by the code visible here; presumably the
    # probability of a binary market — confirm.
    prob: float


class MultiProbs(TypedDict):
    # NOTE(review): not referenced by the code visible here; presumably maps an
    # answer to its probability — confirm.
    answerProbs: dict[str, float]


Probs = dict[str, BinaryProbs | MultiProbs]
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]


class MarketAnswer(TypedDict):
    # One entry of a market's "answers" list as read by format_market.
    id: str
    text: str  # used as the key of the formatted answers mapping
    resolutionProbability: float  # format_market falls back to "probability", then 0


class MarketDetails(TypedDict):
    # Response body of the per-market details endpoint (see get_details).
    id: str
    createdTime: int
    question: str
    outcomeType: OutcomeType
    textDescription: str
    groupSlugs: list[str]
    volume: float
    isResolved: bool
    answers: list[MarketAnswer]


class Market(TypedDict):
    # One search result; format_market adds the NotRequired keys in place.
    id: str
    url: str
    question: str
    volume: int  # compared against min_volume in search_markets
    createdTime: int  # divided by 1000 before datetime.fromtimestamp, i.e. treated as milliseconds
    outcomeType: OutcomeType
    createdAt: NotRequired[str]  # ISO string derived from createdTime
    description: NotRequired[str]
    answers: NotRequired[dict[str, float]]  # answer text -> probability rounded to 3 places
    probability: NotRequired[float]
    details: NotRequired[MarketDetails]
async def get_details(session: aiohttp.ClientSession, market_id: str):
    """Fetch the JSON record for a single market.

    Args:
        session: Open aiohttp session used for the request.
        market_id: Identifier interpolated into the market URL.

    Returns:
        The decoded JSON body. ``raise_for_status`` is called on the response first.
    """
    url = f"https://api.manifold.markets/v0/market/{market_id}"
    async with session.get(url) as response:
        response.raise_for_status()
        payload = await response.json()
    return payload
async def format_market(session: aiohttp.ClientSession, market: Market):
    """Reduce a raw market entry to the fields exposed by the forecast tool.

    For non-binary markets the per-answer probabilities are fetched with
    get_details and stored under ``answers``. The creation time is added as
    an ISO 8601 string under ``createdAt``. The ``market`` dict passed in is
    updated in place with both keys.

    Args:
        session: Open aiohttp session, used only for non-binary markets.
        market: Market entry from the search endpoint.

    Returns:
        A new dict holding only the whitelisted keys present on ``market``.
    """
    if market.get("outcomeType") != "BINARY":
        details = await get_details(session, market["id"])
        # Not every non-binary payload is guaranteed to carry answers; treat
        # a missing/empty list as "no answers" instead of raising KeyError,
        # which would fail the whole asyncio.gather in search_markets.
        market["answers"] = {
            answer["text"]: round(
                answer.get("resolutionProbability") or answer.get("probability") or 0, 3
            )
            for answer in details.get("answers") or []
        }

    if created_ms := market.get("createdTime"):
        # createdTime is in milliseconds. Convert with an explicit UTC zone so
        # the result does not depend on the server's local timezone and
        # carries its offset, matching get_current_time's UTC output.
        market["createdAt"] = datetime.fromtimestamp(
            created_ms / 1000, tz=timezone.utc
        ).isoformat()

    exposed_fields = {
        "id",
        "name",
        "url",
        "question",
        "volume",
        "createdAt",
        "details",
        "probability",
        "answers",
    }
    return {k: v for k, v in market.items() if k in exposed_fields}
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
    """Search Manifold markets and format every sufficiently traded hit.

    Args:
        term: Search term sent to the search-markets endpoint.
        min_volume: Markets whose ``volume`` is below this are dropped; a
            market without a ``volume`` key counts as 0.
        binary: When true, request only BINARY contracts; otherwise ALL.

    Returns:
        The formatted markets, in the order the endpoint returned them.
    """
    query = {
        "term": term,
        "contractType": "BINARY" if binary else "ALL",
    }
    async with aiohttp.ClientSession() as session:
        async with session.get(
            "https://api.manifold.markets/v0/search-markets",
            params=query,
        ) as resp:
            resp.raise_for_status()
            markets = await resp.json()

        # Format concurrently; the shared session stays open for the detail
        # lookups that format_market performs.
        pending = [
            format_market(session, candidate)
            for candidate in markets
            if candidate.get("volume", 0) >= min_volume
        ]
        return await asyncio.gather(*pending)
@meta_mcp.tool()
async def get_forecasts(
    term: str, min_volume: int = 1000, binary: bool = False
) -> list[dict]:
    """Get prediction market forecasts for a given term.

    Args:
        term: The term to search for.
        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
        binary: Whether to only return binary markets.
    """
    # All the work happens in search_markets; this is only the tool wrapper.
    forecasts = await search_markets(term, min_volume=min_volume, binary=binary)
    return forecasts

View File

@ -1,23 +1,23 @@
"""
MCP tools for tracking people.
"""
"""MCP subserver for tracking people."""
import logging
from typing import Any
from fastmcp import FastMCP
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.MCP.base import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import Person
from memory.common import settings
from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
from memory.common.celery_app import app as celery_app
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import Person
logger = logging.getLogger(__name__)
people_mcp = FastMCP("memory-people")
def _person_to_dict(person: Person) -> dict[str, Any]:
"""Convert a Person model to a dictionary for API responses."""
@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]:
}
@mcp.tool()
@people_mcp.tool()
async def add_person(
identifier: str,
display_name: str,
@ -67,7 +67,6 @@ async def add_person(
"""
logger.info(f"MCP: Adding person: {identifier}")
# Check if person already exists
with make_session() as session:
existing = session.query(Person).filter(Person.identifier == identifier).first()
if existing:
@ -93,7 +92,7 @@ async def add_person(
}
@mcp.tool()
@people_mcp.tool()
async def update_person_info(
identifier: str,
display_name: str | None = None,
@ -135,7 +134,6 @@ async def update_person_info(
"""
logger.info(f"MCP: Updating person: {identifier}")
# Verify person exists
with make_session() as session:
person = session.query(Person).filter(Person.identifier == identifier).first()
if not person:
@ -162,7 +160,7 @@ async def update_person_info(
}
@mcp.tool()
@people_mcp.tool()
async def get_person(identifier: str) -> dict | None:
"""
Get a person by their identifier.
@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None:
return _person_to_dict(person)
@mcp.tool()
@people_mcp.tool()
async def list_people(
tags: list[str] | None = None,
search: str | None = None,
@ -207,9 +205,7 @@ async def list_people(
query = session.query(Person)
if tags:
query = query.filter(
Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
)
query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))))
if search:
search_term = f"%{search.lower()}%"
@ -225,7 +221,7 @@ async def list_people(
return [_person_to_dict(p) for p in people]
@mcp.tool()
@people_mcp.tool()
async def delete_person(identifier: str) -> dict:
"""
Delete a person by their identifier.

View File

@ -1,20 +1,37 @@
"""
MCP tools for the epistemic sparring partner system.
"""
"""MCP subserver for scheduling messages and LLM calls."""
import logging
from datetime import datetime, timezone
from typing import Any, cast
from memory.api.MCP.base import get_current_user, mcp
from fastmcp import FastMCP
from memory.common.db.connection import make_session
from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
from memory.common.db.models import DiscordBotUser, ScheduledLLMCall
from memory.discord.messages import schedule_discord_message
logger = logging.getLogger(__name__)
schedule_mcp = FastMCP("memory-schedule")
@mcp.tool()
# The user lookup lives in the base module; it is injected at mount time via
# set_auth_provider(). Until then no provider is registered.
_get_current_user = None


def set_auth_provider(get_current_user_func):
    """Register the callable used to look up the authenticated user."""
    global _get_current_user
    _get_current_user = get_current_user_func


def get_current_user() -> dict:
    """Return the current user, or an error stub when no provider is set."""
    provider = _get_current_user
    if provider is not None:
        return provider()
    return {"authenticated": False, "error": "Auth provider not configured"}
@schedule_mcp.tool()
async def schedule_message(
scheduled_time: str,
message: str,
@ -61,10 +78,8 @@ async def schedule_message(
if not discord_user and not discord_channel:
raise ValueError("Either discord_user or discord_channel must be provided")
# Parse scheduled time
try:
scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
# Ensure we store as naive datetime (remove timezone info for database storage)
if scheduled_dt.tzinfo is not None:
scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
except ValueError:
@ -98,7 +113,7 @@ async def schedule_message(
}
@mcp.tool()
@schedule_mcp.tool()
async def list_scheduled_llm_calls(
status: str | None = None, limit: int | None = 50
) -> dict[str, Any]:
@ -143,7 +158,7 @@ async def list_scheduled_llm_calls(
}
@mcp.tool()
@schedule_mcp.tool()
async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
"""
Cancel a scheduled LLM call.
@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
return {"error": "User not found", "user": current_user}
with make_session() as session:
# Find the scheduled call
scheduled_call = (
session.query(ScheduledLLMCall)
.filter(
@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
if not scheduled_call.can_be_cancelled():
return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
# Update the status
scheduled_call.status = "cancelled"
session.commit()

View File

@ -1,74 +0,0 @@
"""
MCP tools for the epistemic sparring partner system.
"""
import logging
from datetime import datetime, timezone
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem
from memory.api.MCP.base import mcp, get_current_user
logger = logging.getLogger(__name__)
def filter_observation_source_ids(
    tags: list[str] | None = None, observation_types: list[str] | None = None
):
    """Return ids of observations matching the given tags and/or types.

    Args:
        tags: Keep observations whose tags overlap this list.
        observation_types: Keep observations whose type is in this list.

    Returns:
        None when neither filter is given, otherwise the list of matching
        AgentObservation ids (both filters are ANDed when both are given).
    """
    if not (tags or observation_types):
        return None

    conditions = []
    if tags:
        # PostgreSQL array-overlap operator, with the list cast to text[].
        conditions.append(AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))))
    if observation_types:
        conditions.append(AgentObservation.observation_type.in_(observation_types))

    with make_session() as session:
        rows = session.query(AgentObservation.id).filter(*conditions).all()
        return [row.id for row in rows]
def filter_source_ids(
    modalities: set[str],
    tags: list[str] | None = None,
):
    """Return ids of source items whose tags overlap ``tags``.

    Args:
        modalities: When non-empty, additionally restrict matches to these
            modalities. Only applied together with ``tags``: without tags the
            function returns None regardless of ``modalities``.
        tags: Tags to match with the PostgreSQL array-overlap operator.

    Returns:
        None when no tags are given, otherwise the list of matching
        SourceItem ids.
    """
    if not tags:
        return None

    with make_session() as session:
        # tags is known to be non-empty here, so the former second
        # ``if tags:`` check was always true and has been dropped.
        items_query = session.query(SourceItem.id).filter(
            SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
        )
        if modalities:
            items_query = items_query.filter(SourceItem.modality.in_(modalities))
        return [item.id for item in items_query.all()]
@mcp.tool()
async def get_current_time() -> dict:
    """Get the current time in UTC."""
    logger.info("get_current_time tool called")
    # Timezone-aware "now", serialised as an ISO 8601 string.
    now_utc = datetime.now(tz=timezone.utc)
    return {"current_time": now_utc.isoformat()}
@mcp.tool()
async def get_authenticated_user() -> dict:
    """Get information about the authenticated user."""
    # Thin wrapper: the result of get_current_user() is returned unchanged.
    user_info = get_current_user()
    return user_info

View File

@ -2,7 +2,6 @@
FastAPI application for the knowledge base.
"""
import contextlib
import os
import logging
import mimetypes
@ -35,15 +34,10 @@ limiter = Limiter(
enabled=settings.API_RATE_LIMIT_ENABLED,
)
# Create the MCP http app to get its lifespan
mcp_http_app = mcp.http_app(stateless_http=True)
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
async with contextlib.AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run())
yield
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan)
app.state.limiter = limiter
# Rate limit exception handler
@ -158,46 +152,9 @@ app.include_router(auth_router)
# Add health check to MCP server instead of main app
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Health check endpoint that verifies all dependencies are accessible."""
from fastapi.responses import JSONResponse
from sqlalchemy import text
checks = {"mcp_oauth": "enabled"}
all_healthy = True
# Check database connection
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
checks["database"] = "healthy"
except Exception as e:
# Log error details but don't expose in response
logger.error(f"Database health check failed: {e}")
checks["database"] = "unhealthy"
all_healthy = False
# Check Qdrant connection
try:
from memory.common.qdrant import get_qdrant_client
client = get_qdrant_client()
client.get_collections()
checks["qdrant"] = "healthy"
except Exception as e:
# Log error details but don't expose in response
logger.error(f"Qdrant health check failed: {e}")
checks["qdrant"] = "unhealthy"
all_healthy = False
checks["status"] = "healthy" if all_healthy else "degraded"
status_code = 200 if all_healthy else 503
return JSONResponse(checks, status_code=status_code)
# Mount MCP server at root - OAuth endpoints need to be at root level
app.mount("/", mcp.streamable_http_app())
# Health check is defined in MCP/base.py
app.mount("/", mcp_http_app)
def main(reload: bool = False):

View File

@ -58,7 +58,6 @@ sync_code() {
"$PROJECT_DIR/frontend" \
"$PROJECT_DIR/requirements" \
"$PROJECT_DIR/setup.py" \
"$PROJECT_DIR/pyproject.toml" \
"$PROJECT_DIR/docker-compose.yaml" \
"$PROJECT_DIR/pytest.ini" \
"$REMOTE_HOST:$REMOTE_DIR/"

195
tools/diagnose.sh Executable file
View File

@ -0,0 +1,195 @@
#!/bin/bash
# Abort the script as soon as any command fails.
set -e

# Host name handed to ssh by remote(), and the directory on that host in
# which the docker compose / file commands are run.
REMOTE_HOST="memory"
REMOTE_DIR="/home/ec2-user/memory"

# Colors for output (ANSI escape sequences, interpreted by `echo -e`).
RED='\033[0;31m'
GREEN='\033[0;32m'
# YELLOW is not referenced in the part of the script visible here.
YELLOW='\033[1;33m'
# NC ("no color") resets the terminal color.
NC='\033[0m'
# Print the command overview and exit with status 1.
usage() {
    cat <<EOF
Usage: $0 <command> [options]

Safe diagnostic commands for the memory server.

Commands:
 logs [service] [lines] View docker logs (default: all services, 100 lines)
 ps Show docker container status
 disk Show disk usage
 mem Show memory usage
 top Show running processes
 ls <path> List directory contents
 cat <file> View file contents
 tail <file> [lines] Tail a file (default: 50 lines)
 grep <pattern> <path> Search for pattern in files
 db <query> Run read-only SQL query
 get <path> [port] GET request to localhost (default port: 8000)
 status Overall system status
EOF
    exit 1
}
# Run a command on the memory server over SSH.
# Every argument is passed through to ssh as the remote command, so it is
# evaluated by the login shell on the remote side.
remote() {
    local host="$REMOTE_HOST"
    ssh "$host" "$@"
}
# Show docker compose logs, for one service or for all of them.
#   $1  service name (optional; empty means all services)
#   $2  number of trailing lines, or "all" (default: 100)
docker_logs() {
    local service="${1:-}"
    local lines="${2:-100}"

    # The value is spliced into a remote shell command, so accept only what
    # `docker compose logs --tail` understands.
    if ! [[ "$lines" =~ ^([0-9]+|all)$ ]]; then
        echo -e "${RED}Error: lines must be a number or 'all'${NC}"
        exit 1
    fi

    if [ -n "$service" ]; then
        # Shell-quote the service name so it reaches docker as one literal word.
        local quoted_service
        quoted_service=$(printf '%q' "$service")
        echo -e "${GREEN}Logs for $service (last $lines lines):${NC}"
        remote "cd $REMOTE_DIR && docker compose logs --tail=$lines $quoted_service"
    else
        echo -e "${GREEN}All logs (last $lines lines):${NC}"
        remote "cd $REMOTE_DIR && docker compose logs --tail=$lines"
    fi
}
# Show the state of the docker compose containers on the server.
docker_ps() {
    local remote_cmd="cd $REMOTE_DIR && docker compose ps"
    printf '%b\n' "${GREEN}Container status:${NC}"
    remote "$remote_cmd"
}
# Show filesystem usage, then the size of each entry in the project
# directory sorted from smallest to largest.
disk_usage() {
    local remote_cmd="df -h && echo '' && du -sh $REMOTE_DIR/* 2>/dev/null | sort -h"
    printf '%b\n' "${GREEN}Disk usage:${NC}"
    remote "$remote_cmd"
}
# Show memory usage on the server in human-readable units.
mem_usage() {
    local remote_cmd="free -h"
    printf '%b\n' "${GREEN}Memory usage:${NC}"
    remote "$remote_cmd"
}
# Show the first 20 lines of the process list, sorted by memory use
# (highest first).
show_top() {
    local remote_cmd="ps aux --sort=-%mem | head -20"
    printf '%b\n' "${GREEN}Top processes:${NC}"
    remote "$remote_cmd"
}
# List a directory on the server.
#   $1  path (default: "."); relative paths resolve against $REMOTE_DIR
list_dir() {
    local path="${1:-.}"
    # NOTE: the path is NOT confined to the project directory - absolute
    # paths and ".." are passed through unchanged. It is only shell-quoted,
    # so spaces or metacharacters cannot split it or inject commands.
    local quoted_path
    quoted_path=$(printf '%q' "$path")
    echo -e "${GREEN}Contents of $path:${NC}"
    # "--" keeps a path starting with "-" from being parsed as an option.
    remote "cd $REMOTE_DIR && ls -la -- $quoted_path"
}
# Print a file from the server.
#   $1  file path (required); relative paths resolve against $REMOTE_DIR
cat_file() {
    local file="${1:-}"
    if [ -z "$file" ]; then
        echo -e "${RED}Error: No file specified${NC}"
        exit 1
    fi
    # Shell-quote the path: it is spliced into a command line that the
    # remote shell re-parses, so spaces or metacharacters must stay literal.
    local quoted_file
    quoted_file=$(printf '%q' "$file")
    echo -e "${GREEN}Contents of $file:${NC}"
    # "--" keeps a path starting with "-" from being parsed as an option.
    remote "cd $REMOTE_DIR && cat -- $quoted_file"
}
# Print the end of a file on the server.
#   $1  file path (required); relative paths resolve against $REMOTE_DIR
#   $2  line count for `tail -n` (default: 50; "+N" is accepted as well)
tail_file() {
    local file="${1:-}"
    local lines="${2:-50}"
    if [ -z "$file" ]; then
        echo -e "${RED}Error: No file specified${NC}"
        exit 1
    fi
    # The count is spliced into a remote shell command; accept digits only
    # (optionally with tail's leading "+").
    if ! [[ "$lines" =~ ^\+?[0-9]+$ ]]; then
        echo -e "${RED}Error: lines must be a number${NC}"
        exit 1
    fi
    # Shell-quote the path so spaces or metacharacters stay literal remotely.
    local quoted_file
    quoted_file=$(printf '%q' "$file")
    echo -e "${GREEN}Last $lines lines of $file:${NC}"
    remote "cd $REMOTE_DIR && tail -n $lines -- $quoted_file"
}
# Recursively search files on the server.
#   $1  pattern (required)
#   $2  path to search (default: "."); relative to $REMOTE_DIR
grep_files() {
    local pattern="${1:-}"
    local path="${2:-.}"
    if [ -z "$pattern" ]; then
        echo -e "${RED}Error: No pattern specified${NC}"
        exit 1
    fi
    # Shell-quote both values: wrapping the pattern in '...' broke as soon
    # as it contained a single quote, and the path was not quoted at all.
    local quoted_pattern quoted_path
    quoted_pattern=$(printf '%q' "$pattern")
    quoted_path=$(printf '%q' "$path")
    echo -e "${GREEN}Searching for '$pattern' in $path:${NC}"
    # -e / -- stop a pattern or path starting with "-" from being read as an
    # option. "|| true" because grep exits 1 when nothing matches.
    remote "cd $REMOTE_DIR && grep -r --color=always -e $quoted_pattern -- $quoted_path || true"
}
# Run a single read-only SQL query in the postgres container.
#   $1  the query (required); must be one SELECT statement
db_query() {
    local query="${1:-}"
    if [ -z "$query" ]; then
        echo -e "${RED}Error: No query specified${NC}"
        exit 1
    fi

    # Only allow SELECT queries for safety. The regex is anchored to the
    # start of the whole string (the former line-based grep accepted any
    # query with SOME line starting with "select") and tolerates leading
    # whitespace. Character classes keep it case-insensitive on old bash.
    local select_re='^[[:space:]]*[Ss][Ee][Ll][Ee][Cc][Tt]'
    if ! [[ "$query" =~ $select_re ]]; then
        echo -e "${RED}Error: Only SELECT queries are allowed${NC}"
        exit 1
    fi

    # Reject stacked statements ("select 1; drop table ..."): a ";" may only
    # be followed by whitespace. This also rejects ";" inside string
    # literals, which is accepted as the price of a simple check.
    local stacked_re=';[[:space:]]*[^[:space:]]'
    if [[ "$query" =~ $stacked_re ]]; then
        echo -e "${RED}Error: Only a single statement is allowed${NC}"
        exit 1
    fi

    # Shell-quote the query so $, backticks and quotes reach psql literally
    # instead of being expanded by the remote shell.
    local quoted_query
    quoted_query=$(printf '%q' "$query")

    echo -e "${GREEN}Running query:${NC}"
    # PGOPTIONS makes the session's transactions read-only by default, as a
    # second line of defence behind the checks above.
    remote "cd $REMOTE_DIR && docker compose exec -T -e PGOPTIONS='-c default_transaction_read_only=on' postgres psql -U kb -d kb -c $quoted_query"
}
# Issue a GET request against localhost on the server and print the body
# followed by the HTTP status code.
#   $1  request path (required); a leading "/" is added when missing
#   $2  port (default: 8000)
http_get() {
    local path="${1:-}"
    local port="${2:-8000}"
    if [ -z "$path" ]; then
        echo -e "${RED}Error: No path specified${NC}"
        exit 1
    fi
    # The port is spliced into a remote shell command; accept digits only.
    if ! [[ "$port" =~ ^[0-9]+$ ]]; then
        echo -e "${RED}Error: port must be a number${NC}"
        exit 1
    fi
    # Ensure path starts with /
    if [[ "$path" != /* ]]; then
        path="/$path"
    fi
    # Shell-quote the URL: inside '...' a path containing a single quote
    # used to terminate the quoting on the remote side.
    local quoted_url
    quoted_url=$(printf '%q' "http://localhost:${port}${path}")
    echo -e "${GREEN}GET http://localhost:${port}${path}${NC}"
    remote "curl -s -w '\n\nHTTP Status: %{http_code}\n' $quoted_url"
}
# Print an overall status report: containers, memory, root filesystem
# usage, and error-looking lines from the recent docker logs.
system_status() {
    echo -e "${GREEN}=== System Status ===${NC}"
    echo ""
    docker_ps
    echo ""
    echo -e "${GREEN}=== Memory ===${NC}"
    mem_usage
    echo ""
    echo -e "${GREEN}=== Disk ===${NC}"
    remote "df -h /"
    echo ""
    echo -e "${GREEN}=== Recent Errors (last 20 lines) ===${NC}"
    # A pipeline's exit status is that of its LAST command. With `tail` last
    # the status was always 0, so the "|| echo" fallback never ran. The
    # trailing `grep .` passes the lines through and exits non-zero when
    # there are none, which lets the fallback fire.
    remote "cd $REMOTE_DIR && docker compose logs --tail=100 2>&1 | grep -i -E '(error|exception|failed|fatal)' | tail -20 | grep . || echo 'No recent errors found'"
}
# Main: dispatch on the first argument. The second and third arguments are
# forwarded where the sub-command takes them; a missing argument is passed
# as an empty string so each function can apply its own default.
cmd="${1:-}"
arg1="${2:-}"
arg2="${3:-}"

case "$cmd" in
    logs)   docker_logs "$arg1" "$arg2" ;;
    ps)     docker_ps ;;
    disk)   disk_usage ;;
    mem)    mem_usage ;;
    top)    show_top ;;
    ls)     list_dir "$arg1" ;;
    cat)    cat_file "$arg1" ;;
    tail)   tail_file "$arg1" "$arg2" ;;
    grep)   grep_files "$arg1" "$arg2" ;;
    db)     db_query "$arg1" ;;
    get)    http_get "$arg1" "$arg2" ;;
    status) system_status ;;
    *)      usage ;;
esac