mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00
discord notification on error
This commit is contained in:
parent
489265fe31
commit
4d057d1ec6
21
README.md
21
README.md
@ -92,6 +92,27 @@ curl -X POST http://localhost:8000/auth/login \
|
|||||||
|
|
||||||
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
|
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
|
||||||
|
|
||||||
|
## Discord integration
|
||||||
|
|
||||||
|
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
||||||
|
Once you have the bot's token, run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tools/discord_setup.py generate-invite --bot-token <your bot token>
|
||||||
|
```
|
||||||
|
|
||||||
|
to get an url that can be used to connect your Discord bot.
|
||||||
|
|
||||||
|
Next you'll have to set at least the following in your `.env` file:
|
||||||
|
|
||||||
|
```
|
||||||
|
DISCORD_BOT_TOKEN=<your bot token>
|
||||||
|
DISCORD_NOTIFICATIONS_ENABLED=True
|
||||||
|
```
|
||||||
|
|
||||||
|
When the worker starts it will automatically attempt to create the appropriate channels. You
|
||||||
|
can change what they will be called by setting the various `DISCORD_*_CHANNEL` settings.
|
||||||
|
|
||||||
## MCP Proxy Setup
|
## MCP Proxy Setup
|
||||||
|
|
||||||
Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect:
|
Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect:
|
||||||
|
2
src/memory/api/MCP/__init__.py
Normal file
2
src/memory/api/MCP/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
import memory.api.MCP.manifest
|
||||||
|
import memory.api.MCP.memory
|
119
src/memory/api/MCP/manifest.py
Normal file
119
src/memory/api/MCP/manifest.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from typing import TypedDict, NotRequired, Literal
|
||||||
|
from memory.api.MCP.tools import mcp
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryProbs(TypedDict):
|
||||||
|
prob: float
|
||||||
|
|
||||||
|
|
||||||
|
class MultiProbs(TypedDict):
|
||||||
|
answerProbs: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
Probs = dict[str, BinaryProbs | MultiProbs]
|
||||||
|
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
|
||||||
|
|
||||||
|
|
||||||
|
class MarketAnswer(TypedDict):
|
||||||
|
id: str
|
||||||
|
text: str
|
||||||
|
resolutionProbability: float
|
||||||
|
|
||||||
|
|
||||||
|
class MarketDetails(TypedDict):
|
||||||
|
id: str
|
||||||
|
createdTime: int
|
||||||
|
question: str
|
||||||
|
outcomeType: OutcomeType
|
||||||
|
textDescription: str
|
||||||
|
groupSlugs: list[str]
|
||||||
|
volume: float
|
||||||
|
isResolved: bool
|
||||||
|
answers: list[MarketAnswer]
|
||||||
|
|
||||||
|
|
||||||
|
class Market(TypedDict):
|
||||||
|
id: str
|
||||||
|
url: str
|
||||||
|
question: str
|
||||||
|
volume: int
|
||||||
|
createdTime: int
|
||||||
|
outcomeType: OutcomeType
|
||||||
|
createdAt: NotRequired[str]
|
||||||
|
description: NotRequired[str]
|
||||||
|
answers: NotRequired[dict[str, float]]
|
||||||
|
probability: NotRequired[float]
|
||||||
|
details: NotRequired[MarketDetails]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_details(session: aiohttp.ClientSession, market_id: str):
|
||||||
|
async with session.get(
|
||||||
|
f"https://api.manifold.markets/v0/market/{market_id}"
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def format_market(session: aiohttp.ClientSession, market: Market):
|
||||||
|
if market.get("outcomeType") != "BINARY":
|
||||||
|
details = await get_details(session, market["id"])
|
||||||
|
market["answers"] = {
|
||||||
|
answer["text"]: round(
|
||||||
|
answer.get("resolutionProbability") or answer.get("probability") or 0, 3
|
||||||
|
)
|
||||||
|
for answer in details["answers"]
|
||||||
|
}
|
||||||
|
if creationTime := market.get("createdTime"):
|
||||||
|
market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat()
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"url",
|
||||||
|
"question",
|
||||||
|
"volume",
|
||||||
|
"createdAt",
|
||||||
|
"details",
|
||||||
|
"probability",
|
||||||
|
"answers",
|
||||||
|
]
|
||||||
|
return {k: v for k, v in market.items() if k in fields}
|
||||||
|
|
||||||
|
|
||||||
|
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
"https://api.manifold.markets/v0/search-markets",
|
||||||
|
params={
|
||||||
|
"term": term,
|
||||||
|
"contractType": "BINARY" if binary else "ALL",
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
markets = await resp.json()
|
||||||
|
|
||||||
|
return await asyncio.gather(
|
||||||
|
*[
|
||||||
|
format_market(session, market)
|
||||||
|
for market in markets
|
||||||
|
if market.get("volume", 0) >= min_volume
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_forecasts(
|
||||||
|
term: str, min_volume: int = 1000, binary: bool = False
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Get prediction market forecasts for a given term.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
term: The term to search for.
|
||||||
|
min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
|
||||||
|
binary: Whether to only return binary markets.
|
||||||
|
"""
|
||||||
|
return await search_markets(term, min_volume, binary)
|
443
src/memory/api/MCP/memory.py
Normal file
443
src/memory/api/MCP/memory.py
Normal file
@ -0,0 +1,443 @@
|
|||||||
|
"""
|
||||||
|
MCP tools for the epistemic sparring partner system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import pathlib
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import mimetypes
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import Text, func
|
||||||
|
from sqlalchemy import cast as sql_cast
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
|
from memory.api.search.search import SearchFilters, search
|
||||||
|
from memory.common import extract, settings
|
||||||
|
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import AgentObservation, SourceItem
|
||||||
|
from memory.common.formatters import observation
|
||||||
|
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
|
||||||
|
from memory.api.MCP.tools import mcp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_observation_source_ids(
|
||||||
|
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||||
|
):
|
||||||
|
if not tags and not observation_types:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
items_query = session.query(AgentObservation.id)
|
||||||
|
|
||||||
|
if tags:
|
||||||
|
# Use PostgreSQL array overlap operator with proper array casting
|
||||||
|
items_query = items_query.filter(
|
||||||
|
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||||
|
)
|
||||||
|
if observation_types:
|
||||||
|
items_query = items_query.filter(
|
||||||
|
AgentObservation.observation_type.in_(observation_types)
|
||||||
|
)
|
||||||
|
source_ids = [item.id for item in items_query.all()]
|
||||||
|
|
||||||
|
return source_ids
|
||||||
|
|
||||||
|
|
||||||
|
def filter_source_ids(
|
||||||
|
modalities: set[str],
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
if not tags:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with make_session() as session:
|
||||||
|
items_query = session.query(SourceItem.id)
|
||||||
|
|
||||||
|
if tags:
|
||||||
|
# Use PostgreSQL array overlap operator with proper array casting
|
||||||
|
items_query = items_query.filter(
|
||||||
|
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||||
|
)
|
||||||
|
if modalities:
|
||||||
|
items_query = items_query.filter(SourceItem.modality.in_(modalities))
|
||||||
|
source_ids = [item.id for item in items_query.all()]
|
||||||
|
|
||||||
|
return source_ids
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_all_tags() -> list[str]:
|
||||||
|
"""
|
||||||
|
Get all unique tags used across the entire knowledge base.
|
||||||
|
Returns sorted list of tags from both observations and content.
|
||||||
|
"""
|
||||||
|
with make_session() as session:
|
||||||
|
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
||||||
|
return sorted({row[0] for row in tags_query if row[0] is not None})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_all_subjects() -> list[str]:
|
||||||
|
"""
|
||||||
|
Get all unique subjects from observations about the user.
|
||||||
|
Returns sorted list of subject identifiers used in observations.
|
||||||
|
"""
|
||||||
|
with make_session() as session:
|
||||||
|
return sorted(
|
||||||
|
r.subject for r in session.query(AgentObservation.subject).distinct()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_all_observation_types() -> list[str]:
|
||||||
|
"""
|
||||||
|
Get all observation types that have been used.
|
||||||
|
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
||||||
|
"""
|
||||||
|
with make_session() as session:
|
||||||
|
return sorted(
|
||||||
|
{
|
||||||
|
r.observation_type
|
||||||
|
for r in session.query(AgentObservation.observation_type).distinct()
|
||||||
|
if r.observation_type is not None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def search_knowledge_base(
|
||||||
|
query: str,
|
||||||
|
previews: bool = False,
|
||||||
|
modalities: set[str] = set(),
|
||||||
|
tags: list[str] = [],
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Search user's stored content including emails, documents, articles, books.
|
||||||
|
Use to find specific information the user has saved or received.
|
||||||
|
Combine with search_observations for complete user context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Natural language search query - be descriptive about what you're looking for
|
||||||
|
previews: Include actual content in results - when false only a snippet is returned
|
||||||
|
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
|
||||||
|
tags: Filter by tags - content must have at least one matching tag
|
||||||
|
limit: Max results (1-100)
|
||||||
|
|
||||||
|
Returns: List of search results with id, score, chunks, content, filename
|
||||||
|
Higher scores (>0.7) indicate strong matches.
|
||||||
|
"""
|
||||||
|
logger.info(f"MCP search for: {query}")
|
||||||
|
|
||||||
|
if not modalities:
|
||||||
|
modalities = set(ALL_COLLECTIONS.keys())
|
||||||
|
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
|
||||||
|
|
||||||
|
upload_data = extract.extract_text(query)
|
||||||
|
results = await search(
|
||||||
|
upload_data,
|
||||||
|
previews=previews,
|
||||||
|
modalities=modalities,
|
||||||
|
limit=limit,
|
||||||
|
min_text_score=0.4,
|
||||||
|
min_multimodal_score=0.25,
|
||||||
|
filters=SearchFilters(
|
||||||
|
tags=tags,
|
||||||
|
source_ids=filter_source_ids(tags=tags, modalities=modalities),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return [result.model_dump() for result in results]
|
||||||
|
|
||||||
|
|
||||||
|
class RawObservation(BaseModel):
|
||||||
|
subject: str
|
||||||
|
content: str
|
||||||
|
observation_type: str = "general"
|
||||||
|
confidences: dict[str, float] = {}
|
||||||
|
evidence: dict | None = None
|
||||||
|
tags: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def observe(
|
||||||
|
observations: list[RawObservation],
|
||||||
|
session_id: str | None = None,
|
||||||
|
agent_model: str = "unknown",
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Record observations about the user for long-term understanding.
|
||||||
|
Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
|
||||||
|
Be specific and detailed - observations should make sense months later.
|
||||||
|
|
||||||
|
Example call:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"observations": [
|
||||||
|
{
|
||||||
|
"content": "The user is a software engineer.",
|
||||||
|
"subject": "user",
|
||||||
|
"observation_type": "belief",
|
||||||
|
"confidences": {"observation_accuracy": 0.9},
|
||||||
|
"evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
|
||||||
|
"tags": ["programming", "work"]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"session_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||||
|
"agent_model": "gpt-4o"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
RawObservation fields:
|
||||||
|
content (required): Detailed observation text explaining what you observed
|
||||||
|
subject (required): Consistent identifier like "programming_style", "work_habits"
|
||||||
|
observation_type: belief, preference, behavior, contradiction, general
|
||||||
|
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
|
||||||
|
evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation)
|
||||||
|
tags: List of categorization tags for organization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observations: List of RawObservation objects
|
||||||
|
session_id: UUID to group observations from same conversation
|
||||||
|
agent_model: AI model making observations (for quality tracking)
|
||||||
|
"""
|
||||||
|
tasks = [
|
||||||
|
(
|
||||||
|
observation,
|
||||||
|
celery_app.send_task(
|
||||||
|
SYNC_OBSERVATION,
|
||||||
|
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
||||||
|
kwargs={
|
||||||
|
"subject": observation.subject,
|
||||||
|
"content": observation.content,
|
||||||
|
"observation_type": observation.observation_type,
|
||||||
|
"confidences": observation.confidences,
|
||||||
|
"evidence": observation.evidence,
|
||||||
|
"tags": observation.tags,
|
||||||
|
"session_id": session_id,
|
||||||
|
"agent_model": agent_model,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for observation in observations
|
||||||
|
]
|
||||||
|
|
||||||
|
def short_content(obs: RawObservation) -> str:
|
||||||
|
if len(obs.content) > 50:
|
||||||
|
return obs.content[:47] + "..."
|
||||||
|
return obs.content
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_ids": {short_content(obs): task.id for obs, task in tasks},
|
||||||
|
"status": "queued",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def search_observations(
|
||||||
|
query: str,
|
||||||
|
subject: str = "",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
observation_types: list[str] | None = None,
|
||||||
|
min_confidences: dict[str, float] = {},
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Search recorded observations about the user.
|
||||||
|
Use before responding to understand user preferences, patterns, and past insights.
|
||||||
|
Search by meaning - the query matches both content and context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Natural language search query describing what you're looking for
|
||||||
|
subject: Filter by exact subject identifier (empty = search all subjects)
|
||||||
|
tags: Filter by tags (must have at least one matching tag)
|
||||||
|
observation_types: Filter by: belief, preference, behavior, contradiction, general
|
||||||
|
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
|
||||||
|
limit: Max results (1-100)
|
||||||
|
|
||||||
|
Returns: List with content, tags, created_at, metadata
|
||||||
|
Results sorted by relevance to your query.
|
||||||
|
"""
|
||||||
|
semantic_text = observation.generate_semantic_text(
|
||||||
|
subject=subject or "",
|
||||||
|
observation_type="".join(observation_types or []),
|
||||||
|
content=query,
|
||||||
|
evidence=None,
|
||||||
|
)
|
||||||
|
temporal = observation.generate_temporal_text(
|
||||||
|
subject=subject or "",
|
||||||
|
content=query,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
results = await search(
|
||||||
|
[
|
||||||
|
extract.DataChunk(data=[query]),
|
||||||
|
extract.DataChunk(data=[semantic_text]),
|
||||||
|
extract.DataChunk(data=[temporal]),
|
||||||
|
],
|
||||||
|
previews=True,
|
||||||
|
modalities={"semantic", "temporal"},
|
||||||
|
limit=limit,
|
||||||
|
filters=SearchFilters(
|
||||||
|
subject=subject,
|
||||||
|
min_confidences=min_confidences,
|
||||||
|
tags=tags,
|
||||||
|
observation_types=observation_types,
|
||||||
|
source_ids=filter_observation_source_ids(tags=tags),
|
||||||
|
),
|
||||||
|
timeout=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"content": r.content,
|
||||||
|
"tags": r.tags,
|
||||||
|
"created_at": r.created_at.isoformat() if r.created_at else None,
|
||||||
|
"metadata": r.metadata,
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def create_note(
|
||||||
|
subject: str,
|
||||||
|
content: str,
|
||||||
|
filename: str | None = None,
|
||||||
|
note_type: str | None = None,
|
||||||
|
confidences: dict[str, float] = {},
|
||||||
|
tags: list[str] = [],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Create a note when user asks to save or record something.
|
||||||
|
Use when user explicitly requests noting information for future reference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subject: What the note is about (used for organization)
|
||||||
|
content: Note content as a markdown string
|
||||||
|
filename: Optional path relative to notes folder (e.g., "project/ideas.md")
|
||||||
|
note_type: Optional categorization of the note
|
||||||
|
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
|
||||||
|
tags: Organization tags for filtering and discovery
|
||||||
|
"""
|
||||||
|
if filename:
|
||||||
|
path = pathlib.Path(filename)
|
||||||
|
if not path.is_absolute():
|
||||||
|
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
|
||||||
|
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
|
||||||
|
|
||||||
|
try:
|
||||||
|
task = celery_app.send_task(
|
||||||
|
SYNC_NOTE,
|
||||||
|
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
||||||
|
kwargs={
|
||||||
|
"subject": subject,
|
||||||
|
"content": content,
|
||||||
|
"filename": filename,
|
||||||
|
"note_type": note_type,
|
||||||
|
"confidences": confidences,
|
||||||
|
"tags": tags,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
logger.error(f"Error creating note: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_id": task.id,
|
||||||
|
"status": "queued",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def note_files(path: str = "/"):
|
||||||
|
"""
|
||||||
|
List note files in the user's note storage.
|
||||||
|
Use to discover existing notes before reading or to help user navigate their collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Directory path to search (e.g., "/", "/projects", "/meetings")
|
||||||
|
Use "/" for root, or subdirectories to narrow scope
|
||||||
|
|
||||||
|
Returns: List of file paths relative to notes directory
|
||||||
|
"""
|
||||||
|
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
|
||||||
|
return [
|
||||||
|
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
|
||||||
|
for f in root.rglob("*.md")
|
||||||
|
if f.is_file()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def fetch_file(filename: str) -> dict:
|
||||||
|
"""
|
||||||
|
Read file content with automatic type detection.
|
||||||
|
Returns dict with content, mime_type, is_text, file_size.
|
||||||
|
Text content as string, binary as base64.
|
||||||
|
"""
|
||||||
|
path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {filename}")
|
||||||
|
|
||||||
|
mime_type, _ = mimetypes.guess_type(str(path))
|
||||||
|
mime_type = mime_type or "application/octet-stream"
|
||||||
|
|
||||||
|
text_extensions = {
|
||||||
|
".md",
|
||||||
|
".txt",
|
||||||
|
".py",
|
||||||
|
".js",
|
||||||
|
".html",
|
||||||
|
".css",
|
||||||
|
".json",
|
||||||
|
".xml",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".toml",
|
||||||
|
".ini",
|
||||||
|
".cfg",
|
||||||
|
".conf",
|
||||||
|
}
|
||||||
|
text_mimes = {
|
||||||
|
"application/json",
|
||||||
|
"application/xml",
|
||||||
|
"application/javascript",
|
||||||
|
"application/x-yaml",
|
||||||
|
"application/yaml",
|
||||||
|
}
|
||||||
|
is_text = (
|
||||||
|
mime_type.startswith("text/")
|
||||||
|
or mime_type in text_mimes
|
||||||
|
or path.suffix.lower() in text_extensions
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = (
|
||||||
|
path.read_text(encoding="utf-8")
|
||||||
|
if is_text
|
||||||
|
else __import__("base64").b64encode(path.read_bytes()).decode("ascii")
|
||||||
|
)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
import base64
|
||||||
|
|
||||||
|
content = base64.b64encode(path.read_bytes()).decode("ascii")
|
||||||
|
is_text = False
|
||||||
|
mime_type = (
|
||||||
|
"application/octet-stream" if mime_type.startswith("text/") else mime_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": content,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"is_text": is_text,
|
||||||
|
"file_size": path.stat().st_size,
|
||||||
|
"filename": filename,
|
||||||
|
}
|
@ -3,23 +3,15 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import mimetypes
|
|
||||||
|
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
from pydantic import BaseModel
|
from sqlalchemy import Text
|
||||||
from sqlalchemy import Text, func
|
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
from memory.api.search.search import SearchFilters, search
|
|
||||||
from memory.common import extract, settings
|
|
||||||
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import AgentObservation, SourceItem
|
from memory.common.db.models import AgentObservation, SourceItem
|
||||||
from memory.common.formatters import observation
|
|
||||||
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -76,377 +68,3 @@ def filter_source_ids(
|
|||||||
async def get_current_time() -> dict:
|
async def get_current_time() -> dict:
|
||||||
"""Get the current time in UTC."""
|
"""Get the current time in UTC."""
|
||||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_all_tags() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all unique tags used across the entire knowledge base.
|
|
||||||
Returns sorted list of tags from both observations and content.
|
|
||||||
"""
|
|
||||||
with make_session() as session:
|
|
||||||
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
|
||||||
return sorted({row[0] for row in tags_query if row[0] is not None})
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_all_subjects() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all unique subjects from observations about the user.
|
|
||||||
Returns sorted list of subject identifiers used in observations.
|
|
||||||
"""
|
|
||||||
with make_session() as session:
|
|
||||||
return sorted(
|
|
||||||
r.subject for r in session.query(AgentObservation.subject).distinct()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_all_observation_types() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all observation types that have been used.
|
|
||||||
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
|
||||||
"""
|
|
||||||
with make_session() as session:
|
|
||||||
return sorted(
|
|
||||||
{
|
|
||||||
r.observation_type
|
|
||||||
for r in session.query(AgentObservation.observation_type).distinct()
|
|
||||||
if r.observation_type is not None
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def search_knowledge_base(
|
|
||||||
query: str,
|
|
||||||
previews: bool = False,
|
|
||||||
modalities: set[str] = set(),
|
|
||||||
tags: list[str] = [],
|
|
||||||
limit: int = 10,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Search user's stored content including emails, documents, articles, books.
|
|
||||||
Use to find specific information the user has saved or received.
|
|
||||||
Combine with search_observations for complete user context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Natural language search query - be descriptive about what you're looking for
|
|
||||||
previews: Include actual content in results - when false only a snippet is returned
|
|
||||||
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
|
|
||||||
tags: Filter by tags - content must have at least one matching tag
|
|
||||||
limit: Max results (1-100)
|
|
||||||
|
|
||||||
Returns: List of search results with id, score, chunks, content, filename
|
|
||||||
Higher scores (>0.7) indicate strong matches.
|
|
||||||
"""
|
|
||||||
logger.info(f"MCP search for: {query}")
|
|
||||||
|
|
||||||
if not modalities:
|
|
||||||
modalities = set(ALL_COLLECTIONS.keys())
|
|
||||||
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
|
|
||||||
|
|
||||||
upload_data = extract.extract_text(query)
|
|
||||||
results = await search(
|
|
||||||
upload_data,
|
|
||||||
previews=previews,
|
|
||||||
modalities=modalities,
|
|
||||||
limit=limit,
|
|
||||||
min_text_score=0.4,
|
|
||||||
min_multimodal_score=0.25,
|
|
||||||
filters=SearchFilters(
|
|
||||||
tags=tags,
|
|
||||||
source_ids=filter_source_ids(tags=tags, modalities=modalities),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return [result.model_dump() for result in results]
|
|
||||||
|
|
||||||
|
|
||||||
class RawObservation(BaseModel):
|
|
||||||
subject: str
|
|
||||||
content: str
|
|
||||||
observation_type: str = "general"
|
|
||||||
confidences: dict[str, float] = {}
|
|
||||||
evidence: dict | None = None
|
|
||||||
tags: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def observe(
|
|
||||||
observations: list[RawObservation],
|
|
||||||
session_id: str | None = None,
|
|
||||||
agent_model: str = "unknown",
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Record observations about the user for long-term understanding.
|
|
||||||
Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
|
|
||||||
Be specific and detailed - observations should make sense months later.
|
|
||||||
|
|
||||||
Example call:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"observations": [
|
|
||||||
{
|
|
||||||
"content": "The user is a software engineer.",
|
|
||||||
"subject": "user",
|
|
||||||
"observation_type": "belief",
|
|
||||||
"confidences": {"observation_accuracy": 0.9},
|
|
||||||
"evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
|
|
||||||
"tags": ["programming", "work"]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"session_id": "123e4567-e89b-12d3-a456-426614174000",
|
|
||||||
"agent_model": "gpt-4o"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
RawObservation fields:
|
|
||||||
content (required): Detailed observation text explaining what you observed
|
|
||||||
subject (required): Consistent identifier like "programming_style", "work_habits"
|
|
||||||
observation_type: belief, preference, behavior, contradiction, general
|
|
||||||
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
|
|
||||||
evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation)
|
|
||||||
tags: List of categorization tags for organization
|
|
||||||
|
|
||||||
Args:
|
|
||||||
observations: List of RawObservation objects
|
|
||||||
session_id: UUID to group observations from same conversation
|
|
||||||
agent_model: AI model making observations (for quality tracking)
|
|
||||||
"""
|
|
||||||
tasks = [
|
|
||||||
(
|
|
||||||
observation,
|
|
||||||
celery_app.send_task(
|
|
||||||
SYNC_OBSERVATION,
|
|
||||||
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
|
||||||
kwargs={
|
|
||||||
"subject": observation.subject,
|
|
||||||
"content": observation.content,
|
|
||||||
"observation_type": observation.observation_type,
|
|
||||||
"confidences": observation.confidences,
|
|
||||||
"evidence": observation.evidence,
|
|
||||||
"tags": observation.tags,
|
|
||||||
"session_id": session_id,
|
|
||||||
"agent_model": agent_model,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for observation in observations
|
|
||||||
]
|
|
||||||
|
|
||||||
def short_content(obs: RawObservation) -> str:
|
|
||||||
if len(obs.content) > 50:
|
|
||||||
return obs.content[:47] + "..."
|
|
||||||
return obs.content
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_ids": {short_content(obs): task.id for obs, task in tasks},
|
|
||||||
"status": "queued",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def search_observations(
|
|
||||||
query: str,
|
|
||||||
subject: str = "",
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
observation_types: list[str] | None = None,
|
|
||||||
min_confidences: dict[str, float] = {},
|
|
||||||
limit: int = 10,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Search recorded observations about the user.
|
|
||||||
Use before responding to understand user preferences, patterns, and past insights.
|
|
||||||
Search by meaning - the query matches both content and context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Natural language search query describing what you're looking for
|
|
||||||
subject: Filter by exact subject identifier (empty = search all subjects)
|
|
||||||
tags: Filter by tags (must have at least one matching tag)
|
|
||||||
observation_types: Filter by: belief, preference, behavior, contradiction, general
|
|
||||||
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
|
|
||||||
limit: Max results (1-100)
|
|
||||||
|
|
||||||
Returns: List with content, tags, created_at, metadata
|
|
||||||
Results sorted by relevance to your query.
|
|
||||||
"""
|
|
||||||
semantic_text = observation.generate_semantic_text(
|
|
||||||
subject=subject or "",
|
|
||||||
observation_type="".join(observation_types or []),
|
|
||||||
content=query,
|
|
||||||
evidence=None,
|
|
||||||
)
|
|
||||||
temporal = observation.generate_temporal_text(
|
|
||||||
subject=subject or "",
|
|
||||||
content=query,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
results = await search(
|
|
||||||
[
|
|
||||||
extract.DataChunk(data=[query]),
|
|
||||||
extract.DataChunk(data=[semantic_text]),
|
|
||||||
extract.DataChunk(data=[temporal]),
|
|
||||||
],
|
|
||||||
previews=True,
|
|
||||||
modalities={"semantic", "temporal"},
|
|
||||||
limit=limit,
|
|
||||||
filters=SearchFilters(
|
|
||||||
subject=subject,
|
|
||||||
min_confidences=min_confidences,
|
|
||||||
tags=tags,
|
|
||||||
observation_types=observation_types,
|
|
||||||
source_ids=filter_observation_source_ids(tags=tags),
|
|
||||||
),
|
|
||||||
timeout=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"content": r.content,
|
|
||||||
"tags": r.tags,
|
|
||||||
"created_at": r.created_at.isoformat() if r.created_at else None,
|
|
||||||
"metadata": r.metadata,
|
|
||||||
}
|
|
||||||
for r in results
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def create_note(
|
|
||||||
subject: str,
|
|
||||||
content: str,
|
|
||||||
filename: str | None = None,
|
|
||||||
note_type: str | None = None,
|
|
||||||
confidences: dict[str, float] = {},
|
|
||||||
tags: list[str] = [],
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Create a note when user asks to save or record something.
|
|
||||||
Use when user explicitly requests noting information for future reference.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subject: What the note is about (used for organization)
|
|
||||||
content: Note content as a markdown string
|
|
||||||
filename: Optional path relative to notes folder (e.g., "project/ideas.md")
|
|
||||||
note_type: Optional categorization of the note
|
|
||||||
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
|
|
||||||
tags: Organization tags for filtering and discovery
|
|
||||||
"""
|
|
||||||
if filename:
|
|
||||||
path = pathlib.Path(filename)
|
|
||||||
if not path.is_absolute():
|
|
||||||
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
|
|
||||||
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
|
|
||||||
|
|
||||||
try:
|
|
||||||
task = celery_app.send_task(
|
|
||||||
SYNC_NOTE,
|
|
||||||
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
|
||||||
kwargs={
|
|
||||||
"subject": subject,
|
|
||||||
"content": content,
|
|
||||||
"filename": filename,
|
|
||||||
"note_type": note_type,
|
|
||||||
"confidences": confidences,
|
|
||||||
"tags": tags,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
logger.error(f"Error creating note: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task.id,
|
|
||||||
"status": "queued",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def note_files(path: str = "/"):
|
|
||||||
"""
|
|
||||||
List note files in the user's note storage.
|
|
||||||
Use to discover existing notes before reading or to help user navigate their collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Directory path to search (e.g., "/", "/projects", "/meetings")
|
|
||||||
Use "/" for root, or subdirectories to narrow scope
|
|
||||||
|
|
||||||
Returns: List of file paths relative to notes directory
|
|
||||||
"""
|
|
||||||
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
|
|
||||||
return [
|
|
||||||
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
|
|
||||||
for f in root.rglob("*.md")
|
|
||||||
if f.is_file()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
def fetch_file(filename: str) -> dict:
|
|
||||||
"""
|
|
||||||
Read file content with automatic type detection.
|
|
||||||
Returns dict with content, mime_type, is_text, file_size.
|
|
||||||
Text content as string, binary as base64.
|
|
||||||
"""
|
|
||||||
path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
|
|
||||||
if not path.exists():
|
|
||||||
raise FileNotFoundError(f"File not found: {filename}")
|
|
||||||
|
|
||||||
mime_type, _ = mimetypes.guess_type(str(path))
|
|
||||||
mime_type = mime_type or "application/octet-stream"
|
|
||||||
|
|
||||||
text_extensions = {
|
|
||||||
".md",
|
|
||||||
".txt",
|
|
||||||
".py",
|
|
||||||
".js",
|
|
||||||
".html",
|
|
||||||
".css",
|
|
||||||
".json",
|
|
||||||
".xml",
|
|
||||||
".yaml",
|
|
||||||
".yml",
|
|
||||||
".toml",
|
|
||||||
".ini",
|
|
||||||
".cfg",
|
|
||||||
".conf",
|
|
||||||
}
|
|
||||||
text_mimes = {
|
|
||||||
"application/json",
|
|
||||||
"application/xml",
|
|
||||||
"application/javascript",
|
|
||||||
"application/x-yaml",
|
|
||||||
"application/yaml",
|
|
||||||
}
|
|
||||||
is_text = (
|
|
||||||
mime_type.startswith("text/")
|
|
||||||
or mime_type in text_mimes
|
|
||||||
or path.suffix.lower() in text_extensions
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
content = (
|
|
||||||
path.read_text(encoding="utf-8")
|
|
||||||
if is_text
|
|
||||||
else __import__("base64").b64encode(path.read_bytes()).decode("ascii")
|
|
||||||
)
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
import base64
|
|
||||||
|
|
||||||
content = base64.b64encode(path.read_bytes()).decode("ascii")
|
|
||||||
is_text = False
|
|
||||||
mime_type = (
|
|
||||||
"application/octet-stream" if mime_type.startswith("text/") else mime_type
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": content,
|
|
||||||
"mime_type": mime_type,
|
|
||||||
"is_text": is_text,
|
|
||||||
"file_size": path.stat().st_size,
|
|
||||||
"filename": filename,
|
|
||||||
}
|
|
||||||
|
@ -195,6 +195,9 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
if settings.DISABLE_AUTH:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
|
||||||
# Skip authentication for whitelisted endpoints
|
# Skip authentication for whitelisted endpoints
|
||||||
|
@ -83,5 +83,7 @@ app.conf.update(
|
|||||||
@app.on_after_configure.connect # type: ignore[attr-defined]
|
@app.on_after_configure.connect # type: ignore[attr-defined]
|
||||||
def ensure_qdrant_initialised(sender, **_):
|
def ensure_qdrant_initialised(sender, **_):
|
||||||
from memory.common import qdrant
|
from memory.common import qdrant
|
||||||
|
from memory.common.discord import load_servers
|
||||||
|
|
||||||
qdrant.setup_qdrant()
|
qdrant.setup_qdrant()
|
||||||
|
load_servers()
|
||||||
|
184
src/memory/common/discord.py
Normal file
184
src/memory/common/discord.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ERROR_CHANNEL = "memory-errors"
|
||||||
|
ACTIVITY_CHANNEL = "memory-activity"
|
||||||
|
DISCOVERY_CHANNEL = "memory-discoveries"
|
||||||
|
CHAT_CHANNEL = "memory-chat"
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordServer(requests.Session):
|
||||||
|
def __init__(self, server_id: str, server_name: str, *args, **kwargs):
|
||||||
|
self.server_id = server_id
|
||||||
|
self.server_name = server_name
|
||||||
|
self.channels = {}
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.setup_channels()
|
||||||
|
|
||||||
|
def setup_channels(self):
|
||||||
|
resp = self.get(self.channels_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
channels = {channel["name"]: channel["id"] for channel in resp.json()}
|
||||||
|
|
||||||
|
if not (error_channel := channels.get(settings.DISCORD_ERROR_CHANNEL)):
|
||||||
|
error_channel = self.create_channel(settings.DISCORD_ERROR_CHANNEL)
|
||||||
|
self.channels[ERROR_CHANNEL] = error_channel
|
||||||
|
|
||||||
|
if not (activity_channel := channels.get(settings.DISCORD_ACTIVITY_CHANNEL)):
|
||||||
|
activity_channel = self.create_channel(settings.DISCORD_ACTIVITY_CHANNEL)
|
||||||
|
self.channels[ACTIVITY_CHANNEL] = activity_channel
|
||||||
|
|
||||||
|
if not (discovery_channel := channels.get(settings.DISCORD_DISCOVERY_CHANNEL)):
|
||||||
|
discovery_channel = self.create_channel(settings.DISCORD_DISCOVERY_CHANNEL)
|
||||||
|
self.channels[DISCOVERY_CHANNEL] = discovery_channel
|
||||||
|
|
||||||
|
if not (chat_channel := channels.get(settings.DISCORD_CHAT_CHANNEL)):
|
||||||
|
chat_channel = self.create_channel(settings.DISCORD_CHAT_CHANNEL)
|
||||||
|
self.channels[CHAT_CHANNEL] = chat_channel
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error_channel(self) -> str:
|
||||||
|
return self.channels[ERROR_CHANNEL]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activity_channel(self) -> str:
|
||||||
|
return self.channels[ACTIVITY_CHANNEL]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def discovery_channel(self) -> str:
|
||||||
|
return self.channels[DISCOVERY_CHANNEL]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_channel(self) -> str:
|
||||||
|
return self.channels[CHAT_CHANNEL]
|
||||||
|
|
||||||
|
def channel_id(self, channel_name: str) -> str:
|
||||||
|
if not (channel_id := self.channels.get(channel_name)):
|
||||||
|
raise ValueError(f"Channel {channel_name} not found")
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
def send_message(self, channel_id: str, content: str):
|
||||||
|
self.post(
|
||||||
|
f"https://discord.com/api/v10/channels/{channel_id}/messages",
|
||||||
|
json={"content": content},
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_channel(self, channel_name: str, channel_type: int = 0) -> str | None:
|
||||||
|
resp = self.post(
|
||||||
|
self.channels_url, json={"name": channel_name, "type": channel_type}
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["id"]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return (
|
||||||
|
f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def request(self, method: str, url: str, **kwargs):
|
||||||
|
headers = kwargs.get("headers", {})
|
||||||
|
headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
kwargs["headers"] = headers
|
||||||
|
return super().request(method, url, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels_url(self) -> str:
|
||||||
|
return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot_servers() -> List[Dict]:
|
||||||
|
"""Get list of servers the bot is in."""
|
||||||
|
if not settings.DISCORD_BOT_TOKEN:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
|
||||||
|
response = requests.get(
|
||||||
|
"https://discord.com/api/v10/users/@me/guilds", headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get bot servers: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
servers: dict[str, DiscordServer] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def load_servers():
|
||||||
|
for server in get_bot_servers():
|
||||||
|
servers[server["id"]] = DiscordServer(server["id"], server["name"])
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_message(channel: str, message: str):
|
||||||
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
|
return
|
||||||
|
|
||||||
|
for server in servers.values():
|
||||||
|
server.send_message(server.channel_id(channel), message)
|
||||||
|
|
||||||
|
|
||||||
|
def send_error_message(message: str):
|
||||||
|
broadcast_message(ERROR_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
|
def send_activity_message(message: str):
|
||||||
|
broadcast_message(ACTIVITY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
|
def send_discovery_message(message: str):
|
||||||
|
broadcast_message(DISCOVERY_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
|
def send_chat_message(message: str):
|
||||||
|
broadcast_message(CHAT_CHANNEL, message)
|
||||||
|
|
||||||
|
|
||||||
|
def notify_task_failure(
|
||||||
|
task_name: str,
|
||||||
|
error_message: str,
|
||||||
|
task_args: tuple = (),
|
||||||
|
task_kwargs: dict[str, Any] | None = None,
|
||||||
|
traceback_str: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Send a task failure notification to Discord.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_name: Name of the failed task
|
||||||
|
error_message: Error message
|
||||||
|
task_args: Task arguments
|
||||||
|
task_kwargs: Task keyword arguments
|
||||||
|
traceback_str: Full traceback string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification sent successfully
|
||||||
|
"""
|
||||||
|
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||||
|
logger.debug("Discord notifications disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
||||||
|
message += f"**Error:** {error_message[:500]}\n"
|
||||||
|
|
||||||
|
if task_args:
|
||||||
|
message += f"**Args:** `{str(task_args)[:200]}`\n"
|
||||||
|
|
||||||
|
if task_kwargs:
|
||||||
|
message += f"**Kwargs:** `{str(task_kwargs)[:200]}`\n"
|
||||||
|
|
||||||
|
if traceback_str:
|
||||||
|
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
||||||
|
|
||||||
|
try:
|
||||||
|
send_error_message(message)
|
||||||
|
logger.info(f"Discord error notification sent for task: {task_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send Discord notification: {e}")
|
@ -55,7 +55,7 @@ def generate_temporal_text(
|
|||||||
f"Observation: {content}",
|
f"Observation: {content}",
|
||||||
]
|
]
|
||||||
|
|
||||||
return " | ".join(parts)
|
return " | ".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add more embedding dimensions here:
|
# TODO: Add more embedding dimensions here:
|
||||||
|
@ -138,3 +138,17 @@ SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 *
|
|||||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||||
|
|
||||||
REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True
|
REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True
|
||||||
|
DISABLE_AUTH = boolean_env("DISABLE_AUTH", False)
|
||||||
|
|
||||||
|
# Discord notification settings
|
||||||
|
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||||
|
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||||
|
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||||
|
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||||
|
DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||||
|
|
||||||
|
|
||||||
|
# Enable Discord notifications if bot token is set
|
||||||
|
DISCORD_NOTIFICATIONS_ENABLED = (
|
||||||
|
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||||
|
)
|
||||||
|
@ -47,9 +47,20 @@ Text to summarize:
|
|||||||
|
|
||||||
def parse_response(response: str) -> dict[str, Any]:
|
def parse_response(response: str) -> dict[str, Any]:
|
||||||
"""Parse the response from the summarizer."""
|
"""Parse the response from the summarizer."""
|
||||||
soup = BeautifulSoup(response, "xml")
|
if not response or not response.strip():
|
||||||
summary = soup.find("summary").text
|
return {"summary": "", "tags": []}
|
||||||
tags = [tag.text for tag in soup.find_all("tag")]
|
|
||||||
|
# Use html.parser instead of xml parser for better compatibility
|
||||||
|
soup = BeautifulSoup(response, "html.parser")
|
||||||
|
|
||||||
|
# Safely extract summary
|
||||||
|
summary_element = soup.find("summary")
|
||||||
|
summary = summary_element.text if summary_element else ""
|
||||||
|
|
||||||
|
# Safely extract tags
|
||||||
|
tag_elements = soup.find_all("tag")
|
||||||
|
tags = [tag.text for tag in tag_elements if tag.text is not None]
|
||||||
|
|
||||||
return {"summary": summary, "tags": tags}
|
return {"summary": summary, "tags": tags}
|
||||||
|
|
||||||
|
|
||||||
@ -68,7 +79,6 @@ def _call_openai(prompt: str) -> dict[str, Any]:
|
|||||||
},
|
},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
response_format={"type": "json_object"},
|
|
||||||
temperature=0.3,
|
temperature=0.3,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
)
|
)
|
||||||
|
@ -12,8 +12,9 @@ import traceback
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Iterable, Sequence, cast
|
from typing import Any, Callable, Iterable, Sequence, cast
|
||||||
|
|
||||||
from memory.common import embedding, qdrant
|
from memory.common import embedding, qdrant, settings
|
||||||
from memory.common.db.models import SourceItem, Chunk
|
from memory.common.db.models import SourceItem, Chunk
|
||||||
|
from memory.common.discord import notify_task_failure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -274,7 +275,17 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
|||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Task {func.__name__} failed: {e}")
|
logger.error(f"Task {func.__name__} failed: {e}")
|
||||||
logger.error(traceback.format_exc())
|
traceback_str = traceback.format_exc()
|
||||||
|
logger.error(traceback_str)
|
||||||
|
|
||||||
|
notify_task_failure(
|
||||||
|
task_name=func.__name__,
|
||||||
|
error_message=str(e),
|
||||||
|
task_args=args,
|
||||||
|
task_kwargs=kwargs,
|
||||||
|
traceback_str=traceback_str,
|
||||||
|
)
|
||||||
|
|
||||||
return {"status": "error", "error": str(e)}
|
return {"status": "error", "error": str(e)}
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
@ -213,12 +213,15 @@ def mock_file_storage(tmp_path: Path):
|
|||||||
email_storage_dir.mkdir(parents=True, exist_ok=True)
|
email_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||||
notes_storage_dir = tmp_path / "notes"
|
notes_storage_dir = tmp_path / "notes"
|
||||||
notes_storage_dir.mkdir(parents=True, exist_ok=True)
|
notes_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
comic_storage_dir = tmp_path / "comics"
|
||||||
|
comic_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||||
with (
|
with (
|
||||||
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
|
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
|
||||||
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
|
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
|
||||||
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
|
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
|
||||||
patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir),
|
patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir),
|
||||||
patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir),
|
patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir),
|
||||||
|
patch.object(settings, "COMIC_STORAGE_DIR", comic_storage_dir),
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -256,7 +259,7 @@ def mock_openai_client():
|
|||||||
choices=[
|
choices=[
|
||||||
Mock(
|
Mock(
|
||||||
message=Mock(
|
message=Mock(
|
||||||
content='{"summary": "test summary", "tags": ["tag1", "tag2"]}'
|
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -273,8 +276,16 @@ def mock_anthropic_client():
|
|||||||
client.messages.create = Mock(
|
client.messages.create = Mock(
|
||||||
return_value=Mock(
|
return_value=Mock(
|
||||||
content=[
|
content=[
|
||||||
Mock(text='{"summary": "test summary", "tags": ["tag1", "tag2"]}')
|
Mock(
|
||||||
|
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
||||||
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_discord_client():
|
||||||
|
with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False):
|
||||||
|
yield
|
||||||
|
393
tests/memory/common/test_discord.py
Normal file
393
tests/memory/common/test_discord.py
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
from memory.common import discord, settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session_request():
|
||||||
|
with patch("requests.Session.request") as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_channels_response():
|
||||||
|
return [
|
||||||
|
{"name": "memory-errors", "id": "error_channel_id"},
|
||||||
|
{"name": "memory-activity", "id": "activity_channel_id"},
|
||||||
|
{"name": "memory-discoveries", "id": "discovery_channel_id"},
|
||||||
|
{"name": "memory-chat", "id": "chat_channel_id"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_server_init(mock_session_request, mock_get_channels_response):
|
||||||
|
# Mock the channels API call
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = mock_get_channels_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_session_request.return_value = mock_response
|
||||||
|
|
||||||
|
server = discord.DiscordServer("server123", "Test Server")
|
||||||
|
|
||||||
|
assert server.server_id == "server123"
|
||||||
|
assert server.server_name == "Test Server"
|
||||||
|
assert hasattr(server, "channels")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||||
|
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||||
|
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||||
|
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||||
|
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
|
||||||
|
# Mock the channels API call
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = mock_get_channels_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_session_request.return_value = mock_response
|
||||||
|
|
||||||
|
server = discord.DiscordServer("server123", "Test Server")
|
||||||
|
|
||||||
|
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
|
||||||
|
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
|
||||||
|
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
|
||||||
|
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
|
||||||
|
def test_setup_channels_create_missing(mock_session_request):
|
||||||
|
# Mock get channels (empty) and create channel calls
|
||||||
|
get_response = Mock()
|
||||||
|
get_response.json.return_value = []
|
||||||
|
get_response.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
create_response = Mock()
|
||||||
|
create_response.json.return_value = {"id": "new_channel_id"}
|
||||||
|
create_response.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
mock_session_request.side_effect = [
|
||||||
|
get_response,
|
||||||
|
create_response,
|
||||||
|
create_response,
|
||||||
|
create_response,
|
||||||
|
create_response,
|
||||||
|
]
|
||||||
|
|
||||||
|
server = discord.DiscordServer("server123", "Test Server")
|
||||||
|
|
||||||
|
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_properties():
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.channels = {
|
||||||
|
discord.ERROR_CHANNEL: "error_id",
|
||||||
|
discord.ACTIVITY_CHANNEL: "activity_id",
|
||||||
|
discord.DISCOVERY_CHANNEL: "discovery_id",
|
||||||
|
discord.CHAT_CHANNEL: "chat_id",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert server.error_channel == "error_id"
|
||||||
|
assert server.activity_channel == "activity_id"
|
||||||
|
assert server.discovery_channel == "discovery_id"
|
||||||
|
assert server.chat_channel == "chat_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_id_exists():
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.channels = {"test-channel": "channel123"}
|
||||||
|
|
||||||
|
assert server.channel_id("test-channel") == "channel123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_id_not_found():
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.channels = {}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Channel nonexistent not found"):
|
||||||
|
server.channel_id("nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_message(mock_session_request):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_session_request.return_value = mock_response
|
||||||
|
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
|
||||||
|
server.send_message("channel123", "Hello World")
|
||||||
|
|
||||||
|
mock_session_request.assert_called_with(
|
||||||
|
"POST",
|
||||||
|
"https://discord.com/api/v10/channels/channel123/messages",
|
||||||
|
data=None,
|
||||||
|
json={"content": "Hello World"},
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_channel(mock_session_request):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"id": "new_channel_id"}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_session_request.return_value = mock_response
|
||||||
|
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.server_id = "server123"
|
||||||
|
|
||||||
|
channel_id = server.create_channel("new-channel")
|
||||||
|
|
||||||
|
assert channel_id == "new_channel_id"
|
||||||
|
mock_session_request.assert_called_with(
|
||||||
|
"POST",
|
||||||
|
"https://discord.com/api/v10/guilds/server123/channels",
|
||||||
|
data=None,
|
||||||
|
json={"name": "new-channel", "type": 0},
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_channel_custom_type(mock_session_request):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"id": "voice_channel_id"}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_session_request.return_value = mock_response
|
||||||
|
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.server_id = "server123"
|
||||||
|
|
||||||
|
channel_id = server.create_channel("voice-channel", channel_type=2)
|
||||||
|
|
||||||
|
assert channel_id == "voice_channel_id"
|
||||||
|
mock_session_request.assert_called_with(
|
||||||
|
"POST",
|
||||||
|
"https://discord.com/api/v10/guilds/server123/channels",
|
||||||
|
data=None,
|
||||||
|
json={"name": "voice-channel", "type": 2},
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_representation():
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.server_id = "server123"
|
||||||
|
server.server_name = "Test Server"
|
||||||
|
|
||||||
|
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
|
||||||
|
def test_request_adds_headers(mock_session_request):
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
|
||||||
|
server.request("GET", "https://example.com", headers={"Custom": "header"})
|
||||||
|
|
||||||
|
expected_headers = {
|
||||||
|
"Custom": "header",
|
||||||
|
"Authorization": "Bot test_token_123",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
mock_session_request.assert_called_once_with(
|
||||||
|
"GET", "https://example.com", headers=expected_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_channels_url():
|
||||||
|
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||||
|
server.server_id = "server123"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_get_bot_servers_success(mock_get):
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = [
|
||||||
|
{"id": "server1", "name": "Server 1"},
|
||||||
|
{"id": "server2", "name": "Server 2"},
|
||||||
|
]
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
servers = discord.get_bot_servers()
|
||||||
|
|
||||||
|
assert len(servers) == 2
|
||||||
|
assert servers[0] == {"id": "server1", "name": "Server 1"}
|
||||||
|
mock_get.assert_called_once_with(
|
||||||
|
"https://discord.com/api/v10/users/@me/guilds",
|
||||||
|
headers={"Authorization": "Bot test_token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||||
|
def test_get_bot_servers_no_token():
|
||||||
|
assert discord.get_bot_servers() == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_get_bot_servers_exception(mock_get):
|
||||||
|
mock_get.side_effect = requests.RequestException("API Error")
|
||||||
|
|
||||||
|
servers = discord.get_bot_servers()
|
||||||
|
|
||||||
|
assert servers == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.get_bot_servers")
|
||||||
|
@patch("memory.common.discord.DiscordServer")
|
||||||
|
def test_load_servers(mock_discord_server_class, mock_get_servers):
|
||||||
|
mock_get_servers.return_value = [
|
||||||
|
{"id": "server1", "name": "Server 1"},
|
||||||
|
{"id": "server2", "name": "Server 2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
discord.load_servers()
|
||||||
|
|
||||||
|
assert mock_discord_server_class.call_count == 2
|
||||||
|
mock_discord_server_class.assert_any_call("server1", "Server 1")
|
||||||
|
mock_discord_server_class.assert_any_call("server2", "Server 2")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
def test_broadcast_message():
|
||||||
|
mock_server1 = Mock()
|
||||||
|
mock_server2 = Mock()
|
||||||
|
discord.servers = {"1": mock_server1, "2": mock_server2}
|
||||||
|
|
||||||
|
discord.broadcast_message("test-channel", "Hello")
|
||||||
|
|
||||||
|
mock_server1.send_message.assert_called_once_with(
|
||||||
|
mock_server1.channel_id.return_value, "Hello"
|
||||||
|
)
|
||||||
|
mock_server2.send_message.assert_called_once_with(
|
||||||
|
mock_server2.channel_id.return_value, "Hello"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||||
|
def test_broadcast_message_disabled():
|
||||||
|
mock_server = Mock()
|
||||||
|
discord.servers = {"1": mock_server}
|
||||||
|
|
||||||
|
discord.broadcast_message("test-channel", "Hello")
|
||||||
|
|
||||||
|
mock_server.send_message.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
def test_send_error_message(mock_broadcast):
|
||||||
|
discord.send_error_message("Error occurred")
|
||||||
|
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
def test_send_activity_message(mock_broadcast):
|
||||||
|
discord.send_activity_message("Activity update")
|
||||||
|
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
def test_send_discovery_message(mock_broadcast):
|
||||||
|
discord.send_discovery_message("Discovery made")
|
||||||
|
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.discord.broadcast_message")
|
||||||
|
def test_send_chat_message(mock_broadcast):
|
||||||
|
discord.send_chat_message("Chat message")
|
||||||
|
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_basic(mock_send_error):
|
||||||
|
discord.notify_task_failure("test_task", "Something went wrong")
|
||||||
|
|
||||||
|
mock_send_error.assert_called_once()
|
||||||
|
message = mock_send_error.call_args[0][0]
|
||||||
|
|
||||||
|
assert "🚨 **Task Failed: test_task**" in message
|
||||||
|
assert "**Error:** Something went wrong" in message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_with_args(mock_send_error):
|
||||||
|
discord.notify_task_failure(
|
||||||
|
"test_task",
|
||||||
|
"Error message",
|
||||||
|
task_args=("arg1", "arg2"),
|
||||||
|
task_kwargs={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
message = mock_send_error.call_args[0][0]
|
||||||
|
|
||||||
|
assert "**Args:** `('arg1', 'arg2')`" in message
|
||||||
|
assert "**Kwargs:** `{'key': 'value'}`" in message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||||
|
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
|
||||||
|
|
||||||
|
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
|
||||||
|
|
||||||
|
message = mock_send_error.call_args[0][0]
|
||||||
|
assert "**Traceback:**" in message
|
||||||
|
assert traceback in message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||||
|
long_error = "x" * 600 # Longer than 500 char limit
|
||||||
|
|
||||||
|
discord.notify_task_failure("test_task", long_error)
|
||||||
|
|
||||||
|
message = mock_send_error.call_args[0][0]
|
||||||
|
assert long_error[:500] in message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||||
|
long_traceback = "x" * 1000 # Longer than 800 char limit
|
||||||
|
|
||||||
|
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||||
|
|
||||||
|
message = mock_send_error.call_args[0][0]
|
||||||
|
assert long_traceback[-800:] in message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_disabled(mock_send_error):
|
||||||
|
discord.notify_task_failure("test_task", "Error message")
|
||||||
|
mock_send_error.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||||
|
@patch("memory.common.discord.send_error_message")
|
||||||
|
def test_notify_task_failure_send_fails(mock_send_error):
|
||||||
|
mock_send_error.side_effect = Exception("Discord API error")
|
||||||
|
|
||||||
|
# Should not raise, just log the error
|
||||||
|
discord.notify_task_failure("test_task", "Error message")
|
||||||
|
|
||||||
|
mock_send_error.assert_called_once()
|
151
tests/memory/common/test_summarizer.py
Normal file
151
tests/memory/common/test_summarizer.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
from memory.common import summarizer
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"response, expected",
|
||||||
|
(
|
||||||
|
# Basic valid cases
|
||||||
|
("", {"summary": "", "tags": []}),
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
# Missing summary tag
|
||||||
|
(
|
||||||
|
"<tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||||
|
{"summary": "", "tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
# Missing tags section
|
||||||
|
(
|
||||||
|
"<summary>test summary</summary>",
|
||||||
|
{"summary": "test summary", "tags": []},
|
||||||
|
),
|
||||||
|
# Empty summary tag
|
||||||
|
(
|
||||||
|
"<summary></summary><tags><tag>tag1</tag></tags>",
|
||||||
|
{"summary": "", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# Empty tags section
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags></tags>",
|
||||||
|
{"summary": "test", "tags": []},
|
||||||
|
),
|
||||||
|
# Single tag
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag>single-tag</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["single-tag"]},
|
||||||
|
),
|
||||||
|
# Multiple tags
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag><tag>tag3</tag><tag>tag4</tag><tag>tag5</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["tag1", "tag2", "tag3", "tag4", "tag5"]},
|
||||||
|
),
|
||||||
|
# Tags with special characters and hyphens
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag>machine-learning</tag><tag>ai/ml</tag><tag>data_science</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["machine-learning", "ai/ml", "data_science"]},
|
||||||
|
),
|
||||||
|
# Summary with special characters
|
||||||
|
(
|
||||||
|
"<summary>Test with & special characters <></summary><tags><tag>test</tag></tags>",
|
||||||
|
{"summary": "Test with & special characters <>", "tags": ["test"]},
|
||||||
|
),
|
||||||
|
# Whitespace handling
|
||||||
|
(
|
||||||
|
"<summary> test with spaces </summary><tags><tag> tag1 </tag><tag>tag2</tag></tags>",
|
||||||
|
{"summary": " test with spaces ", "tags": [" tag1 ", "tag2"]},
|
||||||
|
),
|
||||||
|
# Mixed case XML tags (should still work with BeautifulSoup)
|
||||||
|
(
|
||||||
|
"<Summary>test</Summary><Tags><Tag>tag1</Tag></Tags>",
|
||||||
|
{"summary": "test", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# Multiple summary tags (should take first one)
|
||||||
|
(
|
||||||
|
"<summary>first</summary><summary>second</summary><tags><tag>tag1</tag></tags>",
|
||||||
|
{"summary": "first", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# Multiple tags sections (should collect all tags)
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag>tag1</tag></tags><tags><tag>tag2</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
# XML with extra elements (should ignore them)
|
||||||
|
(
|
||||||
|
"<root><summary>test</summary><other>ignored</other><tags><tag>tag1</tag></tags></root>",
|
||||||
|
{"summary": "test", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# XML with attributes (should still work)
|
||||||
|
(
|
||||||
|
'<summary id="1">test</summary><tags type="keywords"><tag>tag1</tag></tags>',
|
||||||
|
{"summary": "test", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# Empty tag elements
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag></tag><tag>valid-tag</tag><tag></tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["", "valid-tag", ""]},
|
||||||
|
),
|
||||||
|
# Self-closing tags
|
||||||
|
(
|
||||||
|
"<summary>test</summary><tags><tag/><tag>valid-tag</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["", "valid-tag"]},
|
||||||
|
),
|
||||||
|
# Long content
|
||||||
|
(
|
||||||
|
f"<summary>{'a' * 1000}</summary><tags><tag>long-content</tag></tags>",
|
||||||
|
{"summary": "a" * 1000, "tags": ["long-content"]},
|
||||||
|
),
|
||||||
|
# XML with newlines and formatting
|
||||||
|
(
|
||||||
|
"""<summary>
|
||||||
|
Multi-line
|
||||||
|
summary content
|
||||||
|
</summary>
|
||||||
|
<tags>
|
||||||
|
<tag>formatted</tag>
|
||||||
|
<tag>xml</tag>
|
||||||
|
</tags>""",
|
||||||
|
{
|
||||||
|
"summary": "\n Multi-line\n summary content\n ",
|
||||||
|
"tags": ["formatted", "xml"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
# Malformed XML (missing closing tags) - BeautifulSoup parses as best it can
|
||||||
|
(
|
||||||
|
"<summary>test<tags><tag>tag1</tag></tags>",
|
||||||
|
{"summary": "testtag1", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# Invalid XML characters should be handled by BeautifulSoup
|
||||||
|
(
|
||||||
|
"<summary>test & unescaped</summary><tags><tag>tag1</tag></tags>",
|
||||||
|
{"summary": "test & unescaped", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# Only whitespace
|
||||||
|
(
|
||||||
|
" \n\t ",
|
||||||
|
{"summary": "", "tags": []},
|
||||||
|
),
|
||||||
|
# Non-XML content
|
||||||
|
(
|
||||||
|
"This is just plain text without XML",
|
||||||
|
{"summary": "", "tags": []},
|
||||||
|
),
|
||||||
|
# XML comments (should be ignored)
|
||||||
|
(
|
||||||
|
"<!-- comment --><summary>test</summary><!-- another comment --><tags><tag>tag1</tag></tags>",
|
||||||
|
{"summary": "test", "tags": ["tag1"]},
|
||||||
|
),
|
||||||
|
# CDATA sections
|
||||||
|
(
|
||||||
|
"<summary><![CDATA[test with <special> characters]]></summary><tags><tag>cdata</tag></tags>",
|
||||||
|
{"summary": "test with <special> characters", "tags": ["cdata"]},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_parse_response(response, expected):
|
||||||
|
assert summarizer.parse_response(response) == expected
|
57
tools/discord_setup.py
Normal file
57
tools/discord_setup.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import argparse
|
||||||
|
import click
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option("--bot-token", type=str, required=True)
|
||||||
|
def generate_bot_invite_url(bot_token: str):
|
||||||
|
"""
|
||||||
|
Generate the Discord bot invitation URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
URL that user can click to add bot to their server
|
||||||
|
"""
|
||||||
|
# Get bot's client ID from the token (it's the first part before the first dot)
|
||||||
|
# But safer to get it from the API
|
||||||
|
try:
|
||||||
|
headers = {"Authorization": f"Bot {bot_token}"}
|
||||||
|
response = requests.get(
|
||||||
|
"https://discord.com/api/v10/users/@me", headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
bot_info = response.json()
|
||||||
|
client_id = bot_info["id"]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Could not get bot info: {e}")
|
||||||
|
|
||||||
|
# Permissions needed: Send Messages (2048) + Manage Channels (16) + View Channels (1024)
|
||||||
|
permissions = 2048 + 16 + 1024 # = 3088
|
||||||
|
|
||||||
|
invite_url = f"https://discord.com/oauth2/authorize?client_id={client_id}&scope=bot&permissions={permissions}"
|
||||||
|
click.echo(f"Bot invite URL: {invite_url}")
|
||||||
|
return invite_url
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
def create_channels():
|
||||||
|
"""Create Discord channels using the configured servers."""
|
||||||
|
from memory.common.discord import load_servers
|
||||||
|
|
||||||
|
click.echo("Loading Discord servers and creating channels...")
|
||||||
|
load_servers()
|
||||||
|
click.echo("Discord channels setup completed.")
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def cli():
|
||||||
|
"""Discord setup utilities."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
cli.add_command(generate_bot_invite_url, name="generate-invite")
|
||||||
|
cli.add_command(create_channels, name="create-channels")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
Loading…
x
Reference in New Issue
Block a user