mirror of https://github.com/mruwnik/memory.git
migrate to fastmcp

commit 48c380b903
parent d3d71edf1d
@@ -1,7 +1,8 @@
-fastapi==0.112.2
+fastapi>=0.115.12
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin==0.20.1
 mcp==1.10.0
+fastmcp>=2.10.0
 slowapi==0.1.9
@@ -1,14 +1,14 @@
 sqlalchemy==2.0.30
 psycopg2-binary==2.9.9
-pydantic==2.7.2
+pydantic>=2.11.7
 alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.69.0
 openai==2.3.0
-# Pin the httpx version, as newer versions break the anthropic client
-httpx==0.27.0
+# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1)
+httpx>=0.28.1
 celery[redis,sqs]==5.3.6
 cryptography==43.0.0
 bcrypt==4.1.2
@@ -1,8 +1,16 @@
-import memory.api.MCP.tools
-import memory.api.MCP.memory
-import memory.api.MCP.metadata
-import memory.api.MCP.schedules
-import memory.api.MCP.books
-import memory.api.MCP.manifest
-import memory.api.MCP.github
-import memory.api.MCP.people
+"""
+MCP server with composed subservers.
+
+Subservers are mounted with prefixes:
+- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file
+- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview
+- people: add_person, update_person_info, get_person, list_people, delete_person
+- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call
+- books: all_books, read_book
+- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts
+"""
+
+# Import base to trigger subserver mounting
+from memory.api.MCP.base import mcp, get_current_user
+
+__all__ = ["mcp", "get_current_user"]
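A note on the composition pattern above: fastmcp mounting copies a subserver's tools onto the parent under the given prefix, so the names listed in the docstring are reachable as core_search_knowledge_base, github_list_github_issues, and so on. A minimal sketch of the pattern, assuming fastmcp>=2.10 (the server and tool names below are illustrative, not from this repo):

    from fastmcp import FastMCP

    parent = FastMCP("parent")
    notes = FastMCP("notes")

    @notes.tool()
    def list_notes() -> list[str]:
        """Illustrative tool; returns a static list."""
        return ["todo.md"]

    # After mounting, clients see the tool as "notes_list_notes".
    parent.mount(notes, prefix="notes")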
@@ -1,60 +1,28 @@
 import logging
 import pathlib
-from typing import cast
 
-from mcp.server.auth.handlers.authorize import AuthorizationRequest
-from mcp.server.auth.handlers.token import (
-    AuthorizationCodeRequest,
-    RefreshTokenRequest,
-    TokenRequest,
-)
-from mcp.server.auth.middleware.auth_context import get_access_token
-from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions
-from mcp.server.fastmcp import FastMCP
-from mcp.shared.auth import OAuthClientMetadata
-from pydantic import AnyHttpUrl
+from fastmcp import FastMCP
+from fastmcp.server.dependencies import get_access_token
 from starlette.requests import Request
 from starlette.responses import JSONResponse, RedirectResponse
 from starlette.templating import Jinja2Templates
 
-from memory.api.MCP.oauth_provider import (
-    ALLOWED_SCOPES,
-    BASE_SCOPES,
-    SimpleOAuthProvider,
-)
+from memory.api.MCP.oauth_provider import SimpleOAuthProvider
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth
 from memory.common import settings
-from memory.common.db.connection import make_session
+from memory.common.db.connection import make_session, get_engine
 from memory.common.db.models import OAuthState, UserSession
 from memory.common.db.models.users import HumanUser
 
 logger = logging.getLogger(__name__)
+
+engine = get_engine()
 
-
-def validate_metadata(klass: type):
-    orig_validate = klass.model_validate
-
-    def validate(data: dict):
-        data = dict(data)
-        if "redirect_uris" in data:
-            data["redirect_uris"] = [
-                str(uri).replace("cursor://", "http://")
-                for uri in data["redirect_uris"]
-            ]
-        if "redirect_uri" in data:
-            data["redirect_uri"] = str(data["redirect_uri"]).replace(
-                "cursor://", "http://"
-            )
-
-        return orig_validate(data)
-
-    klass.model_validate = validate
-
-
-validate_metadata(OAuthClientMetadata)
-validate_metadata(AuthorizationRequest)
-validate_metadata(AuthorizationCodeRequest)
-validate_metadata(RefreshTokenRequest)
-validate_metadata(TokenRequest)
-
-
 # Setup templates
@@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir)
 
 oauth_provider = SimpleOAuthProvider()
-auth_settings = AuthSettings(
-    issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
-    resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL),  # type: ignore
-    client_registration_options=ClientRegistrationOptions(
-        enabled=True,
-        valid_scopes=ALLOWED_SCOPES,
-        default_scopes=BASE_SCOPES,
-    ),
-    required_scopes=BASE_SCOPES,
-)
 
 mcp = FastMCP(
     "memory",
-    stateless_http=True,
-    auth_server_provider=oauth_provider,
-    auth=auth_settings,
+    auth=oauth_provider,
 )
 
@@ -162,3 +118,52 @@ def get_current_user() -> dict:
         "client_id": access_token.client_id,
         "user": user_info,
     }
+
+
+@mcp.custom_route("/health", methods=["GET"])
+async def health_check(request: Request):
+    """Health check endpoint that verifies all dependencies are accessible."""
+    from sqlalchemy import text
+
+    checks = {"mcp_oauth": "enabled"}
+    all_healthy = True
+
+    # Check database connection
+    try:
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+        checks["database"] = "healthy"
+    except Exception as e:
+        logger.error(f"Database health check failed: {e}")
+        checks["database"] = "unhealthy"
+        all_healthy = False
+
+    # Check Qdrant connection
+    try:
+        from memory.common.qdrant import get_qdrant_client
+
+        client = get_qdrant_client()
+        client.get_collections()
+        checks["qdrant"] = "healthy"
+    except Exception as e:
+        logger.error(f"Qdrant health check failed: {e}")
+        checks["qdrant"] = "unhealthy"
+        all_healthy = False
+
+    checks["status"] = "healthy" if all_healthy else "degraded"
+    status_code = 200 if all_healthy else 503
+    return JSONResponse(checks, status_code=status_code)
+
+
+# Inject auth provider into subservers that need it
+set_schedule_auth(get_current_user)
+set_meta_auth(get_current_user)
+
+# Mount all subservers onto the main MCP server
+# Tools will be prefixed with their server name (e.g., core_search_knowledge_base)
+mcp.mount(core_mcp, prefix="core")
+mcp.mount(github_mcp, prefix="github")
+mcp.mount(people_mcp, prefix="people")
+mcp.mount(schedule_mcp, prefix="schedule")
+mcp.mount(books_mcp, prefix="books")
+mcp.mount(meta_mcp, prefix="meta")
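The /health route added above is a plain HTTP endpoint on the MCP app, so it can be probed without an MCP client. A hedged example, assuming the server listens on http://localhost:8000 (the address is an assumption, not part of this commit):

    import httpx  # already a project dependency

    resp = httpx.get("http://localhost:8000/health")
    print(resp.status_code)  # 200 when all checks pass, 503 when degraded
    print(resp.json())       # e.g. {"mcp_oauth": "enabled", "database": "healthy", "qdrant": "healthy", "status": "healthy"}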
@@ -1,119 +0,0 @@
-import asyncio
-import aiohttp
-from datetime import datetime
-
-from typing import TypedDict, NotRequired, Literal
-from memory.api.MCP.tools import mcp
-
-
-class BinaryProbs(TypedDict):
-    prob: float
-
-
-class MultiProbs(TypedDict):
-    answerProbs: dict[str, float]
-
-
-Probs = dict[str, BinaryProbs | MultiProbs]
-OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
-
-
-class MarketAnswer(TypedDict):
-    id: str
-    text: str
-    resolutionProbability: float
-
-
-class MarketDetails(TypedDict):
-    id: str
-    createdTime: int
-    question: str
-    outcomeType: OutcomeType
-    textDescription: str
-    groupSlugs: list[str]
-    volume: float
-    isResolved: bool
-    answers: list[MarketAnswer]
-
-
-class Market(TypedDict):
-    id: str
-    url: str
-    question: str
-    volume: int
-    createdTime: int
-    outcomeType: OutcomeType
-    createdAt: NotRequired[str]
-    description: NotRequired[str]
-    answers: NotRequired[dict[str, float]]
-    probability: NotRequired[float]
-    details: NotRequired[MarketDetails]
-
-
-async def get_details(session: aiohttp.ClientSession, market_id: str):
-    async with session.get(
-        f"https://api.manifold.markets/v0/market/{market_id}"
-    ) as resp:
-        resp.raise_for_status()
-        return await resp.json()
-
-
-async def format_market(session: aiohttp.ClientSession, market: Market):
-    if market.get("outcomeType") != "BINARY":
-        details = await get_details(session, market["id"])
-        market["answers"] = {
-            answer["text"]: round(
-                answer.get("resolutionProbability") or answer.get("probability") or 0, 3
-            )
-            for answer in details["answers"]
-        }
-    if creationTime := market.get("createdTime"):
-        market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
-
-    fields = [
-        "id",
-        "name",
-        "url",
-        "question",
-        "volume",
-        "createdAt",
-        "details",
-        "probability",
-        "answers",
-    ]
-    return {k: v for k, v in market.items() if k in fields}
-
-
-async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
-    async with aiohttp.ClientSession() as session:
-        async with session.get(
-            "https://api.manifold.markets/v0/search-markets",
-            params={
-                "term": term,
-                "contractType": "BINARY" if binary else "ALL",
-            },
-        ) as resp:
-            resp.raise_for_status()
-            markets = await resp.json()
-
-        return await asyncio.gather(
-            *[
-                format_market(session, market)
-                for market in markets
-                if market.get("volume", 0) >= min_volume
-            ]
-        )
-
-
-@mcp.tool()
-async def get_forecasts(
-    term: str, min_volume: int = 1000, binary: bool = False
-) -> list[dict]:
-    """Get prediction market forecasts for a given term.
-
-    Args:
-        term: The term to search for.
-        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
-        binary: Whether to only return binary markets.
-    """
-    return await search_markets(term, min_volume, binary)
@@ -1,119 +0,0 @@
-import logging
-from collections import defaultdict
-from typing import Annotated, TypedDict, get_args, get_type_hints
-
-from memory.common import qdrant
-from sqlalchemy import func
-
-from memory.api.MCP.tools import mcp
-from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem
-from memory.common.db.models.source_items import AgentObservation
-
-logger = logging.getLogger(__name__)
-
-
-class SchemaArg(TypedDict):
-    type: str | None
-    description: str | None
-
-
-class CollectionMetadata(TypedDict):
-    schema: dict[str, SchemaArg]
-    size: int
-
-
-def from_annotation(annotation: Annotated) -> SchemaArg | None:
-    try:
-        type_, description = get_args(annotation)
-        type_str = str(type_)
-        if type_str.startswith("typing."):
-            type_str = type_str[7:]
-        elif len((parts := type_str.split("'"))) > 1:
-            type_str = parts[1]
-        return SchemaArg(type=type_str, description=description)
-    except IndexError:
-        logger.error(f"Error from annotation: {annotation}")
-        return None
-
-
-def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
-    if not hasattr(klass, "as_payload"):
-        return {}
-
-    if not (payload_type := get_type_hints(klass.as_payload).get("return")):
-        return {}
-
-    return {
-        name: schema
-        for name, arg in payload_type.__annotations__.items()
-        if (schema := from_annotation(arg))
-    }
-
-
-@mcp.tool()
-async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
-    """Get the metadata schema for each collection used in the knowledge base.
-
-    These schemas can be used to filter the knowledge base.
-
-    Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
-
-    Example:
-    ```
-    {
-        "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
-        "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
-    }
-    """
-    client = qdrant.get_qdrant_client()
-    sizes = qdrant.get_collection_sizes(client)
-    schemas = defaultdict(dict)
-    for klass in SourceItem.__subclasses__():
-        for collection in klass.get_collections():
-            schemas[collection].update(get_schema(klass))
-
-    return {
-        collection: CollectionMetadata(schema=schema, size=size)
-        for collection, schema in schemas.items()
-        if (size := sizes.get(collection))
-    }
-
-
-@mcp.tool()
-async def get_all_tags() -> list[str]:
-    """Get all unique tags used across the entire knowledge base.
-
-    Returns sorted list of tags from both observations and content.
-    """
-    with make_session() as session:
-        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
-        return sorted({row[0] for row in tags_query if row[0] is not None})
-
-
-@mcp.tool()
-async def get_all_subjects() -> list[str]:
-    """Get all unique subjects from observations about the user.
-
-    Returns sorted list of subject identifiers used in observations.
-    """
-    with make_session() as session:
-        return sorted(
-            r.subject for r in session.query(AgentObservation.subject).distinct()
-        )
-
-
-@mcp.tool()
-async def get_all_observation_types() -> list[str]:
-    """Get all observation types that have been used.
-
-    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
-    """
-    with make_session() as session:
-        return sorted(
-            {
-                r.observation_type
-                for r in session.query(AgentObservation.observation_type).distinct()
-                if r.observation_type is not None
-            }
-        )
@@ -5,11 +5,12 @@ from datetime import datetime, timezone
 from typing import Any, Optional, cast
 from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
 
+from fastmcp.server.auth import OAuthProvider
+from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken
 from mcp.server.auth.provider import (
     AccessToken,
     AuthorizationCode,
     AuthorizationParams,
-    OAuthAuthorizationServerProvider,
     RefreshToken,
 )
 from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
@@ -133,8 +134,45 @@ def make_token(
     )
 
 
-class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
-    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
+class SimpleOAuthProvider(OAuthProvider):
+    """OAuth provider that extends fastmcp's OAuthProvider with custom login flow."""
+
+    def __init__(self):
+        super().__init__(
+            base_url=settings.SERVER_URL,
+            issuer_url=settings.SERVER_URL,
+            client_registration_options=None,  # We handle registration ourselves
+            required_scopes=BASE_SCOPES,
+        )
+
+    async def verify_token(self, token: str) -> FastMCPAccessToken | None:
+        """Verify an access token and return token info if valid."""
+        with make_session() as session:
+            # Try as OAuth access token first
+            user_session = session.query(UserSession).get(token)
+            if user_session:
+                now = datetime.now(timezone.utc).replace(tzinfo=None)
+                if user_session.expires_at < now:
+                    return None
+                return FastMCPAccessToken(
+                    token=token,
+                    client_id=user_session.oauth_state.client_id,
+                    scopes=user_session.oauth_state.scopes,
+                )
+
+            # Try as bot API key
+            bot = session.query(User).filter(User.api_key == token).first()
+            if bot:
+                logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
+                return FastMCPAccessToken(
+                    token=token,
+                    client_id=cast(str, bot.name or bot.email),
+                    scopes=["read", "write"],
+                )
+
+        return None
+
+    def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
         """Get OAuth client information."""
         with make_session() as session:
             client = session.get(OAuthClientInformation, client_id)
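The substantive change here is that token verification moves into the provider: fastmcp calls verify_token on each request and treats None as a rejected credential. A stripped-down sketch of the same hook, with an in-memory table standing in for the UserSession and API-key lookups above (the token store and URLs are hypothetical):

    from fastmcp.server.auth import OAuthProvider
    from fastmcp.server.auth.auth import AccessToken

    VALID_TOKENS = {"secret-token": ["read", "write"]}  # hypothetical store

    class StaticTokenProvider(OAuthProvider):
        def __init__(self):
            super().__init__(
                base_url="http://localhost:8000",  # assumed address
                issuer_url="http://localhost:8000",
            )

        async def verify_token(self, token: str) -> AccessToken | None:
            scopes = VALID_TOKENS.get(token)
            if scopes is None:
                return None  # unknown or expired -> request is rejected
            return AccessToken(token=token, client_id="static-client", scopes=scopes)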
src/memory/api/MCP/servers/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+"""MCP subservers for composable tool organization."""
+
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+
+__all__ = [
+    "core_mcp",
+    "github_mcp",
+    "people_mcp",
+    "schedule_mcp",
+    "books_mcp",
+    "meta_mcp",
+]
@@ -1,15 +1,19 @@
+"""MCP subserver for ebook access."""
+
 import logging
 
+from fastmcp import FastMCP
 from sqlalchemy.orm import joinedload
 
-from memory.api.MCP.tools import mcp
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection, BookSectionPayload
 
 logger = logging.getLogger(__name__)
 
+books_mcp = FastMCP("memory-books")
+
 
-@mcp.tool()
+@books_mcp.tool()
 async def all_books(sections: bool = False) -> list[dict]:
     """
     Get all books in the database.
@@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
     return [book.as_payload(sections=sections) for book in books]
 
 
-@mcp.tool()
+@books_mcp.tool()
 def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
     """
     Read a book from the database.
@@ -1,53 +1,45 @@
 """
-MCP tools for the epistemic sparring partner system.
+Core MCP subserver for knowledge base search, observations, and notes.
 """
 
 import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from PIL import Image
 
+from fastmcp import FastMCP
+from PIL import Image
 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from memory.api.MCP.base import mcp
 from memory.api.search.search import search
-from memory.api.search.types import SearchFilters, SearchConfig
+from memory.api.search.types import SearchConfig, SearchFilters
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
 from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
 from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem, AgentObservation
+from memory.common.db.models import AgentObservation, SourceItem
 from memory.common.formatters import observation
 
 logger = logging.getLogger(__name__)
 
+core_mcp = FastMCP("memory-core")
+
 
 def validate_path_within_directory(
     base_dir: pathlib.Path, requested_path: str
 ) -> pathlib.Path:
-    """Validate that a requested path resolves within the base directory.
-
-    Prevents path traversal attacks using ../ or similar techniques.
-
-    Args:
-        base_dir: The allowed base directory
-        requested_path: The user-provided path
-
-    Returns:
-        The resolved absolute path if valid
-
-    Raises:
-        ValueError: If the path would escape the base directory
-    """
+    """Validate that a requested path resolves within the base directory."""
    resolved = (base_dir / requested_path.lstrip("/")).resolve()
    base_resolved = base_dir.resolve()
 
-    if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+    if (
+        not str(resolved).startswith(str(base_resolved) + "/")
+        and resolved != base_resolved
+    ):
         raise ValueError(f"Path escapes allowed directory: {requested_path}")
 
     return resolved
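The reflowed condition keeps the original containment check: resolve the joined path, then require it to sit strictly under the base directory. Expected behavior, sketched with a hypothetical /notes base (uses the validate_path_within_directory defined above):

    import pathlib

    base = pathlib.Path("/notes")  # hypothetical base directory
    print(validate_path_within_directory(base, "daily/todo.md"))  # /notes/daily/todo.md
    try:
        validate_path_within_directory(base, "../etc/passwd")
    except ValueError as err:
        print(err)  # Path escapes allowed directory: ../etc/passwd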
@@ -63,7 +55,6 @@ def filter_observation_source_ids(
     items_query = session.query(AgentObservation.id)
 
     if tags:
-        # Use PostgreSQL array overlap operator with proper array casting
         items_query = items_query.filter(
             AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
         )
@@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
     items_query = session.query(SourceItem.id)
 
     if tags:
-        # Use PostgreSQL array overlap operator with proper array casting
         items_query = items_query.filter(
             SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
         )
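Both filters use PostgreSQL's && array-overlap operator, which is what the deleted comments described; the explicit cast is still needed so the bound list is typed as text[]. A standalone sketch of the SQL this produces (the rendered output is approximate):

    from sqlalchemy import Text, column
    from sqlalchemy import cast as sql_cast
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import ARRAY

    tags_col = column("tags", ARRAY(Text))
    expr = tags_col.op("&&")(sql_cast(["ai", "notes"], ARRAY(Text)))
    print(expr.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
    # roughly: tags && CAST(ARRAY['ai', 'notes'] AS TEXT[])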
@@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
     return source_ids
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def search_knowledge_base(
     query: str,
     filters: SearchFilters = {},
@@ -141,7 +131,6 @@ async def search_knowledge_base(
 
     if not modalities:
         modalities = set(ALL_COLLECTIONS.keys())
-    # Filter to valid collections, excluding observation collections
     modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
 
     search_filters = SearchFilters(**filters)
@@ -167,7 +156,7 @@ class RawObservation(BaseModel):
     tags: list[str] = []
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def observe(
     observations: list[RawObservation],
     session_id: str | None = None,
@@ -212,23 +201,23 @@ async def observe(
     logger.info("MCP: Observing")
     tasks = [
         (
-            observation,
+            obs,
             celery_app.send_task(
                 SYNC_OBSERVATION,
                 queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
                 kwargs={
-                    "subject": observation.subject,
-                    "content": observation.content,
-                    "observation_type": observation.observation_type,
-                    "confidences": observation.confidences,
-                    "evidence": observation.evidence,
-                    "tags": observation.tags,
+                    "subject": obs.subject,
+                    "content": obs.content,
+                    "observation_type": obs.observation_type,
+                    "confidences": obs.confidences,
+                    "evidence": obs.evidence,
+                    "tags": obs.tags,
                     "session_id": session_id,
                     "agent_model": agent_model,
                 },
             ),
         )
-        for observation in observations
+        for obs in observations
     ]
 
     def short_content(obs: RawObservation) -> str:
@@ -242,7 +231,7 @@ async def observe(
     }
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def search_observations(
     query: str,
     subject: str = "",
@@ -308,7 +297,7 @@ async def search_observations(
     ]
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def create_note(
     subject: str,
     content: str,
@@ -362,7 +351,7 @@ async def create_note(
     }
 
 
-@mcp.tool()
+@core_mcp.tool()
 async def note_files(path: str = "/"):
     """
     List note files in the user's note storage.
@@ -386,7 +375,7 @@ async def note_files(path: str = "/"):
     ]
 
 
-@mcp.tool()
+@core_mcp.tool()
 def fetch_file(filename: str) -> dict:
     """
     Read file content with automatic type detection.
@@ -1,14 +1,14 @@
-"""MCP tools for GitHub issue tracking and management."""
+"""MCP subserver for GitHub issue tracking and management."""
 
 import logging
 from datetime import datetime, timezone
 from typing import Any
 
-from sqlalchemy import Text, case, desc, func, asc
+from fastmcp import FastMCP
+from sqlalchemy import Text, case, desc, func
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from memory.api.MCP.base import mcp
 from memory.api.search.search import search
 from memory.api.search.types import SearchConfig, SearchFilters
 from memory.common import extract
@@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem
 
 logger = logging.getLogger(__name__)
 
+github_mcp = FastMCP("memory-github")
+
 
 def _build_github_url(repo_path: str, number: int | None, kind: str) -> str:
     """Build GitHub URL from repo path and issue/PR number."""
@@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
     }
     if include_content:
         result["content"] = item.content
-    # Include PR-specific data if available
     if item.kind == "pr" and item.pr_data:
         result["pr_data"] = {
             "additions": item.pr_data.additions,
@@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st
             "files": item.pr_data.files,
             "reviews": item.pr_data.reviews,
             "review_comments": item.pr_data.review_comments,
-            "diff": item.pr_data.diff,  # Decompressed via property
+            "diff": item.pr_data.diff,
         }
     return result
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def list_github_issues(
     repo: str | None = None,
     assignee: str | None = None,
@@ -109,64 +110,48 @@ async def list_github_issues(
     with make_session() as session:
         query = session.query(GithubItem)
 
-        # Apply filters
         if repo:
             query = query.filter(GithubItem.repo_path == repo)
 
         if assignee:
             query = query.filter(GithubItem.assignees.any(assignee))
 
         if author:
             query = query.filter(GithubItem.author == author)
 
         if state:
             query = query.filter(GithubItem.state == state)
 
         if kind:
             query = query.filter(GithubItem.kind == kind)
         else:
-            # Exclude comments by default, only show issues and PRs
             query = query.filter(GithubItem.kind.in_(["issue", "pr"]))
 
         if labels:
-            # Match any label in the list using PostgreSQL array overlap
             query = query.filter(
                 GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text)))
             )
 
         if project_status:
             query = query.filter(GithubItem.project_status == project_status)
 
         if project_field:
             for key, value in project_field.items():
-                query = query.filter(
-                    GithubItem.project_fields[key].astext == value
-                )
+                query = query.filter(GithubItem.project_fields[key].astext == value)
 
         if updated_since:
             since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00"))
             query = query.filter(GithubItem.github_updated_at >= since_dt)
 
         if updated_before:
             before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00"))
             query = query.filter(GithubItem.github_updated_at <= before_dt)
 
-        # Apply ordering
         if order_by == "created":
             query = query.order_by(desc(GithubItem.created_at))
         elif order_by == "number":
             query = query.order_by(desc(GithubItem.number))
-        else:  # default: updated
+        else:
             query = query.order_by(desc(GithubItem.github_updated_at))
 
         query = query.limit(limit)
 
         items = query.all()
 
         return [_serialize_issue(item) for item in items]
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def search_github_issues(
     query: str,
     repo: str | None = None,
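The updated_since/updated_before handling above works around datetime.fromisoformat rejecting a trailing Z on Python versions before 3.11; swapping it for +00:00 yields a timezone-aware UTC datetime. For example:

    from datetime import datetime

    stamp = "2024-05-01T12:30:00Z"  # illustrative GitHub-style timestamp
    parsed = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
    print(parsed)         # 2024-05-01 12:30:00+00:00
    print(parsed.tzinfo)  # UTC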
@@ -191,7 +176,6 @@ async def search_github_issues(
 
     limit = min(limit, 100)
 
-    # Pre-filter source_ids if repo/state/kind filters are specified
     source_ids = None
     if repo or state or kind:
         with make_session() as session:
@@ -206,7 +190,6 @@ async def search_github_issues(
                 q = q.filter(GithubItem.kind.in_(["issue", "pr"]))
             source_ids = [item.id for item in q.all()]
 
-    # Use the existing search infrastructure
     data = extract.extract_text(query, skip_summary=True)
     config = SearchConfig(limit=limit, previews=True)
     filters = SearchFilters()
@@ -220,7 +203,6 @@ async def search_github_issues(
         config=config,
     )
 
-    # Fetch full issue details for the results
     output = []
     with make_session() as session:
         for result in results:
@@ -233,7 +215,7 @@ async def search_github_issues(
     return output
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def github_issue_details(
     repo: str,
     number: int,
@@ -267,7 +249,7 @@ async def github_issue_details(
     return _serialize_issue(item, include_content=True)
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def github_work_summary(
     since: str,
     until: str | None = None,
@@ -294,7 +276,6 @@ async def github_work_summary(
     else:
         until_dt = datetime.now(timezone.utc)
 
-    # Map group_by to SQL expression
     group_mappings = {
         "client": GithubItem.project_fields["EquiStamp.Client"].astext,
         "status": GithubItem.project_status,
@@ -311,7 +292,6 @@ async def github_work_summary(
     group_col = group_mappings[group_by]
 
     with make_session() as session:
-        # Build base query for the period
         base_query = session.query(GithubItem).filter(
             GithubItem.github_updated_at >= since_dt,
             GithubItem.github_updated_at <= until_dt,
@@ -321,7 +301,6 @@ async def github_work_summary(
         if repo:
             base_query = base_query.filter(GithubItem.repo_path == repo)
 
-        # Get aggregated counts by group
         agg_query = (
             session.query(
                 group_col.label("group_name"),
@@ -346,7 +325,6 @@ async def github_work_summary(
 
         groups = agg_query.all()
 
-        # Build summary with sample issues for each group
         summary = []
         total_issues = 0
         total_prs = 0
@@ -358,7 +336,6 @@ async def github_work_summary(
             total_issues += issue_count
             total_prs += pr_count
 
-            # Get sample issues for this group
             sample_query = base_query.filter(group_col == group_name).limit(5)
             samples = [
                 {
@@ -394,7 +371,7 @@ async def github_work_summary(
     }
 
 
-@mcp.tool()
+@github_mcp.tool()
 async def github_repo_overview(
     repo: str,
 ) -> dict:
@@ -410,13 +387,6 @@ async def github_repo_overview(
     logger.info(f"github_repo_overview called: repo={repo}")
 
     with make_session() as session:
-        # Base query for this repo
-        base_query = session.query(GithubItem).filter(
-            GithubItem.repo_path == repo,
-            GithubItem.kind.in_(["issue", "pr"]),
-        )
-
-        # Get total counts
         counts_query = session.query(
             func.count(GithubItem.id).label("total"),
             func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"),
@@ -441,7 +411,6 @@ async def github_repo_overview(
 
         counts = counts_query.first()
 
-        # Status breakdown (for project_status)
         status_query = (
             session.query(
                 GithubItem.project_status.label("status"),
@@ -458,7 +427,6 @@ async def github_repo_overview(
 
         status_breakdown = {row.status: row.count for row in status_query.all()}
 
-        # Top assignees (open issues only)
         assignee_query = (
             session.query(
                 func.unnest(GithubItem.assignees).label("assignee"),
@@ -479,7 +447,6 @@ async def github_repo_overview(
             for row in assignee_query.all()
         ]
 
-        # Label counts
         label_query = (
             session.query(
                 func.unnest(GithubItem.labels).label("label"),
src/memory/api/MCP/servers/meta.py (new file, 277 lines)
@@ -0,0 +1,277 @@
+"""MCP subserver for metadata, utilities, and forecasting."""
+
+import asyncio
+import logging
+from collections import defaultdict
+from datetime import datetime, timezone
+from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints
+
+import aiohttp
+from fastmcp import FastMCP
+from sqlalchemy import func
+
+from memory.common import qdrant
+from memory.common.db.connection import make_session
+from memory.common.db.models import SourceItem
+from memory.common.db.models.source_items import AgentObservation
+
+logger = logging.getLogger(__name__)
+
+meta_mcp = FastMCP("memory-meta")
+
+# Auth provider will be injected at mount time
+_get_current_user = None
+
+
+def set_auth_provider(get_current_user_func):
+    """Set the authentication provider function."""
+    global _get_current_user
+    _get_current_user = get_current_user_func
+
+
+def get_current_user() -> dict:
+    """Get the current authenticated user."""
+    if _get_current_user is None:
+        return {"authenticated": False, "error": "Auth provider not configured"}
+    return _get_current_user()
+
+
+# --- Metadata tools ---
+
+
+class SchemaArg(TypedDict):
+    type: str | None
+    description: str | None
+
+
+class CollectionMetadata(TypedDict):
+    schema: dict[str, SchemaArg]
+    size: int
+
+
+def from_annotation(annotation: Annotated) -> SchemaArg | None:
+    try:
+        type_, description = get_args(annotation)
+        type_str = str(type_)
+        if type_str.startswith("typing."):
+            type_str = type_str[7:]
+        elif len((parts := type_str.split("'"))) > 1:
+            type_str = parts[1]
+        return SchemaArg(type=type_str, description=description)
+    except IndexError:
+        logger.error(f"Error from annotation: {annotation}")
+        return None
+
+
+def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
+    if not hasattr(klass, "as_payload"):
+        return {}
+
+    if not (payload_type := get_type_hints(klass.as_payload).get("return")):
+        return {}
+
+    return {
+        name: schema
+        for name, arg in payload_type.__annotations__.items()
+        if (schema := from_annotation(arg))
+    }
+
+
+@meta_mcp.tool()
+async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
+    """Get the metadata schema for each collection used in the knowledge base.
+
+    These schemas can be used to filter the knowledge base.
+
+    Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
+
+    Example:
+    ```
+    {
+        "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
+        "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
+    }
+    """
+    client = qdrant.get_qdrant_client()
+    sizes = qdrant.get_collection_sizes(client)
+    schemas = defaultdict(dict)
+    for klass in SourceItem.__subclasses__():
+        for collection in klass.get_collections():
+            schemas[collection].update(get_schema(klass))
+
+    return {
+        collection: CollectionMetadata(schema=schema, size=size)
+        for collection, schema in schemas.items()
+        if (size := sizes.get(collection))
+    }
+
+
+@meta_mcp.tool()
+async def get_all_tags() -> list[str]:
+    """Get all unique tags used across the entire knowledge base.
+
+    Returns sorted list of tags from both observations and content.
+    """
+    with make_session() as session:
+        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
+        return sorted({row[0] for row in tags_query if row[0] is not None})
+
+
+@meta_mcp.tool()
+async def get_all_subjects() -> list[str]:
+    """Get all unique subjects from observations about the user.
+
+    Returns sorted list of subject identifiers used in observations.
+    """
+    with make_session() as session:
+        return sorted(
+            r.subject for r in session.query(AgentObservation.subject).distinct()
+        )
+
+
+@meta_mcp.tool()
+async def get_all_observation_types() -> list[str]:
+    """Get all observation types that have been used.
+
+    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
+    """
+    with make_session() as session:
+        return sorted(
+            {
+                r.observation_type
+                for r in session.query(AgentObservation.observation_type).distinct()
+                if r.observation_type is not None
+            }
+        )
+
+
+# --- Utility tools ---
+
+
+@meta_mcp.tool()
+async def get_current_time() -> dict:
+    """Get the current time in UTC."""
+    logger.info("get_current_time tool called")
+    return {"current_time": datetime.now(timezone.utc).isoformat()}
+
+
+@meta_mcp.tool()
+async def get_authenticated_user() -> dict:
+    """Get information about the authenticated user."""
+    return get_current_user()
+
+
+# --- Forecasting tools ---
+
+
+class BinaryProbs(TypedDict):
+    prob: float
+
+
+class MultiProbs(TypedDict):
+    answerProbs: dict[str, float]
+
+
+Probs = dict[str, BinaryProbs | MultiProbs]
+OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
+
+
+class MarketAnswer(TypedDict):
+    id: str
+    text: str
+    resolutionProbability: float
+
+
+class MarketDetails(TypedDict):
+    id: str
+    createdTime: int
+    question: str
+    outcomeType: OutcomeType
+    textDescription: str
+    groupSlugs: list[str]
+    volume: float
+    isResolved: bool
+    answers: list[MarketAnswer]
+
+
+class Market(TypedDict):
+    id: str
+    url: str
+    question: str
+    volume: int
+    createdTime: int
+    outcomeType: OutcomeType
+    createdAt: NotRequired[str]
+    description: NotRequired[str]
+    answers: NotRequired[dict[str, float]]
+    probability: NotRequired[float]
+    details: NotRequired[MarketDetails]
+
+
+async def get_details(session: aiohttp.ClientSession, market_id: str):
+    async with session.get(
+        f"https://api.manifold.markets/v0/market/{market_id}"
+    ) as resp:
+        resp.raise_for_status()
+        return await resp.json()
+
+
+async def format_market(session: aiohttp.ClientSession, market: Market):
+    if market.get("outcomeType") != "BINARY":
+        details = await get_details(session, market["id"])
+        market["answers"] = {
+            answer["text"]: round(
+                answer.get("resolutionProbability") or answer.get("probability") or 0, 3
+            )
+            for answer in details["answers"]
+        }
+    if creationTime := market.get("createdTime"):
+        market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
+
+    fields = [
+        "id",
+        "name",
+        "url",
+        "question",
+        "volume",
+        "createdAt",
+        "details",
+        "probability",
+        "answers",
+    ]
+    return {k: v for k, v in market.items() if k in fields}
+
+
+async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
+    async with aiohttp.ClientSession() as session:
+        async with session.get(
+            "https://api.manifold.markets/v0/search-markets",
+            params={
+                "term": term,
+                "contractType": "BINARY" if binary else "ALL",
+            },
+        ) as resp:
+            resp.raise_for_status()
+            markets = await resp.json()
+
+        return await asyncio.gather(
+            *[
+                format_market(session, market)
+                for market in markets
+                if market.get("volume", 0) >= min_volume
+            ]
+        )
+
+
+@meta_mcp.tool()
+async def get_forecasts(
+    term: str, min_volume: int = 1000, binary: bool = False
+) -> list[dict]:
+    """Get prediction market forecasts for a given term.
+
+    Args:
+        term: The term to search for.
+        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
+        binary: Whether to only return binary markets.
+    """
+    return await search_markets(term, min_volume, binary)
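The forecasting helpers moved here call Manifold's public REST API (api.manifold.markets) and need no authentication, so outside the MCP server they can be driven directly. A hedged usage sketch (results depend on live market data, and the search term is illustrative):

    import asyncio

    from memory.api.MCP.servers.meta import search_markets

    async def main():
        markets = await search_markets("AGI by 2030", min_volume=2000, binary=True)
        for market in markets[:3]:
            print(market.get("question"), market.get("probability"))

    asyncio.run(main())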
@ -1,23 +1,23 @@
-"""
-MCP tools for tracking people.
-"""
+"""MCP subserver for tracking people."""

 import logging
 from typing import Any

+from fastmcp import FastMCP
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY

-from memory.api.MCP.base import mcp
-from memory.common.db.connection import make_session
-from memory.common.db.models import Person
+from memory.common import settings
 from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON
 from memory.common.celery_app import app as celery_app
-from memory.common import settings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Person

 logger = logging.getLogger(__name__)

+people_mcp = FastMCP("memory-people")
+

 def _person_to_dict(person: Person) -> dict[str, Any]:
     """Convert a Person model to a dictionary for API responses."""
@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]:
     }


-@mcp.tool()
+@people_mcp.tool()
 async def add_person(
     identifier: str,
     display_name: str,
@ -67,7 +67,6 @@ async def add_person(
     """
     logger.info(f"MCP: Adding person: {identifier}")

-    # Check if person already exists
     with make_session() as session:
         existing = session.query(Person).filter(Person.identifier == identifier).first()
         if existing:
@ -93,7 +92,7 @@ async def add_person(
     }


-@mcp.tool()
+@people_mcp.tool()
 async def update_person_info(
     identifier: str,
     display_name: str | None = None,
@ -135,7 +134,6 @@ async def update_person_info(
     """
     logger.info(f"MCP: Updating person: {identifier}")

-    # Verify person exists
     with make_session() as session:
         person = session.query(Person).filter(Person.identifier == identifier).first()
         if not person:
@ -162,7 +160,7 @@ async def update_person_info(
     }


-@mcp.tool()
+@people_mcp.tool()
 async def get_person(identifier: str) -> dict | None:
     """
     Get a person by their identifier.
@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None:
         return _person_to_dict(person)


-@mcp.tool()
+@people_mcp.tool()
 async def list_people(
     tags: list[str] | None = None,
     search: str | None = None,
@ -207,9 +205,7 @@ async def list_people(
         query = session.query(Person)

         if tags:
-            query = query.filter(
-                Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
-            )
+            query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))))

         if search:
             search_term = f"%{search.lower()}%"
@ -225,7 +221,7 @@ async def list_people(
         return [_person_to_dict(p) for p in people]


-@mcp.tool()
+@people_mcp.tool()
 async def delete_person(identifier: str) -> dict:
     """
     Delete a person by their identifier.
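Note: per-file servers like people_mcp are composed into the main server at import time. A minimal sketch of the fastmcp 2.x mounting pattern (illustrative; the exact mount call lives in base.py and is not shown in this diff):

    from fastmcp import FastMCP

    mcp = FastMCP("memory")
    people_mcp = FastMCP("memory-people")

    # Tools registered on the subserver become callable through the main
    # server; a prefix argument can be passed to namespace the tool names.
    mcp.mount(people_mcp)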
@ -1,20 +1,37 @@
-"""
-MCP tools for the epistemic sparring partner system.
-"""
+"""MCP subserver for scheduling messages and LLM calls."""

 import logging
 from datetime import datetime, timezone
 from typing import Any, cast

-from memory.api.MCP.base import get_current_user, mcp
+from fastmcp import FastMCP

 from memory.common.db.connection import make_session
-from memory.common.db.models import ScheduledLLMCall, DiscordBotUser
+from memory.common.db.models import DiscordBotUser, ScheduledLLMCall
 from memory.discord.messages import schedule_discord_message

 logger = logging.getLogger(__name__)

+schedule_mcp = FastMCP("memory-schedule")

-@mcp.tool()
+# We need access to get_current_user from base - this will be injected at mount time
+_get_current_user = None
+
+
+def set_auth_provider(get_current_user_func):
+    """Set the authentication provider function."""
+    global _get_current_user
+    _get_current_user = get_current_user_func
+
+
+def get_current_user() -> dict:
+    """Get the current authenticated user."""
+    if _get_current_user is None:
+        return {"authenticated": False, "error": "Auth provider not configured"}
+    return _get_current_user()
+
+
+@schedule_mcp.tool()
 async def schedule_message(
     scheduled_time: str,
     message: str,
@ -61,10 +78,8 @@ async def schedule_message(
     if not discord_user and not discord_channel:
         raise ValueError("Either discord_user or discord_channel must be provided")

-    # Parse scheduled time
     try:
         scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00"))
-        # Ensure we store as naive datetime (remove timezone info for database storage)
         if scheduled_dt.tzinfo is not None:
             scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None)
     except ValueError:
@ -98,7 +113,7 @@ async def schedule_message(
     }


-@mcp.tool()
+@schedule_mcp.tool()
 async def list_scheduled_llm_calls(
     status: str | None = None, limit: int | None = 50
 ) -> dict[str, Any]:
@ -143,7 +158,7 @@ async def list_scheduled_llm_calls(
     }


-@mcp.tool()
+@schedule_mcp.tool()
 async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
     """
     Cancel a scheduled LLM call.
@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
         return {"error": "User not found", "user": current_user}

     with make_session() as session:
-        # Find the scheduled call
         scheduled_call = (
             session.query(ScheduledLLMCall)
             .filter(
@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]:
         if not scheduled_call.can_be_cancelled():
             return {"error": f"Cannot cancel call with status: {scheduled_call.status}"}

-        # Update the status
         scheduled_call.status = "cancelled"
         session.commit()

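Note: set_auth_provider exists to break the import cycle between base.py and this module, as the comment above says. A minimal wiring sketch of how base.py could inject its helper before mounting (illustrative; the stub body and the mount call are assumptions, not shown in this diff):

    from fastmcp import FastMCP
    from memory.api.MCP import schedules

    mcp = FastMCP("memory")

    def get_current_user() -> dict:
        # Stand-in for the real helper defined in base.py
        return {"authenticated": True, "user": "example"}

    # Inject the auth helper, then compose the subserver into the main server
    schedules.set_auth_provider(get_current_user)
    mcp.mount(schedules.schedule_mcp)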
@ -1,74 +0,0 @@
-"""
-MCP tools for the epistemic sparring partner system.
-"""
-
-import logging
-from datetime import datetime, timezone
-
-from sqlalchemy import Text
-from sqlalchemy import cast as sql_cast
-from sqlalchemy.dialects.postgresql import ARRAY
-
-from memory.common.db.connection import make_session
-from memory.common.db.models import AgentObservation, SourceItem
-from memory.api.MCP.base import mcp, get_current_user
-
-logger = logging.getLogger(__name__)
-
-
-def filter_observation_source_ids(
-    tags: list[str] | None = None, observation_types: list[str] | None = None
-):
-    if not tags and not observation_types:
-        return None
-
-    with make_session() as session:
-        items_query = session.query(AgentObservation.id)
-
-        if tags:
-            # Use PostgreSQL array overlap operator with proper array casting
-            items_query = items_query.filter(
-                AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
-            )
-        if observation_types:
-            items_query = items_query.filter(
-                AgentObservation.observation_type.in_(observation_types)
-            )
-        source_ids = [item.id for item in items_query.all()]
-
-        return source_ids
-
-
-def filter_source_ids(
-    modalities: set[str],
-    tags: list[str] | None = None,
-):
-    if not tags:
-        return None
-
-    with make_session() as session:
-        items_query = session.query(SourceItem.id)
-
-        if tags:
-            # Use PostgreSQL array overlap operator with proper array casting
-            items_query = items_query.filter(
-                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
-            )
-        if modalities:
-            items_query = items_query.filter(SourceItem.modality.in_(modalities))
-        source_ids = [item.id for item in items_query.all()]
-
-        return source_ids
-
-
-@mcp.tool()
-async def get_current_time() -> dict:
-    """Get the current time in UTC."""
-    logger.info("get_current_time tool called")
-    return {"current_time": datetime.now(timezone.utc).isoformat()}
-
-
-@mcp.tool()
-async def get_authenticated_user() -> dict:
-    """Get information about the authenticated user."""
-    return get_current_user()
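Note: the "&&" filter removed above (and still used in the people subserver) is PostgreSQL's array-overlap operator. A minimal standalone sketch of the same construct in 2.0-style SQLAlchemy (illustrative; the tag values are made up):

    from sqlalchemy import Text, select
    from sqlalchemy import cast as sql_cast
    from sqlalchemy.dialects.postgresql import ARRAY

    from memory.common.db.models import SourceItem

    tags = ["ai", "books"]
    # Matches rows whose tags array shares at least one element with `tags`;
    # renders roughly as: WHERE source_item.tags && CAST(ARRAY[...] AS TEXT[])
    stmt = select(SourceItem.id).where(
        SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
    )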
@ -2,7 +2,6 @@
 FastAPI application for the knowledge base.
 """

-import contextlib
 import os
 import logging
 import mimetypes
@ -35,15 +34,10 @@ limiter = Limiter(
     enabled=settings.API_RATE_LIMIT_ENABLED,
 )

-
-@contextlib.asynccontextmanager
-async def lifespan(app: FastAPI):
-    async with contextlib.AsyncExitStack() as stack:
-        await stack.enter_async_context(mcp.session_manager.run())
-        yield
-
-
-app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
+# Create the MCP http app to get its lifespan
+mcp_http_app = mcp.http_app(stateless_http=True)
+
+app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan)
 app.state.limiter = limiter

 # Rate limit exception handler
@ -158,46 +152,9 @@ app.include_router(auth_router)


 # Add health check to MCP server instead of main app
-@mcp.custom_route("/health", methods=["GET"])
-async def health_check(request: Request):
-    """Health check endpoint that verifies all dependencies are accessible."""
-    from fastapi.responses import JSONResponse
-    from sqlalchemy import text
-
-    checks = {"mcp_oauth": "enabled"}
-    all_healthy = True
-
-    # Check database connection
-    try:
-        with engine.connect() as conn:
-            conn.execute(text("SELECT 1"))
-        checks["database"] = "healthy"
-    except Exception as e:
-        # Log error details but don't expose in response
-        logger.error(f"Database health check failed: {e}")
-        checks["database"] = "unhealthy"
-        all_healthy = False
-
-    # Check Qdrant connection
-    try:
-        from memory.common.qdrant import get_qdrant_client
-
-        client = get_qdrant_client()
-        client.get_collections()
-        checks["qdrant"] = "healthy"
-    except Exception as e:
-        # Log error details but don't expose in response
-        logger.error(f"Qdrant health check failed: {e}")
-        checks["qdrant"] = "unhealthy"
-        all_healthy = False
-
-    checks["status"] = "healthy" if all_healthy else "degraded"
-    status_code = 200 if all_healthy else 503
-    return JSONResponse(checks, status_code=status_code)
-
-
 # Mount MCP server at root - OAuth endpoints need to be at root level
-app.mount("/", mcp.streamable_http_app())
+# Health check is defined in MCP/base.py
+app.mount("/", mcp_http_app)


 def main(reload: bool = False):
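Note: the lifespan handoff above is the key to mounting a fastmcp 2.x ASGI app inside FastAPI. A minimal sketch of the pattern, mirroring the diff (illustrative names):

    from fastapi import FastAPI
    from fastmcp import FastMCP

    mcp = FastMCP("example")
    mcp_http_app = mcp.http_app(stateless_http=True)

    # FastAPI must adopt the MCP app's lifespan so the MCP session manager
    # starts and stops together with the outer application
    app = FastAPI(lifespan=mcp_http_app.lifespan)
    app.mount("/", mcp_http_app)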
@ -58,7 +58,6 @@ sync_code() {
         "$PROJECT_DIR/frontend" \
         "$PROJECT_DIR/requirements" \
         "$PROJECT_DIR/setup.py" \
-        "$PROJECT_DIR/pyproject.toml" \
         "$PROJECT_DIR/docker-compose.yaml" \
         "$PROJECT_DIR/pytest.ini" \
         "$REMOTE_HOST:$REMOTE_DIR/"
tools/diagnose.sh (Executable file, 195 lines)
@ -0,0 +1,195 @@
+#!/bin/bash
+set -e
+
+REMOTE_HOST="memory"
+REMOTE_DIR="/home/ec2-user/memory"
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m'
+
+usage() {
+    echo "Usage: $0 <command> [options]"
+    echo ""
+    echo "Safe diagnostic commands for the memory server."
+    echo ""
+    echo "Commands:"
+    echo "  logs [service] [lines]   View docker logs (default: all services, 100 lines)"
+    echo "  ps                       Show docker container status"
+    echo "  disk                     Show disk usage"
+    echo "  mem                      Show memory usage"
+    echo "  top                      Show running processes"
+    echo "  ls <path>                List directory contents"
+    echo "  cat <file>               View file contents"
+    echo "  tail <file> [lines]      Tail a file (default: 50 lines)"
+    echo "  grep <pattern> <path>    Search for pattern in files"
+    echo "  db <query>               Run read-only SQL query"
+    echo "  get <path> [port]        GET request to localhost (default port: 8000)"
+    echo "  status                   Overall system status"
+    exit 1
+}
+
+remote() {
+    ssh "$REMOTE_HOST" "$@"
+}
+
+docker_logs() {
+    local service="${1:-}"
+    local lines="${2:-100}"
+    if [ -n "$service" ]; then
+        echo -e "${GREEN}Logs for $service (last $lines lines):${NC}"
+        remote "cd $REMOTE_DIR && docker compose logs --tail=$lines $service"
+    else
+        echo -e "${GREEN}All logs (last $lines lines):${NC}"
+        remote "cd $REMOTE_DIR && docker compose logs --tail=$lines"
+    fi
+}
+
+docker_ps() {
+    echo -e "${GREEN}Container status:${NC}"
+    remote "cd $REMOTE_DIR && docker compose ps"
+}
+
+disk_usage() {
+    echo -e "${GREEN}Disk usage:${NC}"
+    remote "df -h && echo '' && du -sh $REMOTE_DIR/* 2>/dev/null | sort -h"
+}
+
+mem_usage() {
+    echo -e "${GREEN}Memory usage:${NC}"
+    remote "free -h"
+}
+
+show_top() {
+    echo -e "${GREEN}Top processes:${NC}"
+    remote "ps aux --sort=-%mem | head -20"
+}
+
+list_dir() {
+    local path="${1:-.}"
+    # Ensure path is within project directory for safety
+    echo -e "${GREEN}Contents of $path:${NC}"
+    remote "cd $REMOTE_DIR && ls -la $path"
+}
+
+cat_file() {
+    local file="$1"
+    if [ -z "$file" ]; then
+        echo -e "${RED}Error: No file specified${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Contents of $file:${NC}"
+    remote "cd $REMOTE_DIR && cat $file"
+}
+
+tail_file() {
+    local file="$1"
+    local lines="${2:-50}"
+    if [ -z "$file" ]; then
+        echo -e "${RED}Error: No file specified${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Last $lines lines of $file:${NC}"
+    remote "cd $REMOTE_DIR && tail -n $lines $file"
+}
+
+grep_files() {
+    local pattern="$1"
+    local path="${2:-.}"
+    if [ -z "$pattern" ]; then
+        echo -e "${RED}Error: No pattern specified${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Searching for '$pattern' in $path:${NC}"
+    remote "cd $REMOTE_DIR && grep -r --color=always '$pattern' $path || true"
+}
+
+db_query() {
+    local query="$1"
+    if [ -z "$query" ]; then
+        echo -e "${RED}Error: No query specified${NC}"
+        exit 1
+    fi
+    # Only allow SELECT queries for safety
+    if ! echo "$query" | grep -qi "^select"; then
+        echo -e "${RED}Error: Only SELECT queries are allowed${NC}"
+        exit 1
+    fi
+    echo -e "${GREEN}Running query:${NC}"
+    remote "cd $REMOTE_DIR && docker compose exec -T postgres psql -U kb -d kb -c \"$query\""
+}
+
+http_get() {
+    local path="$1"
+    local port="${2:-8000}"
+    if [ -z "$path" ]; then
+        echo -e "${RED}Error: No path specified${NC}"
+        exit 1
+    fi
+    # Ensure path starts with /
+    if [[ "$path" != /* ]]; then
+        path="/$path"
+    fi
+    echo -e "${GREEN}GET http://localhost:${port}${path}${NC}"
+    remote "curl -s -w '\n\nHTTP Status: %{http_code}\n' 'http://localhost:${port}${path}'"
+}
+
+system_status() {
+    echo -e "${GREEN}=== System Status ===${NC}"
+    echo ""
+    docker_ps
+    echo ""
+    echo -e "${GREEN}=== Memory ===${NC}"
+    mem_usage
+    echo ""
+    echo -e "${GREEN}=== Disk ===${NC}"
+    remote "df -h /"
+    echo ""
+    echo -e "${GREEN}=== Recent Errors (last 20 lines) ===${NC}"
+    remote "cd $REMOTE_DIR && docker compose logs --tail=100 2>&1 | grep -i -E '(error|exception|failed|fatal)' | tail -20 || echo 'No recent errors found'"
+}
+
+# Main
+case "${1:-}" in
+    logs)
+        docker_logs "${2:-}" "${3:-}"
+        ;;
+    ps)
+        docker_ps
+        ;;
+    disk)
+        disk_usage
+        ;;
+    mem)
+        mem_usage
+        ;;
+    top)
+        show_top
+        ;;
+    ls)
+        list_dir "${2:-}"
+        ;;
+    cat)
+        cat_file "${2:-}"
+        ;;
+    tail)
+        tail_file "${2:-}" "${3:-}"
+        ;;
+    grep)
+        grep_files "${2:-}" "${3:-}"
+        ;;
+    db)
+        db_query "${2:-}"
+        ;;
+    get)
+        http_get "${2:-}" "${3:-}"
+        ;;
+    status)
+        system_status
+        ;;
+    *)
+        usage
+        ;;
+esac
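Note: example invocations of the new script (the commands come from its usage text; the service name and query argument are illustrative):

    tools/diagnose.sh status
    tools/diagnose.sh logs api 200
    tools/diagnose.sh db "SELECT 1"
    tools/diagnose.sh get /health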