mirror of https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00

commit 48c380b903: migrate to fastmcp
parent d3d71edf1d
@@ -1,7 +1,8 @@
-fastapi==0.112.2
+fastapi>=0.115.12
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin==0.20.1
-mcp==1.10.0
+fastmcp>=2.10.0
 slowapi==0.1.9
@@ -1,14 +1,14 @@
 sqlalchemy==2.0.30
 psycopg2-binary==2.9.9
-pydantic==2.7.2
+pydantic>=2.11.7
 alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.69.0
 openai==2.3.0
-# Pin the httpx version, as newer versions break the anthropic client
-httpx==0.27.0
+# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
+httpx>=0.28.1
 celery[redis,sqs]==5.3.6
 cryptography==43.0.0
 bcrypt==4.1.2
@@ -1,8 +1,16 @@
-import memory.api.MCP.tools
-import memory.api.MCP.memory
-import memory.api.MCP.metadata
-import memory.api.MCP.schedules
-import memory.api.MCP.books
-import memory.api.MCP.manifest
-import memory.api.MCP.github
-import memory.api.MCP.people
+"""
+MCP server with composed subservers.
+
+Subservers are mounted with prefixes:
+- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file
+- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview
+- people: add_person, update_person_info, get_person, list_people, delete_person
+- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call
+- books: all_books, read_book
+- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts
+"""
+
+# Import base to trigger subserver mounting
+from memory.api.MCP.base import mcp, get_current_user
+
+__all__ = ["mcp", "get_current_user"]
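For readers unfamiliar with fastmcp composition, here is a minimal, self-contained sketch of how `mount()` with a prefix produces the prefixed tool names listed in the docstring above. The server and tool names here are illustrative, not from this repo; only the `mount(server, prefix=...)` call mirrors what base.py does below.

```python
# Minimal sketch of fastmcp server composition (assumes fastmcp>=2.10).
from fastmcp import FastMCP

main = FastMCP("main")
notes = FastMCP("notes")

@notes.tool()
def list_notes() -> list[str]:
    """Stand-in tool to show how prefixing works."""
    return ["todo.md"]

# After mounting, clients see the tool as "notes_list_notes",
# just like core's search_knowledge_base becomes core_search_knowledge_base.
main.mount(notes, prefix="notes")
```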
@@ -1,60 +1,28 @@
 import logging
 import pathlib
-from typing import cast
 
-from mcp.server.auth.handlers.authorize import AuthorizationRequest
-from mcp.server.auth.handlers.token import (
-    AuthorizationCodeRequest,
-    RefreshTokenRequest,
-    TokenRequest,
-)
-from mcp.server.auth.middleware.auth_context import get_access_token
-from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
-from mcp.server.fastmcp import FastMCP
-from mcp.shared.auth import OAuthClientMetadata
-from pydantic import AnyHttpUrl
+from fastmcp import FastMCP
+from fastmcp.server.dependencies import get_access_token
 from starlette.requests import Request
 from starlette.responses import JSONResponse, RedirectResponse
 from starlette.templating import Jinja2Templates
 
-from memory.api.MCP.oauth_provider import (
-    ALLOWED_SCOPES,
-    BASE_SCOPES,
-    SimpleOAuthProvider,
-)
+from memory.api.MCP.oauth_provider import SimpleOAuthProvider
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth
 from memory.common import settings
-from memory.common.db.connection import make_session
+from memory.common.db.connection import make_session, get_engine
 from memory.common.db.models import OAuthState, UserSession
 from memory.common.db.models.users import HumanUser
 
 logger = logging.getLogger(__name__)
 
 
-def validate_metadata(klass: type):
-    orig_validate = klass.model_validate
-
-    def validate(data: dict):
-        data = dict(data)
-        if "redirect_uris" in data:
-            data["redirect_uris"] = [
-                str(uri).replace("cursor://", "http://")
-                for uri in data["redirect_uris"]
-            ]
-        if "redirect_uri" in data:
-            data["redirect_uri"] = str(data["redirect_uri"]).replace(
-                "cursor://", "http://"
-            )
-
-        return orig_validate(data)
-
-    klass.model_validate = validate
-
-
-validate_metadata(OAuthClientMetadata)
-validate_metadata(AuthorizationRequest)
-validate_metadata(AuthorizationCodeRequest)
-validate_metadata(RefreshTokenRequest)
-validate_metadata(TokenRequest)
+engine = get_engine()
 
 
 # Setup templates
@@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir)
 
 
 oauth_provider = SimpleOAuthProvider()
-auth_settings = AuthSettings(
-    issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
-    resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL),  # type: ignore
-    client_registration_options=ClientRegistrationOptions(
-        enabled=True,
-        valid_scopes=ALLOWED_SCOPES,
-        default_scopes=BASE_SCOPES,
-    ),
-    required_scopes=BASE_SCOPES,
-)
 
 mcp = FastMCP(
     "memory",
-    stateless_http=True,
-    auth_server_provider=oauth_provider,
-    auth=auth_settings,
+    auth=oauth_provider,
 )
 
 
@@ -162,3 +118,52 @@ def get_current_user() -> dict:
         "client_id": access_token.client_id,
         "user": user_info,
     }
+
+
+@mcp.custom_route("/health", methods=["GET"])
+async def health_check(request: Request):
+    """Health check endpoint that verifies all dependencies are accessible."""
+    from sqlalchemy import text
+
+    checks = {"mcp_oauth": "enabled"}
+    all_healthy = True
+
+    # Check database connection
+    try:
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+        checks["database"] = "healthy"
+    except Exception as e:
+        logger.error(f"Database health check failed: {e}")
+        checks["database"] = "unhealthy"
+        all_healthy = False
+
+    # Check Qdrant connection
+    try:
+        from memory.common.qdrant import get_qdrant_client
+
+        client = get_qdrant_client()
+        client.get_collections()
+        checks["qdrant"] = "healthy"
+    except Exception as e:
+        logger.error(f"Qdrant health check failed: {e}")
+        checks["qdrant"] = "unhealthy"
+        all_healthy = False
+
+    checks["status"] = "healthy" if all_healthy else "degraded"
+    status_code = 200 if all_healthy else 503
+    return JSONResponse(checks, status_code=status_code)
+
+
+# Inject auth provider into subservers that need it
+set_schedule_auth(get_current_user)
+set_meta_auth(get_current_user)
+
+# Mount all subservers onto the main MCP server
+# Tools will be prefixed with their server name (e.g., core_search_knowledge_base)
+mcp.mount(core_mcp, prefix="core")
+mcp.mount(github_mcp, prefix="github")
+mcp.mount(people_mcp, prefix="people")
+mcp.mount(schedule_mcp, prefix="schedule")
+mcp.mount(books_mcp, prefix="books")
+mcp.mount(meta_mcp, prefix="meta")
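A quick way to exercise the new route above, as a hypothetical smoke test: the port and host are assumptions (a local deployment on 8000), while the status codes and response keys come from the handler itself.

```python
# Hypothetical smoke test for the /health route defined above;
# assumes the server is reachable on localhost:8000 and httpx is installed.
import httpx

resp = httpx.get("http://localhost:8000/health")
print(resp.status_code)  # 200 when all checks pass, 503 when degraded
print(resp.json())       # e.g. {"mcp_oauth": "enabled", "database": "healthy", ...}
```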
@@ -1,119 +0,0 @@ (src/memory/api/MCP/manifest.py deleted; tools move to servers/meta.py)
-import asyncio
-import aiohttp
-from datetime import datetime
-
-from typing import TypedDict, NotRequired, Literal
-from memory.api.MCP.tools import mcp
-
-
-class BinaryProbs(TypedDict):
-    prob: float
-
-
-class MultiProbs(TypedDict):
-    answerProbs: dict[str, float]
-
-
-Probs = dict[str, BinaryProbs | MultiProbs]
-OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
-
-
-class MarketAnswer(TypedDict):
-    id: str
-    text: str
-    resolutionProbability: float
-
-
-class MarketDetails(TypedDict):
-    id: str
-    createdTime: int
-    question: str
-    outcomeType: OutcomeType
-    textDescription: str
-    groupSlugs: list[str]
-    volume: float
-    isResolved: bool
-    answers: list[MarketAnswer]
-
-
-class Market(TypedDict):
-    id: str
-    url: str
-    question: str
-    volume: int
-    createdTime: int
-    outcomeType: OutcomeType
-    createdAt: NotRequired[str]
-    description: NotRequired[str]
-    answers: NotRequired[dict[str, float]]
-    probability: NotRequired[float]
-    details: NotRequired[MarketDetails]
-
-
-async def get_details(session: aiohttp.ClientSession, market_id: str):
-    async with session.get(
-        f"https://api.manifold.markets/v0/market/{market_id}"
-    ) as resp:
-        resp.raise_for_status()
-        return await resp.json()
-
-
-async def format_market(session: aiohttp.ClientSession, market: Market):
-    if market.get("outcomeType") != "BINARY":
-        details = await get_details(session, market["id"])
-        market["answers"] = {
-            answer["text"]: round(
-                answer.get("resolutionProbability") or answer.get("probability") or 0, 3
-            )
-            for answer in details["answers"]
-        }
-    if creationTime := market.get("createdTime"):
-        market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
-
-    fields = [
-        "id",
-        "name",
-        "url",
-        "question",
-        "volume",
-        "createdAt",
-        "details",
-        "probability",
-        "answers",
-    ]
-    return {k: v for k, v in market.items() if k in fields}
-
-
-async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
-    async with aiohttp.ClientSession() as session:
-        async with session.get(
-            "https://api.manifold.markets/v0/search-markets",
-            params={
-                "term": term,
-                "contractType": "BINARY" if binary else "ALL",
-            },
-        ) as resp:
-            resp.raise_for_status()
-            markets = await resp.json()
-
-        return await asyncio.gather(
-            *[
-                format_market(session, market)
-                for market in markets
-                if market.get("volume", 0) >= min_volume
-            ]
-        )
-
-
-@mcp.tool()
-async def get_forecasts(
-    term: str, min_volume: int = 1000, binary: bool = False
-) -> list[dict]:
-    """Get prediction market forecasts for a given term.
-
-    Args:
-        term: The term to search for.
-        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
-        binary: Whether to only return binary markets.
-    """
-    return await search_markets(term, min_volume, binary)
@@ -1,119 +0,0 @@ (src/memory/api/MCP/metadata.py deleted; tools move to servers/meta.py)
-import logging
-from collections import defaultdict
-from typing import Annotated, TypedDict, get_args, get_type_hints
-
-from memory.common import qdrant
-from sqlalchemy import func
-
-from memory.api.MCP.tools import mcp
-from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem
-from memory.common.db.models.source_items import AgentObservation
-
-logger = logging.getLogger(__name__)
-
-
-class SchemaArg(TypedDict):
-    type: str | None
-    description: str | None
-
-
-class CollectionMetadata(TypedDict):
-    schema: dict[str, SchemaArg]
-    size: int
-
-
-def from_annotation(annotation: Annotated) -> SchemaArg | None:
-    try:
-        type_, description = get_args(annotation)
-        type_str = str(type_)
-        if type_str.startswith("typing."):
-            type_str = type_str[7:]
-        elif len((parts := type_str.split("'"))) > 1:
-            type_str = parts[1]
-        return SchemaArg(type=type_str, description=description)
-    except IndexError:
-        logger.error(f"Error from annotation: {annotation}")
-        return None
-
-
-def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
-    if not hasattr(klass, "as_payload"):
-        return {}
-
-    if not (payload_type := get_type_hints(klass.as_payload).get("return")):
-        return {}
-
-    return {
-        name: schema
-        for name, arg in payload_type.__annotations__.items()
-        if (schema := from_annotation(arg))
-    }
-
-
-@mcp.tool()
-async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
-    """Get the metadata schema for each collection used in the knowledge base.
-
-    These schemas can be used to filter the knowledge base.
-
-    Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
-
-    Example:
-    ```
-    {
-        "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
-        "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
-    }
-    ```
-    """
-    client = qdrant.get_qdrant_client()
-    sizes = qdrant.get_collection_sizes(client)
-    schemas = defaultdict(dict)
-    for klass in SourceItem.__subclasses__():
-        for collection in klass.get_collections():
-            schemas[collection].update(get_schema(klass))
-
-    return {
-        collection: CollectionMetadata(schema=schema, size=size)
-        for collection, schema in schemas.items()
-        if (size := sizes.get(collection))
-    }
-
-
-@mcp.tool()
-async def get_all_tags() -> list[str]:
-    """Get all unique tags used across the entire knowledge base.
-
-    Returns sorted list of tags from both observations and content.
-    """
-    with make_session() as session:
-        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
-        return sorted({row[0] for row in tags_query if row[0] is not None})
-
-
-@mcp.tool()
-async def get_all_subjects() -> list[str]:
-    """Get all unique subjects from observations about the user.
-
-    Returns sorted list of subject identifiers used in observations.
-    """
-    with make_session() as session:
-        return sorted(
-            r.subject for r in session.query(AgentObservation.subject).distinct()
-        )
-
-
-@mcp.tool()
-async def get_all_observation_types() -> list[str]:
-    """Get all observation types that have been used.
-
-    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
-    """
-    with make_session() as session:
-        return sorted(
-            {
-                r.observation_type
-                for r in session.query(AgentObservation.observation_type).distinct()
-                if r.observation_type is not None
-            }
-        )
@@ -5,11 +5,12 @@ from datetime import datetime, timezone
 from typing import Any, Optional, cast
 from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
 
+from fastmcp.server.auth import OAuthProvider
+from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken
 from mcp.server.auth.provider import (
     AccessToken,
     AuthorizationCode,
     AuthorizationParams,
-    OAuthAuthorizationServerProvider,
     RefreshToken,
 )
 from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
@@ -133,8 +134,45 @@ def make_token(
     )
 
 
-class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
-    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
+class SimpleOAuthProvider(OAuthProvider):
+    """OAuth provider that extends fastmcp's OAuthProvider with custom login flow."""
+
+    def __init__(self):
+        super().__init__(
+            base_url=settings.SERVER_URL,
+            issuer_url=settings.SERVER_URL,
+            client_registration_options=None,  # We handle registration ourselves
+            required_scopes=BASE_SCOPES,
+        )
+
+    async def verify_token(self, token: str) -> FastMCPAccessToken | None:
+        """Verify an access token and return token info if valid."""
+        with make_session() as session:
+            # Try as OAuth access token first
+            user_session = session.query(UserSession).get(token)
+            if user_session:
+                now = datetime.now(timezone.utc).replace(tzinfo=None)
+                if user_session.expires_at < now:
+                    return None
+                return FastMCPAccessToken(
+                    token=token,
+                    client_id=user_session.oauth_state.client_id,
+                    scopes=user_session.oauth_state.scopes,
+                )
+
+            # Try as bot API key
+            bot = session.query(User).filter(User.api_key == token).first()
+            if bot:
+                logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
+                return FastMCPAccessToken(
+                    token=token,
+                    client_id=cast(str, bot.name or bot.email),
+                    scopes=["read", "write"],
+                )
+
+            return None
+
+    def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
+        """Get OAuth client information."""
+        with make_session() as session:
+            client = session.get(OAuthClientInformation, client_id)
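The contract being implemented above is small: fastmcp calls `verify_token` on every request and treats `None` as an authentication failure. A stripped-down sketch with an in-memory token store instead of the database (the class name and store are invented; a real subclass must also implement OAuthProvider's other hooks):

```python
# Hedged sketch of the verify_token contract; not the repo's provider.
from fastmcp.server.auth import OAuthProvider
from fastmcp.server.auth.auth import AccessToken

class InMemoryProvider(OAuthProvider):
    def __init__(self, tokens: dict[str, str]):
        super().__init__(
            base_url="http://localhost:8000",
            issuer_url="http://localhost:8000",
        )
        self._tokens = tokens  # maps token -> client_id (illustrative store)

    async def verify_token(self, token: str) -> AccessToken | None:
        client_id = self._tokens.get(token)
        if client_id is None:
            return None  # None rejects the request as unauthenticated
        return AccessToken(token=token, client_id=client_id, scopes=["read"])
```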
src/memory/api/MCP/servers/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+"""MCP subservers for composable tool organization."""
+
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+
+__all__ = [
+    "core_mcp",
+    "github_mcp",
+    "people_mcp",
+    "schedule_mcp",
+    "books_mcp",
+    "meta_mcp",
+]
@@ -1,15 +1,19 @@
 """MCP subserver for ebook access."""
 
 import logging
 
+from fastmcp import FastMCP
 from sqlalchemy.orm import joinedload
 
-from memory.api.MCP.tools import mcp
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection, BookSectionPayload
 
 logger = logging.getLogger(__name__)
 
+books_mcp = FastMCP("memory-books")
+
 
-@mcp.tool()
+@books_mcp.tool()
 async def all_books(sections: bool = False) -> list[dict]:
     """
     Get all books in the database.
@@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
     return [book.as_payload(sections=sections) for book in books]
 
 
-@mcp.tool()
+@books_mcp.tool()
 def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
     """
     Read a book from the database.
@@ -1,53 +1,45 @@
 """
-MCP tools for the epistemic sparring partner system.
+Core MCP subserver for knowledge base search, observations, and notes.
 """
 
 import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from PIL import Image
 
+from fastmcp import FastMCP
+from PIL import Image
 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from memory.api.MCP.base import mcp
 from memory.api.search.search import search
-from memory.api.search.types import SearchFilters, SearchConfig
+from memory.api.search.types import SearchConfig, SearchFilters
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
 from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
 from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem, AgentObservation
+from memory.common.db.models import AgentObservation, SourceItem
 from memory.common.formatters import observation
 
 logger = logging.getLogger(__name__)
 
+core_mcp = FastMCP("memory-core")
+
 
 def validate_path_within_directory(
     base_dir: pathlib.Path, requested_path: str
 ) -> pathlib.Path:
-    """Validate that a requested path resolves within the base directory.
-
-    Prevents path traversal attacks using ../ or similar techniques.
-
-    Args:
-        base_dir: The allowed base directory
-        requested_path: The user-provided path
-
-    Returns:
-        The resolved absolute path if valid
-
-    Raises:
-        ValueError: If the path would escape the base directory
-    """
+    """Validate that a requested path resolves within the base directory."""
    resolved = (base_dir / requested_path.lstrip("/")).resolve()
     base_resolved = base_dir.resolve()
 
-    if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+    if (
+        not str(resolved).startswith(str(base_resolved) + "/")
+        and resolved != base_resolved
+    ):
         raise ValueError(f"Path escapes allowed directory: {requested_path}")
 
     return resolved
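To make the guard concrete, a hypothetical call with a made-up base directory; the second call shows the traversal case the function exists to block:

```python
# Illustrative use of validate_path_within_directory(); paths are made up.
import pathlib

notes_dir = pathlib.Path("/srv/notes")

validate_path_within_directory(notes_dir, "projects/todo.md")
# -> PosixPath('/srv/notes/projects/todo.md')

validate_path_within_directory(notes_dir, "../../etc/passwd")
# -> raises ValueError: Path escapes allowed directory: ../../etc/passwd
```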
@@ -63,7 +55,6 @@ def filter_observation_source_ids(
         items_query = session.query(AgentObservation.id)
 
         if tags:
-            # Use PostgreSQL array overlap operator with proper array casting
             items_query = items_query.filter(
                 AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
             )
@@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]:
         items_query = session.query(SourceItem.id)
 
         if tags:
-            # Use PostgreSQL array overlap operator with proper array casting
             items_query = items_query.filter(
                 SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
             )
@@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]:
     return source_ids
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def search_knowledge_base(
     query: str,
     filters: SearchFilters = {},
@@ -141,7 +131,6 @@ async def search_knowledge_base(
 
     if not modalities:
         modalities = set(ALL_COLLECTIONS.keys())
-    # Filter to valid collections, excluding observation collections
     modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
 
     search_filters = SearchFilters(**filters)
@@ -167,7 +156,7 @@ class RawObservation(BaseModel):
     tags: list[str] = []
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def observe(
     observations: list[RawObservation],
     session_id: str | None = None,
@@ -212,23 +201,23 @@ async def observe(
     logger.info("MCP: Observing")
     tasks = [
         (
-            observation,
+            obs,
             celery_app.send_task(
                 SYNC_OBSERVATION,
                 queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
                 kwargs={
-                    "subject": observation.subject,
-                    "content": observation.content,
-                    "observation_type": observation.observation_type,
-                    "confidences": observation.confidences,
-                    "evidence": observation.evidence,
-                    "tags": observation.tags,
+                    "subject": obs.subject,
+                    "content": obs.content,
+                    "observation_type": obs.observation_type,
+                    "confidences": obs.confidences,
+                    "evidence": obs.evidence,
+                    "tags": obs.tags,
                     "session_id": session_id,
                     "agent_model": agent_model,
                 },
             ),
         )
-        for observation in observations
+        for obs in observations
     ]
 
     def short_content(obs: RawObservation) -> str:
@@ -242,7 +231,7 @@ async def observe(
     }
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def search_observations(
     query: str,
     subject: str = "",
@@ -308,7 +297,7 @@ async def search_observations(
     ]
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def create_note(
     subject: str,
     content: str,
@@ -362,7 +351,7 @@ async def create_note(
     }
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def note_files(path: str = "/"):
     """
     List note files in the user's note storage.
@@ -386,7 +375,7 @@ async def note_files(path: str = "/"):
     ]
 
 
-@mcp.tool()
+@core_mcp.tool()
 def fetch_file(filename: str) -> dict:
     """
     Read file content with automatic type detection.
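The `tags.op("&&")(sql_cast(...))` filter used in `filter_observation_source_ids` and `filter_source_ids` above (and again in the GitHub and people tools below) is PostgreSQL's array-overlap operator. A self-contained sketch against a stand-in table shows roughly the SQL it compiles to; the `Item` model is invented for illustration:

```python
# Stand-in model demonstrating the array-overlap filter pattern.
from sqlalchemy import Column, Integer, Text, cast as sql_cast, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    tags = Column(ARRAY(Text))

# Matches rows whose tags share at least one element with ["bug", "p1"].
stmt = select(Item.id).where(Item.tags.op("&&")(sql_cast(["bug", "p1"], ARRAY(Text))))
print(stmt.compile(dialect=postgresql.dialect()))
# roughly: SELECT items.id FROM items WHERE items.tags && CAST(%(param_1)s AS TEXT[])
```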
@@ -1,14 +1,14 @@
-"""MCP tools for GitHub issue tracking and management."""
+"""MCP subserver for GitHub issue tracking and management."""
 
 import logging
 from datetime import datetime, timezone
 from typing import Any
 
-from sqlalchemy import Text, case, desc, func, asc
+from fastmcp import FastMCP
+from sqlalchemy import Text, case, desc, func
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from memory.api.MCP.base import mcp
 from memory.api.search.search import search
 from memory.api.search.types import SearchConfig, SearchFilters
 from memory.common import extract
@@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem
 
 logger = logging.getLogger(__name__)
 
+github_mcp = FastMCP("memory-github")
+
 
 def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
     """Build GitHub URL from repo path and issue/PR number."""
@@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[str, Any]:
     }
     if include_content:
         result["content"] = item.content
-    # Include PR-specific data if available
     if item.kind == "pr" and item.pr_data:
         result["pr_data"] = {
             "additions": item.pr_data.additions,
@@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[str, Any]:
             "files": item.pr_data.files,
             "reviews": item.pr_data.reviews,
             "review_comments": item.pr_data.review_comments,
-            "diff": item.pr_data.diff,  # Decompressed via property
+            "diff": item.pr_data.diff,
         }
     return result
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def list_github_issues(
     repo: str | None = None,
     assignee: str | None = None,
@@ -109,64 +110,48 @@ async def list_github_issues(
     with make_session() as session:
         query = session.query(GithubItem)
 
-        # Apply filters
         if repo:
             query = query.filter(GithubItem.repo_path == repo)
-
         if assignee:
             query = query.filter(GithubItem.assignees.any(assignee))
-
         if author:
             query = query.filter(GithubItem.author == author)
-
         if state:
             query = query.filter(GithubItem.state == state)
-
         if kind:
             query = query.filter(GithubItem.kind == kind)
         else:
             # Exclude comments by default, only show issues and PRs
             query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
-
         if labels:
             # Match any label in the list using PostgreSQL array overlap
             query = query.filter(
                 GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
             )
-
         if project_status:
             query = query.filter(GithubItem.project_status == project_status)
-
         if project_field:
             for key, value in project_field.items():
-                query = query.filter(
-                    GithubItem.project_fields[key].astext == value
-                )
-
+                query = query.filter(GithubItem.project_fields[key].astext == value)
         if updated_since:
             since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
             query = query.filter(GithubItem.github_updated_at >= since_dt)
-
         if updated_before:
             before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
             query = query.filter(GithubItem.github_updated_at <= before_dt)
 
-        # Apply ordering
         if order_by == "created":
             query = query.order_by(desc(GithubItem.created_at))
         elif order_by == "number":
             query = query.order_by(desc(GithubItem.number))
-        else:  # default: updated
+        else:
             query = query.order_by(desc(GithubItem.github_updated_at))
 
         query = query.limit(limit)
         items = query.all()
         return [_serialize_issue(item) for item in items]
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def search_github_issues(
     query: str,
     repo: str | None = None,
@@ -191,7 +176,6 @@ async def search_github_issues(
 
     limit = min(limit, 100)
 
-    # Pre-filter source_ids if repo/state/kind filters are specified
     source_ids = None
     if repo or state or kind:
         with make_session() as session:
@@ -206,7 +190,6 @@ async def search_github_issues(
                 q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
             source_ids = [item.id for item in q.all()]
 
-    # Use the existing search infrastructure
     data = extract.extract_text(query, skip_summary=True)
     config = SearchConfig(limit=limit, previews=True)
     filters = SearchFilters()
@@ -220,7 +203,6 @@ async def search_github_issues(
         config=config,
     )
 
-    # Fetch full issue details for the results
     output = []
     with make_session() as session:
         for result in results:
@@ -233,7 +215,7 @@ async def search_github_issues(
     return output
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def github_issue_details(
     repo: str,
     number: int,
@@ -267,7 +249,7 @@ async def github_issue_details(
     return _serialize_issue(item, include_content=True)
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def github_work_summary(
     since: str,
     until: str | None = None,
@@ -294,7 +276,6 @@ async def github_work_summary(
     else:
         until_dt = datetime.now(timezone.utc)
 
-    # Map group_by to SQL expression
     group_mappings = {
         "client": GithubItem.project_fields["EquiStamp.Client"].astext,
         "status": GithubItem.project_status,
@@ -311,7 +292,6 @@ async def github_work_summary(
     group_col = group_mappings[group_by]
 
     with make_session() as session:
-        # Build base query for the period
         base_query = session.query(GithubItem).filter(
             GithubItem.github_updated_at >= since_dt,
             GithubItem.github_updated_at <= until_dt,
@@ -321,7 +301,6 @@ async def github_work_summary(
         if repo:
             base_query = base_query.filter(GithubItem.repo_path == repo)
 
-        # Get aggregated counts by group
         agg_query = (
             session.query(
                 group_col.label("group_name"),
@@ -346,7 +325,6 @@ async def github_work_summary(
 
         groups = agg_query.all()
 
-        # Build summary with sample issues for each group
         summary = []
         total_issues = 0
         total_prs = 0
@@ -358,7 +336,6 @@ async def github_work_summary(
             total_issues += issue_count
             total_prs += pr_count
 
-            # Get sample issues for this group
             sample_query = base_query.filter(group_col == group_name).limit(5)
             samples = [
                 {
@@ -394,7 +371,7 @@ async def github_work_summary(
     }
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def github_repo_overview(
     repo: str,
 ) -> dict:
@@ -410,13 +387,6 @@ async def github_repo_overview(
     logger.info(f"github_repo_overview called: repo={repo}")
 
     with make_session() as session:
-        # Base query for this repo
-        base_query = session.query(GithubItem).filter(
-            GithubItem.repo_path == repo,
-            GithubItem.kind.in_(["issue", "pr"]),
-        )
-
-        # Get total counts
         counts_query = session.query(
             func.count(GithubItem.id).label("total"),
             func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
@@ -441,7 +411,6 @@ async def github_repo_overview(
 
         counts = counts_query.first()
 
-        # Status breakdown (for project_status)
         status_query = (
             session.query(
                 GithubItem.project_status.label("status"),
@@ -458,7 +427,6 @@ async def github_repo_overview(
 
         status_breakdown = {row.status: row.count for row in status_query.all()}
 
-        # Top assignees (open issues only)
         assignee_query = (
             session.query(
                 func.unnest(GithubItem.assignees).label("assignee"),
@@ -479,7 +447,6 @@ async def github_repo_overview(
             for row in assignee_query.all()
         ]
 
-        # Label counts
         label_query = (
             session.query(
                 func.unnest(GithubItem.labels).label("label"),
src/memory/api/MCP/servers/meta.py (new file, 277 lines)
@@ -0,0 +1,277 @@
+"""MCP subserver for metadata, utilities, and forecasting."""
+
+import asyncio
+import logging
+from collections import defaultdict
+from datetime import datetime, timezone
+from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints
+
+import aiohttp
+from fastmcp import FastMCP
+from sqlalchemy import func
+
+from memory.common import qdrant
+from memory.common.db.connection import make_session
+from memory.common.db.models import SourceItem
+from memory.common.db.models.source_items import AgentObservation
+
+logger = logging.getLogger(__name__)
+
+meta_mcp = FastMCP("memory-meta")
+
+# Auth provider will be injected at mount time
+_get_current_user = None
+
+
+def set_auth_provider(get_current_user_func):
+    """Set the authentication provider function."""
+    global _get_current_user
+    _get_current_user = get_current_user_func
+
+
+def get_current_user() -> dict:
+    """Get the current authenticated user."""
+    if _get_current_user is None:
+        return {"authenticated": False, "error": "Auth provider not configured"}
+    return _get_current_user()
+
+
+# --- Metadata tools ---
+
+
+class SchemaArg(TypedDict):
+    type: str | None
+    description: str | None
+
+
+class CollectionMetadata(TypedDict):
+    schema: dict[str, SchemaArg]
+    size: int
+
+
+def from_annotation(annotation: Annotated) -> SchemaArg | None:
+    try:
+        type_, description = get_args(annotation)
+        type_str = str(type_)
+        if type_str.startswith("typing."):
+            type_str = type_str[7:]
+        elif len((parts := type_str.split("'"))) > 1:
+            type_str = parts[1]
+        return SchemaArg(type=type_str, description=description)
+    except IndexError:
+        logger.error(f"Error from annotation: {annotation}")
+        return None
+
+
+def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
+    if not hasattr(klass, "as_payload"):
+        return {}
+
+    if not (payload_type := get_type_hints(klass.as_payload).get("return")):
+        return {}
+
+    return {
+        name: schema
+        for name, arg in payload_type.__annotations__.items()
+        if (schema := from_annotation(arg))
+    }
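`from_annotation()` above leans on `typing.get_args` to split an `Annotated` payload field into its type and description; a quick illustration (the field is hypothetical):

```python
# How get_args unpacks the Annotated metadata that get_schema() walks.
from typing import Annotated, get_args

field = Annotated[str, "The subject of the email."]
type_, description = get_args(field)
print(type_)        # <class 'str'>  -> from_annotation extracts "str"
print(description)  # The subject of the email.
```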
+@meta_mcp.tool()
+async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
+    """Get the metadata schema for each collection used in the knowledge base.
+
+    These schemas can be used to filter the knowledge base.
+
+    Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
+
+    Example:
+    ```
+    {
+        "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
+        "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
+    }
+    ```
+    """
+    client = qdrant.get_qdrant_client()
+    sizes = qdrant.get_collection_sizes(client)
+    schemas = defaultdict(dict)
+    for klass in SourceItem.__subclasses__():
+        for collection in klass.get_collections():
+            schemas[collection].update(get_schema(klass))
+
+    return {
+        collection: CollectionMetadata(schema=schema, size=size)
+        for collection, schema in schemas.items()
+        if (size := sizes.get(collection))
+    }
+
+
+@meta_mcp.tool()
+async def get_all_tags() -> list[str]:
+    """Get all unique tags used across the entire knowledge base.
+
+    Returns sorted list of tags from both observations and content.
+    """
+    with make_session() as session:
+        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
+        return sorted({row[0] for row in tags_query if row[0] is not None})
+
+
+@meta_mcp.tool()
+async def get_all_subjects() -> list[str]:
+    """Get all unique subjects from observations about the user.
+
+    Returns sorted list of subject identifiers used in observations.
+    """
+    with make_session() as session:
+        return sorted(
+            r.subject for r in session.query(AgentObservation.subject).distinct()
+        )
+
+
+@meta_mcp.tool()
+async def get_all_observation_types() -> list[str]:
+    """Get all observation types that have been used.
+
+    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
+    """
+    with make_session() as session:
+        return sorted(
+            {
+                r.observation_type
+                for r in session.query(AgentObservation.observation_type).distinct()
+                if r.observation_type is not None
+            }
+        )
+
+
+# --- Utility tools ---
+
+
+@meta_mcp.tool()
+async def get_current_time() -> dict:
+    """Get the current time in UTC."""
+    logger.info("get_current_time tool called")
+    return {"current_time": datetime.now(timezone.utc).isoformat()}
+
+
+@meta_mcp.tool()
+async def get_authenticated_user() -> dict:
+    """Get information about the authenticated user."""
+    return get_current_user()
+
+
+# --- Forecasting tools ---
+
+
+class BinaryProbs(TypedDict):
+    prob: float
+
+
+class MultiProbs(TypedDict):
+    answerProbs: dict[str, float]
+
+
+Probs = dict[str, BinaryProbs | MultiProbs]
+OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
+
+
+class MarketAnswer(TypedDict):
+    id: str
+    text: str
+    resolutionProbability: float
+
+
+class MarketDetails(TypedDict):
+    id: str
+    createdTime: int
+    question: str
+    outcomeType: OutcomeType
+    textDescription: str
+    groupSlugs: list[str]
+    volume: float
+    isResolved: bool
+    answers: list[MarketAnswer]
+
+
+class Market(TypedDict):
+    id: str
+    url: str
+    question: str
+    volume: int
+    createdTime: int
+    outcomeType: OutcomeType
+    createdAt: NotRequired[str]
+    description: NotRequired[str]
+    answers: NotRequired[dict[str, float]]
+    probability: NotRequired[float]
+    details: NotRequired[MarketDetails]
+
+
+async def get_details(session: aiohttp.ClientSession, market_id: str):
+    async with session.get(
+        f"https://api.manifold.markets/v0/market/{market_id}"
+    ) as resp:
+        resp.raise_for_status()
+        return await resp.json()
+
+
+async def format_market(session: aiohttp.ClientSession, market: Market):
+    if market.get("outcomeType") != "BINARY":
+        details = await get_details(session, market["id"])
+        market["answers"] = {
+            answer["text"]: round(
+                answer.get("resolutionProbability") or answer.get("probability") or 0, 3
+            )
+            for answer in details["answers"]
+        }
+    if creationTime := market.get("createdTime"):
+        market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
+
+    fields = [
+        "id",
+        "name",
+        "url",
+        "question",
+        "volume",
+        "createdAt",
+        "details",
+        "probability",
+        "answers",
+    ]
+    return {k: v for k, v in market.items() if k in fields}
+
+
+async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
+    async with aiohttp.ClientSession() as session:
+        async with session.get(
+            "https://api.manifold.markets/v0/search-markets",
+            params={
+                "term": term,
+                "contractType": "BINARY" if binary else "ALL",
+            },
+        ) as resp:
+            resp.raise_for_status()
+            markets = await resp.json()
+
+        return await asyncio.gather(
+            *[
+                format_market(session, market)
+                for market in markets
+                if market.get("volume", 0) >= min_volume
+            ]
+        )
+
+
+@meta_mcp.tool()
+async def get_forecasts(
+    term: str, min_volume: int = 1000, binary: bool = False
+) -> list[dict]:
+    """Get prediction market forecasts for a given term.
+
+    Args:
+        term: The term to search for.
+        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
+        binary: Whether to only return binary markets.
+    """
+    return await search_markets(term, min_volume, binary)
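`search_markets()` is plain aiohttp, so it can be exercised outside MCP. A minimal hypothetical driver (the search term is made up, and network access to api.manifold.markets is assumed):

```python
# Standalone run of the search_markets() coroutine defined above.
import asyncio

markets = asyncio.run(search_markets("AGI", min_volume=5000, binary=True))
for market in markets:
    print(market["question"], market.get("probability"))
```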
@@ -1,23 +1,23 @@
-"""
-MCP tools for tracking people.
-"""
+"""MCP subserver for tracking people."""
 
 import logging
 from typing import Any
 
+from fastmcp import FastMCP
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from memory.api.MCP.base import mcp
-from memory.common.db.connection import make_session
-from memory.common.db.models import Person
-from memory.common import settings
 from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
 from memory.common.celery_app import app as celery_app
+from memory.common import settings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Person
 
 logger = logging.getLogger(__name__)
 
+people_mcp = FastMCP("memory-people")
+
 
 def _person_to_dict(person: Person) -> dict[str, Any]:
     """Convert a Person model to a dictionary for API responses."""
@@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]:
     }
 
 
-@mcp.tool()
+@people_mcp.tool()
 async def add_person(
     identifier: str,
     display_name: str,
@@ -67,7 +67,6 @@ async def add_person(
     """
     logger.info(f"MCP: Adding person: {identifier}")
 
-    # Check if person already exists
     with make_session() as session:
         existing = session.query(Person).filter(Person.identifier == identifier).first()
         if existing:
@@ -93,7 +92,7 @@ async def add_person(
     }
 
 
-@mcp.tool()
+@people_mcp.tool()
 async def update_person_info(
     identifier: str,
     display_name: str | None = None,
@@ -135,7 +134,6 @@ async def update_person_info(
     """
     logger.info(f"MCP: Updating person: {identifier}")
 
-    # Verify person exists
     with make_session() as session:
         person = session.query(Person).filter(Person.identifier == identifier).first()
         if not person:
@@ -162,7 +160,7 @@ async def update_person_info(
     }
 
 
-@mcp.tool()
+@people_mcp.tool()
 async def get_person(identifier: str) -> dict | None:
     """
     Get a person by their identifier.
@@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None:
     return _person_to_dict(person)
 
 
-@mcp.tool()
+@people_mcp.tool()
 async def list_people(
     tags: list[str] | None = None,
     search: str | None = None,
@@ -207,9 +205,7 @@ async def list_people(
         query = session.query(Person)
 
         if tags:
-            query = query.filter(
-                Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
-            )
+            query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))))
 
         if search:
             search_term = f"%{search.lower()}%"
@@ -225,7 +221,7 @@ async def list_people(
     return [_person_to_dict(p) for p in people]
 
 
-@mcp.tool()
+@people_mcp.tool()
 async def delete_person(identifier: str) -> dict:
     """
     Delete a person by their identifier.
@@ -1,20 +1,37 @@
-"""
-MCP tools for the epistemic sparring partner system.
-"""
+"""MCP subserver for scheduling messages and LLM calls."""
 
 import logging
 from datetime import datetime, timezone
 from typing import Any, cast
 
-from memory.api.MCP.base import get_current_user, mcp
+from fastmcp import FastMCP
+
 from memory.common.db.connection import make_session
-from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
+from memory.common.db.models import DiscordBotUser, ScheduledLLMCall
 from memory.discord.messages import schedule_discord_message
 
 logger = logging.getLogger(__name__)
 
+schedule_mcp = FastMCP("memory-schedule")
 
-@mcp.tool()
+# We need access to get_current_user from base - this will be injected at mount time
+_get_current_user = None
+
+
+def set_auth_provider(get_current_user_func):
+    """Set the authentication provider function."""
+    global _get_current_user
+    _get_current_user = get_current_user_func
+
+
+def get_current_user() -> dict:
+    """Get the current authenticated user."""
+    if _get_current_user is None:
+        return {"authenticated": False, "error": "Auth provider not configured"}
+    return _get_current_user()
+
+
+@schedule_mcp.tool()
 async def schedule_message(
     scheduled_time: str,
     message: str,
@@ -61,10 +78,8 @@ async def schedule_message(
     if not discord_user and not discord_channel:
         raise ValueError("Either discord_user or discord_channel must be provided")
 
-    # Parse scheduled time
     try:
         scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
-        # Ensure we store as naive datetime (remove timezone info for database storage)
         if scheduled_dt.tzinfo is not None:
             scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
     except ValueError:
@@ -98,7 +113,7 @@ async def schedule_message(
     }
 
 
-@mcp.tool()
+@schedule_mcp.tool()
 async def list_scheduled_llm_calls(
     status: str | None = None, limit: int | None = 50
 ) -> dict[str, Any]:
@@ -143,7 +158,7 @@ async def list_scheduled_llm_calls(
     }
 
 
-@mcp.tool()
+@schedule_mcp.tool()
 async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
     """
     Cancel a scheduled LLM call.
@@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
         return {"error": "User not found", "user": current_user}
 
     with make_session() as session:
-        # Find the scheduled call
        scheduled_call = (
             session.query(ScheduledLLMCall)
             .filter(
@@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
         if not scheduled_call.can_be_cancelled():
             return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}
 
-        # Update the status
         scheduled_call.status = "cancelled"
         session.commit()
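The `set_auth_provider` / `_get_current_user` dance above (also in meta.py) is a late-binding injection: the subserver ships a setter, and base.py supplies the real implementation at mount time, which avoids importing base.py from the subserver and the circular import that would cause. A minimal standalone sketch of the pattern; everything here is illustrative:

```python
# Late-binding auth injection, reduced to its essentials.
def make_subserver():
    hook = {"get_user": None}

    def set_auth_provider(fn):
        hook["get_user"] = fn

    def get_current_user() -> dict:
        if hook["get_user"] is None:
            return {"authenticated": False, "error": "Auth provider not configured"}
        return hook["get_user"]()

    return set_auth_provider, get_current_user

set_auth, get_user = make_subserver()
print(get_user())  # not configured yet -> error dict
set_auth(lambda: {"authenticated": True, "user": "demo"})
print(get_user())  # now resolves via the injected provider
```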
@@ -1,74 +0,0 @@ (file deleted; utilities move to servers/core.py and servers/meta.py)
-"""
-MCP tools for the epistemic sparring partner system.
-"""
-
-import logging
-from datetime import datetime, timezone
-
-from sqlalchemy import Text
-from sqlalchemy import cast as sql_cast
-from sqlalchemy.dialects.postgresql import ARRAY
-
-from memory.common.db.connection import make_session
-from memory.common.db.models import AgentObservation, SourceItem
-from memory.api.MCP.base import mcp, get_current_user
-
-logger = logging.getLogger(__name__)
-
-
-def filter_observation_source_ids(
-    tags: list[str] | None = None, observation_types: list[str] | None = None
-):
-    if not tags and not observation_types:
-        return None
-
-    with make_session() as session:
-        items_query = session.query(AgentObservation.id)
-
-        if tags:
-            # Use PostgreSQL array overlap operator with proper array casting
-            items_query = items_query.filter(
-                AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
-            )
-        if observation_types:
-            items_query = items_query.filter(
-                AgentObservation.observation_type.in_(observation_types)
-            )
-        source_ids = [item.id for item in items_query.all()]
-
-        return source_ids
-
-
-def filter_source_ids(
-    modalities: set[str],
-    tags: list[str] | None = None,
-):
-    if not tags:
-        return None
-
-    with make_session() as session:
-        items_query = session.query(SourceItem.id)
-
-        if tags:
-            # Use PostgreSQL array overlap operator with proper array casting
-            items_query = items_query.filter(
-                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
-            )
-        if modalities:
-            items_query = items_query.filter(SourceItem.modality.in_(modalities))
-        source_ids = [item.id for item in items_query.all()]
-
-        return source_ids
-
-
-@mcp.tool()
-async def get_current_time() -> dict:
-    """Get the current time in UTC."""
-    logger.info("get_current_time tool called")
-    return {"current_time": datetime.now(timezone.utc).isoformat()}
-
-
-@mcp.tool()
-async def get_authenticated_user() -> dict:
-    """Get information about the authenticated user."""
-    return get_current_user()
@@ -2,7 +2,6 @@
 FastAPI application for the knowledge base.
 """
 
-import contextlib
 import os
 import logging
 import mimetypes
@@ -35,15 +34,10 @@ limiter = Limiter(
     enabled=settings.API_RATE_LIMIT_ENABLED,
 )
 
-
-@contextlib.asynccontextmanager
-async def lifespan(app: FastAPI):
-    async with contextlib.AsyncExitStack() as stack:
-        await stack.enter_async_context(mcp.session_manager.run())
-        yield
-
-
-app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
+# Create the MCP http app to get its lifespan
+mcp_http_app = mcp.http_app(stateless_http=True)
+
+app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan)
 app.state.limiter = limiter
 
 # Rate limit exception handler
@@ -158,46 +152,9 @@ app.include_router(auth_router)
 
 
-# Add health check to MCP server instead of main app
-@mcp.custom_route("/health", methods=["GET"])
-async def health_check(request: Request):
-    """Health check endpoint that verifies all dependencies are accessible."""
-    from fastapi.responses import JSONResponse
-    from sqlalchemy import text
-
-    checks = {"mcp_oauth": "enabled"}
-    all_healthy = True
-
-    # Check database connection
-    try:
-        with engine.connect() as conn:
-            conn.execute(text("SELECT 1"))
-        checks["database"] = "healthy"
-    except Exception as e:
-        # Log error details but don't expose in response
-        logger.error(f"Database health check failed: {e}")
-        checks["database"] = "unhealthy"
-        all_healthy = False
-
-    # Check Qdrant connection
-    try:
-        from memory.common.qdrant import get_qdrant_client
-
-        client = get_qdrant_client()
-        client.get_collections()
-        checks["qdrant"] = "healthy"
-    except Exception as e:
-        # Log error details but don't expose in response
-        logger.error(f"Qdrant health check failed: {e}")
-        checks["qdrant"] = "unhealthy"
-        all_healthy = False
-
-    checks["status"] = "healthy" if all_healthy else "degraded"
-    status_code = 200 if all_healthy else 503
-    return JSONResponse(checks, status_code=status_code)
-
-
-# Mount MCP server at root - OAuth endpoints need to be at root level
-app.mount("/", mcp.streamable_http_app())
+# Health check is defined in MCP/base.py
+# Mount MCP server at root - OAuth endpoints need to be at root level
+app.mount("/", mcp_http_app)
 
 
 def main(reload: bool = False):
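The pattern above, reduced to its essentials: `http_app()` returns an ASGI app whose lifespan must drive the FastAPI host, otherwise the MCP session machinery never starts. A minimal sketch assuming fastmcp>=2.10; the server name and title are stand-ins:

```python
# Minimal host-app wiring mirroring the change above.
from fastapi import FastAPI
from fastmcp import FastMCP

mcp = FastMCP("demo")
mcp_http_app = mcp.http_app(stateless_http=True)

# Reuse the MCP app's lifespan so its session manager starts with the host.
app = FastAPI(title="Demo", lifespan=mcp_http_app.lifespan)
app.mount("/", mcp_http_app)
```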
@@ -58,7 +58,6 @@ sync_code() {
         "$PROJECT_DIR/frontend" \
         "$PROJECT_DIR/requirements" \
         "$PROJECT_DIR/setup.py" \
         "$PROJECT_DIR/pyproject.toml" \
         "$PROJECT_DIR/docker-compose.yaml" \
         "$PROJECT_DIR/pytest.ini" \
         "$REMOTE_HOST:$REMOTE_DIR/"
tools/diagnose.sh (new executable file, 195 lines)
@@ -0,0 +1,195 @@
+#!/bin/bash
+set -e
+
+REMOTE_HOST="memory"
+REMOTE_DIR="/home/ec2-user/memory"
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m'
+
+usage() {
+    echo "Usage: $0 <command> [options]"
+    echo ""
+    echo "Safe diagnostic commands for the memory server."
+    echo ""
+    echo "Commands:"
+    echo "  logs [service] [lines]   View docker logs (default: all services, 100 lines)"
+    echo "  ps                       Show docker container status"
+    echo "  disk                     Show disk usage"
+    echo "  mem                      Show memory usage"
+    echo "  top                      Show running processes"
+    echo "  ls <path>                List directory contents"
+    echo "  cat <file>               View file contents"
+    echo "  tail <file> [lines]      Tail a file (default: 50 lines)"
+    echo "  grep <pattern> <path>    Search for pattern in files"
+    echo "  db <query>               Run read-only SQL query"
+    echo "  get <path> [port]        GET request to localhost (default port: 8000)"
+    echo "  status                   Overall system status"
+    exit 1
+}
+
+remote() {
+    ssh "$REMOTE_HOST" "$@"
+}
+
+docker_logs() {
+    local service="${1:-}"
+    local lines="${2:-100}"
+    if [ -n "$service" ]; then
+        echo -e "${GREEN}Logs for $service (last $lines lines):${NC}"
+        remote "cd $REMOTE_DIR && docker compose logs --tail=$lines $service"
+    else
+        echo -e "${GREEN}All logs (last $lines lines):${NC}"
+        remote "cd $REMOTE_DIR && docker compose logs --tail=$lines"
+    fi
+}
+
+docker_ps() {
+    echo -e "${GREEN}Container status:${NC}"
+    remote "cd $REMOTE_DIR && docker compose ps"
+}
+
+disk_usage() {
+    echo -e "${GREEN}Disk usage:${NC}"
+    remote "df -h && echo '' && du -sh $REMOTE_DIR/* 2>/dev/null | sort -h"
+}
+
+mem_usage() {
+    echo -e "${GREEN}Memory usage:${NC}"
+    remote "free -h"
+}
+
+show_top() {
+    echo -e "${GREEN}Top processes:${NC}"
+    remote "ps aux --sort=-%mem | head -20"
+}
+
+list_dir() {
+    local path="${1:-.}"
+    # Ensure path is within project directory for safety
+    echo -e "${GREEN}Contents of $path:${NC}"
+    remote "cd $REMOTE_DIR && ls -la $path"
+}
+
+cat_file() {
+    local file="$1"
+    if [ -z "$file" ]; then
+        echo -e "${RED}Error: No file specified${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Contents of $file:${NC}"
+    remote "cd $REMOTE_DIR && cat $file"
+}
+
+tail_file() {
+    local file="$1"
+    local lines="${2:-50}"
+    if [ -z "$file" ]; then
+        echo -e "${RED}Error: No file specified${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Last $lines lines of $file:${NC}"
+    remote "cd $REMOTE_DIR && tail -n $lines $file"
+}
+
+grep_files() {
+    local pattern="$1"
+    local path="${2:-.}"
+    if [ -z "$pattern" ]; then
+        echo -e "${RED}Error: No pattern specified${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Searching for '$pattern' in $path:${NC}"
+    remote "cd $REMOTE_DIR && grep -r --color=always '$pattern' $path || true"
+}
+
+db_query() {
+    local query="$1"
+    if [ -z "$query" ]; then
+        echo -e "${RED}Error: No query specified${NC}"
+        exit 1
+    fi
+    # Only allow SELECT queries for safety
+    if ! echo "$query" | grep -qi "^select"; then
+        echo -e "${RED}Error: Only SELECT queries are allowed${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Running query:${NC}"
+    remote "cd $REMOTE_DIR && docker compose exec -T postgres psql -U kb -d kb -c \"$query\""
+}
+
+http_get() {
+    local path="$1"
+    local port="${2:-8000}"
+    if [ -z "$path" ]; then
+        echo -e "${RED}Error: No path specified${NC}"
+        exit 1
+    fi
+    # Ensure path starts with /
+    if [[ "$path" != /* ]]; then
+        path="/$path"
+    fi
+    echo -e "${GREEN}GET http://localhost:${port}${path}${NC}"
+    remote "curl -s -w '\n\nHTTP Status: %{http_code}\n' 'http://localhost:${port}${path}'"
+}
+
+system_status() {
+    echo -e "${GREEN}=== System Status ===${NC}"
+    echo ""
+    docker_ps
+    echo ""
+    echo -e "${GREEN}=== Memory ===${NC}"
+    mem_usage
+    echo ""
+    echo -e "${GREEN}=== Disk ===${NC}"
+    remote "df -h /"
+    echo ""
+    echo -e "${GREEN}=== Recent Errors (last 20 lines) ===${NC}"
+    remote "cd $REMOTE_DIR && docker compose logs --tail=100 2>&1 | grep -i -E '(error|exception|failed|fatal)' | tail -20 || echo 'No recent errors found'"
+}
+
+# Main
+case "${1:-}" in
+    logs)
+        docker_logs "${2:-}" "${3:-}"
+        ;;
+    ps)
+        docker_ps
+        ;;
+    disk)
+        disk_usage
+        ;;
+    mem)
+        mem_usage
+        ;;
+    top)
+        show_top
+        ;;
+    ls)
+        list_dir "${2:-}"
+        ;;
+    cat)
+        cat_file "${2:-}"
+        ;;
+    tail)
+        tail_file "${2:-}" "${3:-}"
+        ;;
+    grep)
+        grep_files "${2:-}" "${3:-}"
+        ;;
+    db)
+        db_query "${2:-}"
+        ;;
+    get)
+        http_get "${2:-}" "${3:-}"
+        ;;
+    status)
+        system_status
+        ;;
+    *)
+        usage
+        ;;
+esac