diff --git a/README.md b/README.md index 56cc8a8..fbb0202 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,27 @@ curl -X POST http://localhost:8000/auth/login \ This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header. +## Discord integration + +If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications). +Once you have the bot's token, run + +```bash +python tools/discord_setup.py generate-invite --bot-token +``` + +to get an url that can be used to connect your Discord bot. + +Next you'll have to set at least the following in your `.env` file: + +``` +DISCORD_BOT_TOKEN= +DISCORD_NOTIFICATIONS_ENABLED=True +``` + +When the worker starts it will automatically attempt to create the appropriate channels. You +can change what they will be called by setting the various `DISCORD_*_CHANNEL` settings. + ## MCP Proxy Setup Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect: diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py new file mode 100644 index 0000000..5ee8067 --- /dev/null +++ b/src/memory/api/MCP/__init__.py @@ -0,0 +1,2 @@ +import memory.api.MCP.manifest +import memory.api.MCP.memory diff --git a/src/memory/api/MCP/manifest.py b/src/memory/api/MCP/manifest.py new file mode 100644 index 0000000..6dcd76a --- /dev/null +++ b/src/memory/api/MCP/manifest.py @@ -0,0 +1,119 @@ +import asyncio +import aiohttp +from datetime import datetime + +from typing import TypedDict, NotRequired, Literal +from memory.api.MCP.tools import mcp + + +class BinaryProbs(TypedDict): + prob: float + + +class MultiProbs(TypedDict): + answerProbs: dict[str, float] + + +Probs = dict[str, BinaryProbs | MultiProbs] +OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"] + + +class MarketAnswer(TypedDict): + id: str + text: str + resolutionProbability: float + + +class MarketDetails(TypedDict): + id: str + createdTime: int + question: str + outcomeType: OutcomeType + textDescription: str + groupSlugs: list[str] + volume: float + isResolved: bool + answers: list[MarketAnswer] + + +class Market(TypedDict): + id: str + url: str + question: str + volume: int + createdTime: int + outcomeType: OutcomeType + createdAt: NotRequired[str] + description: NotRequired[str] + answers: NotRequired[dict[str, float]] + probability: NotRequired[float] + details: NotRequired[MarketDetails] + + +async def get_details(session: aiohttp.ClientSession, market_id: str): + async with session.get( + f"https://api.manifold.markets/v0/market/{market_id}" + ) as resp: + resp.raise_for_status() + return await resp.json() + + +async def format_market(session: aiohttp.ClientSession, market: Market): + if market.get("outcomeType") != "BINARY": + details = await get_details(session, market["id"]) + market["answers"] = { + answer["text"]: round( + answer.get("resolutionProbability") or answer.get("probability") or 0, 3 + ) + for answer in details["answers"] + } + if creationTime := market.get("createdTime"): + market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat() + + fields = [ + "id", + "name", + "url", + "question", + "volume", + "createdAt", + "details", + "probability", + "answers", + ] + return {k: v for k, v in market.items() if k in fields} + + +async def search_markets(term: str, min_volume: int = 1000, binary: bool = False): + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.manifold.markets/v0/search-markets", + params={ + "term": term, + "contractType": "BINARY" if binary else "ALL", + }, + ) as resp: + resp.raise_for_status() + markets = await resp.json() + + return await asyncio.gather( + *[ + format_market(session, market) + for market in markets + if market.get("volume", 0) >= min_volume + ] + ) + + +@mcp.tool() +async def get_forecasts( + term: str, min_volume: int = 1000, binary: bool = False +) -> list[dict]: + """Get prediction market forecasts for a given term. + + Args: + term: The term to search for. + min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold. + binary: Whether to only return binary markets. + """ + return await search_markets(term, min_volume, binary) diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py new file mode 100644 index 0000000..89aaf30 --- /dev/null +++ b/src/memory/api/MCP/memory.py @@ -0,0 +1,443 @@ +""" +MCP tools for the epistemic sparring partner system. +""" + +import logging +import pathlib +from datetime import datetime, timezone +import mimetypes + +from pydantic import BaseModel +from sqlalchemy import Text, func +from sqlalchemy import cast as sql_cast +from sqlalchemy.dialects.postgresql import ARRAY + +from memory.api.search.search import SearchFilters, search +from memory.common import extract, settings +from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS +from memory.common.db.connection import make_session +from memory.common.db.models import AgentObservation, SourceItem +from memory.common.formatters import observation +from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE +from memory.api.MCP.tools import mcp + +logger = logging.getLogger(__name__) + + +def filter_observation_source_ids( + tags: list[str] | None = None, observation_types: list[str] | None = None +): + if not tags and not observation_types: + return None + + with make_session() as session: + items_query = session.query(AgentObservation.id) + + if tags: + # Use PostgreSQL array overlap operator with proper array casting + items_query = items_query.filter( + AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))), + ) + if observation_types: + items_query = items_query.filter( + AgentObservation.observation_type.in_(observation_types) + ) + source_ids = [item.id for item in items_query.all()] + + return source_ids + + +def filter_source_ids( + modalities: set[str], + tags: list[str] | None = None, +): + if not tags: + return None + + with make_session() as session: + items_query = session.query(SourceItem.id) + + if tags: + # Use PostgreSQL array overlap operator with proper array casting + items_query = items_query.filter( + SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))), + ) + if modalities: + items_query = items_query.filter(SourceItem.modality.in_(modalities)) + source_ids = [item.id for item in items_query.all()] + + return source_ids + + +@mcp.tool() +async def get_all_tags() -> list[str]: + """ + Get all unique tags used across the entire knowledge base. + Returns sorted list of tags from both observations and content. + """ + with make_session() as session: + tags_query = session.query(func.unnest(SourceItem.tags)).distinct() + return sorted({row[0] for row in tags_query if row[0] is not None}) + + +@mcp.tool() +async def get_all_subjects() -> list[str]: + """ + Get all unique subjects from observations about the user. + Returns sorted list of subject identifiers used in observations. + """ + with make_session() as session: + return sorted( + r.subject for r in session.query(AgentObservation.subject).distinct() + ) + + +@mcp.tool() +async def get_all_observation_types() -> list[str]: + """ + Get all observation types that have been used. + Standard types are belief, preference, behavior, contradiction, general, but there can be more. + """ + with make_session() as session: + return sorted( + { + r.observation_type + for r in session.query(AgentObservation.observation_type).distinct() + if r.observation_type is not None + } + ) + + +@mcp.tool() +async def search_knowledge_base( + query: str, + previews: bool = False, + modalities: set[str] = set(), + tags: list[str] = [], + limit: int = 10, +) -> list[dict]: + """ + Search user's stored content including emails, documents, articles, books. + Use to find specific information the user has saved or received. + Combine with search_observations for complete user context. + + Args: + query: Natural language search query - be descriptive about what you're looking for + previews: Include actual content in results - when false only a snippet is returned + modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all) + tags: Filter by tags - content must have at least one matching tag + limit: Max results (1-100) + + Returns: List of search results with id, score, chunks, content, filename + Higher scores (>0.7) indicate strong matches. + """ + logger.info(f"MCP search for: {query}") + + if not modalities: + modalities = set(ALL_COLLECTIONS.keys()) + modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS + + upload_data = extract.extract_text(query) + results = await search( + upload_data, + previews=previews, + modalities=modalities, + limit=limit, + min_text_score=0.4, + min_multimodal_score=0.25, + filters=SearchFilters( + tags=tags, + source_ids=filter_source_ids(tags=tags, modalities=modalities), + ), + ) + + return [result.model_dump() for result in results] + + +class RawObservation(BaseModel): + subject: str + content: str + observation_type: str = "general" + confidences: dict[str, float] = {} + evidence: dict | None = None + tags: list[str] = [] + + +@mcp.tool() +async def observe( + observations: list[RawObservation], + session_id: str | None = None, + agent_model: str = "unknown", +) -> dict: + """ + Record observations about the user for long-term understanding. + Use proactively when user expresses preferences, behaviors, beliefs, or contradictions. + Be specific and detailed - observations should make sense months later. + + Example call: + ``` + { + "observations": [ + { + "content": "The user is a software engineer.", + "subject": "user", + "observation_type": "belief", + "confidences": {"observation_accuracy": 0.9}, + "evidence": {"quote": "I am a software engineer.", "context": "I work at Google."}, + "tags": ["programming", "work"] + } + ], + "session_id": "123e4567-e89b-12d3-a456-426614174000", + "agent_model": "gpt-4o" + } + ``` + + RawObservation fields: + content (required): Detailed observation text explaining what you observed + subject (required): Consistent identifier like "programming_style", "work_habits" + observation_type: belief, preference, behavior, contradiction, general + confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9} + evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation) + tags: List of categorization tags for organization + + Args: + observations: List of RawObservation objects + session_id: UUID to group observations from same conversation + agent_model: AI model making observations (for quality tracking) + """ + tasks = [ + ( + observation, + celery_app.send_task( + SYNC_OBSERVATION, + queue=f"{settings.CELERY_QUEUE_PREFIX}-notes", + kwargs={ + "subject": observation.subject, + "content": observation.content, + "observation_type": observation.observation_type, + "confidences": observation.confidences, + "evidence": observation.evidence, + "tags": observation.tags, + "session_id": session_id, + "agent_model": agent_model, + }, + ), + ) + for observation in observations + ] + + def short_content(obs: RawObservation) -> str: + if len(obs.content) > 50: + return obs.content[:47] + "..." + return obs.content + + return { + "task_ids": {short_content(obs): task.id for obs, task in tasks}, + "status": "queued", + } + + +@mcp.tool() +async def search_observations( + query: str, + subject: str = "", + tags: list[str] | None = None, + observation_types: list[str] | None = None, + min_confidences: dict[str, float] = {}, + limit: int = 10, +) -> list[dict]: + """ + Search recorded observations about the user. + Use before responding to understand user preferences, patterns, and past insights. + Search by meaning - the query matches both content and context. + + Args: + query: Natural language search query describing what you're looking for + subject: Filter by exact subject identifier (empty = search all subjects) + tags: Filter by tags (must have at least one matching tag) + observation_types: Filter by: belief, preference, behavior, contradiction, general + min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8} + limit: Max results (1-100) + + Returns: List with content, tags, created_at, metadata + Results sorted by relevance to your query. + """ + semantic_text = observation.generate_semantic_text( + subject=subject or "", + observation_type="".join(observation_types or []), + content=query, + evidence=None, + ) + temporal = observation.generate_temporal_text( + subject=subject or "", + content=query, + created_at=datetime.now(timezone.utc), + ) + results = await search( + [ + extract.DataChunk(data=[query]), + extract.DataChunk(data=[semantic_text]), + extract.DataChunk(data=[temporal]), + ], + previews=True, + modalities={"semantic", "temporal"}, + limit=limit, + filters=SearchFilters( + subject=subject, + min_confidences=min_confidences, + tags=tags, + observation_types=observation_types, + source_ids=filter_observation_source_ids(tags=tags), + ), + timeout=2, + ) + + return [ + { + "content": r.content, + "tags": r.tags, + "created_at": r.created_at.isoformat() if r.created_at else None, + "metadata": r.metadata, + } + for r in results + ] + + +@mcp.tool() +async def create_note( + subject: str, + content: str, + filename: str | None = None, + note_type: str | None = None, + confidences: dict[str, float] = {}, + tags: list[str] = [], +) -> dict: + """ + Create a note when user asks to save or record something. + Use when user explicitly requests noting information for future reference. + + Args: + subject: What the note is about (used for organization) + content: Note content as a markdown string + filename: Optional path relative to notes folder (e.g., "project/ideas.md") + note_type: Optional categorization of the note + confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9} + tags: Organization tags for filtering and discovery + """ + if filename: + path = pathlib.Path(filename) + if not path.is_absolute(): + path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path + filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix() + + try: + task = celery_app.send_task( + SYNC_NOTE, + queue=f"{settings.CELERY_QUEUE_PREFIX}-notes", + kwargs={ + "subject": subject, + "content": content, + "filename": filename, + "note_type": note_type, + "confidences": confidences, + "tags": tags, + }, + ) + except Exception as e: + import traceback + + traceback.print_exc() + logger.error(f"Error creating note: {e}") + raise + + return { + "task_id": task.id, + "status": "queued", + } + + +@mcp.tool() +async def note_files(path: str = "/"): + """ + List note files in the user's note storage. + Use to discover existing notes before reading or to help user navigate their collection. + + Args: + path: Directory path to search (e.g., "/", "/projects", "/meetings") + Use "/" for root, or subdirectories to narrow scope + + Returns: List of file paths relative to notes directory + """ + root = settings.NOTES_STORAGE_DIR / path.lstrip("/") + return [ + f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}" + for f in root.rglob("*.md") + if f.is_file() + ] + + +@mcp.tool() +def fetch_file(filename: str) -> dict: + """ + Read file content with automatic type detection. + Returns dict with content, mime_type, is_text, file_size. + Text content as string, binary as base64. + """ + path = settings.FILE_STORAGE_DIR / filename.lstrip("/") + if not path.exists(): + raise FileNotFoundError(f"File not found: {filename}") + + mime_type, _ = mimetypes.guess_type(str(path)) + mime_type = mime_type or "application/octet-stream" + + text_extensions = { + ".md", + ".txt", + ".py", + ".js", + ".html", + ".css", + ".json", + ".xml", + ".yaml", + ".yml", + ".toml", + ".ini", + ".cfg", + ".conf", + } + text_mimes = { + "application/json", + "application/xml", + "application/javascript", + "application/x-yaml", + "application/yaml", + } + is_text = ( + mime_type.startswith("text/") + or mime_type in text_mimes + or path.suffix.lower() in text_extensions + ) + + try: + content = ( + path.read_text(encoding="utf-8") + if is_text + else __import__("base64").b64encode(path.read_bytes()).decode("ascii") + ) + except UnicodeDecodeError: + import base64 + + content = base64.b64encode(path.read_bytes()).decode("ascii") + is_text = False + mime_type = ( + "application/octet-stream" if mime_type.startswith("text/") else mime_type + ) + + return { + "content": content, + "mime_type": mime_type, + "is_text": is_text, + "file_size": path.stat().st_size, + "filename": filename, + } diff --git a/src/memory/api/MCP/tools.py b/src/memory/api/MCP/tools.py index 58d1abd..a0409b8 100644 --- a/src/memory/api/MCP/tools.py +++ b/src/memory/api/MCP/tools.py @@ -3,23 +3,15 @@ MCP tools for the epistemic sparring partner system. """ import logging -import pathlib from datetime import datetime, timezone -import mimetypes from mcp.server.fastmcp import FastMCP -from pydantic import BaseModel -from sqlalchemy import Text, func +from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.search.search import SearchFilters, search -from memory.common import extract, settings -from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS from memory.common.db.connection import make_session from memory.common.db.models import AgentObservation, SourceItem -from memory.common.formatters import observation -from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE logger = logging.getLogger(__name__) @@ -76,377 +68,3 @@ def filter_source_ids( async def get_current_time() -> dict: """Get the current time in UTC.""" return {"current_time": datetime.now(timezone.utc).isoformat()} - - -@mcp.tool() -async def get_all_tags() -> list[str]: - """ - Get all unique tags used across the entire knowledge base. - Returns sorted list of tags from both observations and content. - """ - with make_session() as session: - tags_query = session.query(func.unnest(SourceItem.tags)).distinct() - return sorted({row[0] for row in tags_query if row[0] is not None}) - - -@mcp.tool() -async def get_all_subjects() -> list[str]: - """ - Get all unique subjects from observations about the user. - Returns sorted list of subject identifiers used in observations. - """ - with make_session() as session: - return sorted( - r.subject for r in session.query(AgentObservation.subject).distinct() - ) - - -@mcp.tool() -async def get_all_observation_types() -> list[str]: - """ - Get all observation types that have been used. - Standard types are belief, preference, behavior, contradiction, general, but there can be more. - """ - with make_session() as session: - return sorted( - { - r.observation_type - for r in session.query(AgentObservation.observation_type).distinct() - if r.observation_type is not None - } - ) - - -@mcp.tool() -async def search_knowledge_base( - query: str, - previews: bool = False, - modalities: set[str] = set(), - tags: list[str] = [], - limit: int = 10, -) -> list[dict]: - """ - Search user's stored content including emails, documents, articles, books. - Use to find specific information the user has saved or received. - Combine with search_observations for complete user context. - - Args: - query: Natural language search query - be descriptive about what you're looking for - previews: Include actual content in results - when false only a snippet is returned - modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all) - tags: Filter by tags - content must have at least one matching tag - limit: Max results (1-100) - - Returns: List of search results with id, score, chunks, content, filename - Higher scores (>0.7) indicate strong matches. - """ - logger.info(f"MCP search for: {query}") - - if not modalities: - modalities = set(ALL_COLLECTIONS.keys()) - modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS - - upload_data = extract.extract_text(query) - results = await search( - upload_data, - previews=previews, - modalities=modalities, - limit=limit, - min_text_score=0.4, - min_multimodal_score=0.25, - filters=SearchFilters( - tags=tags, - source_ids=filter_source_ids(tags=tags, modalities=modalities), - ), - ) - - return [result.model_dump() for result in results] - - -class RawObservation(BaseModel): - subject: str - content: str - observation_type: str = "general" - confidences: dict[str, float] = {} - evidence: dict | None = None - tags: list[str] = [] - - -@mcp.tool() -async def observe( - observations: list[RawObservation], - session_id: str | None = None, - agent_model: str = "unknown", -) -> dict: - """ - Record observations about the user for long-term understanding. - Use proactively when user expresses preferences, behaviors, beliefs, or contradictions. - Be specific and detailed - observations should make sense months later. - - Example call: - ``` - { - "observations": [ - { - "content": "The user is a software engineer.", - "subject": "user", - "observation_type": "belief", - "confidences": {"observation_accuracy": 0.9}, - "evidence": {"quote": "I am a software engineer.", "context": "I work at Google."}, - "tags": ["programming", "work"] - } - ], - "session_id": "123e4567-e89b-12d3-a456-426614174000", - "agent_model": "gpt-4o" - } - ``` - - RawObservation fields: - content (required): Detailed observation text explaining what you observed - subject (required): Consistent identifier like "programming_style", "work_habits" - observation_type: belief, preference, behavior, contradiction, general - confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9} - evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation) - tags: List of categorization tags for organization - - Args: - observations: List of RawObservation objects - session_id: UUID to group observations from same conversation - agent_model: AI model making observations (for quality tracking) - """ - tasks = [ - ( - observation, - celery_app.send_task( - SYNC_OBSERVATION, - queue=f"{settings.CELERY_QUEUE_PREFIX}-notes", - kwargs={ - "subject": observation.subject, - "content": observation.content, - "observation_type": observation.observation_type, - "confidences": observation.confidences, - "evidence": observation.evidence, - "tags": observation.tags, - "session_id": session_id, - "agent_model": agent_model, - }, - ), - ) - for observation in observations - ] - - def short_content(obs: RawObservation) -> str: - if len(obs.content) > 50: - return obs.content[:47] + "..." - return obs.content - - return { - "task_ids": {short_content(obs): task.id for obs, task in tasks}, - "status": "queued", - } - - -@mcp.tool() -async def search_observations( - query: str, - subject: str = "", - tags: list[str] | None = None, - observation_types: list[str] | None = None, - min_confidences: dict[str, float] = {}, - limit: int = 10, -) -> list[dict]: - """ - Search recorded observations about the user. - Use before responding to understand user preferences, patterns, and past insights. - Search by meaning - the query matches both content and context. - - Args: - query: Natural language search query describing what you're looking for - subject: Filter by exact subject identifier (empty = search all subjects) - tags: Filter by tags (must have at least one matching tag) - observation_types: Filter by: belief, preference, behavior, contradiction, general - min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8} - limit: Max results (1-100) - - Returns: List with content, tags, created_at, metadata - Results sorted by relevance to your query. - """ - semantic_text = observation.generate_semantic_text( - subject=subject or "", - observation_type="".join(observation_types or []), - content=query, - evidence=None, - ) - temporal = observation.generate_temporal_text( - subject=subject or "", - content=query, - created_at=datetime.now(timezone.utc), - ) - results = await search( - [ - extract.DataChunk(data=[query]), - extract.DataChunk(data=[semantic_text]), - extract.DataChunk(data=[temporal]), - ], - previews=True, - modalities={"semantic", "temporal"}, - limit=limit, - filters=SearchFilters( - subject=subject, - min_confidences=min_confidences, - tags=tags, - observation_types=observation_types, - source_ids=filter_observation_source_ids(tags=tags), - ), - timeout=2, - ) - - return [ - { - "content": r.content, - "tags": r.tags, - "created_at": r.created_at.isoformat() if r.created_at else None, - "metadata": r.metadata, - } - for r in results - ] - - -@mcp.tool() -async def create_note( - subject: str, - content: str, - filename: str | None = None, - note_type: str | None = None, - confidences: dict[str, float] = {}, - tags: list[str] = [], -) -> dict: - """ - Create a note when user asks to save or record something. - Use when user explicitly requests noting information for future reference. - - Args: - subject: What the note is about (used for organization) - content: Note content as a markdown string - filename: Optional path relative to notes folder (e.g., "project/ideas.md") - note_type: Optional categorization of the note - confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9} - tags: Organization tags for filtering and discovery - """ - if filename: - path = pathlib.Path(filename) - if not path.is_absolute(): - path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path - filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix() - - try: - task = celery_app.send_task( - SYNC_NOTE, - queue=f"{settings.CELERY_QUEUE_PREFIX}-notes", - kwargs={ - "subject": subject, - "content": content, - "filename": filename, - "note_type": note_type, - "confidences": confidences, - "tags": tags, - }, - ) - except Exception as e: - import traceback - - traceback.print_exc() - logger.error(f"Error creating note: {e}") - raise - - return { - "task_id": task.id, - "status": "queued", - } - - -@mcp.tool() -async def note_files(path: str = "/"): - """ - List note files in the user's note storage. - Use to discover existing notes before reading or to help user navigate their collection. - - Args: - path: Directory path to search (e.g., "/", "/projects", "/meetings") - Use "/" for root, or subdirectories to narrow scope - - Returns: List of file paths relative to notes directory - """ - root = settings.NOTES_STORAGE_DIR / path.lstrip("/") - return [ - f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}" - for f in root.rglob("*.md") - if f.is_file() - ] - - -@mcp.tool() -def fetch_file(filename: str) -> dict: - """ - Read file content with automatic type detection. - Returns dict with content, mime_type, is_text, file_size. - Text content as string, binary as base64. - """ - path = settings.FILE_STORAGE_DIR / filename.lstrip("/") - if not path.exists(): - raise FileNotFoundError(f"File not found: {filename}") - - mime_type, _ = mimetypes.guess_type(str(path)) - mime_type = mime_type or "application/octet-stream" - - text_extensions = { - ".md", - ".txt", - ".py", - ".js", - ".html", - ".css", - ".json", - ".xml", - ".yaml", - ".yml", - ".toml", - ".ini", - ".cfg", - ".conf", - } - text_mimes = { - "application/json", - "application/xml", - "application/javascript", - "application/x-yaml", - "application/yaml", - } - is_text = ( - mime_type.startswith("text/") - or mime_type in text_mimes - or path.suffix.lower() in text_extensions - ) - - try: - content = ( - path.read_text(encoding="utf-8") - if is_text - else __import__("base64").b64encode(path.read_bytes()).decode("ascii") - ) - except UnicodeDecodeError: - import base64 - - content = base64.b64encode(path.read_bytes()).decode("ascii") - is_text = False - mime_type = ( - "application/octet-stream" if mime_type.startswith("text/") else mime_type - ) - - return { - "content": content, - "mime_type": mime_type, - "is_text": is_text, - "file_size": path.stat().st_size, - "filename": filename, - } diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index 4ba2dbb..2b92e17 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -195,6 +195,9 @@ class AuthenticationMiddleware(BaseHTTPMiddleware): } async def dispatch(self, request: Request, call_next): + if settings.DISABLE_AUTH: + return await call_next(request) + path = request.url.path # Skip authentication for whitelisted endpoints diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index 9a36a8a..275523a 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -83,5 +83,7 @@ app.conf.update( @app.on_after_configure.connect # type: ignore[attr-defined] def ensure_qdrant_initialised(sender, **_): from memory.common import qdrant + from memory.common.discord import load_servers qdrant.setup_qdrant() + load_servers() diff --git a/src/memory/common/discord.py b/src/memory/common/discord.py new file mode 100644 index 0000000..53b323a --- /dev/null +++ b/src/memory/common/discord.py @@ -0,0 +1,184 @@ +import logging +import requests +from typing import Any, Dict, List + +from memory.common import settings + +logger = logging.getLogger(__name__) + +ERROR_CHANNEL = "memory-errors" +ACTIVITY_CHANNEL = "memory-activity" +DISCOVERY_CHANNEL = "memory-discoveries" +CHAT_CHANNEL = "memory-chat" + + +class DiscordServer(requests.Session): + def __init__(self, server_id: str, server_name: str, *args, **kwargs): + self.server_id = server_id + self.server_name = server_name + self.channels = {} + super().__init__(*args, **kwargs) + self.setup_channels() + + def setup_channels(self): + resp = self.get(self.channels_url) + resp.raise_for_status() + channels = {channel["name"]: channel["id"] for channel in resp.json()} + + if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)): + error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL) + self.channels[ERROR_CHANNEL] = error_channel + + if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)): + activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL) + self.channels[ACTIVITY_CHANNEL] = activity_channel + + if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)): + discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL) + self.channels[DISCOVERY_CHANNEL] = discovery_channel + + if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)): + chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL) + self.channels[CHAT_CHANNEL] = chat_channel + + @property + def error_channel(self) -> str: + return self.channels[ERROR_CHANNEL] + + @property + def activity_channel(self) -> str: + return self.channels[ACTIVITY_CHANNEL] + + @property + def discovery_channel(self) -> str: + return self.channels[DISCOVERY_CHANNEL] + + @property + def chat_channel(self) -> str: + return self.channels[CHAT_CHANNEL] + + def channel_id(self, channel_name: str) -> str: + if not (channel_id := self.channels.get(channel_name)): + raise ValueError(f"Channel {channel_name} not found") + return channel_id + + def send_message(self, channel_id: str, content: str): + self.post( + f"https://discord.com/api/v10/channels/{channel_id}/messages", + json={"content": content}, + ) + + def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None: + resp = self.post( + self.channels_url, json={"name": channel_name, "type": channel_type} + ) + resp.raise_for_status() + return resp.json()["id"] + + def __str__(self): + return ( + f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})" + ) + + def request(self, method: str, url: str, **kwargs): + headers = kwargs.get("headers", {}) + headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}" + headers["Content-Type"] = "application/json" + kwargs["headers"] = headers + return super().request(method, url, **kwargs) + + @property + def channels_url(self) -> str: + return f"https://discord.com/api/v10/guilds/{self.server_id}/channels" + + +def get_bot_servers() -> List[Dict]: + """Get list of servers the bot is in.""" + if not settings.DISCORD_BOT_TOKEN: + return [] + + try: + headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} + response = requests.get( + "https://discord.com/api/v10/users/@me/guilds", headers=headers + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Failed to get bot servers: {e}") + return [] + + +servers: dict[str, DiscordServer] = {} + + +def load_servers(): + for server in get_bot_servers(): + servers[server["id"]] = DiscordServer(server["id"], server["name"]) + + +def broadcast_message(channel: str, message: str): + if not settings.DISCORD_NOTIFICATIONS_ENABLED: + return + + for server in servers.values(): + server.send_message(server.channel_id(channel), message) + + +def send_error_message(message: str): + broadcast_message(ERROR_CHANNEL, message) + + +def send_activity_message(message: str): + broadcast_message(ACTIVITY_CHANNEL, message) + + +def send_discovery_message(message: str): + broadcast_message(DISCOVERY_CHANNEL, message) + + +def send_chat_message(message: str): + broadcast_message(CHAT_CHANNEL, message) + + +def notify_task_failure( + task_name: str, + error_message: str, + task_args: tuple = (), + task_kwargs: dict[str, Any] | None = None, + traceback_str: str | None = None, +) -> None: + """ + Send a task failure notification to Discord. + + Args: + task_name: Name of the failed task + error_message: Error message + task_args: Task arguments + task_kwargs: Task keyword arguments + traceback_str: Full traceback string + + Returns: + True if notification sent successfully + """ + if not settings.DISCORD_NOTIFICATIONS_ENABLED: + logger.debug("Discord notifications disabled") + return + + message = f"🚨 **Task Failed: {task_name}**\n\n" + message += f"**Error:** {error_message[:500]}\n" + + if task_args: + message += f"**Args:** `{str(task_args)[:200]}`\n" + + if task_kwargs: + message += f"**Kwargs:** `{str(task_kwargs)[:200]}`\n" + + if traceback_str: + message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```" + + try: + send_error_message(message) + logger.info(f"Discord error notification sent for task: {task_name}") + except Exception as e: + logger.error(f"Failed to send Discord notification: {e}") diff --git a/src/memory/common/formatters/observation.py b/src/memory/common/formatters/observation.py index a644df5..3037344 100644 --- a/src/memory/common/formatters/observation.py +++ b/src/memory/common/formatters/observation.py @@ -55,7 +55,7 @@ def generate_temporal_text( f"Observation: {content}", ] - return " | ".join(parts) + return " | ".join(parts).strip() # TODO: Add more embedding dimensions here: diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 8a84a9d..633aa88 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -138,3 +138,17 @@ SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 * SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30)) REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True +DISABLE_AUTH = boolean_env("DISABLE_AUTH", False) + +# Discord notification settings +DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "") +DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors") +DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity") +DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries") +DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat") + + +# Enable Discord notifications if bot token is set +DISCORD_NOTIFICATIONS_ENABLED = ( + boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN +) diff --git a/src/memory/common/summarizer.py b/src/memory/common/summarizer.py index 85d5454..fe0b51a 100644 --- a/src/memory/common/summarizer.py +++ b/src/memory/common/summarizer.py @@ -47,9 +47,20 @@ Text to summarize: def parse_response(response: str) -> dict[str, Any]: """Parse the response from the summarizer.""" - soup = BeautifulSoup(response, "xml") - summary = soup.find("summary").text - tags = [tag.text for tag in soup.find_all("tag")] + if not response or not response.strip(): + return {"summary": "", "tags": []} + + # Use html.parser instead of xml parser for better compatibility + soup = BeautifulSoup(response, "html.parser") + + # Safely extract summary + summary_element = soup.find("summary") + summary = summary_element.text if summary_element else "" + + # Safely extract tags + tag_elements = soup.find_all("tag") + tags = [tag.text for tag in tag_elements if tag.text is not None] + return {"summary": summary, "tags": tags} @@ -68,7 +79,6 @@ def _call_openai(prompt: str) -> dict[str, Any]: }, {"role": "user", "content": prompt}, ], - response_format={"type": "json_object"}, temperature=0.3, max_tokens=2048, ) diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py index 6378d0e..e781983 100644 --- a/src/memory/workers/tasks/content_processing.py +++ b/src/memory/workers/tasks/content_processing.py @@ -12,8 +12,9 @@ import traceback import logging from typing import Any, Callable, Iterable, Sequence, cast -from memory.common import embedding, qdrant +from memory.common import embedding, qdrant, settings from memory.common.db.models import SourceItem, Chunk +from memory.common.discord import notify_task_failure logger = logging.getLogger(__name__) @@ -274,7 +275,17 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: return func(*args, **kwargs) except Exception as e: logger.error(f"Task {func.__name__} failed: {e}") - logger.error(traceback.format_exc()) + traceback_str = traceback.format_exc() + logger.error(traceback_str) + + notify_task_failure( + task_name=func.__name__, + error_message=str(e), + task_args=args, + task_kwargs=kwargs, + traceback_str=traceback_str, + ) + return {"status": "error", "error": str(e)} return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index d8a0059..92166cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -213,12 +213,15 @@ def mock_file_storage(tmp_path: Path): email_storage_dir.mkdir(parents=True, exist_ok=True) notes_storage_dir = tmp_path / "notes" notes_storage_dir.mkdir(parents=True, exist_ok=True) + comic_storage_dir = tmp_path / "comics" + comic_storage_dir.mkdir(parents=True, exist_ok=True) with ( patch.object(settings, "FILE_STORAGE_DIR", tmp_path), patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir), patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir), patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir), patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir), + patch.object(settings, "COMIC_STORAGE_DIR", comic_storage_dir), ): yield @@ -256,7 +259,7 @@ def mock_openai_client(): choices=[ Mock( message=Mock( - content='{"summary": "test summary", "tags": ["tag1", "tag2"]}' + content="test summarytag1tag2" ) ) ] @@ -273,8 +276,16 @@ def mock_anthropic_client(): client.messages.create = Mock( return_value=Mock( content=[ - Mock(text='{"summary": "test summary", "tags": ["tag1", "tag2"]}') + Mock( + text="test summarytag1tag2" + ) ] ) ) yield client + + +@pytest.fixture(autouse=True) +def mock_discord_client(): + with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False): + yield diff --git a/tests/memory/common/test_discord.py b/tests/memory/common/test_discord.py new file mode 100644 index 0000000..af2c9c9 --- /dev/null +++ b/tests/memory/common/test_discord.py @@ -0,0 +1,393 @@ +import logging +import pytest +from unittest.mock import Mock, patch +import requests +import json + +from memory.common import discord, settings + + +@pytest.fixture +def mock_session_request(): + with patch("requests.Session.request") as mock: + yield mock + + +@pytest.fixture +def mock_get_channels_response(): + return [ + {"name": "memory-errors", "id": "error_channel_id"}, + {"name": "memory-activity", "id": "activity_channel_id"}, + {"name": "memory-discoveries", "id": "discovery_channel_id"}, + {"name": "memory-chat", "id": "chat_channel_id"}, + ] + + +def test_discord_server_init(mock_session_request, mock_get_channels_response): + # Mock the channels API call + mock_response = Mock() + mock_response.json.return_value = mock_get_channels_response + mock_response.raise_for_status.return_value = None + mock_session_request.return_value = mock_response + + server = discord.DiscordServer("server123", "Test Server") + + assert server.server_id == "server123" + assert server.server_name == "Test Server" + assert hasattr(server, "channels") + + +@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors") +@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity") +@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries") +@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat") +def test_setup_channels_existing(mock_session_request, mock_get_channels_response): + # Mock the channels API call + mock_response = Mock() + mock_response.json.return_value = mock_get_channels_response + mock_response.raise_for_status.return_value = None + mock_session_request.return_value = mock_response + + server = discord.DiscordServer("server123", "Test Server") + + assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id" + assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id" + assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id" + assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id" + + +@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel") +def test_setup_channels_create_missing(mock_session_request): + # Mock get channels (empty) and create channel calls + get_response = Mock() + get_response.json.return_value = [] + get_response.raise_for_status.return_value = None + + create_response = Mock() + create_response.json.return_value = {"id": "new_channel_id"} + create_response.raise_for_status.return_value = None + + mock_session_request.side_effect = [ + get_response, + create_response, + create_response, + create_response, + create_response, + ] + + server = discord.DiscordServer("server123", "Test Server") + + assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id" + + +def test_channel_properties(): + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.channels = { + discord.ERROR_CHANNEL: "error_id", + discord.ACTIVITY_CHANNEL: "activity_id", + discord.DISCOVERY_CHANNEL: "discovery_id", + discord.CHAT_CHANNEL: "chat_id", + } + + assert server.error_channel == "error_id" + assert server.activity_channel == "activity_id" + assert server.discovery_channel == "discovery_id" + assert server.chat_channel == "chat_id" + + +def test_channel_id_exists(): + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.channels = {"test-channel": "channel123"} + + assert server.channel_id("test-channel") == "channel123" + + +def test_channel_id_not_found(): + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.channels = {} + + with pytest.raises(ValueError, match="Channel nonexistent not found"): + server.channel_id("nonexistent") + + +def test_send_message(mock_session_request): + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_session_request.return_value = mock_response + + server = discord.DiscordServer.__new__(discord.DiscordServer) + + server.send_message("channel123", "Hello World") + + mock_session_request.assert_called_with( + "POST", + "https://discord.com/api/v10/channels/channel123/messages", + data=None, + json={"content": "Hello World"}, + headers={ + "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", + "Content-Type": "application/json", + }, + ) + + +def test_create_channel(mock_session_request): + mock_response = Mock() + mock_response.json.return_value = {"id": "new_channel_id"} + mock_response.raise_for_status.return_value = None + mock_session_request.return_value = mock_response + + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.server_id = "server123" + + channel_id = server.create_channel("new-channel") + + assert channel_id == "new_channel_id" + mock_session_request.assert_called_with( + "POST", + "https://discord.com/api/v10/guilds/server123/channels", + data=None, + json={"name": "new-channel", "type": 0}, + headers={ + "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", + "Content-Type": "application/json", + }, + ) + + +def test_create_channel_custom_type(mock_session_request): + mock_response = Mock() + mock_response.json.return_value = {"id": "voice_channel_id"} + mock_response.raise_for_status.return_value = None + mock_session_request.return_value = mock_response + + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.server_id = "server123" + + channel_id = server.create_channel("voice-channel", channel_type=2) + + assert channel_id == "voice_channel_id" + mock_session_request.assert_called_with( + "POST", + "https://discord.com/api/v10/guilds/server123/channels", + data=None, + json={"name": "voice-channel", "type": 2}, + headers={ + "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", + "Content-Type": "application/json", + }, + ) + + +def test_str_representation(): + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.server_id = "server123" + server.server_name = "Test Server" + + assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)" + + +@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123") +def test_request_adds_headers(mock_session_request): + server = discord.DiscordServer.__new__(discord.DiscordServer) + + server.request("GET", "https://example.com", headers={"Custom": "header"}) + + expected_headers = { + "Custom": "header", + "Authorization": "Bot test_token_123", + "Content-Type": "application/json", + } + mock_session_request.assert_called_once_with( + "GET", "https://example.com", headers=expected_headers + ) + + +def test_channels_url(): + server = discord.DiscordServer.__new__(discord.DiscordServer) + server.server_id = "server123" + + assert ( + server.channels_url == "https://discord.com/api/v10/guilds/server123/channels" + ) + + +@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token") +@patch("requests.get") +def test_get_bot_servers_success(mock_get): + mock_response = Mock() + mock_response.json.return_value = [ + {"id": "server1", "name": "Server 1"}, + {"id": "server2", "name": "Server 2"}, + ] + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + servers = discord.get_bot_servers() + + assert len(servers) == 2 + assert servers[0] == {"id": "server1", "name": "Server 1"} + mock_get.assert_called_once_with( + "https://discord.com/api/v10/users/@me/guilds", + headers={"Authorization": "Bot test_token"}, + ) + + +@patch("memory.common.settings.DISCORD_BOT_TOKEN", None) +def test_get_bot_servers_no_token(): + assert discord.get_bot_servers() == [] + + +@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token") +@patch("requests.get") +def test_get_bot_servers_exception(mock_get): + mock_get.side_effect = requests.RequestException("API Error") + + servers = discord.get_bot_servers() + + assert servers == [] + + +@patch("memory.common.discord.get_bot_servers") +@patch("memory.common.discord.DiscordServer") +def test_load_servers(mock_discord_server_class, mock_get_servers): + mock_get_servers.return_value = [ + {"id": "server1", "name": "Server 1"}, + {"id": "server2", "name": "Server 2"}, + ] + + discord.load_servers() + + assert mock_discord_server_class.call_count == 2 + mock_discord_server_class.assert_any_call("server1", "Server 1") + mock_discord_server_class.assert_any_call("server2", "Server 2") + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +def test_broadcast_message(): + mock_server1 = Mock() + mock_server2 = Mock() + discord.servers = {"1": mock_server1, "2": mock_server2} + + discord.broadcast_message("test-channel", "Hello") + + mock_server1.send_message.assert_called_once_with( + mock_server1.channel_id.return_value, "Hello" + ) + mock_server2.send_message.assert_called_once_with( + mock_server2.channel_id.return_value, "Hello" + ) + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) +def test_broadcast_message_disabled(): + mock_server = Mock() + discord.servers = {"1": mock_server} + + discord.broadcast_message("test-channel", "Hello") + + mock_server.send_message.assert_not_called() + + +@patch("memory.common.discord.broadcast_message") +def test_send_error_message(mock_broadcast): + discord.send_error_message("Error occurred") + mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred") + + +@patch("memory.common.discord.broadcast_message") +def test_send_activity_message(mock_broadcast): + discord.send_activity_message("Activity update") + mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update") + + +@patch("memory.common.discord.broadcast_message") +def test_send_discovery_message(mock_broadcast): + discord.send_discovery_message("Discovery made") + mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made") + + +@patch("memory.common.discord.broadcast_message") +def test_send_chat_message(mock_broadcast): + discord.send_chat_message("Chat message") + mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message") + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_basic(mock_send_error): + discord.notify_task_failure("test_task", "Something went wrong") + + mock_send_error.assert_called_once() + message = mock_send_error.call_args[0][0] + + assert "🚨 **Task Failed: test_task**" in message + assert "**Error:** Something went wrong" in message + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_with_args(mock_send_error): + discord.notify_task_failure( + "test_task", + "Error message", + task_args=("arg1", "arg2"), + task_kwargs={"key": "value"}, + ) + + message = mock_send_error.call_args[0][0] + + assert "**Args:** `('arg1', 'arg2')`" in message + assert "**Kwargs:** `{'key': 'value'}`" in message + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_with_traceback(mock_send_error): + traceback = "Traceback (most recent call last):\n File ...\nError: Something" + + discord.notify_task_failure("test_task", "Error message", traceback_str=traceback) + + message = mock_send_error.call_args[0][0] + assert "**Traceback:**" in message + assert traceback in message + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_truncates_long_error(mock_send_error): + long_error = "x" * 600 # Longer than 500 char limit + + discord.notify_task_failure("test_task", long_error) + + message = mock_send_error.call_args[0][0] + assert long_error[:500] in message + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_truncates_long_traceback(mock_send_error): + long_traceback = "x" * 1000 # Longer than 800 char limit + + discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) + + message = mock_send_error.call_args[0][0] + assert long_traceback[-800:] in message + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_disabled(mock_send_error): + discord.notify_task_failure("test_task", "Error message") + mock_send_error.assert_not_called() + + +@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) +@patch("memory.common.discord.send_error_message") +def test_notify_task_failure_send_fails(mock_send_error): + mock_send_error.side_effect = Exception("Discord API error") + + # Should not raise, just log the error + discord.notify_task_failure("test_task", "Error message") + + mock_send_error.assert_called_once() diff --git a/tests/memory/common/test_summarizer.py b/tests/memory/common/test_summarizer.py new file mode 100644 index 0000000..2f4629d --- /dev/null +++ b/tests/memory/common/test_summarizer.py @@ -0,0 +1,151 @@ +from memory.common import summarizer +import pytest + + +@pytest.mark.parametrize( + "response, expected", + ( + # Basic valid cases + ("", {"summary": "", "tags": []}), + ( + "testtag1tag2", + {"summary": "test", "tags": ["tag1", "tag2"]}, + ), + ( + "testtag1tag2", + {"summary": "test", "tags": ["tag1", "tag2"]}, + ), + # Missing summary tag + ( + "tag1tag2", + {"summary": "", "tags": ["tag1", "tag2"]}, + ), + # Missing tags section + ( + "test summary", + {"summary": "test summary", "tags": []}, + ), + # Empty summary tag + ( + "tag1", + {"summary": "", "tags": ["tag1"]}, + ), + # Empty tags section + ( + "test", + {"summary": "test", "tags": []}, + ), + # Single tag + ( + "testsingle-tag", + {"summary": "test", "tags": ["single-tag"]}, + ), + # Multiple tags + ( + "testtag1tag2tag3tag4tag5", + {"summary": "test", "tags": ["tag1", "tag2", "tag3", "tag4", "tag5"]}, + ), + # Tags with special characters and hyphens + ( + "testmachine-learningai/mldata_science", + {"summary": "test", "tags": ["machine-learning", "ai/ml", "data_science"]}, + ), + # Summary with special characters + ( + "Test with & special characters <>test", + {"summary": "Test with & special characters <>", "tags": ["test"]}, + ), + # Whitespace handling + ( + " test with spaces tag1 tag2", + {"summary": " test with spaces ", "tags": [" tag1 ", "tag2"]}, + ), + # Mixed case XML tags (should still work with BeautifulSoup) + ( + "testtag1", + {"summary": "test", "tags": ["tag1"]}, + ), + # Multiple summary tags (should take first one) + ( + "firstsecondtag1", + {"summary": "first", "tags": ["tag1"]}, + ), + # Multiple tags sections (should collect all tags) + ( + "testtag1tag2", + {"summary": "test", "tags": ["tag1", "tag2"]}, + ), + # XML with extra elements (should ignore them) + ( + "testignoredtag1", + {"summary": "test", "tags": ["tag1"]}, + ), + # XML with attributes (should still work) + ( + 'testtag1', + {"summary": "test", "tags": ["tag1"]}, + ), + # Empty tag elements + ( + "testvalid-tag", + {"summary": "test", "tags": ["", "valid-tag", ""]}, + ), + # Self-closing tags + ( + "testvalid-tag", + {"summary": "test", "tags": ["", "valid-tag"]}, + ), + # Long content + ( + f"{'a' * 1000}long-content", + {"summary": "a" * 1000, "tags": ["long-content"]}, + ), + # XML with newlines and formatting + ( + """ + Multi-line + summary content + + + formatted + xml + """, + { + "summary": "\n Multi-line\n summary content\n ", + "tags": ["formatted", "xml"], + }, + ), + # Malformed XML (missing closing tags) - BeautifulSoup parses as best it can + ( + "testtag1", + {"summary": "testtag1", "tags": ["tag1"]}, + ), + # Invalid XML characters should be handled by BeautifulSoup + ( + "test & unescapedtag1", + {"summary": "test & unescaped", "tags": ["tag1"]}, + ), + # Only whitespace + ( + " \n\t ", + {"summary": "", "tags": []}, + ), + # Non-XML content + ( + "This is just plain text without XML", + {"summary": "", "tags": []}, + ), + # XML comments (should be ignored) + ( + "testtag1", + {"summary": "test", "tags": ["tag1"]}, + ), + # CDATA sections + ( + " characters]]>cdata", + {"summary": "test with characters", "tags": ["cdata"]}, + ), + ), +) +def test_parse_response(response, expected): + assert summarizer.parse_response(response) == expected diff --git a/tools/discord_setup.py b/tools/discord_setup.py new file mode 100644 index 0000000..d7ab5b6 --- /dev/null +++ b/tools/discord_setup.py @@ -0,0 +1,57 @@ +import argparse +import click +import requests + + +@click.command() +@click.option("--bot-token", type=str, required=True) +def generate_bot_invite_url(bot_token: str): + """ + Generate the Discord bot invitation URL. + + Returns: + URL that user can click to add bot to their server + """ + # Get bot's client ID from the token (it's the first part before the first dot) + # But safer to get it from the API + try: + headers = {"Authorization": f"Bot {bot_token}"} + response = requests.get( + "https://discord.com/api/v10/users/@me", headers=headers + ) + response.raise_for_status() + bot_info = response.json() + client_id = bot_info["id"] + except Exception as e: + raise ValueError(f"Could not get bot info: {e}") + + # Permissions needed: Send Messages (2048) + Manage Channels (16) + View Channels (1024) + permissions = 2048 + 16 + 1024 # = 3088 + + invite_url = f"https://discord.com/oauth2/authorize?client_id={client_id}&scope=bot&permissions={permissions}" + click.echo(f"Bot invite URL: {invite_url}") + return invite_url + + +@click.command() +def create_channels(): + """Create Discord channels using the configured servers.""" + from memory.common.discord import load_servers + + click.echo("Loading Discord servers and creating channels...") + load_servers() + click.echo("Discord channels setup completed.") + + +@click.group() +def cli(): + """Discord setup utilities.""" + pass + + +cli.add_command(generate_bot_invite_url, name="generate-invite") +cli.add_command(create_channels, name="create-channels") + + +if __name__ == "__main__": + cli()