discord notification on error

This commit is contained in:
Daniel O'Connell 2025-06-05 02:21:52 +02:00
parent 489265fe31
commit 4d057d1ec6
16 changed files with 1431 additions and 392 deletions

View File

@ -92,6 +92,27 @@ curl -X POST http://localhost:8000/auth/login \
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header. This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
## Discord integration
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
Once you have the bot's token, run
```bash
python tools/discord_setup.py generate-invite --bot-token <your bot token>
```
to get an url that can be used to connect your Discord bot.
Next you'll have to set at least the following in your `.env` file:
```
DISCORD_BOT_TOKEN=<your bot token>
DISCORD_NOTIFICATIONS_ENABLED=True
```
When the worker starts it will automatically attempt to create the appropriate channels. You
can change what they will be called by setting the various `DISCORD_*_CHANNEL` settings.
## MCP Proxy Setup ## MCP Proxy Setup
Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect: Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect:

View File

@ -0,0 +1,2 @@
import memory.api.MCP.manifest
import memory.api.MCP.memory

View File

@ -0,0 +1,119 @@
import asyncio
import aiohttp
from datetime import datetime
from typing import TypedDict, NotRequired, Literal
from memory.api.MCP.tools import mcp
class BinaryProbs(TypedDict):
prob: float
class MultiProbs(TypedDict):
answerProbs: dict[str, float]
Probs = dict[str, BinaryProbs | MultiProbs]
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
class MarketAnswer(TypedDict):
id: str
text: str
resolutionProbability: float
class MarketDetails(TypedDict):
id: str
createdTime: int
question: str
outcomeType: OutcomeType
textDescription: str
groupSlugs: list[str]
volume: float
isResolved: bool
answers: list[MarketAnswer]
class Market(TypedDict):
id: str
url: str
question: str
volume: int
createdTime: int
outcomeType: OutcomeType
createdAt: NotRequired[str]
description: NotRequired[str]
answers: NotRequired[dict[str, float]]
probability: NotRequired[float]
details: NotRequired[MarketDetails]
async def get_details(session: aiohttp.ClientSession, market_id: str):
async with session.get(
f"https://api.manifold.markets/v0/market/{market_id}"
) as resp:
resp.raise_for_status()
return await resp.json()
async def format_market(session: aiohttp.ClientSession, market: Market):
if market.get("outcomeType") != "BINARY":
details = await get_details(session, market["id"])
market["answers"] = {
answer["text"]: round(
answer.get("resolutionProbability") or answer.get("probability") or 0, 3
)
for answer in details["answers"]
}
if creationTime := market.get("createdTime"):
market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
fields = [
"id",
"name",
"url",
"question",
"volume",
"createdAt",
"details",
"probability",
"answers",
]
return {k: v for k, v in market.items() if k in fields}
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.manifold.markets/v0/search-markets",
params={
"term": term,
"contractType": "BINARY" if binary else "ALL",
},
) as resp:
resp.raise_for_status()
markets = await resp.json()
return await asyncio.gather(
*[
format_market(session, market)
for market in markets
if market.get("volume", 0) >= min_volume
]
)
@mcp.tool()
async def get_forecasts(
term: str, min_volume: int = 1000, binary: bool = False
) -> list[dict]:
"""Get prediction market forecasts for a given term.
Args:
term: The term to search for.
min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
binary: Whether to only return binary markets.
"""
return await search_markets(term, min_volume, binary)

View File

@ -0,0 +1,443 @@
"""
MCP tools for the epistemic sparring partner system.
"""
import logging
import pathlib
from datetime import datetime, timezone
import mimetypes
from pydantic import BaseModel
from sqlalchemy import Text, func
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.search.search import SearchFilters, search
from memory.common import extract, settings
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem
from memory.common.formatters import observation
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
from memory.api.MCP.tools import mcp
logger = logging.getLogger(__name__)
def filter_observation_source_ids(
tags: list[str] | None = None, observation_types: list[str] | None = None
):
if not tags and not observation_types:
return None
with make_session() as session:
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if observation_types:
items_query = items_query.filter(
AgentObservation.observation_type.in_(observation_types)
)
source_ids = [item.id for item in items_query.all()]
return source_ids
def filter_source_ids(
modalities: set[str],
tags: list[str] | None = None,
):
if not tags:
return None
with make_session() as session:
items_query = session.query(SourceItem.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if modalities:
items_query = items_query.filter(SourceItem.modality.in_(modalities))
source_ids = [item.id for item in items_query.all()]
return source_ids
@mcp.tool()
async def get_all_tags() -> list[str]:
"""
Get all unique tags used across the entire knowledge base.
Returns sorted list of tags from both observations and content.
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@mcp.tool()
async def get_all_subjects() -> list[str]:
"""
Get all unique subjects from observations about the user.
Returns sorted list of subject identifiers used in observations.
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@mcp.tool()
async def get_all_observation_types() -> list[str]:
"""
Get all observation types that have been used.
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)
@mcp.tool()
async def search_knowledge_base(
query: str,
previews: bool = False,
modalities: set[str] = set(),
tags: list[str] = [],
limit: int = 10,
) -> list[dict]:
"""
Search user's stored content including emails, documents, articles, books.
Use to find specific information the user has saved or received.
Combine with search_observations for complete user context.
Args:
query: Natural language search query - be descriptive about what you're looking for
previews: Include actual content in results - when false only a snippet is returned
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
tags: Filter by tags - content must have at least one matching tag
limit: Max results (1-100)
Returns: List of search results with id, score, chunks, content, filename
Higher scores (>0.7) indicate strong matches.
"""
logger.info(f"MCP search for: {query}")
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
upload_data = extract.extract_text(query)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
min_text_score=0.4,
min_multimodal_score=0.25,
filters=SearchFilters(
tags=tags,
source_ids=filter_source_ids(tags=tags, modalities=modalities),
),
)
return [result.model_dump() for result in results]
class RawObservation(BaseModel):
subject: str
content: str
observation_type: str = "general"
confidences: dict[str, float] = {}
evidence: dict | None = None
tags: list[str] = []
@mcp.tool()
async def observe(
observations: list[RawObservation],
session_id: str | None = None,
agent_model: str = "unknown",
) -> dict:
"""
Record observations about the user for long-term understanding.
Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
Be specific and detailed - observations should make sense months later.
Example call:
```
{
"observations": [
{
"content": "The user is a software engineer.",
"subject": "user",
"observation_type": "belief",
"confidences": {"observation_accuracy": 0.9},
"evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
"tags": ["programming", "work"]
}
],
"session_id": "123e4567-e89b-12d3-a456-426614174000",
"agent_model": "gpt-4o"
}
```
RawObservation fields:
content (required): Detailed observation text explaining what you observed
subject (required): Consistent identifier like "programming_style", "work_habits"
observation_type: belief, preference, behavior, contradiction, general
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation)
tags: List of categorization tags for organization
Args:
observations: List of RawObservation objects
session_id: UUID to group observations from same conversation
agent_model: AI model making observations (for quality tracking)
"""
tasks = [
(
observation,
celery_app.send_task(
SYNC_OBSERVATION,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
"subject": observation.subject,
"content": observation.content,
"observation_type": observation.observation_type,
"confidences": observation.confidences,
"evidence": observation.evidence,
"tags": observation.tags,
"session_id": session_id,
"agent_model": agent_model,
},
),
)
for observation in observations
]
def short_content(obs: RawObservation) -> str:
if len(obs.content) > 50:
return obs.content[:47] + "..."
return obs.content
return {
"task_ids": {short_content(obs): task.id for obs, task in tasks},
"status": "queued",
}
@mcp.tool()
async def search_observations(
query: str,
subject: str = "",
tags: list[str] | None = None,
observation_types: list[str] | None = None,
min_confidences: dict[str, float] = {},
limit: int = 10,
) -> list[dict]:
"""
Search recorded observations about the user.
Use before responding to understand user preferences, patterns, and past insights.
Search by meaning - the query matches both content and context.
Args:
query: Natural language search query describing what you're looking for
subject: Filter by exact subject identifier (empty = search all subjects)
tags: Filter by tags (must have at least one matching tag)
observation_types: Filter by: belief, preference, behavior, contradiction, general
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
limit: Max results (1-100)
Returns: List with content, tags, created_at, metadata
Results sorted by relevance to your query.
"""
semantic_text = observation.generate_semantic_text(
subject=subject or "",
observation_type="".join(observation_types or []),
content=query,
evidence=None,
)
temporal = observation.generate_temporal_text(
subject=subject or "",
content=query,
created_at=datetime.now(timezone.utc),
)
results = await search(
[
extract.DataChunk(data=[query]),
extract.DataChunk(data=[semantic_text]),
extract.DataChunk(data=[temporal]),
],
previews=True,
modalities={"semantic", "temporal"},
limit=limit,
filters=SearchFilters(
subject=subject,
min_confidences=min_confidences,
tags=tags,
observation_types=observation_types,
source_ids=filter_observation_source_ids(tags=tags),
),
timeout=2,
)
return [
{
"content": r.content,
"tags": r.tags,
"created_at": r.created_at.isoformat() if r.created_at else None,
"metadata": r.metadata,
}
for r in results
]
@mcp.tool()
async def create_note(
subject: str,
content: str,
filename: str | None = None,
note_type: str | None = None,
confidences: dict[str, float] = {},
tags: list[str] = [],
) -> dict:
"""
Create a note when user asks to save or record something.
Use when user explicitly requests noting information for future reference.
Args:
subject: What the note is about (used for organization)
content: Note content as a markdown string
filename: Optional path relative to notes folder (e.g., "project/ideas.md")
note_type: Optional categorization of the note
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
tags: Organization tags for filtering and discovery
"""
if filename:
path = pathlib.Path(filename)
if not path.is_absolute():
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
try:
task = celery_app.send_task(
SYNC_NOTE,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
"subject": subject,
"content": content,
"filename": filename,
"note_type": note_type,
"confidences": confidences,
"tags": tags,
},
)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Error creating note: {e}")
raise
return {
"task_id": task.id,
"status": "queued",
}
@mcp.tool()
async def note_files(path: str = "/"):
"""
List note files in the user's note storage.
Use to discover existing notes before reading or to help user navigate their collection.
Args:
path: Directory path to search (e.g., "/", "/projects", "/meetings")
Use "/" for root, or subdirectories to narrow scope
Returns: List of file paths relative to notes directory
"""
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
return [
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
for f in root.rglob("*.md")
if f.is_file()
]
@mcp.tool()
def fetch_file(filename: str) -> dict:
"""
Read file content with automatic type detection.
Returns dict with content, mime_type, is_text, file_size.
Text content as string, binary as base64.
"""
path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
if not path.exists():
raise FileNotFoundError(f"File not found: {filename}")
mime_type, _ = mimetypes.guess_type(str(path))
mime_type = mime_type or "application/octet-stream"
text_extensions = {
".md",
".txt",
".py",
".js",
".html",
".css",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
}
text_mimes = {
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
}
is_text = (
mime_type.startswith("text/")
or mime_type in text_mimes
or path.suffix.lower() in text_extensions
)
try:
content = (
path.read_text(encoding="utf-8")
if is_text
else __import__("base64").b64encode(path.read_bytes()).decode("ascii")
)
except UnicodeDecodeError:
import base64
content = base64.b64encode(path.read_bytes()).decode("ascii")
is_text = False
mime_type = (
"application/octet-stream" if mime_type.startswith("text/") else mime_type
)
return {
"content": content,
"mime_type": mime_type,
"is_text": is_text,
"file_size": path.stat().st_size,
"filename": filename,
}

View File

@ -3,23 +3,15 @@ MCP tools for the epistemic sparring partner system.
""" """
import logging import logging
import pathlib
from datetime import datetime, timezone from datetime import datetime, timezone
import mimetypes
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel from sqlalchemy import Text
from sqlalchemy import Text, func
from sqlalchemy import cast as sql_cast from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import ARRAY
from memory.api.search.search import SearchFilters, search
from memory.common import extract, settings
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem from memory.common.db.models import AgentObservation, SourceItem
from memory.common.formatters import observation
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -76,377 +68,3 @@ def filter_source_ids(
async def get_current_time() -> dict: async def get_current_time() -> dict:
"""Get the current time in UTC.""" """Get the current time in UTC."""
return {"current_time": datetime.now(timezone.utc).isoformat()} return {"current_time": datetime.now(timezone.utc).isoformat()}
@mcp.tool()
async def get_all_tags() -> list[str]:
"""
Get all unique tags used across the entire knowledge base.
Returns sorted list of tags from both observations and content.
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@mcp.tool()
async def get_all_subjects() -> list[str]:
"""
Get all unique subjects from observations about the user.
Returns sorted list of subject identifiers used in observations.
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@mcp.tool()
async def get_all_observation_types() -> list[str]:
"""
Get all observation types that have been used.
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)
@mcp.tool()
async def search_knowledge_base(
query: str,
previews: bool = False,
modalities: set[str] = set(),
tags: list[str] = [],
limit: int = 10,
) -> list[dict]:
"""
Search user's stored content including emails, documents, articles, books.
Use to find specific information the user has saved or received.
Combine with search_observations for complete user context.
Args:
query: Natural language search query - be descriptive about what you're looking for
previews: Include actual content in results - when false only a snippet is returned
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
tags: Filter by tags - content must have at least one matching tag
limit: Max results (1-100)
Returns: List of search results with id, score, chunks, content, filename
Higher scores (>0.7) indicate strong matches.
"""
logger.info(f"MCP search for: {query}")
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
upload_data = extract.extract_text(query)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
min_text_score=0.4,
min_multimodal_score=0.25,
filters=SearchFilters(
tags=tags,
source_ids=filter_source_ids(tags=tags, modalities=modalities),
),
)
return [result.model_dump() for result in results]
class RawObservation(BaseModel):
subject: str
content: str
observation_type: str = "general"
confidences: dict[str, float] = {}
evidence: dict | None = None
tags: list[str] = []
@mcp.tool()
async def observe(
observations: list[RawObservation],
session_id: str | None = None,
agent_model: str = "unknown",
) -> dict:
"""
Record observations about the user for long-term understanding.
Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
Be specific and detailed - observations should make sense months later.
Example call:
```
{
"observations": [
{
"content": "The user is a software engineer.",
"subject": "user",
"observation_type": "belief",
"confidences": {"observation_accuracy": 0.9},
"evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
"tags": ["programming", "work"]
}
],
"session_id": "123e4567-e89b-12d3-a456-426614174000",
"agent_model": "gpt-4o"
}
```
RawObservation fields:
content (required): Detailed observation text explaining what you observed
subject (required): Consistent identifier like "programming_style", "work_habits"
observation_type: belief, preference, behavior, contradiction, general
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation)
tags: List of categorization tags for organization
Args:
observations: List of RawObservation objects
session_id: UUID to group observations from same conversation
agent_model: AI model making observations (for quality tracking)
"""
tasks = [
(
observation,
celery_app.send_task(
SYNC_OBSERVATION,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
"subject": observation.subject,
"content": observation.content,
"observation_type": observation.observation_type,
"confidences": observation.confidences,
"evidence": observation.evidence,
"tags": observation.tags,
"session_id": session_id,
"agent_model": agent_model,
},
),
)
for observation in observations
]
def short_content(obs: RawObservation) -> str:
if len(obs.content) > 50:
return obs.content[:47] + "..."
return obs.content
return {
"task_ids": {short_content(obs): task.id for obs, task in tasks},
"status": "queued",
}
@mcp.tool()
async def search_observations(
query: str,
subject: str = "",
tags: list[str] | None = None,
observation_types: list[str] | None = None,
min_confidences: dict[str, float] = {},
limit: int = 10,
) -> list[dict]:
"""
Search recorded observations about the user.
Use before responding to understand user preferences, patterns, and past insights.
Search by meaning - the query matches both content and context.
Args:
query: Natural language search query describing what you're looking for
subject: Filter by exact subject identifier (empty = search all subjects)
tags: Filter by tags (must have at least one matching tag)
observation_types: Filter by: belief, preference, behavior, contradiction, general
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
limit: Max results (1-100)
Returns: List with content, tags, created_at, metadata
Results sorted by relevance to your query.
"""
semantic_text = observation.generate_semantic_text(
subject=subject or "",
observation_type="".join(observation_types or []),
content=query,
evidence=None,
)
temporal = observation.generate_temporal_text(
subject=subject or "",
content=query,
created_at=datetime.now(timezone.utc),
)
results = await search(
[
extract.DataChunk(data=[query]),
extract.DataChunk(data=[semantic_text]),
extract.DataChunk(data=[temporal]),
],
previews=True,
modalities={"semantic", "temporal"},
limit=limit,
filters=SearchFilters(
subject=subject,
min_confidences=min_confidences,
tags=tags,
observation_types=observation_types,
source_ids=filter_observation_source_ids(tags=tags),
),
timeout=2,
)
return [
{
"content": r.content,
"tags": r.tags,
"created_at": r.created_at.isoformat() if r.created_at else None,
"metadata": r.metadata,
}
for r in results
]
@mcp.tool()
async def create_note(
subject: str,
content: str,
filename: str | None = None,
note_type: str | None = None,
confidences: dict[str, float] = {},
tags: list[str] = [],
) -> dict:
"""
Create a note when user asks to save or record something.
Use when user explicitly requests noting information for future reference.
Args:
subject: What the note is about (used for organization)
content: Note content as a markdown string
filename: Optional path relative to notes folder (e.g., "project/ideas.md")
note_type: Optional categorization of the note
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
tags: Organization tags for filtering and discovery
"""
if filename:
path = pathlib.Path(filename)
if not path.is_absolute():
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
try:
task = celery_app.send_task(
SYNC_NOTE,
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
kwargs={
"subject": subject,
"content": content,
"filename": filename,
"note_type": note_type,
"confidences": confidences,
"tags": tags,
},
)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Error creating note: {e}")
raise
return {
"task_id": task.id,
"status": "queued",
}
@mcp.tool()
async def note_files(path: str = "/"):
"""
List note files in the user's note storage.
Use to discover existing notes before reading or to help user navigate their collection.
Args:
path: Directory path to search (e.g., "/", "/projects", "/meetings")
Use "/" for root, or subdirectories to narrow scope
Returns: List of file paths relative to notes directory
"""
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
return [
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
for f in root.rglob("*.md")
if f.is_file()
]
@mcp.tool()
def fetch_file(filename: str) -> dict:
"""
Read file content with automatic type detection.
Returns dict with content, mime_type, is_text, file_size.
Text content as string, binary as base64.
"""
path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
if not path.exists():
raise FileNotFoundError(f"File not found: {filename}")
mime_type, _ = mimetypes.guess_type(str(path))
mime_type = mime_type or "application/octet-stream"
text_extensions = {
".md",
".txt",
".py",
".js",
".html",
".css",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
}
text_mimes = {
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
}
is_text = (
mime_type.startswith("text/")
or mime_type in text_mimes
or path.suffix.lower() in text_extensions
)
try:
content = (
path.read_text(encoding="utf-8")
if is_text
else __import__("base64").b64encode(path.read_bytes()).decode("ascii")
)
except UnicodeDecodeError:
import base64
content = base64.b64encode(path.read_bytes()).decode("ascii")
is_text = False
mime_type = (
"application/octet-stream" if mime_type.startswith("text/") else mime_type
)
return {
"content": content,
"mime_type": mime_type,
"is_text": is_text,
"file_size": path.stat().st_size,
"filename": filename,
}

View File

@ -195,6 +195,9 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
} }
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if settings.DISABLE_AUTH:
return await call_next(request)
path = request.url.path path = request.url.path
# Skip authentication for whitelisted endpoints # Skip authentication for whitelisted endpoints

View File

@ -83,5 +83,7 @@ app.conf.update(
@app.on_after_configure.connect # type: ignore[attr-defined] @app.on_after_configure.connect # type: ignore[attr-defined]
def ensure_qdrant_initialised(sender, **_): def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant from memory.common import qdrant
from memory.common.discord import load_servers
qdrant.setup_qdrant() qdrant.setup_qdrant()
load_servers()

View File

@ -0,0 +1,184 @@
import logging
import requests
from typing import Any, Dict, List
from memory.common import settings
logger = logging.getLogger(__name__)
ERROR_CHANNEL = "memory-errors"
ACTIVITY_CHANNEL = "memory-activity"
DISCOVERY_CHANNEL = "memory-discoveries"
CHAT_CHANNEL = "memory-chat"
class DiscordServer(requests.Session):
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
self.server_id = server_id
self.server_name = server_name
self.channels = {}
super().__init__(*args, **kwargs)
self.setup_channels()
def setup_channels(self):
resp = self.get(self.channels_url)
resp.raise_for_status()
channels = {channel["name"]: channel["id"] for channel in resp.json()}
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
self.channels[ERROR_CHANNEL] = error_channel
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
self.channels[ACTIVITY_CHANNEL] = activity_channel
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
self.channels[DISCOVERY_CHANNEL] = discovery_channel
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
self.channels[CHAT_CHANNEL] = chat_channel
@property
def error_channel(self) -> str:
return self.channels[ERROR_CHANNEL]
@property
def activity_channel(self) -> str:
return self.channels[ACTIVITY_CHANNEL]
@property
def discovery_channel(self) -> str:
return self.channels[DISCOVERY_CHANNEL]
@property
def chat_channel(self) -> str:
return self.channels[CHAT_CHANNEL]
def channel_id(self, channel_name: str) -> str:
if not (channel_id := self.channels.get(channel_name)):
raise ValueError(f"Channel {channel_name} not found")
return channel_id
def send_message(self, channel_id: str, content: str):
self.post(
f"https://discord.com/api/v10/channels/{channel_id}/messages",
json={"content": content},
)
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
resp = self.post(
self.channels_url, json={"name": channel_name, "type": channel_type}
)
resp.raise_for_status()
return resp.json()["id"]
def __str__(self):
return (
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
)
def request(self, method: str, url: str, **kwargs):
headers = kwargs.get("headers", {})
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
headers["Content-Type"] = "application/json"
kwargs["headers"] = headers
return super().request(method, url, **kwargs)
@property
def channels_url(self) -> str:
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
def get_bot_servers() -> List[Dict]:
"""Get list of servers the bot is in."""
if not settings.DISCORD_BOT_TOKEN:
return []
try:
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
response = requests.get(
"https://discord.com/api/v10/users/@me/guilds", headers=headers
)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Failed to get bot servers: {e}")
return []
servers: dict[str, DiscordServer] = {}
def load_servers():
for server in get_bot_servers():
servers[server["id"]] = DiscordServer(server["id"], server["name"])
def broadcast_message(channel: str, message: str):
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return
for server in servers.values():
server.send_message(server.channel_id(channel), message)
def send_error_message(message: str):
broadcast_message(ERROR_CHANNEL, message)
def send_activity_message(message: str):
broadcast_message(ACTIVITY_CHANNEL, message)
def send_discovery_message(message: str):
broadcast_message(DISCOVERY_CHANNEL, message)
def send_chat_message(message: str):
broadcast_message(CHAT_CHANNEL, message)
def notify_task_failure(
task_name: str,
error_message: str,
task_args: tuple = (),
task_kwargs: dict[str, Any] | None = None,
traceback_str: str | None = None,
) -> None:
"""
Send a task failure notification to Discord.
Args:
task_name: Name of the failed task
error_message: Error message
task_args: Task arguments
task_kwargs: Task keyword arguments
traceback_str: Full traceback string
Returns:
True if notification sent successfully
"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
logger.debug("Discord notifications disabled")
return
message = f"🚨 **Task Failed: {task_name}**\n\n"
message += f"**Error:** {error_message[:500]}\n"
if task_args:
message += f"**Args:** `{str(task_args)[:200]}`\n"
if task_kwargs:
message += f"**Kwargs:** `{str(task_kwargs)[:200]}`\n"
if traceback_str:
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
try:
send_error_message(message)
logger.info(f"Discord error notification sent for task: {task_name}")
except Exception as e:
logger.error(f"Failed to send Discord notification: {e}")

View File

@ -55,7 +55,7 @@ def generate_temporal_text(
f"Observation: {content}", f"Observation: {content}",
] ]
return " | ".join(parts) return " | ".join(parts).strip()
# TODO: Add more embedding dimensions here: # TODO: Add more embedding dimensions here:

View File

@ -138,3 +138,17 @@ SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 *
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30)) SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True
DISABLE_AUTH = boolean_env("DISABLE_AUTH", False)
# Discord notification settings
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
# Enable Discord notifications if bot token is set
DISCORD_NOTIFICATIONS_ENABLED = (
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
)

View File

@ -47,9 +47,20 @@ Text to summarize:
def parse_response(response: str) -> dict[str, Any]: def parse_response(response: str) -> dict[str, Any]:
"""Parse the response from the summarizer.""" """Parse the response from the summarizer."""
soup = BeautifulSoup(response, "xml") if not response or not response.strip():
summary = soup.find("summary").text return {"summary": "", "tags": []}
tags = [tag.text for tag in soup.find_all("tag")]
# Use html.parser instead of xml parser for better compatibility
soup = BeautifulSoup(response, "html.parser")
# Safely extract summary
summary_element = soup.find("summary")
summary = summary_element.text if summary_element else ""
# Safely extract tags
tag_elements = soup.find_all("tag")
tags = [tag.text for tag in tag_elements if tag.text is not None]
return {"summary": summary, "tags": tags} return {"summary": summary, "tags": tags}
@ -68,7 +79,6 @@ def _call_openai(prompt: str) -> dict[str, Any]:
}, },
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
], ],
response_format={"type": "json_object"},
temperature=0.3, temperature=0.3,
max_tokens=2048, max_tokens=2048,
) )

View File

@ -12,8 +12,9 @@ import traceback
import logging import logging
from typing import Any, Callable, Iterable, Sequence, cast from typing import Any, Callable, Iterable, Sequence, cast
from memory.common import embedding, qdrant from memory.common import embedding, qdrant, settings
from memory.common.db.models import SourceItem, Chunk from memory.common.db.models import SourceItem, Chunk
from memory.common.discord import notify_task_failure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -274,7 +275,17 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
return func(*args, **kwargs) return func(*args, **kwargs)
except Exception as e: except Exception as e:
logger.error(f"Task {func.__name__} failed: {e}") logger.error(f"Task {func.__name__} failed: {e}")
logger.error(traceback.format_exc()) traceback_str = traceback.format_exc()
logger.error(traceback_str)
notify_task_failure(
task_name=func.__name__,
error_message=str(e),
task_args=args,
task_kwargs=kwargs,
traceback_str=traceback_str,
)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
return wrapper return wrapper

View File

@ -213,12 +213,15 @@ def mock_file_storage(tmp_path: Path):
email_storage_dir.mkdir(parents=True, exist_ok=True) email_storage_dir.mkdir(parents=True, exist_ok=True)
notes_storage_dir = tmp_path / "notes" notes_storage_dir = tmp_path / "notes"
notes_storage_dir.mkdir(parents=True, exist_ok=True) notes_storage_dir.mkdir(parents=True, exist_ok=True)
comic_storage_dir = tmp_path / "comics"
comic_storage_dir.mkdir(parents=True, exist_ok=True)
with ( with (
patch.object(settings, "FILE_STORAGE_DIR", tmp_path), patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir), patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir), patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir), patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir),
patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir), patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir),
patch.object(settings, "COMIC_STORAGE_DIR", comic_storage_dir),
): ):
yield yield
@ -256,7 +259,7 @@ def mock_openai_client():
choices=[ choices=[
Mock( Mock(
message=Mock( message=Mock(
content='{"summary": "test summary", "tags": ["tag1", "tag2"]}' content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
) )
) )
] ]
@ -273,8 +276,16 @@ def mock_anthropic_client():
client.messages.create = Mock( client.messages.create = Mock(
return_value=Mock( return_value=Mock(
content=[ content=[
Mock(text='{"summary": "test summary", "tags": ["tag1", "tag2"]}') Mock(
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
)
] ]
) )
) )
yield client yield client
@pytest.fixture(autouse=True)
def mock_discord_client():
with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False):
yield

View File

@ -0,0 +1,393 @@
import logging
import pytest
from unittest.mock import Mock, patch
import requests
import json
from memory.common import discord, settings
@pytest.fixture
def mock_session_request():
with patch("requests.Session.request") as mock:
yield mock
@pytest.fixture
def mock_get_channels_response():
return [
{"name": "memory-errors", "id": "error_channel_id"},
{"name": "memory-activity", "id": "activity_channel_id"},
{"name": "memory-discoveries", "id": "discovery_channel_id"},
{"name": "memory-chat", "id": "chat_channel_id"},
]
def test_discord_server_init(mock_session_request, mock_get_channels_response):
# Mock the channels API call
mock_response = Mock()
mock_response.json.return_value = mock_get_channels_response
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer("server123", "Test Server")
assert server.server_id == "server123"
assert server.server_name == "Test Server"
assert hasattr(server, "channels")
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
# Mock the channels API call
mock_response = Mock()
mock_response.json.return_value = mock_get_channels_response
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer("server123", "Test Server")
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
def test_setup_channels_create_missing(mock_session_request):
# Mock get channels (empty) and create channel calls
get_response = Mock()
get_response.json.return_value = []
get_response.raise_for_status.return_value = None
create_response = Mock()
create_response.json.return_value = {"id": "new_channel_id"}
create_response.raise_for_status.return_value = None
mock_session_request.side_effect = [
get_response,
create_response,
create_response,
create_response,
create_response,
]
server = discord.DiscordServer("server123", "Test Server")
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
def test_channel_properties():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {
discord.ERROR_CHANNEL: "error_id",
discord.ACTIVITY_CHANNEL: "activity_id",
discord.DISCOVERY_CHANNEL: "discovery_id",
discord.CHAT_CHANNEL: "chat_id",
}
assert server.error_channel == "error_id"
assert server.activity_channel == "activity_id"
assert server.discovery_channel == "discovery_id"
assert server.chat_channel == "chat_id"
def test_channel_id_exists():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {"test-channel": "channel123"}
assert server.channel_id("test-channel") == "channel123"
def test_channel_id_not_found():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.channels = {}
with pytest.raises(ValueError, match="Channel nonexistent not found"):
server.channel_id("nonexistent")
def test_send_message(mock_session_request):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.send_message("channel123", "Hello World")
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/channels/channel123/messages",
data=None,
json={"content": "Hello World"},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
)
def test_create_channel(mock_session_request):
mock_response = Mock()
mock_response.json.return_value = {"id": "new_channel_id"}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
channel_id = server.create_channel("new-channel")
assert channel_id == "new_channel_id"
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/guilds/server123/channels",
data=None,
json={"name": "new-channel", "type": 0},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
)
def test_create_channel_custom_type(mock_session_request):
mock_response = Mock()
mock_response.json.return_value = {"id": "voice_channel_id"}
mock_response.raise_for_status.return_value = None
mock_session_request.return_value = mock_response
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
channel_id = server.create_channel("voice-channel", channel_type=2)
assert channel_id == "voice_channel_id"
mock_session_request.assert_called_with(
"POST",
"https://discord.com/api/v10/guilds/server123/channels",
data=None,
json={"name": "voice-channel", "type": 2},
headers={
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
"Content-Type": "application/json",
},
)
def test_str_representation():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
server.server_name = "Test Server"
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
def test_request_adds_headers(mock_session_request):
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.request("GET", "https://example.com", headers={"Custom": "header"})
expected_headers = {
"Custom": "header",
"Authorization": "Bot test_token_123",
"Content-Type": "application/json",
}
mock_session_request.assert_called_once_with(
"GET", "https://example.com", headers=expected_headers
)
def test_channels_url():
server = discord.DiscordServer.__new__(discord.DiscordServer)
server.server_id = "server123"
assert (
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
)
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
def test_get_bot_servers_success(mock_get):
mock_response = Mock()
mock_response.json.return_value = [
{"id": "server1", "name": "Server 1"},
{"id": "server2", "name": "Server 2"},
]
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
servers = discord.get_bot_servers()
assert len(servers) == 2
assert servers[0] == {"id": "server1", "name": "Server 1"}
mock_get.assert_called_once_with(
"https://discord.com/api/v10/users/@me/guilds",
headers={"Authorization": "Bot test_token"},
)
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
def test_get_bot_servers_no_token():
assert discord.get_bot_servers() == []
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
@patch("requests.get")
def test_get_bot_servers_exception(mock_get):
mock_get.side_effect = requests.RequestException("API Error")
servers = discord.get_bot_servers()
assert servers == []
@patch("memory.common.discord.get_bot_servers")
@patch("memory.common.discord.DiscordServer")
def test_load_servers(mock_discord_server_class, mock_get_servers):
mock_get_servers.return_value = [
{"id": "server1", "name": "Server 1"},
{"id": "server2", "name": "Server 2"},
]
discord.load_servers()
assert mock_discord_server_class.call_count == 2
mock_discord_server_class.assert_any_call("server1", "Server 1")
mock_discord_server_class.assert_any_call("server2", "Server 2")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_broadcast_message():
mock_server1 = Mock()
mock_server2 = Mock()
discord.servers = {"1": mock_server1, "2": mock_server2}
discord.broadcast_message("test-channel", "Hello")
mock_server1.send_message.assert_called_once_with(
mock_server1.channel_id.return_value, "Hello"
)
mock_server2.send_message.assert_called_once_with(
mock_server2.channel_id.return_value, "Hello"
)
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_broadcast_message_disabled():
mock_server = Mock()
discord.servers = {"1": mock_server}
discord.broadcast_message("test-channel", "Hello")
mock_server.send_message.assert_not_called()
@patch("memory.common.discord.broadcast_message")
def test_send_error_message(mock_broadcast):
discord.send_error_message("Error occurred")
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
@patch("memory.common.discord.broadcast_message")
def test_send_activity_message(mock_broadcast):
discord.send_activity_message("Activity update")
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
@patch("memory.common.discord.broadcast_message")
def test_send_discovery_message(mock_broadcast):
discord.send_discovery_message("Discovery made")
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
@patch("memory.common.discord.broadcast_message")
def test_send_chat_message(mock_broadcast):
discord.send_chat_message("Chat message")
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_basic(mock_send_error):
discord.notify_task_failure("test_task", "Something went wrong")
mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0]
assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_with_args(mock_send_error):
discord.notify_task_failure(
"test_task",
"Error message",
task_args=("arg1", "arg2"),
task_kwargs={"key": "value"},
)
message = mock_send_error.call_args[0][0]
assert "**Args:** `('arg1', 'arg2')`" in message
assert "**Kwargs:** `{'key': 'value'}`" in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_with_traceback(mock_send_error):
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
message = mock_send_error.call_args[0][0]
assert "**Traceback:**" in message
assert traceback in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_truncates_long_error(mock_send_error):
long_error = "x" * 600 # Longer than 500 char limit
discord.notify_task_failure("test_task", long_error)
message = mock_send_error.call_args[0][0]
assert long_error[:500] in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
long_traceback = "x" * 1000 # Longer than 800 char limit
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
message = mock_send_error.call_args[0][0]
assert long_traceback[-800:] in message
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_disabled(mock_send_error):
discord.notify_task_failure("test_task", "Error message")
mock_send_error.assert_not_called()
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
@patch("memory.common.discord.send_error_message")
def test_notify_task_failure_send_fails(mock_send_error):
mock_send_error.side_effect = Exception("Discord API error")
# Should not raise, just log the error
discord.notify_task_failure("test_task", "Error message")
mock_send_error.assert_called_once()

View File

@ -0,0 +1,151 @@
from memory.common import summarizer
import pytest
@pytest.mark.parametrize(
"response, expected",
(
# Basic valid cases
("", {"summary": "", "tags": []}),
(
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
{"summary": "test", "tags": ["tag1", "tag2"]},
),
(
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
{"summary": "test", "tags": ["tag1", "tag2"]},
),
# Missing summary tag
(
"<tags><tag>tag1</tag><tag>tag2</tag></tags>",
{"summary": "", "tags": ["tag1", "tag2"]},
),
# Missing tags section
(
"<summary>test summary</summary>",
{"summary": "test summary", "tags": []},
),
# Empty summary tag
(
"<summary></summary><tags><tag>tag1</tag></tags>",
{"summary": "", "tags": ["tag1"]},
),
# Empty tags section
(
"<summary>test</summary><tags></tags>",
{"summary": "test", "tags": []},
),
# Single tag
(
"<summary>test</summary><tags><tag>single-tag</tag></tags>",
{"summary": "test", "tags": ["single-tag"]},
),
# Multiple tags
(
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag><tag>tag3</tag><tag>tag4</tag><tag>tag5</tag></tags>",
{"summary": "test", "tags": ["tag1", "tag2", "tag3", "tag4", "tag5"]},
),
# Tags with special characters and hyphens
(
"<summary>test</summary><tags><tag>machine-learning</tag><tag>ai/ml</tag><tag>data_science</tag></tags>",
{"summary": "test", "tags": ["machine-learning", "ai/ml", "data_science"]},
),
# Summary with special characters
(
"<summary>Test with &amp; special characters &lt;&gt;</summary><tags><tag>test</tag></tags>",
{"summary": "Test with & special characters <>", "tags": ["test"]},
),
# Whitespace handling
(
"<summary> test with spaces </summary><tags><tag> tag1 </tag><tag>tag2</tag></tags>",
{"summary": " test with spaces ", "tags": [" tag1 ", "tag2"]},
),
# Mixed case XML tags (should still work with BeautifulSoup)
(
"<Summary>test</Summary><Tags><Tag>tag1</Tag></Tags>",
{"summary": "test", "tags": ["tag1"]},
),
# Multiple summary tags (should take first one)
(
"<summary>first</summary><summary>second</summary><tags><tag>tag1</tag></tags>",
{"summary": "first", "tags": ["tag1"]},
),
# Multiple tags sections (should collect all tags)
(
"<summary>test</summary><tags><tag>tag1</tag></tags><tags><tag>tag2</tag></tags>",
{"summary": "test", "tags": ["tag1", "tag2"]},
),
# XML with extra elements (should ignore them)
(
"<root><summary>test</summary><other>ignored</other><tags><tag>tag1</tag></tags></root>",
{"summary": "test", "tags": ["tag1"]},
),
# XML with attributes (should still work)
(
'<summary id="1">test</summary><tags type="keywords"><tag>tag1</tag></tags>',
{"summary": "test", "tags": ["tag1"]},
),
# Empty tag elements
(
"<summary>test</summary><tags><tag></tag><tag>valid-tag</tag><tag></tag></tags>",
{"summary": "test", "tags": ["", "valid-tag", ""]},
),
# Self-closing tags
(
"<summary>test</summary><tags><tag/><tag>valid-tag</tag></tags>",
{"summary": "test", "tags": ["", "valid-tag"]},
),
# Long content
(
f"<summary>{'a' * 1000}</summary><tags><tag>long-content</tag></tags>",
{"summary": "a" * 1000, "tags": ["long-content"]},
),
# XML with newlines and formatting
(
"""<summary>
Multi-line
summary content
</summary>
<tags>
<tag>formatted</tag>
<tag>xml</tag>
</tags>""",
{
"summary": "\n Multi-line\n summary content\n ",
"tags": ["formatted", "xml"],
},
),
# Malformed XML (missing closing tags) - BeautifulSoup parses as best it can
(
"<summary>test<tags><tag>tag1</tag></tags>",
{"summary": "testtag1", "tags": ["tag1"]},
),
# Invalid XML characters should be handled by BeautifulSoup
(
"<summary>test & unescaped</summary><tags><tag>tag1</tag></tags>",
{"summary": "test & unescaped", "tags": ["tag1"]},
),
# Only whitespace
(
" \n\t ",
{"summary": "", "tags": []},
),
# Non-XML content
(
"This is just plain text without XML",
{"summary": "", "tags": []},
),
# XML comments (should be ignored)
(
"<!-- comment --><summary>test</summary><!-- another comment --><tags><tag>tag1</tag></tags>",
{"summary": "test", "tags": ["tag1"]},
),
# CDATA sections
(
"<summary><![CDATA[test with <special> characters]]></summary><tags><tag>cdata</tag></tags>",
{"summary": "test with <special> characters", "tags": ["cdata"]},
),
),
)
def test_parse_response(response, expected):
assert summarizer.parse_response(response) == expected

57
tools/discord_setup.py Normal file
View File

@ -0,0 +1,57 @@
import argparse
import click
import requests
@click.command()
@click.option("--bot-token", type=str, required=True)
def generate_bot_invite_url(bot_token: str):
"""
Generate the Discord bot invitation URL.
Returns:
URL that user can click to add bot to their server
"""
# Get bot's client ID from the token (it's the first part before the first dot)
# But safer to get it from the API
try:
headers = {"Authorization": f"Bot {bot_token}"}
response = requests.get(
"https://discord.com/api/v10/users/@me", headers=headers
)
response.raise_for_status()
bot_info = response.json()
client_id = bot_info["id"]
except Exception as e:
raise ValueError(f"Could not get bot info: {e}")
# Permissions needed: Send Messages (2048) + Manage Channels (16) + View Channels (1024)
permissions = 2048 + 16 + 1024 # = 3088
invite_url = f"https://discord.com/oauth2/authorize?client_id={client_id}&scope=bot&permissions={permissions}"
click.echo(f"Bot invite URL: {invite_url}")
return invite_url
@click.command()
def create_channels():
"""Create Discord channels using the configured servers."""
from memory.common.discord import load_servers
click.echo("Loading Discord servers and creating channels...")
load_servers()
click.echo("Discord channels setup completed.")
@click.group()
def cli():
"""Discord setup utilities."""
pass
cli.add_command(generate_bot_invite_url, name="generate-invite")
cli.add_command(create_channels, name="create-channels")
if __name__ == "__main__":
cli()