Compare commits

..

10 Commits

Author SHA1 Message Date
e95a082147 allow discord tools 2025-11-02 00:50:12 +00:00
c42513100b db backups 2025-11-02 00:24:35 +00:00
a5bc53326d backups 2025-11-02 00:01:35 +00:00
131427255a fix typing indicator 2025-11-01 20:27:57 +00:00
Daniel O'Connell
ff3ca4f109 show typing 2025-11-01 21:13:39 +01:00
3b216953ab better docker compise 2025-11-01 19:51:41 +00:00
Daniel O'Connell
d7e403fb83 optional chattiness 2025-11-01 20:39:15 +01:00
57145ac7b4 fix bugs 2025-11-01 19:35:20 +00:00
Daniel O'Connell
814090dccb use db bots 2025-11-01 18:52:37 +01:00
Daniel O'Connell
9639fa3dd7 use usage tracker 2025-11-01 18:49:06 +01:00
35 changed files with 1716 additions and 295 deletions

View File

@ -0,0 +1,114 @@
"""allow no chattiness
Revision ID: 2024235e37e7
Revises: 7dc03dbf184c
Create Date: 2025-11-01 20:38:10.849651
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "2024235e37e7"
down_revision: Union[str, None] = "7dc03dbf184c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the migration.

    Relaxes the NOT NULL constraint on ``chattiness_threshold`` for the
    discord_channels, discord_servers and discord_users tables (NULL now
    represents "no chattiness configured"), and drops the ``track_messages``
    boolean from all three tables.  The ``discord_servers_active_idx`` index,
    which previously covered ``track_messages``, is rebuilt over
    ``ignore_messages`` + ``last_sync_at``.
    """
    op.alter_column(
        "discord_channels",
        "chattiness_threshold",
        existing_type=sa.INTEGER(),
        nullable=True,
        existing_server_default=sa.text("50"),
    )
    op.drop_column("discord_channels", "track_messages")
    op.alter_column(
        "discord_servers",
        "chattiness_threshold",
        existing_type=sa.INTEGER(),
        nullable=True,
        existing_server_default=sa.text("50"),
    )
    # Rebuild the index before dropping track_messages: the old index
    # referenced that column.
    op.drop_index("discord_servers_active_idx", table_name="discord_servers")
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["ignore_messages", "last_sync_at"],
        unique=False,
    )
    op.drop_column("discord_servers", "track_messages")
    op.alter_column(
        "discord_users",
        "chattiness_threshold",
        existing_type=sa.INTEGER(),
        nullable=True,
        existing_server_default=sa.text("50"),
    )
    op.drop_column("discord_users", "track_messages")
def downgrade() -> None:
    """Revert the migration.

    Re-adds the ``track_messages`` boolean (server default ``true``, NOT NULL)
    to discord_users, discord_servers and discord_channels, restores the
    NOT NULL constraint on ``chattiness_threshold`` for all three tables, and
    rebuilds ``discord_servers_active_idx`` over ``track_messages`` +
    ``last_sync_at``.

    NOTE(review): rows whose ``chattiness_threshold`` was set to NULL after
    upgrading will make the ``nullable=False`` alter fail unless they are
    backfilled first — confirm before running a downgrade in production.
    """
    op.add_column(
        "discord_users",
        sa.Column(
            "track_messages",
            sa.BOOLEAN(),
            server_default=sa.text("true"),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.alter_column(
        "discord_users",
        "chattiness_threshold",
        existing_type=sa.INTEGER(),
        nullable=False,
        existing_server_default=sa.text("50"),
    )
    op.add_column(
        "discord_servers",
        sa.Column(
            "track_messages",
            sa.BOOLEAN(),
            server_default=sa.text("true"),
            autoincrement=False,
            nullable=False,
        ),
    )
    # The column must exist before the index can reference it again.
    op.drop_index("discord_servers_active_idx", table_name="discord_servers")
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["track_messages", "last_sync_at"],
        unique=False,
    )
    op.alter_column(
        "discord_servers",
        "chattiness_threshold",
        existing_type=sa.INTEGER(),
        nullable=False,
        existing_server_default=sa.text("50"),
    )
    op.add_column(
        "discord_channels",
        sa.Column(
            "track_messages",
            sa.BOOLEAN(),
            server_default=sa.text("true"),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.alter_column(
        "discord_channels",
        "chattiness_threshold",
        existing_type=sa.INTEGER(),
        nullable=False,
        existing_server_default=sa.text("50"),
    )

View File

@ -1,5 +1,3 @@
version: "3.9"
# --------------------------------------------------------------------- networks
networks:
kbnet:
@ -26,7 +24,7 @@ x-common-env: &env
REDIS_HOST: redis
REDIS_PORT: 6379
REDIS_DB: 0
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
REDIS_PASSWORD: ${REDIS_PASSWORD}
QDRANT_HOST: qdrant
DB_HOST: postgres
DB_PORT: 5432
@ -107,11 +105,11 @@ services:
image: redis:7.2-alpine
restart: unless-stopped
networks: [ kbnet ]
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"]
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${REDIS_PASSWORD}"]
volumes:
- redis_data:/data:rw
healthcheck:
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ]
test: [ "CMD", "redis-cli", "--pass", "${REDIS_PASSWORD}", "ping" ]
interval: 15s
timeout: 5s
retries: 5
@ -175,7 +173,7 @@ services:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
QUEUES: "backup,email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
ingest-hub:
<<: *worker-base
@ -196,6 +194,22 @@ services:
- /var/run/supervisor
deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
# ------------------------------------------------------------ database backups
backup:
image: postgres:15 # Has pg_dump, wget, curl
networks: [kbnet]
depends_on: [postgres, qdrant]
env_file: [ .env ]
environment:
<<: *worker-env
secrets: [postgres_password]
volumes:
- ./tools/backup_databases.sh:/backup.sh:ro
entrypoint: ["/bin/bash"]
command: ["/backup.sh"]
profiles: [backup] # Only start when explicitly called
security_opt: ["no-new-privileges=true"]
# ------------------------------------------------------------ watchtower (auto-update)
# watchtower:
# image: containrrr/watchtower

View File

@ -16,7 +16,7 @@ RUN apt-get update && apt-get install -y \
COPY requirements ./requirements/
COPY setup.py ./
RUN mkdir src
RUN pip install -e ".[common]"
RUN pip install -e ".[workers]"
# Install Python dependencies
COPY src/ ./src/
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
git config --global user.name "${GIT_USER_NAME}"
# Default queues to process
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
ENV QUEUES="backup,ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"]

View File

@ -10,3 +10,4 @@ openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
celery[redis,sqs]==5.3.6
cryptography==43.0.0

View File

@ -0,0 +1,2 @@
boto3
awscli==1.42.64

View File

@ -18,6 +18,7 @@ parsers_requires = read_requirements("requirements-parsers.txt")
api_requires = read_requirements("requirements-api.txt")
dev_requires = read_requirements("requirements-dev.txt")
ingesters_requires = read_requirements("requirements-ingesters.txt")
workers_requires = read_requirements("requirements-workers.txt")
setup(
name="memory",
@ -30,10 +31,12 @@ setup(
"common": common_requires + parsers_requires,
"dev": dev_requires,
"ingesters": common_requires + parsers_requires + ingesters_requires,
"workers": common_requires + parsers_requires + workers_requires,
"all": api_requires
+ common_requires
+ dev_requires
+ parsers_requires
+ ingesters_requires,
+ ingesters_requires
+ workers_requires,
},
)

View File

@ -290,6 +290,7 @@ class ScheduledLLMCallAdmin(ModelView, model=ScheduledLLMCall):
"created_at",
"updated_at",
]
column_sortable_list = ["executed_at", "scheduled_time", "created_at", "updated_at"]
def setup_admin(admin: Admin):

View File

@ -13,6 +13,7 @@ NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@ -53,13 +54,25 @@ SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
# Backup tasks
BACKUP_PATH = f"{BACKUP_ROOT}.backup_path"
BACKUP_ALL = f"{BACKUP_ROOT}.backup_all"
def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE
user = safequote(settings.CELERY_BROKER_USER)
password = safequote(settings.CELERY_BROKER_PASSWORD)
password = safequote(settings.CELERY_BROKER_PASSWORD or "")
host = settings.CELERY_BROKER_HOST
return f"{protocol}://{user}:{password}@{host}"
if password:
url = f"{protocol}://{user}:{password}@{host}"
else:
url = f"{protocol}://{host}"
if protocol == "redis":
url += f"/{settings.REDIS_DB}"
return url
app = Celery(
@ -91,6 +104,7 @@ app.conf.update(
f"{SCHEDULED_CALLS_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
},
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
},
)

View File

@ -22,7 +22,6 @@ from memory.common.db.models.base import Base
class MessageProcessor:
track_messages = Column(Boolean, nullable=False, server_default="true")
ignore_messages = Column(Boolean, nullable=True, default=False)
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
@ -35,8 +34,7 @@ class MessageProcessor:
)
chattiness_threshold = Column(
Integer,
nullable=False,
default=50,
nullable=True,
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
)
@ -90,7 +88,7 @@ class DiscordServer(Base, MessageProcessor):
)
__table_args__ = (
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
Index("discord_servers_active_idx", "ignore_messages", "last_sync_at"),
)

View File

@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
email: str,
api_key: str | None = None,
) -> "DiscordBotUser":
if not discord_users:
raise ValueError("discord_users must be provided")
bot = super().create_with_api_key(name, email, api_key)
bot.discord_users = discord_users
return bot

View File

@ -5,9 +5,10 @@ Simple HTTP client that communicates with the Discord collector's API server.
"""
import logging
import requests
from typing import Any
import requests
from memory.common import settings
logger = logging.getLogger(__name__)
@ -20,12 +21,12 @@ def get_api_url() -> str:
return f"http://{host}:{port}"
def send_dm(user_identifier: str, message: str) -> bool:
def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
"""Send a DM via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/send_dm",
json={"user": user_identifier, "message": message},
json={"bot_id": bot_id, "user": user_identifier, "message": message},
timeout=10,
)
response.raise_for_status()
@ -37,12 +38,33 @@ def send_dm(user_identifier: str, message: str) -> bool:
return False
def send_to_channel(channel_name: str, message: str) -> bool:
def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
"""Trigger typing indicator for a DM via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/typing/dm",
json={"bot_id": bot_id, "user": user_identifier},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a DM via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message},
json={
"bot_id": bot_id,
"channel_name": channel_name,
"message": message,
},
timeout=10,
)
response.raise_for_status()
@ -55,12 +77,33 @@ def send_to_channel(channel_name: str, message: str) -> bool:
return False
def broadcast_message(channel_name: str, message: str) -> bool:
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
"""Trigger typing indicator for a channel via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/typing/channel",
json={"bot_id": bot_id, "channel_name": channel_name},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
return False
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message},
json={
"bot_id": bot_id,
"channel_name": channel_name,
"message": message,
},
timeout=10,
)
response.raise_for_status()
@ -72,19 +115,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
return False
def is_collector_healthy() -> bool:
def is_collector_healthy(bot_id: int) -> bool:
"""Check if the Discord collector is running and healthy"""
try:
response = requests.get(f"{get_api_url()}/health", timeout=5)
response.raise_for_status()
result = response.json()
return result.get("status") == "healthy"
bot_status = result.get(str(bot_id))
if not isinstance(bot_status, dict):
return False
return bool(bot_status.get("connected"))
except requests.RequestException:
return False
def refresh_discord_metadata() -> dict[str, int] | None:
def refresh_discord_metadata() -> dict[str, Any] | None:
"""Refresh Discord server/channel/user metadata from Discord API"""
try:
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
@ -96,24 +142,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
# Convenience functions
def send_error_message(message: str) -> bool:
def send_error_message(bot_id: int, message: str) -> bool:
"""Send an error message to the error channel"""
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
def send_activity_message(message: str) -> bool:
def send_activity_message(bot_id: int, message: str) -> bool:
"""Send an activity message to the activity channel"""
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
def send_discovery_message(message: str) -> bool:
def send_discovery_message(bot_id: int, message: str) -> bool:
"""Send a discovery message to the discovery channel"""
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
def send_chat_message(message: str) -> bool:
def send_chat_message(bot_id: int, message: str) -> bool:
"""Send a chat message to the chat channel"""
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
def notify_task_failure(
@ -122,6 +168,7 @@ def notify_task_failure(
task_args: tuple = (),
task_kwargs: dict[str, Any] | None = None,
traceback_str: str | None = None,
bot_id: int | None = None,
) -> None:
"""
Send a task failure notification to Discord.
@ -137,6 +184,15 @@ def notify_task_failure(
logger.debug("Discord notifications disabled")
return
if bot_id is None:
bot_id = settings.DISCORD_BOT_ID
if not bot_id:
logger.debug(
"No Discord bot ID provided for task failure notification; skipping"
)
return
message = f"🚨 **Task Failed: {task_name}**\n\n"
message += f"**Error:** {error_message[:500]}\n"
@ -150,7 +206,7 @@ def notify_task_failure(
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
try:
send_error_message(message)
send_error_message(bot_id, message)
logger.info(f"Discord error notification sent for task: {task_name}")
except Exception as e:
logger.error(f"Failed to send Discord notification: {e}")

View File

@ -24,13 +24,6 @@ from memory.common.llms.base import (
)
from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker,
RateLimitConfig,
TokenAllowance,
UsageBreakdown,
UsageTracker,
)
from memory.common import tokens
__all__ = [
@ -49,11 +42,6 @@ __all__ = [
"StreamEvent",
"LLMSettings",
"create_provider",
"InMemoryUsageTracker",
"RateLimitConfig",
"TokenAllowance",
"UsageBreakdown",
"UsageTracker",
]
logger = logging.getLogger(__name__)
@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content
# bla = 1
# from memory.common.llms import *
# from memory.common.llms.tools.discord import make_discord_tools
# from memory.common.db.connection import make_session
# from memory.common.db.models import *
# model = "anthropic/claude-sonnet-4-5"
# provider = create_provider(model=model)
# with make_session() as session:
# bot = session.query(DiscordBotUser).first()
# server = session.query(DiscordServer).first()
# channel = server.channels[0]
# tools = make_discord_tools(bot, None, channel, model)
# def demo(msg: str):
# messages = [
# Message(
# role=MessageRole.USER,
# content=msg,
# )
# ]
# for m in provider.stream_with_tools(messages, tools):
# print(m)

View File

@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
class AnthropicProvider(BaseLLMProvider):
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
provider = "anthropic"
# Models that support extended thinking
THINKING_MODELS = {
"claude-opus-4",
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
Usage(
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
total_tokens=usage.total_tokens,
total_tokens=usage.input_tokens + usage.output_tokens,
)
)

View File

@ -12,6 +12,7 @@ from PIL import Image
from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
logger = logging.getLogger(__name__)
@ -204,7 +205,11 @@ class LLMSettings:
class BaseLLMProvider(ABC):
"""Base class for LLM providers."""
def __init__(self, api_key: str, model: str):
provider: str = ""
def __init__(
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
):
"""
Initialize the LLM provider.
@ -215,6 +220,7 @@ class BaseLLMProvider(ABC):
self.api_key = api_key
self.model = model
self._client: Any = None
self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()
@abstractmethod
def _initialize_client(self) -> Any:
@ -230,8 +236,14 @@ class BaseLLMProvider(ABC):
def log_usage(self, usage: Usage):
"""Log usage data."""
logger.debug(f"Token usage: {usage.to_dict()}")
print(f"Token usage: {usage.to_dict()}")
logger.debug(
f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total"
)
self.usage_tracker.record_usage(
model=f"{self.provider}/{self.model}",
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
)
def execute_tool(
self,

View File

@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider):
"""OpenAI LLM provider with streaming and tool support."""
provider = "openai"
# Models that use max_completion_tokens instead of max_tokens
# These are reasoning models with different parameter requirements
NON_REASONING_MODELS = {"gpt-4o"}

View File

@ -0,0 +1,17 @@
"""Public API of the LLM usage-tracking package.

Re-exports the Redis-backed tracker alongside the in-memory implementation
and its supporting configuration/value types, so callers can import
everything from ``memory.common.llms.usage``.
"""

from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
from memory.common.llms.usage.usage_tracker import (
    InMemoryUsageTracker,
    RateLimitConfig,
    TokenAllowance,
    UsageBreakdown,
    UsageTracker,
)

__all__ = [
    "InMemoryUsageTracker",
    "RateLimitConfig",
    "RedisUsageTracker",
    "TokenAllowance",
    "UsageBreakdown",
    "UsageTracker",
]

View File

@ -0,0 +1,83 @@
"""Redis-backed usage tracker implementation."""
import json
from typing import Any, Iterable, Protocol
import redis
from memory.common import settings
from memory.common.llms.usage.usage_tracker import (
RateLimitConfig,
UsageState,
UsageTracker,
)
class RedisClientProtocol(Protocol):
    """Structural type for the minimal redis-client surface this module uses.

    Any object providing ``get``, ``set`` and ``scan_iter`` (e.g. a
    ``redis.Redis`` instance, or a test double) satisfies this protocol.
    """

    def get(self, key: str) -> Any:  # pragma: no cover - Protocol definition
        ...

    def set(
        self, key: str, value: Any
    ) -> Any:  # pragma: no cover - Protocol definition
        ...

    def scan_iter(
        self, match: str
    ) -> Iterable[Any]:  # pragma: no cover - Protocol definition
        ...
...
class RedisUsageTracker(UsageTracker):
    """Tracks LLM usage for providers and models using Redis for persistence.

    Each model key maps to one Redis string holding the JSON-serialized
    ``UsageState`` (see ``UsageState.to_payload``/``from_payload``), stored
    under ``<key_prefix>:<model>``.
    """

    def __init__(
        self,
        configs: dict[str, RateLimitConfig] | None = None,
        default_config: RateLimitConfig | None = None,
        *,
        redis_client: RedisClientProtocol | None = None,
        key_prefix: str | None = None,
    ) -> None:
        """Initialize the tracker.

        Args:
            configs: Per-model rate-limit configs (passed to the base class).
            default_config: Fallback config (passed to the base class).
            redis_client: Injectable client; defaults to a ``redis.Redis``
                built from ``settings.REDIS_URL``.
            key_prefix: Redis key namespace; defaults to
                ``settings.LLM_USAGE_REDIS_PREFIX``. A trailing ``:`` is
                stripped so keys are always ``prefix:model``.
        """
        super().__init__(configs=configs, default_config=default_config)
        if redis_client is None:
            redis_client = redis.Redis.from_url(settings.REDIS_URL)
        self._redis = redis_client
        prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
        self._key_prefix = prefix.rstrip(":")

    def get_state(self, model: str) -> UsageState:
        """Load the persisted state for *model*; empty state if none stored."""
        redis_key = self._format_key(model)
        payload = self._redis.get(redis_key)
        if not payload:
            return UsageState()
        # redis returns bytes unless decode_responses=True was configured.
        if isinstance(payload, bytes):
            payload = payload.decode()
        return UsageState.from_payload(json.loads(payload))

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        """Yield ``(model, state)`` pairs for every key under the prefix.

        Keys that disappear between SCAN and GET are silently skipped.
        """
        pattern = f"{self._key_prefix}:*"
        for redis_key in self._redis.scan_iter(match=pattern):
            key = self._ensure_str(redis_key)
            payload = self._redis.get(key)
            if not payload:
                continue
            if isinstance(payload, bytes):
                payload = payload.decode()
            state = UsageState.from_payload(json.loads(payload))
            # Strip "<prefix>:" to recover the bare model key.
            yield key[len(self._key_prefix) + 1 :], state

    def save_state(self, model: str, state: UsageState) -> None:
        """Persist *state* for *model* as compact JSON."""
        redis_key = self._format_key(model)
        self._redis.set(
            redis_key, json.dumps(state.to_payload(), separators=(",", ":"))
        )

    def _format_key(self, model: str) -> str:
        """Build the namespaced Redis key for *model*."""
        return f"{self._key_prefix}:{model}"

    @staticmethod
    def _ensure_str(value: Any) -> str:
        """Normalize a Redis key (bytes or str) to str."""
        if isinstance(value, bytes):
            return value.decode()
        return str(value)

View File

@ -5,6 +5,9 @@ from collections.abc import Iterable
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import Any
from memory.common import settings
@dataclass(frozen=True)
@ -45,6 +48,40 @@ class UsageState:
lifetime_input_tokens: int = 0
lifetime_output_tokens: int = 0
def to_payload(self) -> dict[str, Any]:
    """Serialize this state to a JSON-compatible dict.

    Inverse of ``from_payload``; timestamps are encoded as ISO-8601 strings.
    """
    return {
        "events": [
            {
                "timestamp": event.timestamp.isoformat(),
                "input_tokens": event.input_tokens,
                "output_tokens": event.output_tokens,
            }
            for event in self.events
        ],
        "window_input_tokens": self.window_input_tokens,
        "window_output_tokens": self.window_output_tokens,
        "lifetime_input_tokens": self.lifetime_input_tokens,
        "lifetime_output_tokens": self.lifetime_output_tokens,
    }
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
    """Reconstruct a ``UsageState`` from a dict produced by ``to_payload``.

    Missing keys fall back to empty/zero, so partially-populated payloads
    (e.g. from older serialized states) still load.
    """
    events = deque(
        UsageEvent(
            timestamp=datetime.fromisoformat(event["timestamp"]),
            input_tokens=event["input_tokens"],
            output_tokens=event["output_tokens"],
        )
        for event in payload.get("events", [])
    )
    return cls(
        events=events,
        window_input_tokens=payload.get("window_input_tokens", 0),
        window_output_tokens=payload.get("window_output_tokens", 0),
        lifetime_input_tokens=payload.get("lifetime_input_tokens", 0),
        lifetime_output_tokens=payload.get("lifetime_output_tokens", 0),
    )
@dataclass
class TokenAllowance:
@ -77,13 +114,13 @@ class UsageBreakdown:
def split_model_key(model: str) -> tuple[str, str]:
if "/" not in model:
raise ValueError(
"model must be formatted as '<provider>/<model_name>'"
f"model must be formatted as '<provider>/<model_name>': got '{model}'"
)
provider, model_name = model.split("/", maxsplit=1)
if not provider or not model_name:
raise ValueError(
"model must include both provider and model name separated by '/'"
f"model must include both provider and model name separated by '/': got '{model}'"
)
return provider, model_name
@ -93,11 +130,15 @@ class UsageTracker:
def __init__(
self,
configs: dict[str, RateLimitConfig],
configs: dict[str, RateLimitConfig] | None = None,
default_config: RateLimitConfig | None = None,
) -> None:
self._configs = configs
self._default_config = default_config
self._configs = configs or {}
self._default_config = default_config or RateLimitConfig(
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
)
self._lock = Lock()
# ------------------------------------------------------------------
@ -180,15 +221,14 @@ class UsageTracker:
"""
split_model_key(model)
key = model
with self._lock:
config = self._get_config(key)
config = self._get_config(model)
if config is None:
return None
state = self.get_state(key)
state = self.get_state(model)
self._prune_expired_events(state, config, now=timestamp)
self.save_state(key, state)
self.save_state(model, state)
if config.max_total_tokens is None:
total_remaining = None
@ -205,9 +245,7 @@ class UsageTracker:
if config.max_output_tokens is None:
output_remaining = None
else:
output_remaining = (
config.max_output_tokens - state.window_output_tokens
)
output_remaining = config.max_output_tokens - state.window_output_tokens
return TokenAllowance(
input_tokens=clamp_non_negative(input_remaining),
@ -222,8 +260,8 @@ class UsageTracker:
with self._lock:
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
for key, state in self.iter_state_items():
prov, model_name = split_model_key(key)
for model, state in self.iter_state_items():
prov, model_name = split_model_key(model)
if provider and provider != prov:
continue
if model and model != model_name:
@ -265,8 +303,8 @@ class UsageTracker:
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _get_config(self, key: str) -> RateLimitConfig | None:
return self._configs.get(key) or self._default_config
def _get_config(self, model: str) -> RateLimitConfig | None:
return self._configs.get(model) or self._default_config
def _prune_expired_events(
self,
@ -313,4 +351,3 @@ def clamp_non_negative(value: int | None) -> int | None:
if value is None:
return None
return 0 if value < 0 else value

View File

@ -31,33 +31,25 @@ def make_db_url(
DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Redis settings
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
if REDIS_PASSWORD:
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
else:
REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
# Broker settings
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0")
CELERY_BROKER_USER = os.getenv(
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
)
CELERY_BROKER_PASSWORD = os.getenv(
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
)
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
if not CELERY_BROKER_HOST:
if CELERY_BROKER_TYPE == "amqp":
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
elif CELERY_BROKER_TYPE == "redis":
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
# File storage settings
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
EBOOK_STORAGE_DIR = pathlib.Path(
@ -81,9 +73,14 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
NOTES_STORAGE_DIR = pathlib.Path(
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
)
PRIVATE_DIRS = [
EMAIL_STORAGE_DIR,
NOTES_STORAGE_DIR,
PHOTO_STORAGE_DIR,
CHUNK_STORAGE_DIR,
]
storage_dirs = [
FILE_STORAGE_DIR,
EBOOK_STORAGE_DIR,
EMAIL_STORAGE_DIR,
CHUNK_STORAGE_DIR,
@ -148,6 +145,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30)
)
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000)
)
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000)
)
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
@ -193,3 +202,14 @@ DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
# S3 Backup settings
S3_BACKUP_BUCKET = os.getenv("S3_BACKUP_BUCKET", "equistamp-memory-backup")
S3_BACKUP_PREFIX = os.getenv("S3_BACKUP_PREFIX", "Daniel")
S3_BACKUP_REGION = os.getenv("S3_BACKUP_REGION", "eu-central-1")
BACKUP_ENCRYPTION_KEY = os.getenv("BACKUP_ENCRYPTION_KEY", "")
S3_BACKUP_ENABLED = boolean_env("S3_BACKUP_ENABLED", bool(BACKUP_ENCRYPTION_KEY))
S3_BACKUP_INTERVAL = int(
os.getenv("S3_BACKUP_INTERVAL", 60 * 60 * 24)
) # Daily by default

View File

@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
import asyncio
import logging
from contextlib import asynccontextmanager
import traceback
from contextlib import asynccontextmanager
from typing import cast
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from memory.common import settings
from memory.discord.collector import MessageCollector
from memory.common.db.models.users import BotUser
from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser
from memory.discord.collector import MessageCollector
logger = logging.getLogger(__name__)
@ -34,6 +35,16 @@ class SendChannelRequest(BaseModel):
message: str
class TypingDMRequest(BaseModel):
bot_id: int
user: int | str
class TypingChannelRequest(BaseModel):
bot_id: int
channel_name: str
class Collector:
collector: MessageCollector
collector_task: asyncio.Task
@ -41,37 +52,25 @@ class Collector:
bot_token: str
bot_name: str
def __init__(self, collector: MessageCollector, bot: BotUser):
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
self.collector = collector
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
self.bot_id = bot.id
self.bot_token = bot.api_key
self.bot_name = bot.name
# Application state
class AppState:
def __init__(self):
self.collector: MessageCollector | None = None
self.collector_task: asyncio.Task | None = None
app_state = AppState()
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
self.bot_id = cast(int, bot.id)
self.bot_token = str(bot.api_key)
self.bot_name = str(bot.name)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage Discord collector lifecycle"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
def make_collector(bot: BotUser):
def make_collector(bot: DiscordBotUser):
collector = MessageCollector()
return Collector(collector=collector, bot=bot)
with make_session() as session:
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
bots = session.query(DiscordBotUser).all()
app.bots = {bot.id: make_collector(bot) for bot in bots}
logger.info(f"Discord collectors started for {len(app.bots)} bots")
@ -120,6 +119,32 @@ async def send_dm_endpoint(request: SendDMRequest):
}
@app.post("/typing/dm")
async def trigger_dm_typing(request: TypingDMRequest):
"""Trigger a typing indicator for a DM via the collector"""
collector = app.bots.get(request.bot_id)
if not collector:
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.trigger_typing_dm(request.user)
except Exception as e:
logger.error(f"Failed to trigger DM typing: {e}")
raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to trigger typing for {request.user}",
)
return {
"success": True,
"user": request.user,
"message": f"Typing triggered for {request.user}",
}
@app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest):
"""Send a message to a channel via the collector's Discord client"""
@ -131,6 +156,9 @@ async def send_channel_endpoint(request: SendChannelRequest):
success = await collector.collector.send_to_channel(
request.channel_name, request.message
)
except Exception as e:
logger.error(f"Failed to send channel message: {e}")
raise HTTPException(status_code=500, detail=str(e))
if success:
return {
@ -138,16 +166,38 @@ async def send_channel_endpoint(request: SendChannelRequest):
"message": f"Message sent to channel {request.channel_name}",
"channel": request.channel_name,
}
else:
raise HTTPException(
status_code=400,
detail=f"Failed to send message to channel {request.channel_name}",
)
@app.post("/typing/channel")
async def trigger_channel_typing(request: TypingChannelRequest):
"""Trigger a typing indicator for a channel via the collector"""
collector = app.bots.get(request.bot_id)
if not collector:
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.trigger_typing_channel(request.channel_name)
except Exception as e:
logger.error(f"Failed to send channel message: {e}")
logger.error(f"Failed to trigger channel typing: {e}")
raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to trigger typing for channel {request.channel_name}",
)
return {
"success": True,
"channel": request.channel_name,
"message": f"Typing triggered for channel {request.channel_name}",
}
@app.get("/health")
async def health_check():
@ -155,9 +205,8 @@ async def health_check():
if not app.bots:
raise HTTPException(status_code=503, detail="Discord collector not running")
collector = app_state.collector
return {
collector.bot_name: {
bot.bot_name: {
"status": "healthy",
"connected": not bot.collector.is_closed(),
"user": str(bot.collector.user) if bot.collector.user else None,

View File

@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
async def setup_hook(self):
"""Register slash commands when the bot is ready."""
register_slash_commands(self)
register_slash_commands(self, name=self.user.name)
async def on_ready(self):
"""Called when bot connects to Discord"""
@ -381,6 +381,27 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to send DM to {user_identifier}: {e}")
return False
    async def trigger_typing_dm(self, user_identifier: int | str) -> bool:
        """Trigger typing indicator in a DM.

        Resolves the user, ensures a DM channel exists, then briefly enters
        the channel's typing context so Discord shows the "typing..." hint.

        NOTE(review): stock discord.py ``Client.get_user`` is a synchronous
        cache lookup taking an int id; awaiting it here assumes this class
        provides an async override accepting ``int | str`` -- confirm.

        Args:
            user_identifier: Discord user id or name.

        Returns:
            True when the indicator was triggered; False on any failure
            (errors are logged, never raised to the caller).
        """
        try:
            user = await self.get_user(user_identifier)
            if not user:
                logger.error(f"User {user_identifier} not found")
                return False
            # Reuse the cached DM channel when present; otherwise open one.
            channel = user.dm_channel or await user.create_dm()
            if not channel:
                logger.error(f"DM channel not available for {user_identifier}")
                return False
            # Entering and immediately leaving the typing context sends a
            # single typing event (displayed client-side for ~10 seconds).
            async with channel.typing():
                pass
            return True
        except Exception as e:
            logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
            return False
async def send_to_channel(self, channel_name: str, message: str) -> bool:
"""Send a message to a channel by name across all guilds"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
@ -400,23 +421,21 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to send message to channel {channel_name}: {e}")
return False
async def run_collector():
"""Run the Discord message collector"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
collector = MessageCollector()
async def trigger_typing_channel(self, channel_name: str) -> bool:
"""Trigger typing indicator in a channel"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
return False
try:
await collector.start(settings.DISCORD_BOT_TOKEN)
channel = await self.get_channel_by_name(channel_name)
if not channel:
logger.error(f"Channel {channel_name} not found")
return False
async with channel.typing():
pass
return True
except Exception as e:
logger.error(f"Discord collector failed: {e}")
raise
if __name__ == "__main__":
import asyncio
asyncio.run(run_collector())
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
return False

View File

@ -41,8 +41,13 @@ class CommandContext:
CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client) -> None:
"""Register the collector slash commands on the provided bot."""
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
"""Register the collector slash commands on the provided bot.
Args:
bot: Discord bot client
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
"""
if getattr(bot, "_memory_commands_registered", False):
return
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
tree = bot.tree
@tree.command(name="memory_prompt", description="Show the current system prompt")
@tree.command(
name=f"{name}_show_prompt", description="Show the current system prompt"
)
@discord.app_commands.describe(
scope="Which configuration to inspect",
user="Target user when the scope is 'user'",
)
async def prompt_command(
async def show_prompt_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
user: discord.User | None = None,
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
)
@tree.command(
name="memory_chattiness",
description="Show or update the chattiness threshold for the target",
name=f"{name}_set_prompt",
description="Set the system prompt for the target",
)
@discord.app_commands.describe(
scope="Which configuration to modify",
prompt="The system prompt to set",
user="Target user when the scope is 'user'",
)
async def set_prompt_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
prompt: str,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_set_prompt,
target_user=user,
prompt=prompt,
)
@tree.command(
name=f"{name}_chattiness",
description="Show or update the chattiness for the target",
)
@discord.app_commands.describe(
scope="Which configuration to inspect",
value="Optional new threshold value between 0 and 100",
value="Optional new chattiness value between 0 and 100",
user="Target user when the scope is 'user'",
)
async def chattiness_command(
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
)
@tree.command(
name="memory_ignore",
name=f"{name}_ignore",
description="Toggle whether the bot should ignore messages for the target",
)
@discord.app_commands.describe(
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
ignore_enabled=enabled,
)
@tree.command(name="memory_summary", description="Show the stored summary for the target")
@tree.command(
name=f"{name}_show_summary",
description="Show the stored summary for the target",
)
@discord.app_commands.describe(
scope="Which configuration to inspect",
user="Target user when the scope is 'user'",
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
)
def handle_set_prompt(
    context: CommandContext,
    *,
    prompt: str,
) -> CommandResponse:
    """Store *prompt* as the system prompt on the command's target."""
    target = context.target
    # setattr keeps type-checkers happy with ORM-mapped attributes.
    setattr(target, "system_prompt", prompt)
    confirmation = f"Updated system prompt for {context.display_name}."
    return CommandResponse(content=confirmation)
def handle_chattiness(
context: CommandContext,
*,
@ -347,20 +392,22 @@ def handle_chattiness(
if value is None:
return CommandResponse(
content=(
f"Chattiness threshold for {context.display_name}: "
f"Chattiness for {context.display_name}: "
f"{getattr(model, 'chattiness_threshold', 'not set')}"
)
)
if not 0 <= value <= 100:
raise CommandError("Chattiness threshold must be between 0 and 100.")
raise CommandError("Chattiness must be between 0 and 100.")
setattr(model, "chattiness_threshold", value)
return CommandResponse(
content=(
f"Updated chattiness threshold for {context.display_name} "
f"to {value}."
f"Updated chattiness for {context.display_name} to {value}."
"\n"
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
"where 0 is never and 100 is always."
)
)

View File

@ -10,6 +10,7 @@ from memory.common.celery_app import (
TRACK_GIT_CHANGES,
SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS,
BACKUP_ALL,
)
logger = logging.getLogger(__name__)
@ -48,4 +49,8 @@ app.conf.beat_schedule = {
"task": RUN_SCHEDULED_CALLS,
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
},
"backup-all": {
"task": BACKUP_ALL,
"schedule": settings.S3_BACKUP_INTERVAL,
},
}

View File

@ -3,11 +3,12 @@ Import sub-modules so Celery can register their @app.task decorators.
"""
from memory.workers.tasks import (
email,
comic,
backup,
blogs,
comic,
discord,
ebook,
email,
forums,
maintenance,
notes,
@ -15,8 +16,8 @@ from memory.workers.tasks import (
scheduled_calls,
) # noqa
__all__ = [
"backup",
"email",
"comic",
"blogs",

View File

@ -0,0 +1,152 @@
"""S3 backup tasks for memory files."""
import base64
import hashlib
import io
import logging
import subprocess
import tarfile
from pathlib import Path
import boto3
from cryptography.fernet import Fernet
from memory.common import settings
from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL
logger = logging.getLogger(__name__)
def get_cipher() -> Fernet:
    """Build a Fernet cipher from the configured backup password.

    The password is stretched to a 32-byte key with SHA-256 and then
    urlsafe-base64 encoded, which is the key format Fernet expects.

    Raises:
        ValueError: when BACKUP_ENCRYPTION_KEY is unset or empty.
    """
    password = settings.BACKUP_ENCRYPTION_KEY
    if not password:
        raise ValueError("BACKUP_ENCRYPTION_KEY not set in environment")
    digest = hashlib.sha256(password.encode()).digest()
    return Fernet(base64.urlsafe_b64encode(digest))
def create_tarball(directory: Path) -> bytes:
    """Return the contents of *directory* as an in-memory gzipped tarball.

    The directory is stored under its own basename as the archive root.
    A missing directory yields empty bytes (logged as a warning).
    """
    if not directory.exists():
        logger.warning(f"Directory does not exist: {directory}")
        return b""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as archive:
        archive.add(directory, arcname=directory.name)
    # getvalue() returns the full buffer regardless of stream position.
    return buffer.getvalue()
def sync_unencrypted_directory(path: Path) -> dict:
    """Sync an unencrypted directory to S3 using ``aws s3 sync``.

    Mirrors *path* to ``s3://<bucket>/<prefix>/<dirname>``, deleting remote
    files that no longer exist locally.

    Args:
        path: Local directory to mirror.

    Returns:
        A JSON-serializable dict describing the outcome: ``synced`` is True
        on success, otherwise ``reason``/``error`` explains the failure.
    """
    if not path.exists():
        logger.warning(f"Directory does not exist: {path}")
        return {"synced": False, "reason": "directory_not_found"}

    s3_uri = f"s3://{settings.S3_BACKUP_BUCKET}/{settings.S3_BACKUP_PREFIX}/{path.name}"
    cmd = [
        "aws",
        "s3",
        "sync",
        str(path),
        s3_uri,
        "--delete",
        "--region",
        settings.S3_BACKUP_REGION,
    ]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True,
        )
        logger.info(f"Synced {path} to {s3_uri}")
        logger.debug(f"Output: {result.stdout}")
        # str(path): this dict is a Celery task result and must be
        # JSON-serializable; a Path object is not.
        return {"synced": True, "directory": str(path), "s3_uri": s3_uri}
    except FileNotFoundError as e:
        # Raised when the aws CLI itself is missing from the worker image;
        # check=True only covers non-zero exits, not a missing executable.
        logger.error(f"aws CLI not available while syncing {path}: {e}")
        return {"synced": False, "directory": str(path), "error": str(e)}
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to sync {path}: {e.stderr}")
        return {"synced": False, "directory": str(path), "error": str(e)}
def backup_encrypted_directory(path: Path) -> dict:
    """Tar, encrypt, and upload *path* to S3.

    The directory is tarballed in memory (create_tarball), encrypted with
    the Fernet key derived from BACKUP_ENCRYPTION_KEY (get_cipher), and
    stored as ``<prefix>/<dirname>.tar.gz.enc``.

    Args:
        path: Local directory to back up.

    Returns:
        A JSON-serializable dict: ``uploaded`` is True on success, otherwise
        ``reason``/``error`` explains the failure.
    """
    if not path.exists():
        logger.warning(f"Directory does not exist: {path}")
        return {"uploaded": False, "reason": "directory_not_found"}

    # Create tarball
    logger.info(f"Creating tarball of {path}...")
    tarball_bytes = create_tarball(path)
    if not tarball_bytes:
        logger.warning(f"Empty tarball for {path}, skipping")
        return {"uploaded": False, "reason": "empty_directory"}

    # Encrypt
    logger.info(f"Encrypting {path} ({len(tarball_bytes)} bytes)...")
    cipher = get_cipher()
    encrypted_bytes = cipher.encrypt(tarball_bytes)

    # Upload to S3
    s3_client = boto3.client("s3", region_name=settings.S3_BACKUP_REGION)
    s3_key = f"{settings.S3_BACKUP_PREFIX}/{path.name}.tar.gz.enc"
    try:
        logger.info(
            f"Uploading encrypted {path} to s3://{settings.S3_BACKUP_BUCKET}/{s3_key}"
        )
        s3_client.put_object(
            Bucket=settings.S3_BACKUP_BUCKET,
            Key=s3_key,
            Body=encrypted_bytes,
            ServerSideEncryption="AES256",
        )
        # str(path): this dict is a Celery task result and must be
        # JSON-serializable; a Path object is not.
        return {
            "uploaded": True,
            "directory": str(path),
            "size_bytes": len(encrypted_bytes),
            "s3_key": s3_key,
        }
    except Exception as e:
        logger.error(f"Failed to upload {path}: {e}")
        return {"uploaded": False, "directory": str(path), "error": str(e)}
@app.task(name=BACKUP_PATH)
def backup_to_s3(path: Path | str):
    """Backup a specific directory to S3.

    Dispatches to encrypted tarball upload for directories listed in
    ``settings.PRIVATE_DIRS``, plain ``aws s3 sync`` otherwise.

    Args:
        path: Directory to back up (arrives as a string when queued
            via Celery, hence the normalization below).

    Returns:
        The result dict from the chosen backup helper, or a
        ``directory_not_found`` marker when *path* does not exist.
    """
    path = Path(path)
    if not path.exists():
        logger.warning(f"Directory does not exist: {path}")
        return {"uploaded": False, "reason": "directory_not_found"}
    # NOTE(review): membership test compares Path objects -- assumes
    # settings.PRIVATE_DIRS holds Path instances, not strings; confirm.
    if path in settings.PRIVATE_DIRS:
        return backup_encrypted_directory(path)
    return sync_unencrypted_directory(path)
@app.task(name=BACKUP_ALL)
def backup_all_to_s3():
    """Fan out one backup_to_s3 task per configured storage directory.

    Fire-and-forget: the returned "success" status only means the
    per-directory tasks were queued, not that any upload has completed.

    Returns:
        A status dict; ``{"status": "disabled"}`` when S3 backups are off.
    """
    if not settings.S3_BACKUP_ENABLED:
        logger.info("S3 backup is disabled")
        return {"status": "disabled"}

    logger.info("Starting S3 backup...")
    for dir_name in settings.storage_dirs:
        # .delay() queues the per-directory task on the Celery broker.
        backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())

    return {
        "status": "success",
        "message": f"Started backup for {len(settings.storage_dirs)} directories",
    }

View File

@ -7,7 +7,7 @@ import logging
import re
import textwrap
from datetime import datetime
from typing import Any
from typing import Any, cast
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session
@ -56,8 +56,15 @@ def call_llm(
message: DiscordMessage,
model: str,
msgs: list[str] = [],
allowed_tools: list[str] = [],
allowed_tools: list[str] | None = None,
) -> str | None:
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
tools = make_discord_tools(
message.recipient_user.system_user,
message.from_user,
@ -67,13 +74,13 @@ def call_llm(
tools = {
name: tool
for name, tool in tools.items()
if message.tool_allowed(name) and name in allowed_tools
if message.tool_allowed(name)
and (allowed_tools is None or name in allowed_tools)
}
system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel
)
provider = create_provider(model=model)
messages = previous_messages(
session,
message.recipient_user and message.recipient_user.id,
@ -127,10 +134,32 @@ def should_process(message: DiscordMessage) -> bool:
if not (res := re.search(r"<number>(.*)</number>", response)):
return False
try:
return int(res.group(1)) > message.chattiness_threshold
if int(res.group(1)) < 100 - message.chattiness_threshold:
return False
except ValueError:
return False
if not (bot_id := _resolve_bot_id(message)):
return False
if message.channel and message.channel.server:
discord.trigger_typing_channel(bot_id, message.channel.name)
else:
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
return True
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
    """Return the bot's system-user id for *discord_message*, or None.

    None is returned when the message has no recipient user or the
    recipient has no linked system user.
    """
    recipient = discord_message.recipient_user
    system_user = recipient.system_user if recipient else None
    if not system_user:
        return None
    return getattr(system_user, "id", None)
@app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution
@ -152,14 +181,33 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id,
}
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
bot_id = _resolve_bot_id(discord_message)
if not bot_id:
logger.warning(
"No associated Discord bot user for message %s; skipping send",
message_id,
)
return {
"status": "processed",
"message_id": message_id,
}
try:
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
except Exception:
logger.exception("Failed to generate Discord response")
print("response:", response)
if not response:
pass
elif discord_message.channel.server:
discord.send_to_channel(discord_message.channel.name, response)
return {
"status": "processed",
"message_id": message_id,
}
if discord_message.channel.server:
discord.send_to_channel(bot_id, discord_message.channel.name, response)
else:
discord.send_dm(discord_message.from_user.username, response)
discord.send_dm(bot_id, discord_message.from_user.username, response)
return {
"status": "processed",

View File

@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)"
bot_id_value = scheduled_call.user_id
if bot_id_value is None:
logger.warning(
"Scheduled call %s has no associated bot user; skipping Discord send",
scheduled_call.id,
)
return
bot_id = cast(int, bot_id_value)
if discord_user := scheduled_call.discord_user:
logger.info(f"Sending DM to {discord_user.username}: {message}")
discord.send_dm(discord_user.username, message)
discord.send_dm(bot_id, discord_user.username, message)
elif discord_channel := scheduled_call.discord_channel:
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
discord.broadcast_message(discord_channel.name, message)
discord.broadcast_message(bot_id, discord_channel.name, message)
else:
logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}"

View File

@ -1,7 +1,24 @@
from datetime import datetime, timedelta, timezone
from typing import Iterable
import pytest
# Optional dependency guard: the tests exercise RedisUsageTracker against a
# fake client, so a real `redis` package is unnecessary. When it is absent,
# register a stub module whose Redis class fails loudly if instantiated.
try:
    import redis  # noqa: F401 # pragma: no cover - optional test dependency
except ModuleNotFoundError:  # pragma: no cover - import guard for test envs
    import sys
    from types import SimpleNamespace

    class _RedisStub(SimpleNamespace):
        # Mirrors the one attribute production code imports from `redis`.
        class Redis:  # type: ignore[no-redef]
            def __init__(self, *args: object, **kwargs: object) -> None:
                raise ModuleNotFoundError(
                    "The 'redis' package is required to use RedisUsageTracker"
                )

    # setdefault: never clobber a real redis module if one appeared.
    sys.modules.setdefault("redis", _RedisStub())
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker,
RateLimitConfig,
@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
)
class FakeRedis:
def __init__(self) -> None:
self._store: dict[str, str] = {}
def get(self, key: str) -> str | None:
return self._store.get(key)
def set(self, key: str, value: str) -> None:
self._store[key] = value
def scan_iter(self, match: str) -> Iterable[str]:
from fnmatch import fnmatch
for key in list(self._store.keys()):
if fnmatch(key, match):
yield key
@pytest.fixture
def tracker() -> InMemoryUsageTracker:
config = RateLimitConfig(
@ -25,6 +60,23 @@ def tracker() -> InMemoryUsageTracker:
)
@pytest.fixture
def redis_tracker() -> RedisUsageTracker:
    """Tracker wired to a FakeRedis backend with two models sharing limits.

    Both models use a one-minute window so window-expiry behavior can be
    exercised deterministically with explicit timestamps.
    """
    config = RateLimitConfig(
        window=timedelta(minutes=1),
        max_input_tokens=1_000,
        max_output_tokens=2_000,
        max_total_tokens=2_500,
    )
    return RedisUsageTracker(
        {
            "anthropic/claude-3": config,
            "anthropic/haiku": config,
        },
        redis_client=FakeRedis(),
    )
@pytest.mark.parametrize(
"window, kwargs",
[
@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
assert tracker.is_rate_limited("openai/gpt-4o")
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
    """Recorded usage survives round-trips through the (fake) redis store."""
    # Fixed timestamp keeps both records inside the same one-minute window.
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
    redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)

    # 1_000 max input tokens minus the 100 recorded above.
    allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
    assert allowance is not None
    assert allowance.input_tokens == 900

    # Breakdown is keyed provider -> model, split on the "/" in the name.
    breakdown = redis_tracker.get_usage_breakdown()
    assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200

    items = dict(redis_tracker.iter_state_items())
    assert set(items.keys()) == {"anthropic/claude-3", "anthropic/haiku"}
def test_usage_tracker_base_not_instantiable() -> None:
class DummyTracker(UsageTracker):
pass

View File

@ -4,6 +4,8 @@ import requests
from memory.common import discord
BOT_ID = 42
@pytest.fixture
def mock_api_url():
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
json={"user": "user123", "message": "Hello!"},
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
timeout=10,
)
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!")
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"},
json={
"bot_id": BOT_ID,
"channel_name": "general",
"message": "Announcement!",
},
timeout=10,
)
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!")
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!")
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy"""
mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"}
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status"""
mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"}
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is False
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is False
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel"""
mock_broadcast.return_value = True
result = discord.send_error_message("Something broke")
result = discord.send_error_message(BOT_ID, "Something broke")
assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke")
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel"""
mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in")
result = discord.send_activity_message(BOT_ID, "User logged in")
assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in")
mock_broadcast.assert_called_once_with(
BOT_ID, "activity", "User logged in"
)
@patch("memory.common.discord.broadcast_message")
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern")
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
mock_broadcast.assert_called_once_with(
BOT_ID, "discoveries", "Found interesting pattern"
)
@patch("memory.common.discord.broadcast_message")
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel"""
mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot")
result = discord.send_chat_message(BOT_ID, "Hello from bot")
assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong")
discord.notify_task_failure(
"test_task", "Something went wrong", bot_id=BOT_ID
)
mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0]
assert mock_send_error.call_args[0][0] == BOT_ID
message = mock_send_error.call_args[0][1]
assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
"Error occurred",
task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123},
bot_id=BOT_ID,
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
discord.notify_task_failure(
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
assert "**Traceback:**" in message
assert "Exception: test" in message
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated"""
long_error = "x" * 600
discord.notify_task_failure("test_task", long_error)
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated"""
long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
discord.notify_task_failure(
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Traceback should show last 800 chars
assert long_traceback[-800:] in message
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated"""
long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args)
discord.notify_task_failure(
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Args should be truncated to 200 chars
assert (
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
discord.notify_task_failure(
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred")
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_not_called()
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
mock_send_error.side_effect = Exception("Failed to send")
# Should not raise
discord.notify_task_failure("test_task", "Error occurred")
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_called_once()
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
):
"""Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message)
mock_broadcast.assert_called_once_with("test-channel", message)
function(BOT_ID, message)
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
@patch("requests.post")
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars)
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars
json_payload = call_args[1]["json"]
assert json_payload["message"] == message_with_special_chars
assert json_payload["bot_id"] == BOT_ID
@patch("requests.post")
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
mock_post.return_value = mock_response
long_message = "A" * 2000
result = discord.broadcast_message("general", long_message)
result = discord.broadcast_message(BOT_ID, "general", long_message)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message
json_payload = call_args[1]["json"]
assert json_payload["message"] == long_message
assert json_payload["bot_id"] == BOT_ID
@patch("requests.get")
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is False

View File

@ -4,6 +4,8 @@ import requests
from memory.common import discord
BOT_ID = 42
@pytest.fixture
def mock_api_url():
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_dm",
json={"user": "user123", "message": "Hello!"},
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
timeout=10,
)
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!")
result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!")
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is True
mock_post.assert_called_once_with(
"http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"},
json={
"bot_id": BOT_ID,
"channel_name": "general",
"message": "Announcement!",
},
timeout=10,
)
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!")
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!")
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy"""
mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"}
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status"""
mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"}
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is False
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is False
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel"""
mock_broadcast.return_value = True
result = discord.send_error_message("Something broke")
result = discord.send_error_message(BOT_ID, "Something broke")
assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke")
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
@patch("memory.common.discord.broadcast_message")
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel"""
mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in")
result = discord.send_activity_message(BOT_ID, "User logged in")
assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in")
mock_broadcast.assert_called_once_with(
BOT_ID, "activity", "User logged in"
)
@patch("memory.common.discord.broadcast_message")
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern")
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
mock_broadcast.assert_called_once_with(
BOT_ID, "discoveries", "Found interesting pattern"
)
@patch("memory.common.discord.broadcast_message")
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel"""
mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot")
result = discord.send_chat_message(BOT_ID, "Hello from bot")
assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
@patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong")
discord.notify_task_failure(
"test_task", "Something went wrong", bot_id=BOT_ID
)
mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0]
assert mock_send_error.call_args[0][0] == BOT_ID
message = mock_send_error.call_args[0][1]
assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
"Error occurred",
task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123},
bot_id=BOT_ID,
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
discord.notify_task_failure(
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
assert "**Traceback:**" in message
assert "Exception: test" in message
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated"""
long_error = "x" * 600
discord.notify_task_failure("test_task", long_error)
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated"""
long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
discord.notify_task_failure(
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Traceback should show last 800 chars
assert long_traceback[-800:] in message
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated"""
long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args)
discord.notify_task_failure(
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Args should be truncated to 200 chars
assert (
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
discord.notify_task_failure(
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0]
message = mock_send_error.call_args[0][1]
# Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred")
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_not_called()
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
mock_send_error.side_effect = Exception("Failed to send")
# Should not raise
discord.notify_task_failure("test_task", "Error occurred")
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_called_once()
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
):
"""Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message)
mock_broadcast.assert_called_once_with("test-channel", message)
function(BOT_ID, message)
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
@patch("requests.post")
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars)
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars
json_payload = call_args[1]["json"]
assert json_payload["message"] == message_with_special_chars
assert json_payload["bot_id"] == BOT_ID
@patch("requests.post")
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
mock_post.return_value = mock_response
long_message = "A" * 2000
result = discord.broadcast_message("general", long_message)
result = discord.broadcast_message(BOT_ID, "general", long_message)
assert result is True
call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message
json_payload = call_args[1]["json"]
assert json_payload["message"] == long_message
assert json_payload["bot_id"] == BOT_ID
@patch("requests.get")
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
result = discord.is_collector_healthy()
result = discord.is_collector_healthy(BOT_ID)
assert result is False

View File

@ -0,0 +1,393 @@
import io
import subprocess
import tarfile
from unittest.mock import Mock, patch, MagicMock
import pytest
from botocore.exceptions import ClientError
from memory.common import settings
from memory.workers.tasks import backup
@pytest.fixture
def sample_files():
    """Create sample files in the memory_files storage structure.

    Writes small text files under several storage subdirectories so the
    backup/tarball tests have real content to pack and verify.
    """
    base = settings.FILE_STORAGE_DIR
    dirs_with_files = {
        "emails": ["email1.txt", "email2.txt"],
        "notes": ["note1.md", "note2.md"],
        "photos": ["photo1.jpg"],
        "comics": ["comic1.png", "comic2.png"],
        "ebooks": ["book1.epub"],
        "webpages": ["page1.html"],
    }
    for dir_name, filenames in dirs_with_files.items():
        dir_path = base / dir_name
        dir_path.mkdir(parents=True, exist_ok=True)
        for filename in filenames:
            file_path = dir_path / filename
            # Marker line must include the filename: the integrity tests
            # assert on e.g. "Content of notes/note1.md". Repeated 100x so
            # the files are non-trivially sized.
            content = f"Content of {dir_name}/{filename}\n" * 100
            file_path.write_text(content)
@pytest.fixture
def mock_s3_client():
    """Yield a MagicMock standing in for the boto3 S3 client."""
    with patch("boto3.client") as client_factory:
        fake_s3 = MagicMock()
        client_factory.return_value = fake_s3
        yield fake_s3
@pytest.fixture
def backup_settings():
    """Patch all S3 backup settings to deterministic test values.

    Enables backups and pins the encryption key, bucket, prefix, and
    region for the duration of a test.
    """
    with (
        patch.object(settings, "S3_BACKUP_ENABLED", True),
        patch.object(settings, "BACKUP_ENCRYPTION_KEY", "test-password-123"),
        patch.object(settings, "S3_BACKUP_BUCKET", "test-bucket"),
        patch.object(settings, "S3_BACKUP_PREFIX", "test-prefix"),
        patch.object(settings, "S3_BACKUP_REGION", "us-east-1"),
    ):
        yield
@pytest.fixture
def get_test_path():
    """Return a helper that maps a directory name to its storage path."""

    def _path(dir_name):
        return settings.FILE_STORAGE_DIR / dir_name

    return _path
@pytest.mark.parametrize(
    "data,key",
    [
        (b"This is a test message", "my-secret-key"),
        (b"\x00\x01\x02\xff" * 10000, "another-key"),
        (b"x" * 1000000, "large-data-key"),
    ],
)
def test_encrypt_decrypt_roundtrip(data, key):
    """Encrypting then decrypting must reproduce the plaintext exactly."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
        fernet = backup.get_cipher()
        ciphertext = fernet.encrypt(data)
        # Ciphertext must differ from the plaintext...
        assert ciphertext != data
        # ...and must round-trip back to it.
        assert fernet.decrypt(ciphertext) == data
def test_encrypt_decrypt_tarball(sample_files):
    """Round-trip a real directory through tarball -> encrypt -> decrypt."""
    archive = backup.create_tarball(settings.FILE_STORAGE_DIR / "emails")
    assert len(archive) > 0

    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "tarball-key"):
        cipher = backup.get_cipher()
        restored = cipher.decrypt(cipher.encrypt(archive))

    assert restored == archive

    # The decrypted bytes must still be a readable gzipped tarball.
    with tarfile.open(fileobj=io.BytesIO(restored), mode="r:gz") as tar:
        entries = tar.getmembers()
        assert len(entries) >= 2  # At least 2 email files
        for entry in entries:
            if not entry.isfile():
                continue
            handle = tar.extractfile(entry)
            assert handle is not None
            assert "Content of emails/" in handle.read().decode()
def test_different_keys_produce_different_ciphertext():
    """The same plaintext must encrypt differently under different keys."""
    plaintext = b"Same data encrypted with different keys"
    ciphertexts = []
    for key in ("key1", "key2"):
        with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
            ciphertexts.append(backup.get_cipher().encrypt(plaintext))
    assert ciphertexts[0] != ciphertexts[1]
def test_missing_encryption_key_raises_error():
    """get_cipher must refuse to run when no encryption key is configured."""
    with (
        patch.object(settings, "BACKUP_ENCRYPTION_KEY", ""),
        pytest.raises(ValueError, match="BACKUP_ENCRYPTION_KEY not set"),
    ):
        backup.get_cipher()
def test_create_tarball_with_files(sample_files):
    """A populated directory produces a valid gzipped tarball."""
    data = backup.create_tarball(settings.FILE_STORAGE_DIR / "notes")
    assert len(data) > 0

    # Extract the member list and make sure both note files made it in.
    with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar:
        names = [m.name for m in tar.getmembers() if m.isfile()]

    assert len(names) >= 2
    for expected in ("note1.md", "note2.md"):
        assert any(expected in name for name in names)
def test_create_tarball_nonexistent_directory():
    """A missing directory yields an empty byte string, not an error."""
    missing = settings.FILE_STORAGE_DIR / "does_not_exist"
    assert backup.create_tarball(missing) == b""
def test_create_tarball_empty_directory():
    """An empty directory still yields a tarball holding the dir entry."""
    target = settings.FILE_STORAGE_DIR / "empty"
    target.mkdir(parents=True, exist_ok=True)

    data = backup.create_tarball(target)
    assert len(data) > 0

    with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar:
        entries = tar.getmembers()
    assert len(entries) >= 1
    assert entries[0].isdir()
def test_sync_unencrypted_success(sample_files, backup_settings):
    """A successful `aws s3 sync` is reported with the target S3 URI."""
    target = settings.FILE_STORAGE_DIR / "comics"
    with patch("subprocess.run") as mock_run:
        mock_run.return_value = Mock(stdout="Synced files", returncode=0)
        result = backup.sync_unencrypted_directory(target)

    assert result["synced"] is True
    assert result["directory"] == target
    assert "s3_uri" in result
    assert "test-bucket" in result["s3_uri"]
    assert "test-prefix/comics" in result["s3_uri"]

    # The command handed to subprocess must be an `aws s3 sync` invocation
    # with mirroring (--delete) and an explicit region.
    mock_run.assert_called_once()
    argv = mock_run.call_args[0][0]
    assert list(argv[:3]) == ["aws", "s3", "sync"]
    assert "--delete" in argv
    assert "--region" in argv
def test_sync_unencrypted_nonexistent_directory(backup_settings):
    """Syncing a directory that does not exist is reported, not raised."""
    outcome = backup.sync_unencrypted_directory(
        settings.FILE_STORAGE_DIR / "does_not_exist"
    )
    assert outcome["synced"] is False
    assert outcome["reason"] == "directory_not_found"
def test_sync_unencrypted_aws_cli_failure(sample_files, backup_settings):
    """A failing AWS CLI invocation is surfaced as an error result."""
    with patch("subprocess.run") as mock_run:
        mock_run.side_effect = subprocess.CalledProcessError(
            1, "aws", stderr="AWS CLI error"
        )
        outcome = backup.sync_unencrypted_directory(
            settings.FILE_STORAGE_DIR / "comics"
        )

    assert outcome["synced"] is False
    assert "error" in outcome
def test_backup_encrypted_success(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """A populated directory is encrypted and uploaded to S3."""
    outcome = backup.backup_encrypted_directory(get_test_path("emails"))

    assert outcome["uploaded"] is True
    assert outcome["size_bytes"] > 0
    assert outcome["s3_key"].endswith("emails.tar.gz.enc")

    # Upload must target the configured bucket with server-side encryption.
    upload_kwargs = mock_s3_client.put_object.call_args[1]
    assert upload_kwargs["Bucket"] == "test-bucket"
    assert upload_kwargs["ServerSideEncryption"] == "AES256"
def test_backup_encrypted_nonexistent_directory(
    mock_s3_client, backup_settings, get_test_path
):
    """No upload is attempted for a directory that does not exist."""
    outcome = backup.backup_encrypted_directory(get_test_path("does_not_exist"))

    assert outcome["uploaded"] is False
    assert outcome["reason"] == "directory_not_found"
    mock_s3_client.put_object.assert_not_called()
def test_backup_encrypted_empty_directory(
    mock_s3_client, backup_settings, get_test_path
):
    """Backing up an empty directory completes and reports a status."""
    target = get_test_path("empty_encrypted")
    target.mkdir(parents=True, exist_ok=True)
    assert "uploaded" in backup.backup_encrypted_directory(target)
def test_backup_encrypted_s3_failure(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """An S3 ClientError is reported as a failed upload, not raised."""
    mock_s3_client.put_object.side_effect = ClientError(
        {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, "PutObject"
    )

    outcome = backup.backup_encrypted_directory(get_test_path("notes"))

    assert outcome["uploaded"] is False
    assert "error" in outcome
def test_backup_encrypted_data_integrity(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """Bytes uploaded to S3 decrypt back into the original file contents."""
    outcome = backup.backup_encrypted_directory(get_test_path("notes"))
    assert outcome["uploaded"] is True

    # Recover the payload handed to put_object and decrypt it.
    payload = mock_s3_client.put_object.call_args[1]["Body"]
    tarball = backup.get_cipher().decrypt(payload)

    with tarfile.open(fileobj=io.BytesIO(tarball), mode="r:gz") as tar:
        matches = [
            m
            for m in tar.getmembers()
            if m.isfile() and m.name.endswith("note1.md")
        ]
        assert matches, "note1.md not found in tarball"
        for member in matches:
            text = tar.extractfile(member).read().decode()
            assert "Content of notes/note1.md" in text
def test_backup_disabled():
    """backup_all_to_s3 short-circuits when backups are switched off."""
    with patch.object(settings, "S3_BACKUP_ENABLED", False):
        assert backup.backup_all_to_s3()["status"] == "disabled"
def test_backup_full_execution(sample_files, mock_s3_client, backup_settings):
    """backup_all_to_s3 queues one backup task per storage directory."""
    with patch.object(backup, "backup_to_s3") as mock_task:
        mock_task.delay = Mock()
        outcome = backup.backup_all_to_s3()

    assert outcome["status"] == "success"
    assert "message" in outcome
    # One queued task for every configured storage directory.
    assert mock_task.delay.call_count == len(settings.storage_dirs)
def test_backup_handles_partial_failures(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """A sync failure for one directory is reported rather than raised."""
    with patch("subprocess.run") as mock_run:
        mock_run.side_effect = subprocess.CalledProcessError(
            1, "aws", stderr="Sync failed"
        )
        outcome = backup.sync_unencrypted_directory(get_test_path("comics"))

    assert outcome["synced"] is False
    assert "error" in outcome
def test_same_key_different_runs_different_ciphertext():
    """Fernet's random nonce makes repeated encryptions differ."""
    plaintext = b"Consistent data"
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "same-key"):
        cipher = backup.get_cipher()
        first = cipher.encrypt(plaintext)
        second = cipher.encrypt(plaintext)
        # Distinct ciphertexts, but both recover the same plaintext.
        assert first != second
        assert cipher.decrypt(first) == cipher.decrypt(second) == plaintext
def test_key_derivation_consistency():
    """Two ciphers derived from the same password are interchangeable."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "test-password"):
        writer = backup.get_cipher()
        reader = backup.get_cipher()
        message = b"Test data"
        # Ciphertext produced by one must decrypt under the other.
        assert reader.decrypt(writer.encrypt(message)) == message
@pytest.mark.parametrize(
    "dir_name,is_private",
    [
        ("emails", True),
        ("notes", True),
        ("photos", True),
        ("comics", False),
        ("ebooks", False),
        ("webpages", False),
        ("lesswrong", False),
        ("chunks", False),
    ],
)
def test_directory_encryption_classification(dir_name, is_private, backup_settings):
    """Directories listed in PRIVATE_DIRS (and only those) count as private."""
    private_names = ["emails", "notes", "photos"]
    patched_dirs = [settings.FILE_STORAGE_DIR / d for d in private_names]
    with patch.object(settings, "PRIVATE_DIRS", patched_dirs):
        candidate = settings.FILE_STORAGE_DIR / dir_name
        assert (candidate in settings.PRIVATE_DIRS) == is_private

View File

@ -3,6 +3,7 @@ from datetime import datetime, timezone
from unittest.mock import Mock, patch
from memory.common.db.models import (
DiscordBotUser,
DiscordMessage,
DiscordUser,
DiscordServer,
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
@pytest.fixture
def mock_discord_user(db_session):
def discord_bot_user(db_session):
bot = DiscordBotUser.create_with_api_key(
discord_users=[],
name="Test Bot",
email="bot@example.com",
)
db_session.add(bot)
db_session.commit()
return bot
@pytest.fixture
def mock_discord_user(db_session, discord_bot_user):
"""Create a Discord user for testing."""
user = DiscordUser(
id=123456789,
username="testuser",
ignore_messages=False,
system_user_id=discord_bot_user.id,
)
db_session.add(user)
db_session.commit()

View File

@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch
import uuid
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
from memory.common.db.models import (
ScheduledLLMCall,
DiscordBotUser,
DiscordUser,
DiscordChannel,
DiscordServer,
)
from memory.workers.tasks import scheduled_calls
@pytest.fixture
def sample_user(db_session):
"""Create a sample user for testing."""
user = HumanUser.create_with_password(
name="testuser",
email="test@example.com",
password="password",
user = DiscordBotUser.create_with_api_key(
discord_users=[],
name="testbot",
email="bot@example.com",
)
db_session.add(user)
db_session.commit()
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
scheduled_calls._send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id,
"testuser", # username, not ID
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
)
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
scheduled_calls._send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with(
completed_scheduled_call.user_id,
"test-channel", # channel name, not ID
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
)
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
# Verify the message was truncated
args, kwargs = mock_send_dm.call_args
message = args[1]
assert args[0] == pending_scheduled_call.user_id
message = args[2]
assert len(message) <= 1950 # Should be truncated
assert message.endswith("... (response truncated)")
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args
message = args[1]
assert args[0] == pending_scheduled_call.user_id
message = args[2]
assert not message.endswith("... (response truncated)")
assert "This is a normal length response." in message
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_discord_user.username = "testuser"
mock_call = Mock()
mock_call.user_id = 987
mock_call.topic = topic
mock_call.model = model
mock_call.discord_user = mock_discord_user
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
# Get the actual message that was sent
args, kwargs = mock_send_dm.call_args
actual_message = args[1]
assert args[0] == mock_call.user_id
actual_message = args[2]
# Verify all expected parts are in the message
for expected_part in expected_in_message:

182
tools/backup_databases.sh Normal file
View File

@ -0,0 +1,182 @@
#!/bin/bash
# Backup Postgres and Qdrant databases to S3
# Each dump is streamed through openssl (AES-256-CBC, PBKDF2) before upload.
set -euo pipefail
# Install AWS CLI if not present (postgres:15 image doesn't include it)
if ! command -v aws >/dev/null 2>&1; then
    echo "Installing AWS CLI, wget, and jq..."
    apt-get update -qq && apt-get install -y -qq awscli wget jq >/dev/null 2>&1
fi
# Configuration - read from environment or use defaults
BUCKET="${S3_BACKUP_BUCKET:-equistamp-memory-backup}"
PREFIX="${S3_BACKUP_PREFIX:-Daniel}/databases"
REGION="${S3_BACKUP_REGION:-eu-central-1}"
# Required: `:?` aborts with a message when the encryption key is unset.
PASSWORD="${BACKUP_ENCRYPTION_KEY:?BACKUP_ENCRYPTION_KEY not set}"
MAX_BACKUPS="${MAX_BACKUPS:-30}" # Keep last N backups
# Service names (docker-compose network)
POSTGRES_HOST="${POSTGRES_HOST:-postgres}"
POSTGRES_USER="${POSTGRES_USER:-kb}"
POSTGRES_DB="${POSTGRES_DB:-kb}"
QDRANT_URL="${QDRANT_URL:-http://qdrant:6333}"
# Timestamp for backups
DATE=$(date +%Y%m%d-%H%M%S)  # full timestamp; unused in this script as written
DATE_SIMPLE=$(date +%Y%m%d)  # date-only stamp embedded in S3 object names
log() {
    # Timestamped informational message to stdout.
    printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"
}
error() {
    # Timestamped error message to stderr.
    printf '[%s] ERROR: %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*" >&2
}
# Clean old backups - keep only last N
cleanup_old_backups() {
    local prefix=$1
    local pattern=$2 # e.g., "postgres-" or "qdrant-"

    log "Checking for old ${pattern} backups to clean up..."

    # List backups matching the pattern, sorted by name (oldest first).
    # `|| true` keeps an empty listing (grep exits 1) from aborting the
    # whole script under `set -e -o pipefail`.
    local backups
    backups=$(aws s3 ls "s3://${BUCKET}/${prefix}/" --region "${REGION}" | \
        grep "${pattern}" | \
        awk '{print $4}' | \
        sort) || true

    # Count non-empty lines; `echo "" | wc -l` would wrongly report 1.
    local count
    count=$(printf '%s' "$backups" | grep -c . || true)

    if [ "$count" -le "$MAX_BACKUPS" ]; then
        log "Found ${count} ${pattern} backups (max: ${MAX_BACKUPS}), no cleanup needed"
        return 0
    fi

    local to_delete=$((count - MAX_BACKUPS))
    log "Found ${count} ${pattern} backups, deleting ${to_delete} oldest..."

    echo "$backups" | head -n "$to_delete" | while read -r file; do
        if [ -n "$file" ]; then
            log "Deleting old backup: ${file}"
            aws s3 rm "s3://${BUCKET}/${prefix}/${file}" --region "${REGION}"
        fi
    done
}
# Backup Postgres
backup_postgres() {
    log "Starting Postgres backup..."

    local output_path="s3://${BUCKET}/${PREFIX}/postgres-${DATE_SIMPLE}.sql.gz.enc"

    # Fail with a clear message (instead of a bare `set -u` unbound-variable
    # abort) when the password file variable is missing; assign separately so
    # a failing `cat` is not masked by the export.
    local password_file="${POSTGRES_PASSWORD_FILE:?POSTGRES_PASSWORD_FILE not set}"
    export PGPASSWORD
    PGPASSWORD=$(cat "${password_file}")

    # Dump -> gzip -> encrypt -> stream straight to S3 (no temp files).
    local rc=0
    if pg_dump -h "${POSTGRES_HOST}" -U "${POSTGRES_USER}" "${POSTGRES_DB}" 2>/dev/null | \
        gzip | \
        openssl enc -aes-256-cbc -salt -pbkdf2 -pass "pass:${PASSWORD}" | \
        aws s3 cp - "${output_path}" --region "${REGION}"; then
        log "Postgres backup completed: ${output_path}"
        cleanup_old_backups "${PREFIX}" "postgres-"
    else
        error "Postgres backup failed"
        rc=1
    fi

    # Don't leave the DB password in the environment either way.
    unset PGPASSWORD
    return $rc
}
# Backup Qdrant
# Creates a snapshot via the HTTP API, streams it encrypted to S3, then
# deletes the snapshot server-side (best effort on the failure path).
backup_qdrant() {
    log "Starting Qdrant backup..."

    # Create snapshot via HTTP API (no docker exec needed)
    local snapshot_response
    if ! snapshot_response=$(wget -q -O - --post-data='{}' \
        --header='Content-Type: application/json' \
        "${QDRANT_URL}/snapshots" 2>/dev/null); then
        error "Failed to create Qdrant snapshot"
        return 1
    fi

    local snapshot_name
    # Parse snapshot name - wget/busybox may not have jq, so use grep/sed
    if command -v jq >/dev/null 2>&1; then
        snapshot_name=$(echo "${snapshot_response}" | jq -r '.result.name // empty')
    else
        # Fallback: parse JSON without jq (fragile but works for simple case)
        snapshot_name=$(echo "${snapshot_response}" | grep -o '"name":"[^"]*"' | cut -d'"' -f4)
    fi

    if [ -z "${snapshot_name}" ]; then
        error "Could not extract snapshot name from response: ${snapshot_response}"
        return 1
    fi

    log "Created Qdrant snapshot: ${snapshot_name}"

    # Download snapshot and upload to S3
    local output_path="s3://${BUCKET}/${PREFIX}/qdrant-${DATE_SIMPLE}.snapshot.enc"
    if wget -q -O - "${QDRANT_URL}/snapshots/${snapshot_name}" | \
        openssl enc -aes-256-cbc -salt -pbkdf2 -pass "pass:${PASSWORD}" | \
        aws s3 cp - "${output_path}" --region "${REGION}"; then
        log "Qdrant backup completed: ${output_path}"

        # Delete the snapshot from Qdrant
        # NOTE(review): --method=DELETE requires GNU wget; busybox wget
        # lacks it — confirm the container image ships GNU wget.
        if wget -q -O - --method=DELETE \
            "${QDRANT_URL}/snapshots/${snapshot_name}" >/dev/null 2>&1; then
            log "Deleted Qdrant snapshot: ${snapshot_name}"
        else
            error "Failed to delete Qdrant snapshot: ${snapshot_name}"
        fi

        cleanup_old_backups "${PREFIX}" "qdrant-"
        return 0
    else
        error "Qdrant backup failed"
        # Try to clean up snapshot
        wget -q -O - --method=DELETE \
            "${QDRANT_URL}/snapshots/${snapshot_name}" >/dev/null 2>&1 || true
        return 1
    fi
}
# Main execution
main() {
    log "Database backup started"

    # Run both backups regardless of individual failures, then summarize.
    local postgres_result=0
    local qdrant_result=0
    backup_postgres || postgres_result=1
    backup_qdrant || qdrant_result=1

    if [ $postgres_result -eq 0 ] && [ $qdrant_result -eq 0 ]; then
        log "All database backups completed successfully"
        return 0
    elif [ $postgres_result -ne 0 ] && [ $qdrant_result -ne 0 ]; then
        error "All database backups failed"
        return 1
    else
        error "Some database backups failed (Postgres: ${postgres_result}, Qdrant: ${qdrant_result})"
        return 1
    fi
}
# Script entry point; exit status is main's return value.
main "$@"

View File

@ -51,6 +51,8 @@ from memory.common.celery_app import (
UPDATE_METADATA_FOR_SOURCE_ITEMS,
SETUP_GIT_NOTES,
TRACK_GIT_CHANGES,
BACKUP_TO_S3_DIRECTORY,
BACKUP_ALL,
app,
)
@ -97,6 +99,10 @@ TASK_MAPPINGS = {
"setup_git_notes": SETUP_GIT_NOTES,
"track_git_changes": TRACK_GIT_CHANGES,
},
"backup": {
"backup_to_s3_directory": BACKUP_TO_S3_DIRECTORY,
"backup_all": BACKUP_ALL,
},
}
QUEUE_MAPPINGS = {
"email": "email",
@ -177,6 +183,28 @@ def execute_task(ctx, category: str, task_name: str, **kwargs):
sys.exit(1)
@cli.group()
@click.pass_context
def backup(ctx):
    """Backup-related tasks (dispatches the tasks in TASK_MAPPINGS["backup"])."""
    pass
@backup.command("all")
@click.pass_context
def backup_all(ctx):
"""Backup all directories."""
execute_task(ctx, "backup", "backup_all")
@backup.command("path")
@click.option("--path", required=True, help="Path to backup")
@click.pass_context
def backup_to_s3_directory(ctx, path):
"""Backup a specific path."""
execute_task(ctx, "backup", "backup_to_s3_directory", path=path)
@cli.group()
@click.pass_context
def email(ctx):