mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00
discord notification on error
This commit is contained in:
parent
489265fe31
commit
4d057d1ec6
21
README.md
21
README.md
@ -92,6 +92,27 @@ curl -X POST http://localhost:8000/auth/login \
|
||||
|
||||
This returns a session ID that should be included in subsequent requests as the `X-Session-ID` header.
|
||||
|
||||
## Discord integration
|
||||
|
||||
If you want to have notifications sent to discord, you'll have to [create a bot for that](https://discord.com/developers/applications).
|
||||
Once you have the bot's token, run
|
||||
|
||||
```bash
|
||||
python tools/discord_setup.py generate-invite --bot-token <your bot token>
|
||||
```
|
||||
|
||||
to get a URL that can be used to connect your Discord bot.
|
||||
|
||||
Next you'll have to set at least the following in your `.env` file:
|
||||
|
||||
```
|
||||
DISCORD_BOT_TOKEN=<your bot token>
|
||||
DISCORD_NOTIFICATIONS_ENABLED=True
|
||||
```
|
||||
|
||||
When the worker starts it will automatically attempt to create the appropriate channels. You
|
||||
can change what they will be called by setting the various `DISCORD_*_CHANNEL` settings.
|
||||
|
||||
## MCP Proxy Setup
|
||||
|
||||
Since MCP doesn't support basic authentication, use the included proxy for AI assistants that need to connect:
|
||||
|
2
src/memory/api/MCP/__init__.py
Normal file
2
src/memory/api/MCP/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
import memory.api.MCP.manifest
|
||||
import memory.api.MCP.memory
|
119
src/memory/api/MCP/manifest.py
Normal file
119
src/memory/api/MCP/manifest.py
Normal file
@ -0,0 +1,119 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from datetime import datetime
|
||||
|
||||
from typing import TypedDict, NotRequired, Literal
|
||||
from memory.api.MCP.tools import mcp
|
||||
|
||||
|
||||
class BinaryProbs(TypedDict):
|
||||
prob: float
|
||||
|
||||
|
||||
class MultiProbs(TypedDict):
|
||||
answerProbs: dict[str, float]
|
||||
|
||||
|
||||
Probs = dict[str, BinaryProbs | MultiProbs]
|
||||
OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"]
|
||||
|
||||
|
||||
class MarketAnswer(TypedDict):
|
||||
id: str
|
||||
text: str
|
||||
resolutionProbability: float
|
||||
|
||||
|
||||
class MarketDetails(TypedDict):
|
||||
id: str
|
||||
createdTime: int
|
||||
question: str
|
||||
outcomeType: OutcomeType
|
||||
textDescription: str
|
||||
groupSlugs: list[str]
|
||||
volume: float
|
||||
isResolved: bool
|
||||
answers: list[MarketAnswer]
|
||||
|
||||
|
||||
class Market(TypedDict):
|
||||
id: str
|
||||
url: str
|
||||
question: str
|
||||
volume: int
|
||||
createdTime: int
|
||||
outcomeType: OutcomeType
|
||||
createdAt: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
answers: NotRequired[dict[str, float]]
|
||||
probability: NotRequired[float]
|
||||
details: NotRequired[MarketDetails]
|
||||
|
||||
|
||||
async def get_details(session: aiohttp.ClientSession, market_id: str):
    """Fetch the full record for one market from the Manifold API.

    Args:
        session: Open aiohttp session to issue the request on.
        market_id: Manifold market identifier.

    Raises:
        aiohttp.ClientResponseError: If the API responds with an error status.
    """
    url = f"https://api.manifold.markets/v0/market/{market_id}"
    async with session.get(url) as resp:
        resp.raise_for_status()
        return await resp.json()
|
||||
|
||||
|
||||
async def format_market(session: aiohttp.ClientSession, market: Market):
    """Normalise a market record to a compact, serialisable summary.

    Non-binary markets get their per-answer probabilities fetched from the
    details endpoint; `createdTime` (epoch millis) is converted to an ISO
    `createdAt` string. Only a whitelisted set of keys is returned.
    """
    if market.get("outcomeType") != "BINARY":
        details = await get_details(session, market["id"])
        probs: dict[str, float] = {}
        for answer in details["answers"]:
            # Prefer the resolution probability once known, else the live one.
            value = answer.get("resolutionProbability") or answer.get("probability") or 0
            probs[answer["text"]] = round(value, 3)
        market["answers"] = probs

    created = market.get("createdTime")
    if created:
        market["createdAt"] = datetime.fromtimestamp(created / 1000).isoformat()

    wanted = {
        "id",
        "name",
        "url",
        "question",
        "volume",
        "createdAt",
        "details",
        "probability",
        "answers",
    }
    return {key: value for key, value in market.items() if key in wanted}
|
||||
|
||||
|
||||
async def search_markets(term: str, min_volume: int = 1000, binary: bool = False):
    """Search Manifold markets and return formatted summaries.

    Args:
        term: Free-text search term.
        min_volume: Drop markets whose volume is below this threshold.
        binary: When True, restrict the search to binary markets only.
    """
    params = {
        "term": term,
        "contractType": "BINARY" if binary else "ALL",
    }
    async with aiohttp.ClientSession() as session:
        async with session.get(
            "https://api.manifold.markets/v0/search-markets", params=params
        ) as resp:
            resp.raise_for_status()
            markets = await resp.json()

        # Format the surviving markets concurrently on the shared session.
        liquid = [m for m in markets if m.get("volume", 0) >= min_volume]
        return await asyncio.gather(*(format_market(session, m) for m in liquid))
|
||||
|
||||
|
||||
@mcp.tool()
async def get_forecasts(
    term: str, min_volume: int = 1000, binary: bool = False
) -> list[dict]:
    """Get prediction market forecasts for a given term.

    Args:
        term: The term to search for.
        min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold.
        binary: Whether to only return binary markets.
    """
    # Thin MCP wrapper: all the work happens in search_markets.
    forecasts = await search_markets(term, min_volume, binary)
    return forecasts
|
443
src/memory/api/MCP/memory.py
Normal file
443
src/memory/api/MCP/memory.py
Normal file
@ -0,0 +1,443 @@
|
||||
"""
|
||||
MCP tools for the epistemic sparring partner system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
from datetime import datetime, timezone
|
||||
import mimetypes
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Text, func
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.api.search.search import SearchFilters, search
|
||||
from memory.common import extract, settings
|
||||
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import AgentObservation, SourceItem
|
||||
from memory.common.formatters import observation
|
||||
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
|
||||
from memory.api.MCP.tools import mcp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_observation_source_ids(
    tags: list[str] | None = None, observation_types: list[str] | None = None
):
    """Collect ids of AgentObservations matching the given tag/type filters.

    Returns None when no filter is supplied, which callers treat as
    "no source-id restriction".
    """
    if not tags and not observation_types:
        return None

    with make_session() as session:
        query = session.query(AgentObservation.id)
        if tags:
            # PostgreSQL array-overlap (&&) needs an explicit text[] cast.
            query = query.filter(
                AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
            )
        if observation_types:
            query = query.filter(
                AgentObservation.observation_type.in_(observation_types)
            )
        return [row.id for row in query.all()]
|
||||
|
||||
|
||||
def filter_source_ids(
    modalities: set[str],
    tags: list[str] | None = None,
):
    """Collect ids of SourceItems matching the given tags (and modalities).

    Returns None when no tags are given — modalities alone are handled by the
    search layer, so no id restriction is needed in that case.
    """
    if not tags:
        return None

    with make_session() as session:
        query = session.query(SourceItem.id)
        if tags:
            # PostgreSQL array-overlap (&&) needs an explicit text[] cast.
            query = query.filter(
                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
            )
        if modalities:
            query = query.filter(SourceItem.modality.in_(modalities))
        return [row.id for row in query.all()]
|
||||
|
||||
|
||||
@mcp.tool()
async def get_all_tags() -> list[str]:
    """
    Get all unique tags used across the entire knowledge base.
    Returns sorted list of tags from both observations and content.
    """
    with make_session() as session:
        # unnest() flattens the tag arrays so DISTINCT applies per tag.
        rows = session.query(func.unnest(SourceItem.tags)).distinct()
        unique = {tag for (tag,) in rows if tag is not None}
    return sorted(unique)
|
||||
|
||||
|
||||
@mcp.tool()
async def get_all_subjects() -> list[str]:
    """
    Get all unique subjects from observations about the user.
    Returns sorted list of subject identifiers used in observations.
    """
    with make_session() as session:
        rows = session.query(AgentObservation.subject).distinct()
        subjects = [row.subject for row in rows]
    return sorted(subjects)
|
||||
|
||||
|
||||
@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """
    Get all observation types that have been used.
    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
    """
    with make_session() as session:
        rows = session.query(AgentObservation.observation_type).distinct()
        types = {row.observation_type for row in rows}
    # Drop NULLs before sorting; the column is nullable.
    types.discard(None)
    return sorted(types)
|
||||
|
||||
|
||||
@mcp.tool()
async def search_knowledge_base(
    query: str,
    previews: bool = False,
    modalities: set[str] = set(),
    tags: list[str] = [],
    limit: int = 10,
) -> list[dict]:
    """
    Search user's stored content including emails, documents, articles, books.
    Use to find specific information the user has saved or received.
    Combine with search_observations for complete user context.

    Args:
        query: Natural language search query - be descriptive about what you're looking for
        previews: Include actual content in results - when false only a snippet is returned
        modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
        tags: Filter by tags - content must have at least one matching tag
        limit: Max results (1-100)

    Returns: List of search results with id, score, chunks, content, filename
    Higher scores (>0.7) indicate strong matches.
    """
    logger.info(f"MCP search for: {query}")

    # Empty set means "all", then clamp to known collections minus the
    # observation collections (those have their own search tool).
    if not modalities:
        modalities = set(ALL_COLLECTIONS.keys())
    modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS

    chunks = extract.extract_text(query)
    hits = await search(
        chunks,
        previews=previews,
        modalities=modalities,
        limit=limit,
        min_text_score=0.4,
        min_multimodal_score=0.25,
        filters=SearchFilters(
            tags=tags,
            source_ids=filter_source_ids(tags=tags, modalities=modalities),
        ),
    )
    return [hit.model_dump() for hit in hits]
|
||||
|
||||
|
||||
class RawObservation(BaseModel):
    """One observation about the user, as submitted by an agent via `observe`."""

    subject: str
    content: str
    observation_type: str = "general"
    confidences: dict[str, float] = {}
    evidence: dict | None = None
    tags: list[str] = []
|
||||
|
||||
|
||||
@mcp.tool()
async def observe(
    observations: list[RawObservation],
    session_id: str | None = None,
    agent_model: str = "unknown",
) -> dict:
    """
    Record observations about the user for long-term understanding.
    Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
    Be specific and detailed - observations should make sense months later.

    Example call:
    ```
    {
        "observations": [
            {
                "content": "The user is a software engineer.",
                "subject": "user",
                "observation_type": "belief",
                "confidences": {"observation_accuracy": 0.9},
                "evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
                "tags": ["programming", "work"]
            }
        ],
        "session_id": "123e4567-e89b-12d3-a456-426614174000",
        "agent_model": "gpt-4o"
    }
    ```

    RawObservation fields:
        content (required): Detailed observation text explaining what you observed
        subject (required): Consistent identifier like "programming_style", "work_habits"
        observation_type: belief, preference, behavior, contradiction, general
        confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
        evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation)
        tags: List of categorization tags for organization

    Args:
        observations: List of RawObservation objects
        session_id: UUID to group observations from same conversation
        agent_model: AI model making observations (for quality tracking)
    """

    def short_content(obs: RawObservation) -> str:
        # Truncate long content so the returned task-id map stays readable.
        if len(obs.content) > 50:
            return obs.content[:47] + "..."
        return obs.content

    # Queue one background sync task per observation.
    queued = []
    for obs in observations:
        task = celery_app.send_task(
            SYNC_OBSERVATION,
            queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
            kwargs={
                "subject": obs.subject,
                "content": obs.content,
                "observation_type": obs.observation_type,
                "confidences": obs.confidences,
                "evidence": obs.evidence,
                "tags": obs.tags,
                "session_id": session_id,
                "agent_model": agent_model,
            },
        )
        queued.append((obs, task))

    return {
        "task_ids": {short_content(obs): task.id for obs, task in queued},
        "status": "queued",
    }
|
||||
|
||||
|
||||
@mcp.tool()
async def search_observations(
    query: str,
    subject: str = "",
    tags: list[str] | None = None,
    observation_types: list[str] | None = None,
    min_confidences: dict[str, float] = {},
    limit: int = 10,
) -> list[dict]:
    """
    Search recorded observations about the user.
    Use before responding to understand user preferences, patterns, and past insights.
    Search by meaning - the query matches both content and context.

    Args:
        query: Natural language search query describing what you're looking for
        subject: Filter by exact subject identifier (empty = search all subjects)
        tags: Filter by tags (must have at least one matching tag)
        observation_types: Filter by: belief, preference, behavior, contradiction, general
        min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
        limit: Max results (1-100)

    Returns: List with content, tags, created_at, metadata
    Results sorted by relevance to your query.
    """
    # Phrase the query three ways so both the semantic and the temporal
    # embedding collections get a well-matched probe.
    semantic_text = observation.generate_semantic_text(
        subject=subject or "",
        observation_type="".join(observation_types or []),
        content=query,
        evidence=None,
    )
    temporal = observation.generate_temporal_text(
        subject=subject or "",
        content=query,
        created_at=datetime.now(timezone.utc),
    )
    probes = [
        extract.DataChunk(data=[query]),
        extract.DataChunk(data=[semantic_text]),
        extract.DataChunk(data=[temporal]),
    ]
    hits = await search(
        probes,
        previews=True,
        modalities={"semantic", "temporal"},
        limit=limit,
        filters=SearchFilters(
            subject=subject,
            min_confidences=min_confidences,
            tags=tags,
            observation_types=observation_types,
            source_ids=filter_observation_source_ids(tags=tags),
        ),
        timeout=2,
    )

    def as_row(hit) -> dict:
        return {
            "content": hit.content,
            "tags": hit.tags,
            "created_at": hit.created_at.isoformat() if hit.created_at else None,
            "metadata": hit.metadata,
        }

    return [as_row(hit) for hit in hits]
|
||||
|
||||
|
||||
@mcp.tool()
async def create_note(
    subject: str,
    content: str,
    filename: str | None = None,
    note_type: str | None = None,
    confidences: dict[str, float] = {},
    tags: list[str] = [],
) -> dict:
    """
    Create a note when user asks to save or record something.
    Use when user explicitly requests noting information for future reference.

    Args:
        subject: What the note is about (used for organization)
        content: Note content as a markdown string
        filename: Optional path relative to notes folder (e.g., "project/ideas.md")
        note_type: Optional categorization of the note
        confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
        tags: Organization tags for filtering and discovery

    Raises:
        ValueError: If filename resolves outside the notes storage directory.
    """
    if filename:
        # Normalize to a path relative to the notes storage root.
        path = pathlib.Path(filename)
        if not path.is_absolute():
            path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
        filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()

    try:
        task = celery_app.send_task(
            SYNC_NOTE,
            queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
            kwargs={
                "subject": subject,
                "content": content,
                "filename": filename,
                "note_type": note_type,
                "confidences": confidences,
                "tags": tags,
            },
        )
    except Exception:
        # logger.exception logs the full stack trace, replacing the previous
        # traceback.print_exc() + logger.error combination.
        logger.exception("Error creating note")
        raise

    return {
        "task_id": task.id,
        "status": "queued",
    }
|
||||
|
||||
|
||||
@mcp.tool()
async def note_files(path: str = "/"):
    """
    List note files in the user's note storage.
    Use to discover existing notes before reading or to help user navigate their collection.

    Args:
        path: Directory path to search (e.g., "/", "/projects", "/meetings")
              Use "/" for root, or subdirectories to narrow scope

    Returns: List of file paths relative to notes directory
    """
    base = settings.NOTES_STORAGE_DIR / path.lstrip("/")
    found = []
    for candidate in base.rglob("*.md"):
        if candidate.is_file():
            found.append(f"/notes/{candidate.relative_to(settings.NOTES_STORAGE_DIR)}")
    return found
|
||||
|
||||
|
||||
@mcp.tool()
def fetch_file(filename: str) -> dict:
    """
    Read file content with automatic type detection.
    Returns dict with content, mime_type, is_text, file_size.
    Text content as string, binary as base64.

    Raises:
        FileNotFoundError: If the file does not exist under FILE_STORAGE_DIR.
    """
    import base64

    path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
    if not path.exists():
        # Include the requested name so the caller knows what was missing
        # (the previous message had no placeholder at all).
        raise FileNotFoundError(f"File not found: {filename}")

    mime_type, _ = mimetypes.guess_type(str(path))
    mime_type = mime_type or "application/octet-stream"

    # Extensions and mime types treated as text even when guess_type is vague.
    text_extensions = {
        ".md",
        ".txt",
        ".py",
        ".js",
        ".html",
        ".css",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
        ".toml",
        ".ini",
        ".cfg",
        ".conf",
    }
    text_mimes = {
        "application/json",
        "application/xml",
        "application/javascript",
        "application/x-yaml",
        "application/yaml",
    }
    is_text = (
        mime_type.startswith("text/")
        or mime_type in text_mimes
        or path.suffix.lower() in text_extensions
    )

    try:
        if is_text:
            content = path.read_text(encoding="utf-8")
        else:
            content = base64.b64encode(path.read_bytes()).decode("ascii")
    except UnicodeDecodeError:
        # Mis-detected as text: fall back to base64 and correct the mime type.
        content = base64.b64encode(path.read_bytes()).decode("ascii")
        is_text = False
        if mime_type.startswith("text/"):
            mime_type = "application/octet-stream"

    return {
        "content": content,
        "mime_type": mime_type,
        "is_text": is_text,
        "file_size": path.stat().st_size,
        "filename": filename,
    }
|
@ -3,23 +3,15 @@ MCP tools for the epistemic sparring partner system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
from datetime import datetime, timezone
|
||||
import mimetypes
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Text, func
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from memory.api.search.search import SearchFilters, search
|
||||
from memory.common import extract, settings
|
||||
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import AgentObservation, SourceItem
|
||||
from memory.common.formatters import observation
|
||||
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -76,377 +68,3 @@ def filter_source_ids(
|
||||
async def get_current_time() -> dict:
|
||||
"""Get the current time in UTC."""
|
||||
return {"current_time": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_tags() -> list[str]:
|
||||
"""
|
||||
Get all unique tags used across the entire knowledge base.
|
||||
Returns sorted list of tags from both observations and content.
|
||||
"""
|
||||
with make_session() as session:
|
||||
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
||||
return sorted({row[0] for row in tags_query if row[0] is not None})
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_subjects() -> list[str]:
|
||||
"""
|
||||
Get all unique subjects from observations about the user.
|
||||
Returns sorted list of subject identifiers used in observations.
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
r.subject for r in session.query(AgentObservation.subject).distinct()
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_observation_types() -> list[str]:
|
||||
"""
|
||||
Get all observation types that have been used.
|
||||
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
{
|
||||
r.observation_type
|
||||
for r in session.query(AgentObservation.observation_type).distinct()
|
||||
if r.observation_type is not None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_knowledge_base(
|
||||
query: str,
|
||||
previews: bool = False,
|
||||
modalities: set[str] = set(),
|
||||
tags: list[str] = [],
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search user's stored content including emails, documents, articles, books.
|
||||
Use to find specific information the user has saved or received.
|
||||
Combine with search_observations for complete user context.
|
||||
|
||||
Args:
|
||||
query: Natural language search query - be descriptive about what you're looking for
|
||||
previews: Include actual content in results - when false only a snippet is returned
|
||||
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
|
||||
tags: Filter by tags - content must have at least one matching tag
|
||||
limit: Max results (1-100)
|
||||
|
||||
Returns: List of search results with id, score, chunks, content, filename
|
||||
Higher scores (>0.7) indicate strong matches.
|
||||
"""
|
||||
logger.info(f"MCP search for: {query}")
|
||||
|
||||
if not modalities:
|
||||
modalities = set(ALL_COLLECTIONS.keys())
|
||||
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
|
||||
|
||||
upload_data = extract.extract_text(query)
|
||||
results = await search(
|
||||
upload_data,
|
||||
previews=previews,
|
||||
modalities=modalities,
|
||||
limit=limit,
|
||||
min_text_score=0.4,
|
||||
min_multimodal_score=0.25,
|
||||
filters=SearchFilters(
|
||||
tags=tags,
|
||||
source_ids=filter_source_ids(tags=tags, modalities=modalities),
|
||||
),
|
||||
)
|
||||
|
||||
return [result.model_dump() for result in results]
|
||||
|
||||
|
||||
class RawObservation(BaseModel):
|
||||
subject: str
|
||||
content: str
|
||||
observation_type: str = "general"
|
||||
confidences: dict[str, float] = {}
|
||||
evidence: dict | None = None
|
||||
tags: list[str] = []
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def observe(
|
||||
observations: list[RawObservation],
|
||||
session_id: str | None = None,
|
||||
agent_model: str = "unknown",
|
||||
) -> dict:
|
||||
"""
|
||||
Record observations about the user for long-term understanding.
|
||||
Use proactively when user expresses preferences, behaviors, beliefs, or contradictions.
|
||||
Be specific and detailed - observations should make sense months later.
|
||||
|
||||
Example call:
|
||||
```
|
||||
{
|
||||
"observations": [
|
||||
{
|
||||
"content": "The user is a software engineer.",
|
||||
"subject": "user",
|
||||
"observation_type": "belief",
|
||||
"confidences": {"observation_accuracy": 0.9},
|
||||
"evidence": {"quote": "I am a software engineer.", "context": "I work at Google."},
|
||||
"tags": ["programming", "work"]
|
||||
}
|
||||
],
|
||||
"session_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"agent_model": "gpt-4o"
|
||||
}
|
||||
```
|
||||
|
||||
RawObservation fields:
|
||||
content (required): Detailed observation text explaining what you observed
|
||||
subject (required): Consistent identifier like "programming_style", "work_habits"
|
||||
observation_type: belief, preference, behavior, contradiction, general
|
||||
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
|
||||
evidence: Context dict with extra context, e.g. "quote" (exact words) and "context" (situation)
|
||||
tags: List of categorization tags for organization
|
||||
|
||||
Args:
|
||||
observations: List of RawObservation objects
|
||||
session_id: UUID to group observations from same conversation
|
||||
agent_model: AI model making observations (for quality tracking)
|
||||
"""
|
||||
tasks = [
|
||||
(
|
||||
observation,
|
||||
celery_app.send_task(
|
||||
SYNC_OBSERVATION,
|
||||
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
||||
kwargs={
|
||||
"subject": observation.subject,
|
||||
"content": observation.content,
|
||||
"observation_type": observation.observation_type,
|
||||
"confidences": observation.confidences,
|
||||
"evidence": observation.evidence,
|
||||
"tags": observation.tags,
|
||||
"session_id": session_id,
|
||||
"agent_model": agent_model,
|
||||
},
|
||||
),
|
||||
)
|
||||
for observation in observations
|
||||
]
|
||||
|
||||
def short_content(obs: RawObservation) -> str:
|
||||
if len(obs.content) > 50:
|
||||
return obs.content[:47] + "..."
|
||||
return obs.content
|
||||
|
||||
return {
|
||||
"task_ids": {short_content(obs): task.id for obs, task in tasks},
|
||||
"status": "queued",
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_observations(
|
||||
query: str,
|
||||
subject: str = "",
|
||||
tags: list[str] | None = None,
|
||||
observation_types: list[str] | None = None,
|
||||
min_confidences: dict[str, float] = {},
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search recorded observations about the user.
|
||||
Use before responding to understand user preferences, patterns, and past insights.
|
||||
Search by meaning - the query matches both content and context.
|
||||
|
||||
Args:
|
||||
query: Natural language search query describing what you're looking for
|
||||
subject: Filter by exact subject identifier (empty = search all subjects)
|
||||
tags: Filter by tags (must have at least one matching tag)
|
||||
observation_types: Filter by: belief, preference, behavior, contradiction, general
|
||||
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
|
||||
limit: Max results (1-100)
|
||||
|
||||
Returns: List with content, tags, created_at, metadata
|
||||
Results sorted by relevance to your query.
|
||||
"""
|
||||
semantic_text = observation.generate_semantic_text(
|
||||
subject=subject or "",
|
||||
observation_type="".join(observation_types or []),
|
||||
content=query,
|
||||
evidence=None,
|
||||
)
|
||||
temporal = observation.generate_temporal_text(
|
||||
subject=subject or "",
|
||||
content=query,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
results = await search(
|
||||
[
|
||||
extract.DataChunk(data=[query]),
|
||||
extract.DataChunk(data=[semantic_text]),
|
||||
extract.DataChunk(data=[temporal]),
|
||||
],
|
||||
previews=True,
|
||||
modalities={"semantic", "temporal"},
|
||||
limit=limit,
|
||||
filters=SearchFilters(
|
||||
subject=subject,
|
||||
min_confidences=min_confidences,
|
||||
tags=tags,
|
||||
observation_types=observation_types,
|
||||
source_ids=filter_observation_source_ids(tags=tags),
|
||||
),
|
||||
timeout=2,
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"content": r.content,
|
||||
"tags": r.tags,
|
||||
"created_at": r.created_at.isoformat() if r.created_at else None,
|
||||
"metadata": r.metadata,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def create_note(
|
||||
subject: str,
|
||||
content: str,
|
||||
filename: str | None = None,
|
||||
note_type: str | None = None,
|
||||
confidences: dict[str, float] = {},
|
||||
tags: list[str] = [],
|
||||
) -> dict:
|
||||
"""
|
||||
Create a note when user asks to save or record something.
|
||||
Use when user explicitly requests noting information for future reference.
|
||||
|
||||
Args:
|
||||
subject: What the note is about (used for organization)
|
||||
content: Note content as a markdown string
|
||||
filename: Optional path relative to notes folder (e.g., "project/ideas.md")
|
||||
note_type: Optional categorization of the note
|
||||
confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
|
||||
tags: Organization tags for filtering and discovery
|
||||
"""
|
||||
if filename:
|
||||
path = pathlib.Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
|
||||
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
|
||||
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
SYNC_NOTE,
|
||||
queue=f"{settings.CELERY_QUEUE_PREFIX}-notes",
|
||||
kwargs={
|
||||
"subject": subject,
|
||||
"content": content,
|
||||
"filename": filename,
|
||||
"note_type": note_type,
|
||||
"confidences": confidences,
|
||||
"tags": tags,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"Error creating note: {e}")
|
||||
raise
|
||||
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"status": "queued",
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def note_files(path: str = "/"):
|
||||
"""
|
||||
List note files in the user's note storage.
|
||||
Use to discover existing notes before reading or to help user navigate their collection.
|
||||
|
||||
Args:
|
||||
path: Directory path to search (e.g., "/", "/projects", "/meetings")
|
||||
Use "/" for root, or subdirectories to narrow scope
|
||||
|
||||
Returns: List of file paths relative to notes directory
|
||||
"""
|
||||
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
|
||||
return [
|
||||
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
|
||||
for f in root.rglob("*.md")
|
||||
if f.is_file()
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool()
def fetch_file(filename: str) -> dict:
    """
    Read file content with automatic type detection.

    Args:
        filename: Path relative to the file storage root.

    Returns dict with content, mime_type, is_text, file_size, filename.
    Text content is returned as a str, binary content as base64 ASCII.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    import base64

    path = settings.FILE_STORAGE_DIR / filename.lstrip("/")
    if not path.exists():
        # Include the requested name so callers can tell what was missing
        # (the original message had a hard-coded placeholder here).
        raise FileNotFoundError(f"File not found: {filename}")

    mime_type, _ = mimetypes.guess_type(str(path))
    mime_type = mime_type or "application/octet-stream"

    # Extensions/mimes treated as text even when mimetypes can't classify them.
    text_extensions = {
        ".md",
        ".txt",
        ".py",
        ".js",
        ".html",
        ".css",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
        ".toml",
        ".ini",
        ".cfg",
        ".conf",
    }
    text_mimes = {
        "application/json",
        "application/xml",
        "application/javascript",
        "application/x-yaml",
        "application/yaml",
    }
    is_text = (
        mime_type.startswith("text/")
        or mime_type in text_mimes
        or path.suffix.lower() in text_extensions
    )

    if is_text:
        try:
            content = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Mislabelled binary content: fall back to base64 and correct
            # the metadata so the caller isn't told it received text.
            content = base64.b64encode(path.read_bytes()).decode("ascii")
            is_text = False
            if mime_type.startswith("text/"):
                mime_type = "application/octet-stream"
    else:
        content = base64.b64encode(path.read_bytes()).decode("ascii")

    return {
        "content": content,
        "mime_type": mime_type,
        "is_text": is_text,
        "file_size": path.stat().st_size,
        "filename": filename,
    }
|
||||
|
@ -195,6 +195,9 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if settings.DISABLE_AUTH:
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
|
||||
# Skip authentication for whitelisted endpoints
|
||||
|
@ -83,5 +83,7 @@ app.conf.update(
|
||||
@app.on_after_configure.connect # type: ignore[attr-defined]
|
||||
def ensure_qdrant_initialised(sender, **_):
|
||||
from memory.common import qdrant
|
||||
from memory.common.discord import load_servers
|
||||
|
||||
qdrant.setup_qdrant()
|
||||
load_servers()
|
||||
|
184
src/memory/common/discord.py
Normal file
184
src/memory/common/discord.py
Normal file
@ -0,0 +1,184 @@
|
||||
import logging
|
||||
import requests
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ERROR_CHANNEL = "memory-errors"
|
||||
ACTIVITY_CHANNEL = "memory-activity"
|
||||
DISCOVERY_CHANNEL = "memory-discoveries"
|
||||
CHAT_CHANNEL = "memory-chat"
|
||||
|
||||
|
||||
class DiscordServer(requests.Session):
    """HTTP session bound to a single Discord guild (server).

    Adds bot authentication headers to every request and, on construction,
    ensures the configured notification channels exist on the guild.
    """

    def __init__(self, server_id: str, server_name: str, *args, **kwargs):
        # Attributes must be set before setup_channels() issues requests.
        self.server_id = server_id
        self.server_name = server_name
        self.channels: dict[str, str] = {}
        super().__init__(*args, **kwargs)
        self.setup_channels()

    def setup_channels(self):
        """Map logical channel keys to guild channel ids, creating missing channels."""
        resp = self.get(self.channels_url)
        resp.raise_for_status()
        existing = {channel["name"]: channel["id"] for channel in resp.json()}

        # Logical key -> configured channel name; one loop replaces four
        # copies of the same lookup-or-create pattern.
        wanted = {
            ERROR_CHANNEL: settings.DISCORD_ERROR_CHANNEL,
            ACTIVITY_CHANNEL: settings.DISCORD_ACTIVITY_CHANNEL,
            DISCOVERY_CHANNEL: settings.DISCORD_DISCOVERY_CHANNEL,
            CHAT_CHANNEL: settings.DISCORD_CHAT_CHANNEL,
        }
        for key, name in wanted.items():
            self.channels[key] = existing.get(name) or self.create_channel(name)

    @property
    def error_channel(self) -> str:
        return self.channels[ERROR_CHANNEL]

    @property
    def activity_channel(self) -> str:
        return self.channels[ACTIVITY_CHANNEL]

    @property
    def discovery_channel(self) -> str:
        return self.channels[DISCOVERY_CHANNEL]

    @property
    def chat_channel(self) -> str:
        return self.channels[CHAT_CHANNEL]

    def channel_id(self, channel_name: str) -> str:
        """Return the guild channel id for a logical channel name.

        Raises:
            ValueError: if the channel is not known to this server.
        """
        if not (channel_id := self.channels.get(channel_name)):
            raise ValueError(f"Channel {channel_name} not found")
        return channel_id

    def send_message(self, channel_id: str, content: str):
        """Post a plain-text message to the given channel."""
        self.post(
            f"https://discord.com/api/v10/channels/{channel_id}/messages",
            json={"content": content},
        )

    def create_channel(self, channel_name: str, channel_type: int = 0) -> str:
        """Create a guild channel and return its id (type 0 = text channel)."""
        resp = self.post(
            self.channels_url, json={"name": channel_name, "type": channel_type}
        )
        resp.raise_for_status()
        return resp.json()["id"]

    def __str__(self):
        return (
            f"DiscordServer(server_id={self.server_id}, server_name={self.server_name})"
        )

    def request(self, method: str, url: str, **kwargs):
        """Inject bot auth headers on every request.

        Copies the caller's headers dict instead of mutating it in place.
        """
        headers = dict(kwargs.get("headers") or {})
        headers["Authorization"] = f"Bot {settings.DISCORD_BOT_TOKEN}"
        headers["Content-Type"] = "application/json"
        kwargs["headers"] = headers
        return super().request(method, url, **kwargs)

    @property
    def channels_url(self) -> str:
        return f"https://discord.com/api/v10/guilds/{self.server_id}/channels"
|
||||
|
||||
|
||||
def get_bot_servers() -> List[Dict]:
    """Return the guilds the configured bot belongs to (empty list on failure)."""
    if not settings.DISCORD_BOT_TOKEN:
        return []

    auth = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"}
    try:
        # NOTE(review): no timeout here — a network stall blocks the caller
        # (e.g. celery startup); consider adding timeout=10.
        response = requests.get(
            "https://discord.com/api/v10/users/@me/guilds", headers=auth
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best effort: notifications are optional, so log and carry on.
        logger.error(f"Failed to get bot servers: {e}")
        return []
|
||||
|
||||
|
||||
servers: dict[str, DiscordServer] = {}
|
||||
|
||||
|
||||
def load_servers():
    """Populate the module-level `servers` registry with every reachable guild.

    A failure while initialising one guild (e.g. missing channel permissions)
    is logged and skipped so the remaining guilds still register — this runs
    at worker startup, which must not be aborted by a single bad server.
    """
    for server in get_bot_servers():
        try:
            servers[server["id"]] = DiscordServer(server["id"], server["name"])
        except Exception as e:
            logger.error(
                f"Failed to initialise Discord server {server.get('name')}: {e}"
            )
|
||||
|
||||
|
||||
def broadcast_message(channel: str, message: str):
    """Send `message` to the named logical channel on every known server.

    No-op when Discord notifications are disabled. A failure on one server
    (unknown channel, HTTP error) is logged and does not block delivery to
    the remaining servers.
    """
    if not settings.DISCORD_NOTIFICATIONS_ENABLED:
        return

    for server in servers.values():
        try:
            server.send_message(server.channel_id(channel), message)
        except Exception as e:
            logger.error(f"Failed to broadcast to {server}: {e}")
|
||||
|
||||
|
||||
def send_error_message(message: str):
    """Broadcast `message` to the error channel on all servers."""
    broadcast_message(ERROR_CHANNEL, message)
|
||||
|
||||
|
||||
def send_activity_message(message: str):
    """Broadcast `message` to the activity channel on all servers."""
    broadcast_message(ACTIVITY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_discovery_message(message: str):
    """Broadcast `message` to the discovery channel on all servers."""
    broadcast_message(DISCOVERY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_chat_message(message: str):
    """Broadcast `message` to the chat channel on all servers."""
    broadcast_message(CHAT_CHANNEL, message)
|
||||
|
||||
|
||||
def notify_task_failure(
    task_name: str,
    error_message: str,
    task_args: tuple = (),
    task_kwargs: dict[str, Any] | None = None,
    traceback_str: str | None = None,
) -> None:
    """
    Send a task failure notification to Discord.

    Long fields are truncated (error to 500 chars, args/kwargs to 200,
    traceback to its last 800 chars) to stay within Discord message limits.

    Args:
        task_name: Name of the failed task
        error_message: Error message
        task_args: Task arguments
        task_kwargs: Task keyword arguments
        traceback_str: Full traceback string

    Returns:
        None. Send failures are logged, never raised, so notification
        problems can't mask the original task error.
    """
    if not settings.DISCORD_NOTIFICATIONS_ENABLED:
        logger.debug("Discord notifications disabled")
        return

    message = f"🚨 **Task Failed: {task_name}**\n\n"
    message += f"**Error:** {error_message[:500]}\n"

    if task_args:
        message += f"**Args:** `{str(task_args)[:200]}`\n"

    if task_kwargs:
        message += f"**Kwargs:** `{str(task_kwargs)[:200]}`\n"

    if traceback_str:
        # Keep the tail — that's where the actual exception is.
        message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"

    try:
        send_error_message(message)
        logger.info(f"Discord error notification sent for task: {task_name}")
    except Exception as e:
        logger.error(f"Failed to send Discord notification: {e}")
|
@ -55,7 +55,7 @@ def generate_temporal_text(
|
||||
f"Observation: {content}",
|
||||
]
|
||||
|
||||
return " | ".join(parts)
|
||||
return " | ".join(parts).strip()
|
||||
|
||||
|
||||
# TODO: Add more embedding dimensions here:
|
||||
|
@ -138,3 +138,17 @@ SESSION_COOKIE_MAX_AGE = int(os.getenv("SESSION_COOKIE_MAX_AGE", 30 * 24 * 60 *
|
||||
SESSION_VALID_FOR = int(os.getenv("SESSION_VALID_FOR", 30))
|
||||
|
||||
# NOTE(review): `... or True` makes this always True, so the REGISTER_ENABLED
# env var is currently a no-op — confirm this override is intentional.
REGISTER_ENABLED = boolean_env("REGISTER_ENABLED", False) or True
|
||||
DISABLE_AUTH = boolean_env("DISABLE_AUTH", False)
|
||||
|
||||
# Discord notification settings
|
||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||
DISCORD_ERROR_CHANNEL = os.getenv("DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||
DISCORD_ACTIVITY_CHANNEL = os.getenv("DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||
DISCORD_DISCOVERY_CHANNEL = os.getenv("DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||
DISCORD_CHAT_CHANNEL = os.getenv("DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
|
||||
|
||||
# Enable Discord notifications if bot token is set
|
||||
DISCORD_NOTIFICATIONS_ENABLED = (
|
||||
boolean_env("DISCORD_NOTIFICATIONS_ENABLED", True) and DISCORD_BOT_TOKEN
|
||||
)
|
||||
|
@ -47,9 +47,20 @@ Text to summarize:
|
||||
|
||||
def parse_response(response: str) -> dict[str, Any]:
    """Parse the summarizer's XML-ish response into a summary and tags.

    Args:
        response: Raw model output expected to contain ``<summary>`` and
            ``<tag>`` elements.

    Returns:
        ``{"summary": str, "tags": list[str]}``; empty values when the
        response is blank or the elements are missing.
    """
    if not response or not response.strip():
        return {"summary": "", "tags": []}

    # html.parser is lenient with malformed markup, unlike the strict
    # "xml" parser which would crash on partial model output.
    soup = BeautifulSoup(response, "html.parser")

    # Missing <summary> yields "" instead of an AttributeError.
    summary_element = soup.find("summary")
    summary = summary_element.text if summary_element else ""

    tag_elements = soup.find_all("tag")
    tags = [tag.text for tag in tag_elements if tag.text is not None]

    return {"summary": summary, "tags": tags}
|
||||
|
||||
|
||||
@ -68,7 +79,6 @@ def _call_openai(prompt: str) -> dict[str, Any]:
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
@ -12,8 +12,9 @@ import traceback
|
||||
import logging
|
||||
from typing import Any, Callable, Iterable, Sequence, cast
|
||||
|
||||
from memory.common import embedding, qdrant
|
||||
from memory.common import embedding, qdrant, settings
|
||||
from memory.common.db.models import SourceItem, Chunk
|
||||
from memory.common.discord import notify_task_failure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -274,7 +275,17 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Task {func.__name__} failed: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
traceback_str = traceback.format_exc()
|
||||
logger.error(traceback_str)
|
||||
|
||||
notify_task_failure(
|
||||
task_name=func.__name__,
|
||||
error_message=str(e),
|
||||
task_args=args,
|
||||
task_kwargs=kwargs,
|
||||
traceback_str=traceback_str,
|
||||
)
|
||||
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
return wrapper
|
||||
|
@ -213,12 +213,15 @@ def mock_file_storage(tmp_path: Path):
|
||||
email_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
notes_storage_dir = tmp_path / "notes"
|
||||
notes_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
comic_storage_dir = tmp_path / "comics"
|
||||
comic_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
with (
|
||||
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
|
||||
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
|
||||
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
|
||||
patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir),
|
||||
patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir),
|
||||
patch.object(settings, "COMIC_STORAGE_DIR", comic_storage_dir),
|
||||
):
|
||||
yield
|
||||
|
||||
@ -256,7 +259,7 @@ def mock_openai_client():
|
||||
choices=[
|
||||
Mock(
|
||||
message=Mock(
|
||||
content='{"summary": "test summary", "tags": ["tag1", "tag2"]}'
|
||||
content="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -273,8 +276,16 @@ def mock_anthropic_client():
|
||||
client.messages.create = Mock(
|
||||
return_value=Mock(
|
||||
content=[
|
||||
Mock(text='{"summary": "test summary", "tags": ["tag1", "tag2"]}')
|
||||
Mock(
|
||||
text="<summary>test summary</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>"
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_discord_client():
|
||||
with patch.object(settings, "DISCORD_NOTIFICATIONS_ENABLED", False):
|
||||
yield
|
||||
|
393
tests/memory/common/test_discord.py
Normal file
393
tests/memory/common/test_discord.py
Normal file
@ -0,0 +1,393 @@
|
||||
import logging
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import requests
|
||||
import json
|
||||
|
||||
from memory.common import discord, settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_request():
|
||||
with patch("requests.Session.request") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_channels_response():
|
||||
return [
|
||||
{"name": "memory-errors", "id": "error_channel_id"},
|
||||
{"name": "memory-activity", "id": "activity_channel_id"},
|
||||
{"name": "memory-discoveries", "id": "discovery_channel_id"},
|
||||
{"name": "memory-chat", "id": "chat_channel_id"},
|
||||
]
|
||||
|
||||
|
||||
def test_discord_server_init(mock_session_request, mock_get_channels_response):
|
||||
# Mock the channels API call
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_get_channels_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.server_id == "server123"
|
||||
assert server.server_name == "Test Server"
|
||||
assert hasattr(server, "channels")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "memory-errors")
|
||||
@patch("memory.common.settings.DISCORD_ACTIVITY_CHANNEL", "memory-activity")
|
||||
@patch("memory.common.settings.DISCORD_DISCOVERY_CHANNEL", "memory-discoveries")
|
||||
@patch("memory.common.settings.DISCORD_CHAT_CHANNEL", "memory-chat")
|
||||
def test_setup_channels_existing(mock_session_request, mock_get_channels_response):
|
||||
# Mock the channels API call
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_get_channels_response
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.channels[discord.ERROR_CHANNEL] == "error_channel_id"
|
||||
assert server.channels[discord.ACTIVITY_CHANNEL] == "activity_channel_id"
|
||||
assert server.channels[discord.DISCOVERY_CHANNEL] == "discovery_channel_id"
|
||||
assert server.channels[discord.CHAT_CHANNEL] == "chat_channel_id"
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_ERROR_CHANNEL", "new-error-channel")
|
||||
def test_setup_channels_create_missing(mock_session_request):
|
||||
# Mock get channels (empty) and create channel calls
|
||||
get_response = Mock()
|
||||
get_response.json.return_value = []
|
||||
get_response.raise_for_status.return_value = None
|
||||
|
||||
create_response = Mock()
|
||||
create_response.json.return_value = {"id": "new_channel_id"}
|
||||
create_response.raise_for_status.return_value = None
|
||||
|
||||
mock_session_request.side_effect = [
|
||||
get_response,
|
||||
create_response,
|
||||
create_response,
|
||||
create_response,
|
||||
create_response,
|
||||
]
|
||||
|
||||
server = discord.DiscordServer("server123", "Test Server")
|
||||
|
||||
assert server.channels[discord.ERROR_CHANNEL] == "new_channel_id"
|
||||
|
||||
|
||||
def test_channel_properties():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {
|
||||
discord.ERROR_CHANNEL: "error_id",
|
||||
discord.ACTIVITY_CHANNEL: "activity_id",
|
||||
discord.DISCOVERY_CHANNEL: "discovery_id",
|
||||
discord.CHAT_CHANNEL: "chat_id",
|
||||
}
|
||||
|
||||
assert server.error_channel == "error_id"
|
||||
assert server.activity_channel == "activity_id"
|
||||
assert server.discovery_channel == "discovery_id"
|
||||
assert server.chat_channel == "chat_id"
|
||||
|
||||
|
||||
def test_channel_id_exists():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {"test-channel": "channel123"}
|
||||
|
||||
assert server.channel_id("test-channel") == "channel123"
|
||||
|
||||
|
||||
def test_channel_id_not_found():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.channels = {}
|
||||
|
||||
with pytest.raises(ValueError, match="Channel nonexistent not found"):
|
||||
server.channel_id("nonexistent")
|
||||
|
||||
|
||||
def test_send_message(mock_session_request):
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
|
||||
server.send_message("channel123", "Hello World")
|
||||
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/channels/channel123/messages",
|
||||
data=None,
|
||||
json={"content": "Hello World"},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_create_channel(mock_session_request):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"id": "new_channel_id"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
|
||||
channel_id = server.create_channel("new-channel")
|
||||
|
||||
assert channel_id == "new_channel_id"
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/guilds/server123/channels",
|
||||
data=None,
|
||||
json={"name": "new-channel", "type": 0},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_create_channel_custom_type(mock_session_request):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"id": "voice_channel_id"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_session_request.return_value = mock_response
|
||||
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
|
||||
channel_id = server.create_channel("voice-channel", channel_type=2)
|
||||
|
||||
assert channel_id == "voice_channel_id"
|
||||
mock_session_request.assert_called_with(
|
||||
"POST",
|
||||
"https://discord.com/api/v10/guilds/server123/channels",
|
||||
data=None,
|
||||
json={"name": "voice-channel", "type": 2},
|
||||
headers={
|
||||
"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_str_representation():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
server.server_name = "Test Server"
|
||||
|
||||
assert str(server) == "DiscordServer(server_id=server123, server_name=Test Server)"
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token_123")
|
||||
def test_request_adds_headers(mock_session_request):
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
|
||||
server.request("GET", "https://example.com", headers={"Custom": "header"})
|
||||
|
||||
expected_headers = {
|
||||
"Custom": "header",
|
||||
"Authorization": "Bot test_token_123",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
mock_session_request.assert_called_once_with(
|
||||
"GET", "https://example.com", headers=expected_headers
|
||||
)
|
||||
|
||||
|
||||
def test_channels_url():
|
||||
server = discord.DiscordServer.__new__(discord.DiscordServer)
|
||||
server.server_id = "server123"
|
||||
|
||||
assert (
|
||||
server.channels_url == "https://discord.com/api/v10/guilds/server123/channels"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
@patch("requests.get")
|
||||
def test_get_bot_servers_success(mock_get):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = [
|
||||
{"id": "server1", "name": "Server 1"},
|
||||
{"id": "server2", "name": "Server 2"},
|
||||
]
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
servers = discord.get_bot_servers()
|
||||
|
||||
assert len(servers) == 2
|
||||
assert servers[0] == {"id": "server1", "name": "Server 1"}
|
||||
mock_get.assert_called_once_with(
|
||||
"https://discord.com/api/v10/users/@me/guilds",
|
||||
headers={"Authorization": "Bot test_token"},
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", None)
|
||||
def test_get_bot_servers_no_token():
|
||||
assert discord.get_bot_servers() == []
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_BOT_TOKEN", "test_token")
|
||||
@patch("requests.get")
|
||||
def test_get_bot_servers_exception(mock_get):
|
||||
mock_get.side_effect = requests.RequestException("API Error")
|
||||
|
||||
servers = discord.get_bot_servers()
|
||||
|
||||
assert servers == []
|
||||
|
||||
|
||||
@patch("memory.common.discord.get_bot_servers")
|
||||
@patch("memory.common.discord.DiscordServer")
|
||||
def test_load_servers(mock_discord_server_class, mock_get_servers):
|
||||
mock_get_servers.return_value = [
|
||||
{"id": "server1", "name": "Server 1"},
|
||||
{"id": "server2", "name": "Server 2"},
|
||||
]
|
||||
|
||||
discord.load_servers()
|
||||
|
||||
assert mock_discord_server_class.call_count == 2
|
||||
mock_discord_server_class.assert_any_call("server1", "Server 1")
|
||||
mock_discord_server_class.assert_any_call("server2", "Server 2")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_broadcast_message():
|
||||
mock_server1 = Mock()
|
||||
mock_server2 = Mock()
|
||||
discord.servers = {"1": mock_server1, "2": mock_server2}
|
||||
|
||||
discord.broadcast_message("test-channel", "Hello")
|
||||
|
||||
mock_server1.send_message.assert_called_once_with(
|
||||
mock_server1.channel_id.return_value, "Hello"
|
||||
)
|
||||
mock_server2.send_message.assert_called_once_with(
|
||||
mock_server2.channel_id.return_value, "Hello"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_broadcast_message_disabled():
|
||||
mock_server = Mock()
|
||||
discord.servers = {"1": mock_server}
|
||||
|
||||
discord.broadcast_message("test-channel", "Hello")
|
||||
|
||||
mock_server.send_message.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_send_error_message(mock_broadcast):
|
||||
discord.send_error_message("Error occurred")
|
||||
mock_broadcast.assert_called_once_with(discord.ERROR_CHANNEL, "Error occurred")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_send_activity_message(mock_broadcast):
|
||||
discord.send_activity_message("Activity update")
|
||||
mock_broadcast.assert_called_once_with(discord.ACTIVITY_CHANNEL, "Activity update")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_send_discovery_message(mock_broadcast):
|
||||
discord.send_discovery_message("Discovery made")
|
||||
mock_broadcast.assert_called_once_with(discord.DISCOVERY_CHANNEL, "Discovery made")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
def test_send_chat_message(mock_broadcast):
|
||||
discord.send_chat_message("Chat message")
|
||||
mock_broadcast.assert_called_once_with(discord.CHAT_CHANNEL, "Chat message")
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_with_args(mock_send_error):
|
||||
discord.notify_task_failure(
|
||||
"test_task",
|
||||
"Error message",
|
||||
task_args=("arg1", "arg2"),
|
||||
task_kwargs={"key": "value"},
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
|
||||
assert "**Args:** `('arg1', 'arg2')`" in message
|
||||
assert "**Kwargs:** `{'key': 'value'}`" in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
traceback = "Traceback (most recent call last):\n File ...\nError: Something"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error message", traceback_str=traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert "**Traceback:**" in message
|
||||
assert traceback in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
long_error = "x" * 600 # Longer than 500 char limit
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert long_error[:500] in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
long_traceback = "x" * 1000 # Longer than 800 char limit
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert long_traceback[-800:] in message
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
discord.notify_task_failure("test_task", "Error message")
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
def test_notify_task_failure_send_fails(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Discord API error")
|
||||
|
||||
# Should not raise, just log the error
|
||||
discord.notify_task_failure("test_task", "Error message")
|
||||
|
||||
mock_send_error.assert_called_once()
|
151
tests/memory/common/test_summarizer.py
Normal file
151
tests/memory/common/test_summarizer.py
Normal file
@ -0,0 +1,151 @@
|
||||
from memory.common import summarizer
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response, expected",
|
||||
(
|
||||
# Basic valid cases
|
||||
("", {"summary": "", "tags": []}),
|
||||
(
|
||||
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
{"summary": "test", "tags": ["tag1", "tag2"]},
|
||||
),
|
||||
(
|
||||
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
{"summary": "test", "tags": ["tag1", "tag2"]},
|
||||
),
|
||||
# Missing summary tag
|
||||
(
|
||||
"<tags><tag>tag1</tag><tag>tag2</tag></tags>",
|
||||
{"summary": "", "tags": ["tag1", "tag2"]},
|
||||
),
|
||||
# Missing tags section
|
||||
(
|
||||
"<summary>test summary</summary>",
|
||||
{"summary": "test summary", "tags": []},
|
||||
),
|
||||
# Empty summary tag
|
||||
(
|
||||
"<summary></summary><tags><tag>tag1</tag></tags>",
|
||||
{"summary": "", "tags": ["tag1"]},
|
||||
),
|
||||
# Empty tags section
|
||||
(
|
||||
"<summary>test</summary><tags></tags>",
|
||||
{"summary": "test", "tags": []},
|
||||
),
|
||||
# Single tag
|
||||
(
|
||||
"<summary>test</summary><tags><tag>single-tag</tag></tags>",
|
||||
{"summary": "test", "tags": ["single-tag"]},
|
||||
),
|
||||
# Multiple tags
|
||||
(
|
||||
"<summary>test</summary><tags><tag>tag1</tag><tag>tag2</tag><tag>tag3</tag><tag>tag4</tag><tag>tag5</tag></tags>",
|
||||
{"summary": "test", "tags": ["tag1", "tag2", "tag3", "tag4", "tag5"]},
|
||||
),
|
||||
# Tags with special characters and hyphens
|
||||
(
|
||||
"<summary>test</summary><tags><tag>machine-learning</tag><tag>ai/ml</tag><tag>data_science</tag></tags>",
|
||||
{"summary": "test", "tags": ["machine-learning", "ai/ml", "data_science"]},
|
||||
),
|
||||
# Summary with special characters
|
||||
(
|
||||
"<summary>Test with & special characters <></summary><tags><tag>test</tag></tags>",
|
||||
{"summary": "Test with & special characters <>", "tags": ["test"]},
|
||||
),
|
||||
# Whitespace handling
|
||||
(
|
||||
"<summary> test with spaces </summary><tags><tag> tag1 </tag><tag>tag2</tag></tags>",
|
||||
{"summary": " test with spaces ", "tags": [" tag1 ", "tag2"]},
|
||||
),
|
||||
# Mixed case XML tags (should still work with BeautifulSoup)
|
||||
(
|
||||
"<Summary>test</Summary><Tags><Tag>tag1</Tag></Tags>",
|
||||
{"summary": "test", "tags": ["tag1"]},
|
||||
),
|
||||
# Multiple summary tags (should take first one)
|
||||
(
|
||||
"<summary>first</summary><summary>second</summary><tags><tag>tag1</tag></tags>",
|
||||
{"summary": "first", "tags": ["tag1"]},
|
||||
),
|
||||
# Multiple tags sections (should collect all tags)
|
||||
(
|
||||
"<summary>test</summary><tags><tag>tag1</tag></tags><tags><tag>tag2</tag></tags>",
|
||||
{"summary": "test", "tags": ["tag1", "tag2"]},
|
||||
),
|
||||
# XML with extra elements (should ignore them)
|
||||
(
|
||||
"<root><summary>test</summary><other>ignored</other><tags><tag>tag1</tag></tags></root>",
|
||||
{"summary": "test", "tags": ["tag1"]},
|
||||
),
|
||||
# XML with attributes (should still work)
|
||||
(
|
||||
'<summary id="1">test</summary><tags type="keywords"><tag>tag1</tag></tags>',
|
||||
{"summary": "test", "tags": ["tag1"]},
|
||||
),
|
||||
# Empty tag elements
|
||||
(
|
||||
"<summary>test</summary><tags><tag></tag><tag>valid-tag</tag><tag></tag></tags>",
|
||||
{"summary": "test", "tags": ["", "valid-tag", ""]},
|
||||
),
|
||||
# Self-closing tags
|
||||
(
|
||||
"<summary>test</summary><tags><tag/><tag>valid-tag</tag></tags>",
|
||||
{"summary": "test", "tags": ["", "valid-tag"]},
|
||||
),
|
||||
# Long content
|
||||
(
|
||||
f"<summary>{'a' * 1000}</summary><tags><tag>long-content</tag></tags>",
|
||||
{"summary": "a" * 1000, "tags": ["long-content"]},
|
||||
),
|
||||
# XML with newlines and formatting
|
||||
(
|
||||
"""<summary>
|
||||
Multi-line
|
||||
summary content
|
||||
</summary>
|
||||
<tags>
|
||||
<tag>formatted</tag>
|
||||
<tag>xml</tag>
|
||||
</tags>""",
|
||||
{
|
||||
"summary": "\n Multi-line\n summary content\n ",
|
||||
"tags": ["formatted", "xml"],
|
||||
},
|
||||
),
|
||||
# Malformed XML (missing closing tags) - BeautifulSoup parses as best it can
|
||||
(
|
||||
"<summary>test<tags><tag>tag1</tag></tags>",
|
||||
{"summary": "testtag1", "tags": ["tag1"]},
|
||||
),
|
||||
# Invalid XML characters should be handled by BeautifulSoup
|
||||
(
|
||||
"<summary>test & unescaped</summary><tags><tag>tag1</tag></tags>",
|
||||
{"summary": "test & unescaped", "tags": ["tag1"]},
|
||||
),
|
||||
# Only whitespace
|
||||
(
|
||||
" \n\t ",
|
||||
{"summary": "", "tags": []},
|
||||
),
|
||||
# Non-XML content
|
||||
(
|
||||
"This is just plain text without XML",
|
||||
{"summary": "", "tags": []},
|
||||
),
|
||||
# XML comments (should be ignored)
|
||||
(
|
||||
"<!-- comment --><summary>test</summary><!-- another comment --><tags><tag>tag1</tag></tags>",
|
||||
{"summary": "test", "tags": ["tag1"]},
|
||||
),
|
||||
# CDATA sections
|
||||
(
|
||||
"<summary><![CDATA[test with <special> characters]]></summary><tags><tag>cdata</tag></tags>",
|
||||
{"summary": "test with <special> characters", "tags": ["cdata"]},
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_parse_response(response, expected):
|
||||
assert summarizer.parse_response(response) == expected
|
57
tools/discord_setup.py
Normal file
57
tools/discord_setup.py
Normal file
@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import click
|
||||
import requests
|
||||
|
||||
|
||||
@click.command()
@click.option("--bot-token", type=str, required=True)
def generate_bot_invite_url(bot_token: str):
    """
    Generate the Discord bot invitation URL.

    Fetches the bot's client ID from the Discord API rather than decoding it
    from the token itself (the ID is base64-encoded in the token's first
    segment, but asking the API is safer), then builds an OAuth2 invite link
    with the permissions the bot needs.

    Returns:
        URL that user can click to add bot to their server

    Raises:
        ValueError: if the bot info cannot be fetched (bad token, network error).
    """
    try:
        headers = {"Authorization": f"Bot {bot_token}"}
        response = requests.get(
            "https://discord.com/api/v10/users/@me",
            headers=headers,
            # Without a timeout, requests will wait indefinitely on a
            # stalled connection; 10s is plenty for this one small call.
            timeout=10,
        )
        response.raise_for_status()
        bot_info = response.json()
        client_id = bot_info["id"]
    except Exception as e:
        # Chain the original exception so the root cause (HTTP error,
        # connection failure, bad JSON, missing key) stays visible.
        raise ValueError(f"Could not get bot info: {e}") from e

    # Permissions needed: Send Messages (2048) + Manage Channels (16) + View Channels (1024)
    permissions = 2048 + 16 + 1024  # = 3088

    invite_url = f"https://discord.com/oauth2/authorize?client_id={client_id}&scope=bot&permissions={permissions}"
    click.echo(f"Bot invite URL: {invite_url}")
    return invite_url
|
||||
|
||||
|
||||
@click.command()
def create_channels():
    """Create Discord channels using the configured servers."""
    # Imported lazily so the `generate-invite` subcommand can run even when
    # the memory package's configuration/environment is not fully set up.
    from memory.common.discord import load_servers

    click.echo("Loading Discord servers and creating channels...")
    # load_servers() does the actual work as a side effect — presumably
    # connecting to the configured servers and creating the notification
    # channels (see memory.common.discord for the details).
    load_servers()
    click.echo("Discord channels setup completed.")
|
||||
|
||||
|
||||
@click.group()
def cli():
    """Discord setup utilities."""
    pass


# Register subcommands under kebab-case names (the function names use
# snake_case, so the explicit `name=` keeps the CLI surface consistent
# with the README's `generate-invite` / `create-channels` usage).
cli.add_command(generate_bot_invite_url, name="generate-invite")
cli.add_command(create_channels, name="create-channels")


if __name__ == "__main__":
    cli()
|
Loading…
x
Reference in New Issue
Block a user