Compare commits

...

10 Commits

Author SHA1 Message Date
e95a082147 allow discord tools 2025-11-02 00:50:12 +00:00
c42513100b db backups 2025-11-02 00:24:35 +00:00
a5bc53326d backups 2025-11-02 00:01:35 +00:00
131427255a fix typing indicator 2025-11-01 20:27:57 +00:00
Daniel O'Connell
ff3ca4f109 show typing 2025-11-01 21:13:39 +01:00
3b216953ab better docker compise 2025-11-01 19:51:41 +00:00
Daniel O'Connell
d7e403fb83 optional chattiness 2025-11-01 20:39:15 +01:00
57145ac7b4 fix bugs 2025-11-01 19:35:20 +00:00
Daniel O'Connell
814090dccb use db bots 2025-11-01 18:52:37 +01:00
Daniel O'Connell
9639fa3dd7 use usage tracker 2025-11-01 18:49:06 +01:00
35 changed files with 1716 additions and 295 deletions

View File

@ -0,0 +1,114 @@
"""allow no chattiness
Revision ID: 2024235e37e7
Revises: 7dc03dbf184c
Create Date: 2025-11-01 20:38:10.849651
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "2024235e37e7"
down_revision: Union[str, None] = "7dc03dbf184c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Make chattiness_threshold optional and drop the track_messages flag.

    Applies the same pair of changes to all three Discord tables; the
    servers table additionally has its active-servers index rebuilt on
    ignore_messages before the track_messages column is removed.
    """
    for table in ("discord_channels", "discord_servers", "discord_users"):
        op.alter_column(
            table,
            "chattiness_threshold",
            existing_type=sa.INTEGER(),
            nullable=True,
            existing_server_default=sa.text("50"),
        )
        if table == "discord_servers":
            # The old index referenced track_messages; rebuild it on
            # ignore_messages before that column disappears below.
            op.drop_index("discord_servers_active_idx", table_name="discord_servers")
            op.create_index(
                "discord_servers_active_idx",
                "discord_servers",
                ["ignore_messages", "last_sync_at"],
                unique=False,
            )
        op.drop_column(table, "track_messages")
def downgrade() -> None:
    """Restore track_messages and make chattiness_threshold required again.

    Reverses upgrade() table by table (users, then servers, then
    channels), re-pointing the servers' active index back at
    track_messages once the column exists again.
    """

    def restore_track_messages(table: str) -> None:
        # Re-add the boolean flag exactly as it existed before the upgrade.
        op.add_column(
            table,
            sa.Column(
                "track_messages",
                sa.BOOLEAN(),
                server_default=sa.text("true"),
                autoincrement=False,
                nullable=False,
            ),
        )

    def require_chattiness(table: str) -> None:
        op.alter_column(
            table,
            "chattiness_threshold",
            existing_type=sa.INTEGER(),
            nullable=False,
            existing_server_default=sa.text("50"),
        )

    restore_track_messages("discord_users")
    require_chattiness("discord_users")

    restore_track_messages("discord_servers")
    # Point the active-servers index back at the restored column.
    op.drop_index("discord_servers_active_idx", table_name="discord_servers")
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["track_messages", "last_sync_at"],
        unique=False,
    )
    require_chattiness("discord_servers")

    restore_track_messages("discord_channels")
    require_chattiness("discord_channels")

View File

@ -1,5 +1,3 @@
version: "3.9"
# --------------------------------------------------------------------- networks # --------------------------------------------------------------------- networks
networks: networks:
kbnet: kbnet:
@ -26,7 +24,7 @@ x-common-env: &env
REDIS_HOST: redis REDIS_HOST: redis
REDIS_PORT: 6379 REDIS_PORT: 6379
REDIS_DB: 0 REDIS_DB: 0
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD} REDIS_PASSWORD: ${REDIS_PASSWORD}
QDRANT_HOST: qdrant QDRANT_HOST: qdrant
DB_HOST: postgres DB_HOST: postgres
DB_PORT: 5432 DB_PORT: 5432
@ -107,11 +105,11 @@ services:
image: redis:7.2-alpine image: redis:7.2-alpine
restart: unless-stopped restart: unless-stopped
networks: [ kbnet ] networks: [ kbnet ]
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"] command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${REDIS_PASSWORD}"]
volumes: volumes:
- redis_data:/data:rw - redis_data:/data:rw
healthcheck: healthcheck:
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ] test: [ "CMD", "redis-cli", "--pass", "${REDIS_PASSWORD}", "ping" ]
interval: 15s interval: 15s
timeout: 5s timeout: 5s
retries: 5 retries: 5
@ -175,7 +173,7 @@ services:
<<: *worker-base <<: *worker-base
environment: environment:
<<: *worker-env <<: *worker-env
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler" QUEUES: "backup,email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
ingest-hub: ingest-hub:
<<: *worker-base <<: *worker-base
@ -196,6 +194,22 @@ services:
- /var/run/supervisor - /var/run/supervisor
deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } } deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
# ------------------------------------------------------------ database backups
backup:
image: postgres:15 # Has pg_dump, wget, curl
networks: [kbnet]
depends_on: [postgres, qdrant]
env_file: [ .env ]
environment:
<<: *worker-env
secrets: [postgres_password]
volumes:
- ./tools/backup_databases.sh:/backup.sh:ro
entrypoint: ["/bin/bash"]
command: ["/backup.sh"]
profiles: [backup] # Only start when explicitly called
security_opt: ["no-new-privileges=true"]
# ------------------------------------------------------------ watchtower (auto-update) # ------------------------------------------------------------ watchtower (auto-update)
# watchtower: # watchtower:
# image: containrrr/watchtower # image: containrrr/watchtower

View File

@ -16,7 +16,7 @@ RUN apt-get update && apt-get install -y \
COPY requirements ./requirements/ COPY requirements ./requirements/
COPY setup.py ./ COPY setup.py ./
RUN mkdir src RUN mkdir src
RUN pip install -e ".[common]" RUN pip install -e ".[workers]"
# Install Python dependencies # Install Python dependencies
COPY src/ ./src/ COPY src/ ./src/
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
git config --global user.name "${GIT_USER_NAME}" git config --global user.name "${GIT_USER_NAME}"
# Default queues to process # Default queues to process
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance" ENV QUEUES="backup,ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app" ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"] ENTRYPOINT ["./entry.sh"]

View File

@ -10,3 +10,4 @@ openai==2.3.0
# Pin the httpx version, as newer versions break the anthropic client # Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0 httpx==0.27.0
celery[redis,sqs]==5.3.6 celery[redis,sqs]==5.3.6
cryptography==43.0.0

View File

@ -0,0 +1,2 @@
boto3
awscli==1.42.64

View File

@ -18,6 +18,7 @@ parsers_requires = read_requirements("requirements-parsers.txt")
api_requires = read_requirements("requirements-api.txt") api_requires = read_requirements("requirements-api.txt")
dev_requires = read_requirements("requirements-dev.txt") dev_requires = read_requirements("requirements-dev.txt")
ingesters_requires = read_requirements("requirements-ingesters.txt") ingesters_requires = read_requirements("requirements-ingesters.txt")
workers_requires = read_requirements("requirements-workers.txt")
setup( setup(
name="memory", name="memory",
@ -30,10 +31,12 @@ setup(
"common": common_requires + parsers_requires, "common": common_requires + parsers_requires,
"dev": dev_requires, "dev": dev_requires,
"ingesters": common_requires + parsers_requires + ingesters_requires, "ingesters": common_requires + parsers_requires + ingesters_requires,
"workers": common_requires + parsers_requires + workers_requires,
"all": api_requires "all": api_requires
+ common_requires + common_requires
+ dev_requires + dev_requires
+ parsers_requires + parsers_requires
+ ingesters_requires, + ingesters_requires
+ workers_requires,
}, },
) )

View File

@ -290,6 +290,7 @@ class ScheduledLLMCallAdmin(ModelView, model=ScheduledLLMCall):
"created_at", "created_at",
"updated_at", "updated_at",
] ]
column_sortable_list = ["executed_at", "scheduled_time", "created_at", "updated_at"]
def setup_admin(admin: Admin): def setup_admin(admin: Admin):

View File

@ -13,6 +13,7 @@ NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations" OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls" SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
DISCORD_ROOT = "memory.workers.tasks.discord" DISCORD_ROOT = "memory.workers.tasks.discord"
BACKUP_ROOT = "memory.workers.tasks.backup"
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message" ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message" EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message" PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
@ -53,13 +54,25 @@ SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call" EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls" RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
# Backup tasks
BACKUP_PATH = f"{BACKUP_ROOT}.backup_path"
BACKUP_ALL = f"{BACKUP_ROOT}.backup_all"
def get_broker_url() -> str: def get_broker_url() -> str:
protocol = settings.CELERY_BROKER_TYPE protocol = settings.CELERY_BROKER_TYPE
user = safequote(settings.CELERY_BROKER_USER) user = safequote(settings.CELERY_BROKER_USER)
password = safequote(settings.CELERY_BROKER_PASSWORD) password = safequote(settings.CELERY_BROKER_PASSWORD or "")
host = settings.CELERY_BROKER_HOST host = settings.CELERY_BROKER_HOST
return f"{protocol}://{user}:{password}@{host}"
if password:
url = f"{protocol}://{user}:{password}@{host}"
else:
url = f"{protocol}://{host}"
if protocol == "redis":
url += f"/{settings.REDIS_DB}"
return url
app = Celery( app = Celery(
@ -91,6 +104,7 @@ app.conf.update(
f"{SCHEDULED_CALLS_ROOT}.*": { f"{SCHEDULED_CALLS_ROOT}.*": {
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler" "queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
}, },
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
}, },
) )

View File

@ -22,7 +22,6 @@ from memory.common.db.models.base import Base
class MessageProcessor: class MessageProcessor:
track_messages = Column(Boolean, nullable=False, server_default="true")
ignore_messages = Column(Boolean, nullable=True, default=False) ignore_messages = Column(Boolean, nullable=True, default=False)
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
@ -35,8 +34,7 @@ class MessageProcessor:
) )
chattiness_threshold = Column( chattiness_threshold = Column(
Integer, Integer,
nullable=False, nullable=True,
default=50,
doc="The threshold for the bot to continue the conversation, between 0 and 100.", doc="The threshold for the bot to continue the conversation, between 0 and 100.",
) )
@ -90,7 +88,7 @@ class DiscordServer(Base, MessageProcessor):
) )
__table_args__ = ( __table_args__ = (
Index("discord_servers_active_idx", "track_messages", "last_sync_at"), Index("discord_servers_active_idx", "ignore_messages", "last_sync_at"),
) )

View File

@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
email: str, email: str,
api_key: str | None = None, api_key: str | None = None,
) -> "DiscordBotUser": ) -> "DiscordBotUser":
if not discord_users:
raise ValueError("discord_users must be provided")
bot = super().create_with_api_key(name, email, api_key) bot = super().create_with_api_key(name, email, api_key)
bot.discord_users = discord_users bot.discord_users = discord_users
return bot return bot

View File

@ -5,9 +5,10 @@ Simple HTTP client that communicates with the Discord collector's API server.
""" """
import logging import logging
import requests
from typing import Any from typing import Any
import requests
from memory.common import settings from memory.common import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,12 +21,12 @@ def get_api_url() -> str:
return f"http://{host}:{port}" return f"http://{host}:{port}"
def send_dm(user_identifier: str, message: str) -> bool: def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
"""Send a DM via the Discord collector API""" """Send a DM via the Discord collector API"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_dm", f"{get_api_url()}/send_dm",
json={"user": user_identifier, "message": message}, json={"bot_id": bot_id, "user": user_identifier, "message": message},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -37,12 +38,33 @@ def send_dm(user_identifier: str, message: str) -> bool:
return False return False
def send_to_channel(channel_name: str, message: str) -> bool: def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
"""Trigger typing indicator for a DM via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/typing/dm",
json={"bot_id": bot_id, "user": user_identifier},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a DM via the Discord collector API""" """Send a DM via the Discord collector API"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message}, json={
"bot_id": bot_id,
"channel_name": channel_name,
"message": message,
},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -55,12 +77,33 @@ def send_to_channel(channel_name: str, message: str) -> bool:
return False return False
def broadcast_message(channel_name: str, message: str) -> bool: def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
"""Trigger typing indicator for a channel via the Discord collector API"""
try:
response = requests.post(
f"{get_api_url()}/typing/channel",
json={"bot_id": bot_id, "channel_name": channel_name},
timeout=10,
)
response.raise_for_status()
result = response.json()
return result.get("success", False)
except requests.RequestException as e:
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
return False
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
"""Send a message to a channel via the Discord collector API""" """Send a message to a channel via the Discord collector API"""
try: try:
response = requests.post( response = requests.post(
f"{get_api_url()}/send_channel", f"{get_api_url()}/send_channel",
json={"channel_name": channel_name, "message": message}, json={
"bot_id": bot_id,
"channel_name": channel_name,
"message": message,
},
timeout=10, timeout=10,
) )
response.raise_for_status() response.raise_for_status()
@ -72,19 +115,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
return False return False
def is_collector_healthy() -> bool: def is_collector_healthy(bot_id: int) -> bool:
"""Check if the Discord collector is running and healthy""" """Check if the Discord collector is running and healthy"""
try: try:
response = requests.get(f"{get_api_url()}/health", timeout=5) response = requests.get(f"{get_api_url()}/health", timeout=5)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
return result.get("status") == "healthy" bot_status = result.get(str(bot_id))
if not isinstance(bot_status, dict):
return False
return bool(bot_status.get("connected"))
except requests.RequestException: except requests.RequestException:
return False return False
def refresh_discord_metadata() -> dict[str, int] | None: def refresh_discord_metadata() -> dict[str, Any] | None:
"""Refresh Discord server/channel/user metadata from Discord API""" """Refresh Discord server/channel/user metadata from Discord API"""
try: try:
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30) response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
@ -96,24 +142,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
# Convenience functions # Convenience functions
def send_error_message(message: str) -> bool: def send_error_message(bot_id: int, message: str) -> bool:
"""Send an error message to the error channel""" """Send an error message to the error channel"""
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
def send_activity_message(message: str) -> bool: def send_activity_message(bot_id: int, message: str) -> bool:
"""Send an activity message to the activity channel""" """Send an activity message to the activity channel"""
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
def send_discovery_message(message: str) -> bool: def send_discovery_message(bot_id: int, message: str) -> bool:
"""Send a discovery message to the discovery channel""" """Send a discovery message to the discovery channel"""
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
def send_chat_message(message: str) -> bool: def send_chat_message(bot_id: int, message: str) -> bool:
"""Send a chat message to the chat channel""" """Send a chat message to the chat channel"""
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message) return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
def notify_task_failure( def notify_task_failure(
@ -122,6 +168,7 @@ def notify_task_failure(
task_args: tuple = (), task_args: tuple = (),
task_kwargs: dict[str, Any] | None = None, task_kwargs: dict[str, Any] | None = None,
traceback_str: str | None = None, traceback_str: str | None = None,
bot_id: int | None = None,
) -> None: ) -> None:
""" """
Send a task failure notification to Discord. Send a task failure notification to Discord.
@ -137,6 +184,15 @@ def notify_task_failure(
logger.debug("Discord notifications disabled") logger.debug("Discord notifications disabled")
return return
if bot_id is None:
bot_id = settings.DISCORD_BOT_ID
if not bot_id:
logger.debug(
"No Discord bot ID provided for task failure notification; skipping"
)
return
message = f"🚨 **Task Failed: {task_name}**\n\n" message = f"🚨 **Task Failed: {task_name}**\n\n"
message += f"**Error:** {error_message[:500]}\n" message += f"**Error:** {error_message[:500]}\n"
@ -150,7 +206,7 @@ def notify_task_failure(
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```" message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
try: try:
send_error_message(message) send_error_message(bot_id, message)
logger.info(f"Discord error notification sent for task: {task_name}") logger.info(f"Discord error notification sent for task: {task_name}")
except Exception as e: except Exception as e:
logger.error(f"Failed to send Discord notification: {e}") logger.error(f"Failed to send Discord notification: {e}")

View File

@ -24,13 +24,6 @@ from memory.common.llms.base import (
) )
from memory.common.llms.anthropic_provider import AnthropicProvider from memory.common.llms.anthropic_provider import AnthropicProvider
from memory.common.llms.openai_provider import OpenAIProvider from memory.common.llms.openai_provider import OpenAIProvider
from memory.common.llms.usage_tracker import (
InMemoryUsageTracker,
RateLimitConfig,
TokenAllowance,
UsageBreakdown,
UsageTracker,
)
from memory.common import tokens from memory.common import tokens
__all__ = [ __all__ = [
@ -49,11 +42,6 @@ __all__ = [
"StreamEvent", "StreamEvent",
"LLMSettings", "LLMSettings",
"create_provider", "create_provider",
"InMemoryUsageTracker",
"RateLimitConfig",
"TokenAllowance",
"UsageBreakdown",
"UsageTracker",
] ]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
if len(content) > target_chars: if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..." return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content return content
# bla = 1
# from memory.common.llms import *
# from memory.common.llms.tools.discord import make_discord_tools
# from memory.common.db.connection import make_session
# from memory.common.db.models import *
# model = "anthropic/claude-sonnet-4-5"
# provider = create_provider(model=model)
# with make_session() as session:
# bot = session.query(DiscordBotUser).first()
# server = session.query(DiscordServer).first()
# channel = server.channels[0]
# tools = make_discord_tools(bot, None, channel, model)
# def demo(msg: str):
# messages = [
# Message(
# role=MessageRole.USER,
# content=msg,
# )
# ]
# for m in provider.stream_with_tools(messages, tools):
# print(m)

View File

@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
class AnthropicProvider(BaseLLMProvider): class AnthropicProvider(BaseLLMProvider):
"""Anthropic LLM provider with streaming, tool support, and extended thinking.""" """Anthropic LLM provider with streaming, tool support, and extended thinking."""
provider = "anthropic"
# Models that support extended thinking # Models that support extended thinking
THINKING_MODELS = { THINKING_MODELS = {
"claude-opus-4", "claude-opus-4",
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
Usage( Usage(
input_tokens=usage.input_tokens, input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens, output_tokens=usage.output_tokens,
total_tokens=usage.total_tokens, total_tokens=usage.input_tokens + usage.output_tokens,
) )
) )

View File

@ -12,6 +12,7 @@ from PIL import Image
from memory.common import settings from memory.common import settings
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -204,7 +205,11 @@ class LLMSettings:
class BaseLLMProvider(ABC): class BaseLLMProvider(ABC):
"""Base class for LLM providers.""" """Base class for LLM providers."""
def __init__(self, api_key: str, model: str): provider: str = ""
def __init__(
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
):
""" """
Initialize the LLM provider. Initialize the LLM provider.
@ -215,6 +220,7 @@ class BaseLLMProvider(ABC):
self.api_key = api_key self.api_key = api_key
self.model = model self.model = model
self._client: Any = None self._client: Any = None
self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()
@abstractmethod @abstractmethod
def _initialize_client(self) -> Any: def _initialize_client(self) -> Any:
@ -230,8 +236,14 @@ class BaseLLMProvider(ABC):
def log_usage(self, usage: Usage): def log_usage(self, usage: Usage):
"""Log usage data.""" """Log usage data."""
logger.debug(f"Token usage: {usage.to_dict()}") logger.debug(
print(f"Token usage: {usage.to_dict()}") f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total"
)
self.usage_tracker.record_usage(
model=f"{self.provider}/{self.model}",
input_tokens=usage.input_tokens,
output_tokens=usage.output_tokens,
)
def execute_tool( def execute_tool(
self, self,

View File

@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(BaseLLMProvider): class OpenAIProvider(BaseLLMProvider):
"""OpenAI LLM provider with streaming and tool support.""" """OpenAI LLM provider with streaming and tool support."""
provider = "openai"
# Models that use max_completion_tokens instead of max_tokens # Models that use max_completion_tokens instead of max_tokens
# These are reasoning models with different parameter requirements # These are reasoning models with different parameter requirements
NON_REASONING_MODELS = {"gpt-4o"} NON_REASONING_MODELS = {"gpt-4o"}

View File

@ -0,0 +1,17 @@
"""Public entry point for LLM usage tracking.

Re-exports the in-memory and Redis-backed trackers together with their
supporting configuration and reporting types, so callers can import
everything from ``memory.common.llms.usage`` directly.
"""
from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
from memory.common.llms.usage.usage_tracker import (
    InMemoryUsageTracker,
    RateLimitConfig,
    TokenAllowance,
    UsageBreakdown,
    UsageTracker,
)

__all__ = [
    "InMemoryUsageTracker",
    "RateLimitConfig",
    "RedisUsageTracker",
    "TokenAllowance",
    "UsageBreakdown",
    "UsageTracker",
]

View File

@ -0,0 +1,83 @@
"""Redis-backed usage tracker implementation."""
import json
from typing import Any, Iterable, Protocol
import redis
from memory.common import settings
from memory.common.llms.usage.usage_tracker import (
RateLimitConfig,
UsageState,
UsageTracker,
)
class RedisClientProtocol(Protocol):
    """Structural type for the subset of the redis client API we use.

    Lets tests (or alternative backends) inject any object that provides
    ``get``/``set``/``scan_iter`` without depending on ``redis.Redis``.
    """

    def get(self, key: str) -> Any:  # pragma: no cover - Protocol definition
        """Return the raw value stored at *key*, or a falsy value if absent."""
        ...

    def set(
        self, key: str, value: Any
    ) -> Any:  # pragma: no cover - Protocol definition
        """Store *value* at *key*."""
        ...

    def scan_iter(
        self, match: str
    ) -> Iterable[Any]:  # pragma: no cover - Protocol definition
        """Iterate over keys matching the glob pattern *match*."""
        ...
class RedisUsageTracker(UsageTracker):
    """Tracks LLM usage for providers and models using Redis for persistence.

    Each model's state is serialized to compact JSON and stored under
    ``<prefix>:<model>``, so usage survives process restarts and can be
    shared between workers pointing at the same Redis instance.
    """

    def __init__(
        self,
        configs: dict[str, RateLimitConfig] | None = None,
        default_config: RateLimitConfig | None = None,
        *,
        redis_client: RedisClientProtocol | None = None,
        key_prefix: str | None = None,
    ) -> None:
        """Initialize the tracker.

        Args:
            configs: Optional per-model rate-limit configs (passed through
                to the base tracker).
            default_config: Fallback config for models without an entry.
            redis_client: Injectable client (e.g. a fake in tests); when
                omitted, a client is built from ``settings.REDIS_URL``.
            key_prefix: Key namespace; defaults to
                ``settings.LLM_USAGE_REDIS_PREFIX``.
        """
        super().__init__(configs=configs, default_config=default_config)
        if redis_client is None:
            redis_client = redis.Redis.from_url(settings.REDIS_URL)
        self._redis = redis_client
        prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
        # Normalize so _format_key always produces exactly one ":" separator.
        self._key_prefix = prefix.rstrip(":")

    def get_state(self, model: str) -> UsageState:
        """Load the stored state for *model*, or a fresh empty state."""
        state = self._load_state(self._redis.get(self._format_key(model)))
        return state if state is not None else UsageState()

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        """Yield ``(model, state)`` pairs for every key under this prefix."""
        pattern = f"{self._key_prefix}:*"
        for redis_key in self._redis.scan_iter(match=pattern):
            key = self._ensure_str(redis_key)
            state = self._load_state(self._redis.get(key))
            if state is None:
                # Key vanished (deleted/expired) between scan and get.
                continue
            # Strip "<prefix>:" to recover the bare model key.
            yield key[len(self._key_prefix) + 1 :], state

    def save_state(self, model: str, state: UsageState) -> None:
        """Persist *state* for *model* as compact JSON."""
        self._redis.set(
            self._format_key(model),
            json.dumps(state.to_payload(), separators=(",", ":")),
        )

    def _format_key(self, model: str) -> str:
        return f"{self._key_prefix}:{model}"

    def _load_state(self, payload: Any) -> UsageState | None:
        """Decode a raw Redis value into a UsageState; None when empty/missing."""
        if not payload:
            return None
        return UsageState.from_payload(json.loads(self._ensure_str(payload)))

    @staticmethod
    def _ensure_str(value: Any) -> str:
        # redis-py returns bytes unless decode_responses is enabled.
        if isinstance(value, bytes):
            return value.decode()
        return str(value)

View File

@ -5,6 +5,9 @@ from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from threading import Lock from threading import Lock
from typing import Any
from memory.common import settings
@dataclass(frozen=True) @dataclass(frozen=True)
@ -45,6 +48,40 @@ class UsageState:
lifetime_input_tokens: int = 0 lifetime_input_tokens: int = 0
lifetime_output_tokens: int = 0 lifetime_output_tokens: int = 0
def to_payload(self) -> dict[str, Any]:
return {
"events": [
{
"timestamp": event.timestamp.isoformat(),
"input_tokens": event.input_tokens,
"output_tokens": event.output_tokens,
}
for event in self.events
],
"window_input_tokens": self.window_input_tokens,
"window_output_tokens": self.window_output_tokens,
"lifetime_input_tokens": self.lifetime_input_tokens,
"lifetime_output_tokens": self.lifetime_output_tokens,
}
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
events = deque(
UsageEvent(
timestamp=datetime.fromisoformat(event["timestamp"]),
input_tokens=event["input_tokens"],
output_tokens=event["output_tokens"],
)
for event in payload.get("events", [])
)
return cls(
events=events,
window_input_tokens=payload.get("window_input_tokens", 0),
window_output_tokens=payload.get("window_output_tokens", 0),
lifetime_input_tokens=payload.get("lifetime_input_tokens", 0),
lifetime_output_tokens=payload.get("lifetime_output_tokens", 0),
)
@dataclass @dataclass
class TokenAllowance: class TokenAllowance:
@ -77,13 +114,13 @@ class UsageBreakdown:
def split_model_key(model: str) -> tuple[str, str]: def split_model_key(model: str) -> tuple[str, str]:
if "/" not in model: if "/" not in model:
raise ValueError( raise ValueError(
"model must be formatted as '<provider>/<model_name>'" f"model must be formatted as '<provider>/<model_name>': got '{model}'"
) )
provider, model_name = model.split("/", maxsplit=1) provider, model_name = model.split("/", maxsplit=1)
if not provider or not model_name: if not provider or not model_name:
raise ValueError( raise ValueError(
"model must include both provider and model name separated by '/'" f"model must include both provider and model name separated by '/': got '{model}'"
) )
return provider, model_name return provider, model_name
@ -93,11 +130,15 @@ class UsageTracker:
def __init__( def __init__(
self, self,
configs: dict[str, RateLimitConfig], configs: dict[str, RateLimitConfig] | None = None,
default_config: RateLimitConfig | None = None, default_config: RateLimitConfig | None = None,
) -> None: ) -> None:
self._configs = configs self._configs = configs or {}
self._default_config = default_config self._default_config = default_config or RateLimitConfig(
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
)
self._lock = Lock() self._lock = Lock()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -180,15 +221,14 @@ class UsageTracker:
""" """
split_model_key(model) split_model_key(model)
key = model
with self._lock: with self._lock:
config = self._get_config(key) config = self._get_config(model)
if config is None: if config is None:
return None return None
state = self.get_state(key) state = self.get_state(model)
self._prune_expired_events(state, config, now=timestamp) self._prune_expired_events(state, config, now=timestamp)
self.save_state(key, state) self.save_state(model, state)
if config.max_total_tokens is None: if config.max_total_tokens is None:
total_remaining = None total_remaining = None
@ -205,9 +245,7 @@ class UsageTracker:
if config.max_output_tokens is None: if config.max_output_tokens is None:
output_remaining = None output_remaining = None
else: else:
output_remaining = ( output_remaining = config.max_output_tokens - state.window_output_tokens
config.max_output_tokens - state.window_output_tokens
)
return TokenAllowance( return TokenAllowance(
input_tokens=clamp_non_negative(input_remaining), input_tokens=clamp_non_negative(input_remaining),
@ -222,8 +260,8 @@ class UsageTracker:
with self._lock: with self._lock:
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict) providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
for key, state in self.iter_state_items(): for model, state in self.iter_state_items():
prov, model_name = split_model_key(key) prov, model_name = split_model_key(model)
if provider and provider != prov: if provider and provider != prov:
continue continue
if model and model != model_name: if model and model != model_name:
@ -265,8 +303,8 @@ class UsageTracker:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Internal helpers # Internal helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _get_config(self, key: str) -> RateLimitConfig | None: def _get_config(self, model: str) -> RateLimitConfig | None:
return self._configs.get(key) or self._default_config return self._configs.get(model) or self._default_config
def _prune_expired_events( def _prune_expired_events(
self, self,
@ -313,4 +351,3 @@ def clamp_non_negative(value: int | None) -> int | None:
if value is None: if value is None:
return None return None
return 0 if value < 0 else value return 0 if value < 0 else value

View File

@ -31,33 +31,25 @@ def make_db_url(
DB_URL = os.getenv("DATABASE_URL", make_db_url()) DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Redis settings
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_DB = os.getenv("REDIS_DB", "0")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
if REDIS_PASSWORD:
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
else:
REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
# Broker settings # Broker settings
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory") CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower() CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
REDIS_HOST = os.getenv("REDIS_HOST", "redis") CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
REDIS_PORT = os.getenv("REDIS_PORT", "6379") CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)
REDIS_DB = os.getenv("REDIS_DB", "0")
CELERY_BROKER_USER = os.getenv(
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
)
CELERY_BROKER_PASSWORD = os.getenv(
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
)
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
if not CELERY_BROKER_HOST:
if CELERY_BROKER_TYPE == "amqp":
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
elif CELERY_BROKER_TYPE == "redis":
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}") CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
# File storage settings # File storage settings
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files")) FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
EBOOK_STORAGE_DIR = pathlib.Path( EBOOK_STORAGE_DIR = pathlib.Path(
@ -81,9 +73,14 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
NOTES_STORAGE_DIR = pathlib.Path( NOTES_STORAGE_DIR = pathlib.Path(
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes") os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
) )
PRIVATE_DIRS = [
EMAIL_STORAGE_DIR,
NOTES_STORAGE_DIR,
PHOTO_STORAGE_DIR,
CHUNK_STORAGE_DIR,
]
storage_dirs = [ storage_dirs = [
FILE_STORAGE_DIR,
EBOOK_STORAGE_DIR, EBOOK_STORAGE_DIR,
EMAIL_STORAGE_DIR, EMAIL_STORAGE_DIR,
CHUNK_STORAGE_DIR, CHUNK_STORAGE_DIR,
@ -148,6 +145,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307") RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000)) MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30)
)
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000)
)
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000)
)
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
# Search settings # Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True) ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True) ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
@ -193,3 +202,14 @@ DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003)) DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0") DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10)) DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
# S3 Backup settings
S3_BACKUP_BUCKET = os.getenv("S3_BACKUP_BUCKET", "equistamp-memory-backup")
S3_BACKUP_PREFIX = os.getenv("S3_BACKUP_PREFIX", "Daniel")
S3_BACKUP_REGION = os.getenv("S3_BACKUP_REGION", "eu-central-1")
BACKUP_ENCRYPTION_KEY = os.getenv("BACKUP_ENCRYPTION_KEY", "")
S3_BACKUP_ENABLED = boolean_env("S3_BACKUP_ENABLED", bool(BACKUP_ENCRYPTION_KEY))
S3_BACKUP_INTERVAL = int(
os.getenv("S3_BACKUP_INTERVAL", 60 * 60 * 24)
) # Daily by default

View File

@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager
import traceback import traceback
from contextlib import asynccontextmanager
from typing import cast
import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import uvicorn
from memory.common import settings from memory.common import settings
from memory.discord.collector import MessageCollector
from memory.common.db.models.users import BotUser
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models.users import DiscordBotUser
from memory.discord.collector import MessageCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,6 +35,16 @@ class SendChannelRequest(BaseModel):
message: str message: str
class TypingDMRequest(BaseModel):
bot_id: int
user: int | str
class TypingChannelRequest(BaseModel):
bot_id: int
channel_name: str
class Collector: class Collector:
collector: MessageCollector collector: MessageCollector
collector_task: asyncio.Task collector_task: asyncio.Task
@ -41,37 +52,25 @@ class Collector:
bot_token: str bot_token: str
bot_name: str bot_name: str
def __init__(self, collector: MessageCollector, bot: BotUser): def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
self.collector = collector self.collector = collector
self.collector_task = asyncio.create_task(collector.start(bot.api_key)) self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
self.bot_id = bot.id self.bot_id = cast(int, bot.id)
self.bot_token = bot.api_key self.bot_token = str(bot.api_key)
self.bot_name = bot.name self.bot_name = str(bot.name)
# Application state
class AppState:
def __init__(self):
self.collector: MessageCollector | None = None
self.collector_task: asyncio.Task | None = None
app_state = AppState()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Manage Discord collector lifecycle""" """Manage Discord collector lifecycle"""
if not settings.DISCORD_BOT_TOKEN:
logger.error("DISCORD_BOT_TOKEN not configured")
return
def make_collector(bot: BotUser): def make_collector(bot: DiscordBotUser):
collector = MessageCollector() collector = MessageCollector()
return Collector(collector=collector, bot=bot) return Collector(collector=collector, bot=bot)
with make_session() as session: with make_session() as session:
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()} bots = session.query(DiscordBotUser).all()
app.bots = {bot.id: make_collector(bot) for bot in bots}
logger.info(f"Discord collectors started for {len(app.bots)} bots") logger.info(f"Discord collectors started for {len(app.bots)} bots")
@ -120,6 +119,32 @@ async def send_dm_endpoint(request: SendDMRequest):
} }
@app.post("/typing/dm")
async def trigger_dm_typing(request: TypingDMRequest):
"""Trigger a typing indicator for a DM via the collector"""
collector = app.bots.get(request.bot_id)
if not collector:
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.trigger_typing_dm(request.user)
except Exception as e:
logger.error(f"Failed to trigger DM typing: {e}")
raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to trigger typing for {request.user}",
)
return {
"success": True,
"user": request.user,
"message": f"Typing triggered for {request.user}",
}
@app.post("/send_channel") @app.post("/send_channel")
async def send_channel_endpoint(request: SendChannelRequest): async def send_channel_endpoint(request: SendChannelRequest):
"""Send a message to a channel via the collector's Discord client""" """Send a message to a channel via the collector's Discord client"""
@ -131,6 +156,9 @@ async def send_channel_endpoint(request: SendChannelRequest):
success = await collector.collector.send_to_channel( success = await collector.collector.send_to_channel(
request.channel_name, request.message request.channel_name, request.message
) )
except Exception as e:
logger.error(f"Failed to send channel message: {e}")
raise HTTPException(status_code=500, detail=str(e))
if success: if success:
return { return {
@ -138,16 +166,38 @@ async def send_channel_endpoint(request: SendChannelRequest):
"message": f"Message sent to channel {request.channel_name}", "message": f"Message sent to channel {request.channel_name}",
"channel": request.channel_name, "channel": request.channel_name,
} }
else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Failed to send message to channel {request.channel_name}", detail=f"Failed to send message to channel {request.channel_name}",
) )
@app.post("/typing/channel")
async def trigger_channel_typing(request: TypingChannelRequest):
"""Trigger a typing indicator for a channel via the collector"""
collector = app.bots.get(request.bot_id)
if not collector:
raise HTTPException(status_code=404, detail="Bot not found")
try:
success = await collector.collector.trigger_typing_channel(request.channel_name)
except Exception as e: except Exception as e:
logger.error(f"Failed to send channel message: {e}") logger.error(f"Failed to trigger channel typing: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
if not success:
raise HTTPException(
status_code=400,
detail=f"Failed to trigger typing for channel {request.channel_name}",
)
return {
"success": True,
"channel": request.channel_name,
"message": f"Typing triggered for channel {request.channel_name}",
}
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
@ -155,9 +205,8 @@ async def health_check():
if not app.bots: if not app.bots:
raise HTTPException(status_code=503, detail="Discord collector not running") raise HTTPException(status_code=503, detail="Discord collector not running")
collector = app_state.collector
return { return {
collector.bot_name: { bot.bot_name: {
"status": "healthy", "status": "healthy",
"connected": not bot.collector.is_closed(), "connected": not bot.collector.is_closed(),
"user": str(bot.collector.user) if bot.collector.user else None, "user": str(bot.collector.user) if bot.collector.user else None,

View File

@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
async def setup_hook(self): async def setup_hook(self):
"""Register slash commands when the bot is ready.""" """Register slash commands when the bot is ready."""
register_slash_commands(self) register_slash_commands(self, name=self.user.name)
async def on_ready(self): async def on_ready(self):
"""Called when bot connects to Discord""" """Called when bot connects to Discord"""
@ -381,6 +381,27 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to send DM to {user_identifier}: {e}") logger.error(f"Failed to send DM to {user_identifier}: {e}")
return False return False
async def trigger_typing_dm(self, user_identifier: int | str) -> bool:
"""Trigger typing indicator in a DM"""
try:
user = await self.get_user(user_identifier)
if not user:
logger.error(f"User {user_identifier} not found")
return False
channel = user.dm_channel or await user.create_dm()
if not channel:
logger.error(f"DM channel not available for {user_identifier}")
return False
async with channel.typing():
pass
return True
except Exception as e:
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
return False
async def send_to_channel(self, channel_name: str, message: str) -> bool: async def send_to_channel(self, channel_name: str, message: str) -> bool:
"""Send a message to a channel by name across all guilds""" """Send a message to a channel by name across all guilds"""
if not settings.DISCORD_NOTIFICATIONS_ENABLED: if not settings.DISCORD_NOTIFICATIONS_ENABLED:
@ -400,23 +421,21 @@ class MessageCollector(commands.Bot):
logger.error(f"Failed to send message to channel {channel_name}: {e}") logger.error(f"Failed to send message to channel {channel_name}: {e}")
return False return False
async def trigger_typing_channel(self, channel_name: str) -> bool:
async def run_collector(): """Trigger typing indicator in a channel"""
"""Run the Discord message collector""" if not settings.DISCORD_NOTIFICATIONS_ENABLED:
if not settings.DISCORD_BOT_TOKEN: return False
logger.error("DISCORD_BOT_TOKEN not configured")
return
collector = MessageCollector()
try: try:
await collector.start(settings.DISCORD_BOT_TOKEN) channel = await self.get_channel_by_name(channel_name)
if not channel:
logger.error(f"Channel {channel_name} not found")
return False
async with channel.typing():
pass
return True
except Exception as e: except Exception as e:
logger.error(f"Discord collector failed: {e}") logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
raise return False
if __name__ == "__main__":
import asyncio
asyncio.run(run_collector())

View File

@ -41,8 +41,13 @@ class CommandContext:
CommandHandler = Callable[..., CommandResponse] CommandHandler = Callable[..., CommandResponse]
def register_slash_commands(bot: discord.Client) -> None: def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
"""Register the collector slash commands on the provided bot.""" """Register the collector slash commands on the provided bot.
Args:
bot: Discord bot client
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
"""
if getattr(bot, "_memory_commands_registered", False): if getattr(bot, "_memory_commands_registered", False):
return return
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
tree = bot.tree tree = bot.tree
@tree.command(name="memory_prompt", description="Show the current system prompt") @tree.command(
name=f"{name}_show_prompt", description="Show the current system prompt"
)
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to inspect", scope="Which configuration to inspect",
user="Target user when the scope is 'user'", user="Target user when the scope is 'user'",
) )
async def prompt_command( async def show_prompt_command(
interaction: discord.Interaction, interaction: discord.Interaction,
scope: ScopeLiteral, scope: ScopeLiteral,
user: discord.User | None = None, user: discord.User | None = None,
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
) )
@tree.command( @tree.command(
name="memory_chattiness", name=f"{name}_set_prompt",
description="Show or update the chattiness threshold for the target", description="Set the system prompt for the target",
)
@discord.app_commands.describe(
scope="Which configuration to modify",
prompt="The system prompt to set",
user="Target user when the scope is 'user'",
)
async def set_prompt_command(
interaction: discord.Interaction,
scope: ScopeLiteral,
prompt: str,
user: discord.User | None = None,
) -> None:
await _run_interaction_command(
interaction,
scope=scope,
handler=handle_set_prompt,
target_user=user,
prompt=prompt,
)
@tree.command(
name=f"{name}_chattiness",
description="Show or update the chattiness for the target",
) )
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to inspect", scope="Which configuration to inspect",
value="Optional new threshold value between 0 and 100", value="Optional new chattiness value between 0 and 100",
user="Target user when the scope is 'user'", user="Target user when the scope is 'user'",
) )
async def chattiness_command( async def chattiness_command(
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
) )
@tree.command( @tree.command(
name="memory_ignore", name=f"{name}_ignore",
description="Toggle whether the bot should ignore messages for the target", description="Toggle whether the bot should ignore messages for the target",
) )
@discord.app_commands.describe( @discord.app_commands.describe(
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
ignore_enabled=enabled, ignore_enabled=enabled,
) )
@tree.command(name="memory_summary", description="Show the stored summary for the target") @tree.command(
name=f"{name}_show_summary",
description="Show the stored summary for the target",
)
@discord.app_commands.describe( @discord.app_commands.describe(
scope="Which configuration to inspect", scope="Which configuration to inspect",
user="Target user when the scope is 'user'", user="Target user when the scope is 'user'",
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
) )
def handle_set_prompt(
context: CommandContext,
*,
prompt: str,
) -> CommandResponse:
setattr(context.target, "system_prompt", prompt)
return CommandResponse(
content=f"Updated system prompt for {context.display_name}.",
)
def handle_chattiness( def handle_chattiness(
context: CommandContext, context: CommandContext,
*, *,
@ -347,20 +392,22 @@ def handle_chattiness(
if value is None: if value is None:
return CommandResponse( return CommandResponse(
content=( content=(
f"Chattiness threshold for {context.display_name}: " f"Chattiness for {context.display_name}: "
f"{getattr(model, 'chattiness_threshold', 'not set')}" f"{getattr(model, 'chattiness_threshold', 'not set')}"
) )
) )
if not 0 <= value <= 100: if not 0 <= value <= 100:
raise CommandError("Chattiness threshold must be between 0 and 100.") raise CommandError("Chattiness must be between 0 and 100.")
setattr(model, "chattiness_threshold", value) setattr(model, "chattiness_threshold", value)
return CommandResponse( return CommandResponse(
content=( content=(
f"Updated chattiness threshold for {context.display_name} " f"Updated chattiness for {context.display_name} to {value}."
f"to {value}." "\n"
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
"where 0 is never and 100 is always."
) )
) )

View File

@ -10,6 +10,7 @@ from memory.common.celery_app import (
TRACK_GIT_CHANGES, TRACK_GIT_CHANGES,
SYNC_LESSWRONG, SYNC_LESSWRONG,
RUN_SCHEDULED_CALLS, RUN_SCHEDULED_CALLS,
BACKUP_ALL,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,4 +49,8 @@ app.conf.beat_schedule = {
"task": RUN_SCHEDULED_CALLS, "task": RUN_SCHEDULED_CALLS,
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL, "schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
}, },
"backup-all": {
"task": BACKUP_ALL,
"schedule": settings.S3_BACKUP_INTERVAL,
},
} }

View File

@ -3,11 +3,12 @@ Import sub-modules so Celery can register their @app.task decorators.
""" """
from memory.workers.tasks import ( from memory.workers.tasks import (
email, backup,
comic,
blogs, blogs,
comic,
discord, discord,
ebook, ebook,
email,
forums, forums,
maintenance, maintenance,
notes, notes,
@ -15,8 +16,8 @@ from memory.workers.tasks import (
scheduled_calls, scheduled_calls,
) # noqa ) # noqa
__all__ = [ __all__ = [
"backup",
"email", "email",
"comic", "comic",
"blogs", "blogs",

View File

@ -0,0 +1,152 @@
"""S3 backup tasks for memory files."""
import base64
import hashlib
import io
import logging
import subprocess
import tarfile
from pathlib import Path
import boto3
from cryptography.fernet import Fernet
from memory.common import settings
from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL
logger = logging.getLogger(__name__)
def get_cipher() -> Fernet:
"""Create Fernet cipher from password in settings."""
if not settings.BACKUP_ENCRYPTION_KEY:
raise ValueError("BACKUP_ENCRYPTION_KEY not set in environment")
# Derive key from password using SHA256
key_bytes = hashlib.sha256(settings.BACKUP_ENCRYPTION_KEY.encode()).digest()
key = base64.urlsafe_b64encode(key_bytes)
return Fernet(key)
def create_tarball(directory: Path) -> bytes:
"""Create a gzipped tarball of a directory in memory."""
if not directory.exists():
logger.warning(f"Directory does not exist: {directory}")
return b""
tar_buffer = io.BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar:
tar.add(directory, arcname=directory.name)
tar_buffer.seek(0)
return tar_buffer.read()
def sync_unencrypted_directory(path: Path) -> dict:
"""Sync an unencrypted directory to S3 using aws s3 sync."""
if not path.exists():
logger.warning(f"Directory does not exist: {path}")
return {"synced": False, "reason": "directory_not_found"}
s3_uri = f"s3://{settings.S3_BACKUP_BUCKET}/{settings.S3_BACKUP_PREFIX}/{path.name}"
cmd = [
"aws",
"s3",
"sync",
str(path),
s3_uri,
"--delete",
"--region",
settings.S3_BACKUP_REGION,
]
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=True,
)
logger.info(f"Synced {path} to {s3_uri}")
logger.debug(f"Output: {result.stdout}")
return {"synced": True, "directory": path, "s3_uri": s3_uri}
except subprocess.CalledProcessError as e:
logger.error(f"Failed to sync {path}: {e.stderr}")
return {"synced": False, "directory": path, "error": str(e)}
def backup_encrypted_directory(path: Path) -> dict:
"""Create encrypted tarball of directory and upload to S3."""
if not path.exists():
logger.warning(f"Directory does not exist: {path}")
return {"uploaded": False, "reason": "directory_not_found"}
# Create tarball
logger.info(f"Creating tarball of {path}...")
tarball_bytes = create_tarball(path)
if not tarball_bytes:
logger.warning(f"Empty tarball for {path}, skipping")
return {"uploaded": False, "reason": "empty_directory"}
# Encrypt
logger.info(f"Encrypting {path} ({len(tarball_bytes)} bytes)...")
cipher = get_cipher()
encrypted_bytes = cipher.encrypt(tarball_bytes)
# Upload to S3
s3_client = boto3.client("s3", region_name=settings.S3_BACKUP_REGION)
s3_key = f"{settings.S3_BACKUP_PREFIX}/{path.name}.tar.gz.enc"
try:
logger.info(
f"Uploading encrypted {path} to s3://{settings.S3_BACKUP_BUCKET}/{s3_key}"
)
s3_client.put_object(
Bucket=settings.S3_BACKUP_BUCKET,
Key=s3_key,
Body=encrypted_bytes,
ServerSideEncryption="AES256",
)
return {
"uploaded": True,
"directory": path,
"size_bytes": len(encrypted_bytes),
"s3_key": s3_key,
}
except Exception as e:
logger.error(f"Failed to upload {path}: {e}")
return {"uploaded": False, "directory": path, "error": str(e)}
@app.task(name=BACKUP_PATH)
def backup_to_s3(path: Path | str):
"""Backup a specific directory to S3."""
path = Path(path)
if not path.exists():
logger.warning(f"Directory does not exist: {path}")
return {"uploaded": False, "reason": "directory_not_found"}
if path in settings.PRIVATE_DIRS:
return backup_encrypted_directory(path)
return sync_unencrypted_directory(path)
@app.task(name=BACKUP_ALL)
def backup_all_to_s3():
"""Main backup task that syncs unencrypted dirs and uploads encrypted dirs."""
if not settings.S3_BACKUP_ENABLED:
logger.info("S3 backup is disabled")
return {"status": "disabled"}
logger.info("Starting S3 backup...")
for dir_name in settings.storage_dirs:
backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())
return {
"status": "success",
"message": f"Started backup for {len(settings.storage_dirs)} directories",
}

View File

@ -7,7 +7,7 @@ import logging
import re import re
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, cast
from sqlalchemy import exc as sqlalchemy_exc from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.orm import Session, scoped_session from sqlalchemy.orm import Session, scoped_session
@ -56,8 +56,15 @@ def call_llm(
message: DiscordMessage, message: DiscordMessage,
model: str, model: str,
msgs: list[str] = [], msgs: list[str] = [],
allowed_tools: list[str] = [], allowed_tools: list[str] | None = None,
) -> str | None: ) -> str | None:
provider = create_provider(model=model)
if provider.usage_tracker.is_rate_limited(model):
logger.error(
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
)
return None
tools = make_discord_tools( tools = make_discord_tools(
message.recipient_user.system_user, message.recipient_user.system_user,
message.from_user, message.from_user,
@ -67,13 +74,13 @@ def call_llm(
tools = { tools = {
name: tool name: tool
for name, tool in tools.items() for name, tool in tools.items()
if message.tool_allowed(name) and name in allowed_tools if message.tool_allowed(name)
and (allowed_tools is None or name in allowed_tools)
} }
system_prompt = message.system_prompt or "" system_prompt = message.system_prompt or ""
system_prompt += comm_channel_prompt( system_prompt += comm_channel_prompt(
session, message.recipient_user, message.channel session, message.recipient_user, message.channel
) )
provider = create_provider(model=model)
messages = previous_messages( messages = previous_messages(
session, session,
message.recipient_user and message.recipient_user.id, message.recipient_user and message.recipient_user.id,
@ -127,10 +134,32 @@ def should_process(message: DiscordMessage) -> bool:
if not (res := re.search(r"<number>(.*)</number>", response)): if not (res := re.search(r"<number>(.*)</number>", response)):
return False return False
try: try:
return int(res.group(1)) > message.chattiness_threshold if int(res.group(1)) < 100 - message.chattiness_threshold:
return False
except ValueError: except ValueError:
return False return False
if not (bot_id := _resolve_bot_id(message)):
return False
if message.channel and message.channel.server:
discord.trigger_typing_channel(bot_id, message.channel.name)
else:
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
return True
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
recipient = discord_message.recipient_user
if not recipient:
return None
system_user = recipient.system_user
if not system_user:
return None
return getattr(system_user, "id", None)
@app.task(name=PROCESS_DISCORD_MESSAGE) @app.task(name=PROCESS_DISCORD_MESSAGE)
@safe_task_execution @safe_task_execution
@ -152,14 +181,33 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
"message_id": message_id, "message_id": message_id,
} }
response = call_llm(session, discord_message, settings.DISCORD_MODEL) bot_id = _resolve_bot_id(discord_message)
if not bot_id:
logger.warning(
"No associated Discord bot user for message %s; skipping send",
message_id,
)
return {
"status": "processed",
"message_id": message_id,
}
try:
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
except Exception:
logger.exception("Failed to generate Discord response")
print("response:", response)
if not response: if not response:
pass return {
elif discord_message.channel.server: "status": "processed",
discord.send_to_channel(discord_message.channel.name, response) "message_id": message_id,
}
if discord_message.channel.server:
discord.send_to_channel(bot_id, discord_message.channel.name, response)
else: else:
discord.send_dm(discord_message.from_user.username, response) discord.send_dm(bot_id, discord_message.from_user.username, response)
return { return {
"status": "processed", "status": "processed",

View File

@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
if len(message) > 1900: # Leave some buffer if len(message) > 1900: # Leave some buffer
message = message[:1900] + "\n\n... (response truncated)" message = message[:1900] + "\n\n... (response truncated)"
bot_id_value = scheduled_call.user_id
if bot_id_value is None:
logger.warning(
"Scheduled call %s has no associated bot user; skipping Discord send",
scheduled_call.id,
)
return
bot_id = cast(int, bot_id_value)
if discord_user := scheduled_call.discord_user: if discord_user := scheduled_call.discord_user:
logger.info(f"Sending DM to {discord_user.username}: {message}") logger.info(f"Sending DM to {discord_user.username}: {message}")
discord.send_dm(discord_user.username, message) discord.send_dm(bot_id, discord_user.username, message)
elif discord_channel := scheduled_call.discord_channel: elif discord_channel := scheduled_call.discord_channel:
logger.info(f"Broadcasting message to {discord_channel.name}: {message}") logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
discord.broadcast_message(discord_channel.name, message) discord.broadcast_message(bot_id, discord_channel.name, message)
else: else:
logger.warning( logger.warning(
f"No Discord user or channel found for scheduled call {scheduled_call.id}" f"No Discord user or channel found for scheduled call {scheduled_call.id}"

View File

@ -1,7 +1,24 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Iterable
import pytest import pytest
try:
import redis # noqa: F401 # pragma: no cover - optional test dependency
except ModuleNotFoundError: # pragma: no cover - import guard for test envs
import sys
from types import SimpleNamespace
class _RedisStub(SimpleNamespace):
class Redis: # type: ignore[no-redef]
def __init__(self, *args: object, **kwargs: object) -> None:
raise ModuleNotFoundError(
"The 'redis' package is required to use RedisUsageTracker"
)
sys.modules.setdefault("redis", _RedisStub())
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
from memory.common.llms.usage_tracker import ( from memory.common.llms.usage_tracker import (
InMemoryUsageTracker, InMemoryUsageTracker,
RateLimitConfig, RateLimitConfig,
@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
) )
class FakeRedis:
def __init__(self) -> None:
self._store: dict[str, str] = {}
def get(self, key: str) -> str | None:
return self._store.get(key)
def set(self, key: str, value: str) -> None:
self._store[key] = value
def scan_iter(self, match: str) -> Iterable[str]:
from fnmatch import fnmatch
for key in list(self._store.keys()):
if fnmatch(key, match):
yield key
@pytest.fixture @pytest.fixture
def tracker() -> InMemoryUsageTracker: def tracker() -> InMemoryUsageTracker:
config = RateLimitConfig( config = RateLimitConfig(
@ -25,6 +60,23 @@ def tracker() -> InMemoryUsageTracker:
) )
@pytest.fixture
def redis_tracker() -> RedisUsageTracker:
config = RateLimitConfig(
window=timedelta(minutes=1),
max_input_tokens=1_000,
max_output_tokens=2_000,
max_total_tokens=2_500,
)
return RedisUsageTracker(
{
"anthropic/claude-3": config,
"anthropic/haiku": config,
},
redis_client=FakeRedis(),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"window, kwargs", "window, kwargs",
[ [
@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
assert tracker.is_rate_limited("openai/gpt-4o") assert tracker.is_rate_limited("openai/gpt-4o")
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
now = datetime(2024, 1, 1, tzinfo=timezone.utc)
redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)
allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
assert allowance is not None
assert allowance.input_tokens == 900
breakdown = redis_tracker.get_usage_breakdown()
assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200
items = dict(redis_tracker.iter_state_items())
assert set(items.keys()) == {"anthropic/claude-3", "anthropic/haiku"}
def test_usage_tracker_base_not_instantiable() -> None: def test_usage_tracker_base_not_instantiable() -> None:
class DummyTracker(UsageTracker): class DummyTracker(UsageTracker):
pass pass

View File

@ -4,6 +4,8 @@ import requests
from memory.common import discord from memory.common import discord
BOT_ID = 42
@pytest.fixture @pytest.fixture
def mock_api_url(): def mock_api_url():
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_dm", "http://localhost:8000/send_dm",
json={"user": "user123", "message": "Hello!"}, json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
timeout=10, timeout=10,
) )
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception""" """Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error") mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_channel", "http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"}, json={
"bot_id": BOT_ID,
"channel_name": "general",
"message": "Announcement!",
},
timeout=10, timeout=10,
) )
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception""" """Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout") mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
def test_is_collector_healthy_true(mock_get, mock_api_url): def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy""" """Test health check when collector is healthy"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is True assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
def test_is_collector_healthy_false_status(mock_get, mock_api_url): def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status""" """Test health check when collector returns unhealthy status"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails""" """Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused") mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel""" """Test sending error message to error channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_error_message("Something broke") result = discord.send_error_message(BOT_ID, "Something broke")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke") mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel""" """Test sending activity message to activity channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in") result = discord.send_activity_message(BOT_ID, "User logged in")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in") mock_broadcast.assert_called_once_with(
BOT_ID, "activity", "User logged in"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel""" """Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern") result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") mock_broadcast.assert_called_once_with(
BOT_ID, "discoveries", "Found interesting pattern"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel""" """Test sending chat message to chat channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot") result = discord.send_chat_message(BOT_ID, "Hello from bot")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot") mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
@patch("memory.common.discord.send_error_message") @patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error): def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification""" """Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong") discord.notify_task_failure(
"test_task", "Something went wrong", bot_id=BOT_ID
)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0] assert mock_send_error.call_args[0][0] == BOT_ID
message = mock_send_error.call_args[0][1]
assert "🚨 **Task Failed: test_task**" in message assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message assert "**Error:** Something went wrong" in message
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
"Error occurred", "Error occurred",
task_args=("arg1", 42), task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123}, task_kwargs={"key": "value", "number": 123},
bot_id=BOT_ID,
) )
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Args:** `('arg1', 42)" in message assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback""" """Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) discord.notify_task_failure(
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Traceback:**" in message assert "**Traceback:**" in message
assert "Exception: test" in message assert "Exception: test" in message
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated""" """Test that long error messages are truncated"""
long_error = "x" * 600 long_error = "x" * 600
discord.notify_task_failure("test_task", long_error) discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Error should be truncated to 500 chars - check that the full 600 char string is not there # Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message assert "**Error:** " + long_error[:500] in message
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated""" """Test that long tracebacks are truncated"""
long_traceback = "x" * 1000 long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) discord.notify_task_failure(
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Traceback should show last 800 chars # Traceback should show last 800 chars
assert long_traceback[-800:] in message assert long_traceback[-800:] in message
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated""" """Test that long task arguments are truncated"""
long_args = ("x" * 300,) long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args) discord.notify_task_failure(
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Args should be truncated to 200 chars # Args should be truncated to 200 chars
assert ( assert (
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated""" """Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300} long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) discord.notify_task_failure(
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Kwargs should be truncated to 200 chars # Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error): def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled""" """Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_not_called() mock_send_error.assert_not_called()
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
mock_send_error.side_effect = Exception("Failed to send") mock_send_error.side_effect = Exception("Failed to send")
# Should not raise # Should not raise
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
): ):
"""Test that convenience functions use the correct channel settings""" """Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"): with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message) function(BOT_ID, message)
mock_broadcast.assert_called_once_with("test-channel", message) mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
@patch("requests.post") @patch("requests.post")
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general" message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars) result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars json_payload = call_args[1]["json"]
assert json_payload["message"] == message_with_special_chars
assert json_payload["bot_id"] == BOT_ID
@patch("requests.post") @patch("requests.post")
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
long_message = "A" * 2000 long_message = "A" * 2000
result = discord.broadcast_message("general", long_message) result = discord.broadcast_message(BOT_ID, "general", long_message)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message json_payload = call_args[1]["json"]
assert json_payload["message"] == long_message
assert json_payload["bot_id"] == BOT_ID
@patch("requests.get") @patch("requests.get")
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False

View File

@ -4,6 +4,8 @@ import requests
from memory.common import discord from memory.common import discord
BOT_ID = 42
@pytest.fixture @pytest.fixture
def mock_api_url(): def mock_api_url():
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_dm", "http://localhost:8000/send_dm",
json={"user": "user123", "message": "Hello!"}, json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
timeout=10, timeout=10,
) )
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
"""Test DM sending when request raises exception""" """Test DM sending when request raises exception"""
mock_post.side_effect = requests.RequestException("Network error") mock_post.side_effect = requests.RequestException("Network error")
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found") mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.send_dm("user123", "Hello!") result = discord.send_dm(BOT_ID, "user123", "Hello!")
assert result is False assert result is False
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is True assert result is True
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"http://localhost:8000/send_channel", "http://localhost:8000/send_channel",
json={"channel_name": "general", "message": "Announcement!"}, json={
"bot_id": BOT_ID,
"channel_name": "general",
"message": "Announcement!",
},
timeout=10, timeout=10,
) )
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
"""Test channel message broadcast with exception""" """Test channel message broadcast with exception"""
mock_post.side_effect = requests.Timeout("Request timeout") mock_post.side_effect = requests.Timeout("Request timeout")
result = discord.broadcast_message("general", "Announcement!") result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
assert result is False assert result is False
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
def test_is_collector_healthy_true(mock_get, mock_api_url): def test_is_collector_healthy_true(mock_get, mock_api_url):
"""Test health check when collector is healthy""" """Test health check when collector is healthy"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "healthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is True assert result is True
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5) mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
def test_is_collector_healthy_false_status(mock_get, mock_api_url): def test_is_collector_healthy_false_status(mock_get, mock_api_url):
"""Test health check when collector returns unhealthy status""" """Test health check when collector returns unhealthy status"""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"status": "unhealthy"} mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
"""Test health check when request fails""" """Test health check when request fails"""
mock_get.side_effect = requests.ConnectionError("Connection refused") mock_get.side_effect = requests.ConnectionError("Connection refused")
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
"""Test sending error message to error channel""" """Test sending error message to error channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_error_message("Something broke") result = discord.send_error_message(BOT_ID, "Something broke")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("errors", "Something broke") mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
"""Test sending activity message to activity channel""" """Test sending activity message to activity channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_activity_message("User logged in") result = discord.send_activity_message(BOT_ID, "User logged in")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("activity", "User logged in") mock_broadcast.assert_called_once_with(
BOT_ID, "activity", "User logged in"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
"""Test sending discovery message to discovery channel""" """Test sending discovery message to discovery channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_discovery_message("Found interesting pattern") result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern") mock_broadcast.assert_called_once_with(
BOT_ID, "discoveries", "Found interesting pattern"
)
@patch("memory.common.discord.broadcast_message") @patch("memory.common.discord.broadcast_message")
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
"""Test sending chat message to chat channel""" """Test sending chat message to chat channel"""
mock_broadcast.return_value = True mock_broadcast.return_value = True
result = discord.send_chat_message("Hello from bot") result = discord.send_chat_message(BOT_ID, "Hello from bot")
assert result is True assert result is True
mock_broadcast.assert_called_once_with("chat", "Hello from bot") mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
@patch("memory.common.discord.send_error_message") @patch("memory.common.discord.send_error_message")
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
def test_notify_task_failure_basic(mock_send_error): def test_notify_task_failure_basic(mock_send_error):
"""Test basic task failure notification""" """Test basic task failure notification"""
discord.notify_task_failure("test_task", "Something went wrong") discord.notify_task_failure(
"test_task", "Something went wrong", bot_id=BOT_ID
)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
message = mock_send_error.call_args[0][0] assert mock_send_error.call_args[0][0] == BOT_ID
message = mock_send_error.call_args[0][1]
assert "🚨 **Task Failed: test_task**" in message assert "🚨 **Task Failed: test_task**" in message
assert "**Error:** Something went wrong" in message assert "**Error:** Something went wrong" in message
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
"Error occurred", "Error occurred",
task_args=("arg1", 42), task_args=("arg1", 42),
task_kwargs={"key": "value", "number": 123}, task_kwargs={"key": "value", "number": 123},
bot_id=BOT_ID,
) )
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Args:** `('arg1', 42)" in message assert "**Args:** `('arg1', 42)" in message
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
"""Test task failure notification with traceback""" """Test task failure notification with traceback"""
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test" traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback) discord.notify_task_failure(
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
assert "**Traceback:**" in message assert "**Traceback:**" in message
assert "Exception: test" in message assert "Exception: test" in message
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
"""Test that long error messages are truncated""" """Test that long error messages are truncated"""
long_error = "x" * 600 long_error = "x" * 600
discord.notify_task_failure("test_task", long_error) discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Error should be truncated to 500 chars - check that the full 600 char string is not there # Error should be truncated to 500 chars - check that the full 600 char string is not there
assert "**Error:** " + long_error[:500] in message assert "**Error:** " + long_error[:500] in message
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
"""Test that long tracebacks are truncated""" """Test that long tracebacks are truncated"""
long_traceback = "x" * 1000 long_traceback = "x" * 1000
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback) discord.notify_task_failure(
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Traceback should show last 800 chars # Traceback should show last 800 chars
assert long_traceback[-800:] in message assert long_traceback[-800:] in message
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
"""Test that long task arguments are truncated""" """Test that long task arguments are truncated"""
long_args = ("x" * 300,) long_args = ("x" * 300,)
discord.notify_task_failure("test_task", "Error", task_args=long_args) discord.notify_task_failure(
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Args should be truncated to 200 chars # Args should be truncated to 200 chars
assert ( assert (
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
"""Test that long task kwargs are truncated""" """Test that long task kwargs are truncated"""
long_kwargs = {"key": "x" * 300} long_kwargs = {"key": "x" * 300}
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs) discord.notify_task_failure(
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
)
message = mock_send_error.call_args[0][0] message = mock_send_error.call_args[0][1]
# Kwargs should be truncated to 200 chars # Kwargs should be truncated to 200 chars
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210 assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False) @patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
def test_notify_task_failure_disabled(mock_send_error): def test_notify_task_failure_disabled(mock_send_error):
"""Test that notifications are not sent when disabled""" """Test that notifications are not sent when disabled"""
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_not_called() mock_send_error.assert_not_called()
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
mock_send_error.side_effect = Exception("Failed to send") mock_send_error.side_effect = Exception("Failed to send")
# Should not raise # Should not raise
discord.notify_task_failure("test_task", "Error occurred") discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
mock_send_error.assert_called_once() mock_send_error.assert_called_once()
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
): ):
"""Test that convenience functions use the correct channel settings""" """Test that convenience functions use the correct channel settings"""
with patch(f"memory.common.settings.{channel_setting}", "test-channel"): with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
function(message) function(BOT_ID, message)
mock_broadcast.assert_called_once_with("test-channel", message) mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
@patch("requests.post") @patch("requests.post")
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
message_with_special_chars = "Hello! 🎉 <@123> #general" message_with_special_chars = "Hello! 🎉 <@123> #general"
result = discord.send_dm("user123", message_with_special_chars) result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == message_with_special_chars json_payload = call_args[1]["json"]
assert json_payload["message"] == message_with_special_chars
assert json_payload["bot_id"] == BOT_ID
@patch("requests.post") @patch("requests.post")
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
mock_post.return_value = mock_response mock_post.return_value = mock_response
long_message = "A" * 2000 long_message = "A" * 2000
result = discord.broadcast_message("general", long_message) result = discord.broadcast_message(BOT_ID, "general", long_message)
assert result is True assert result is True
call_args = mock_post.call_args call_args = mock_post.call_args
assert call_args[1]["json"]["message"] == long_message json_payload = call_args[1]["json"]
assert json_payload["message"] == long_message
assert json_payload["bot_id"] == BOT_ID
@patch("requests.get") @patch("requests.get")
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response mock_get.return_value = mock_response
result = discord.is_collector_healthy() result = discord.is_collector_healthy(BOT_ID)
assert result is False assert result is False

View File

@ -0,0 +1,393 @@
import io
import subprocess
import tarfile
from unittest.mock import Mock, patch, MagicMock
import pytest
from botocore.exceptions import ClientError
from memory.common import settings
from memory.workers.tasks import backup
@pytest.fixture
def sample_files():
    """Create sample files in memory_files structure.

    Populates FILE_STORAGE_DIR with a small tree of per-category text files
    so backup tests have real content to tar, encrypt, and verify.
    """
    base = settings.FILE_STORAGE_DIR
    dirs_with_files = {
        "emails": ["email1.txt", "email2.txt"],
        "notes": ["note1.md", "note2.md"],
        "photos": ["photo1.jpg"],
        "comics": ["comic1.png", "comic2.png"],
        "ebooks": ["book1.epub"],
        "webpages": ["page1.html"],
    }
    for dir_name, filenames in dirs_with_files.items():
        dir_path = base / dir_name
        dir_path.mkdir(parents=True, exist_ok=True)
        for filename in filenames:
            file_path = dir_path / filename
            # Embed the relative path in the content so integrity tests can
            # assert on e.g. "Content of notes/note1.md" after decryption.
            content = f"Content of {dir_name}/{filename}\n" * 100
            file_path.write_text(content)
@pytest.fixture
def mock_s3_client():
    """Mock boto3 S3 client."""
    fake_s3 = MagicMock()
    with patch("boto3.client") as patched_client:
        patched_client.return_value = fake_s3
        yield fake_s3
@pytest.fixture
def backup_settings():
    """Mock backup settings."""
    with patch.object(settings, "S3_BACKUP_ENABLED", True), patch.object(
        settings, "BACKUP_ENCRYPTION_KEY", "test-password-123"
    ), patch.object(settings, "S3_BACKUP_BUCKET", "test-bucket"), patch.object(
        settings, "S3_BACKUP_PREFIX", "test-prefix"
    ), patch.object(settings, "S3_BACKUP_REGION", "us-east-1"):
        yield
@pytest.fixture
def get_test_path():
    """Helper to construct test paths."""

    def _path_for(dir_name):
        return settings.FILE_STORAGE_DIR / dir_name

    return _path_for
@pytest.mark.parametrize(
    "data,key",
    [
        (b"This is a test message", "my-secret-key"),
        (b"\x00\x01\x02\xff" * 10000, "another-key"),
        (b"x" * 1000000, "large-data-key"),
    ],
)
def test_encrypt_decrypt_roundtrip(data, key):
    """Test encryption and decryption produces original data."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
        fernet = backup.get_cipher()
        ciphertext = fernet.encrypt(data)
        assert ciphertext != data
        assert fernet.decrypt(ciphertext) == data
def test_encrypt_decrypt_tarball(sample_files):
    """Test full tarball creation, encryption, and decryption."""
    # Create tarball
    tarball_bytes = backup.create_tarball(settings.FILE_STORAGE_DIR / "emails")
    assert len(tarball_bytes) > 0

    # Encrypt then decrypt and confirm a perfect round trip
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "tarball-key"):
        fernet = backup.get_cipher()
        roundtripped = fernet.decrypt(fernet.encrypt(tarball_bytes))
        assert roundtripped == tarball_bytes

    # Verify tarball can be extracted
    with tarfile.open(fileobj=io.BytesIO(roundtripped), mode="r:gz") as tar:
        entries = tar.getmembers()
        assert len(entries) >= 2  # At least 2 email files
        # Extract and verify content
        for entry in entries:
            if not entry.isfile():
                continue
            handle = tar.extractfile(entry)
            assert handle is not None
            assert "Content of emails/" in handle.read().decode()
def test_different_keys_produce_different_ciphertext():
    """Test that different encryption keys produce different ciphertext."""
    data = b"Same data encrypted with different keys"
    ciphertexts = []
    for key in ("key1", "key2"):
        with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
            ciphertexts.append(backup.get_cipher().encrypt(data))
    assert ciphertexts[0] != ciphertexts[1]
def test_missing_encryption_key_raises_error():
    """Test that missing encryption key raises ValueError."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", ""), pytest.raises(
        ValueError, match="BACKUP_ENCRYPTION_KEY not set"
    ):
        backup.get_cipher()
def test_create_tarball_with_files(sample_files):
    """Test creating tarball from directory with files."""
    tarball_bytes = backup.create_tarball(settings.FILE_STORAGE_DIR / "notes")
    assert len(tarball_bytes) > 0
    # Verify it's a valid gzipped tarball
    with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar:
        names = [entry.name for entry in tar.getmembers() if entry.isfile()]
    assert len(names) >= 2
    assert any("note1.md" in name for name in names)
    assert any("note2.md" in name for name in names)
def test_create_tarball_nonexistent_directory():
    """Test creating tarball from nonexistent directory."""
    missing = settings.FILE_STORAGE_DIR / "does_not_exist"
    assert backup.create_tarball(missing) == b""
def test_create_tarball_empty_directory():
    """Test creating tarball from empty directory."""
    empty_dir = settings.FILE_STORAGE_DIR / "empty"
    empty_dir.mkdir(parents=True, exist_ok=True)
    tarball_bytes = backup.create_tarball(empty_dir)
    # Should create tarball with just the directory entry
    assert len(tarball_bytes) > 0
    with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar:
        entries = tar.getmembers()
    assert len(entries) >= 1
    assert entries[0].isdir()
def test_sync_unencrypted_success(sample_files, backup_settings):
    """Test successful sync of unencrypted directory."""
    comics_path = settings.FILE_STORAGE_DIR / "comics"
    with patch("subprocess.run") as mock_run:
        mock_run.return_value = Mock(stdout="Synced files", returncode=0)
        outcome = backup.sync_unencrypted_directory(comics_path)

    assert outcome["synced"] is True
    assert outcome["directory"] == comics_path
    assert "s3_uri" in outcome
    assert "test-bucket" in outcome["s3_uri"]
    assert "test-prefix/comics" in outcome["s3_uri"]

    # Verify aws s3 sync was called correctly
    mock_run.assert_called_once()
    cmd = mock_run.call_args[0][0]
    assert cmd[0] == "aws"
    assert cmd[1] == "s3"
    assert cmd[2] == "sync"
    assert "--delete" in cmd
    assert "--region" in cmd
def test_sync_unencrypted_nonexistent_directory(backup_settings):
    """Test syncing nonexistent directory."""
    missing = settings.FILE_STORAGE_DIR / "does_not_exist"
    outcome = backup.sync_unencrypted_directory(missing)
    assert outcome["synced"] is False
    assert outcome["reason"] == "directory_not_found"
def test_sync_unencrypted_aws_cli_failure(sample_files, backup_settings):
    """Test handling of AWS CLI failure."""
    failure = subprocess.CalledProcessError(1, "aws", stderr="AWS CLI error")
    with patch("subprocess.run", side_effect=failure):
        outcome = backup.sync_unencrypted_directory(
            settings.FILE_STORAGE_DIR / "comics"
        )
    assert outcome["synced"] is False
    assert "error" in outcome
def test_backup_encrypted_success(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """Test successful encrypted backup."""
    outcome = backup.backup_encrypted_directory(get_test_path("emails"))
    assert outcome["uploaded"] is True
    assert outcome["size_bytes"] > 0
    assert outcome["s3_key"].endswith("emails.tar.gz.enc")
    upload_kwargs = mock_s3_client.put_object.call_args[1]
    assert upload_kwargs["Bucket"] == "test-bucket"
    assert upload_kwargs["ServerSideEncryption"] == "AES256"
def test_backup_encrypted_nonexistent_directory(
    mock_s3_client, backup_settings, get_test_path
):
    """Test backing up nonexistent directory."""
    outcome = backup.backup_encrypted_directory(get_test_path("does_not_exist"))
    assert outcome["uploaded"] is False
    assert outcome["reason"] == "directory_not_found"
    # Nothing should have been sent to S3
    mock_s3_client.put_object.assert_not_called()
def test_backup_encrypted_empty_directory(
    mock_s3_client, backup_settings, get_test_path
):
    """Test backing up empty directory."""
    empty_dir = get_test_path("empty_encrypted")
    empty_dir.mkdir(parents=True, exist_ok=True)
    assert "uploaded" in backup.backup_encrypted_directory(empty_dir)
def test_backup_encrypted_s3_failure(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """Test handling of S3 upload failure."""
    s3_error = ClientError(
        {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, "PutObject"
    )
    mock_s3_client.put_object.side_effect = s3_error
    outcome = backup.backup_encrypted_directory(get_test_path("notes"))
    assert outcome["uploaded"] is False
    assert "error" in outcome
def test_backup_encrypted_data_integrity(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """Test that encrypted backup maintains data integrity through full cycle."""
    assert backup.backup_encrypted_directory(get_test_path("notes"))["uploaded"] is True

    # Decrypt uploaded data
    uploaded_body = mock_s3_client.put_object.call_args[1]["Body"]
    plaintext_tarball = backup.get_cipher().decrypt(uploaded_body)

    # Verify content
    with tarfile.open(fileobj=io.BytesIO(plaintext_tarball), mode="r:gz") as tar:
        found = False
        for entry in tar.getmembers():
            if entry.isfile() and entry.name.endswith("note1.md"):
                text = tar.extractfile(entry).read().decode()
                assert "Content of notes/note1.md" in text
                found = True
        assert found, "note1.md not found in tarball"
def test_backup_disabled():
    """Test that backup returns early when disabled."""
    with patch.object(settings, "S3_BACKUP_ENABLED", False):
        assert backup.backup_all_to_s3()["status"] == "disabled"
def test_backup_full_execution(sample_files, mock_s3_client, backup_settings):
    """Test full backup execution dispatches tasks for all directories."""
    with patch.object(backup, "backup_to_s3") as mock_task:
        mock_task.delay = Mock()
        outcome = backup.backup_all_to_s3()
    assert outcome["status"] == "success"
    assert "message" in outcome
    # Verify task was queued for each storage directory
    assert mock_task.delay.call_count == len(settings.storage_dirs)
def test_backup_handles_partial_failures(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """Test that backup continues even if some directories fail."""
    failure = subprocess.CalledProcessError(1, "aws", stderr="Sync failed")
    with patch("subprocess.run", side_effect=failure):
        outcome = backup.sync_unencrypted_directory(get_test_path("comics"))
    assert outcome["synced"] is False
    assert "error" in outcome
def test_same_key_different_runs_different_ciphertext():
    """Test that Fernet produces different ciphertext each run (due to nonce)."""
    data = b"Consistent data"
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "same-key"):
        fernet = backup.get_cipher()
        first = fernet.encrypt(data)
        second = fernet.encrypt(data)
        # Should be different due to random nonce, but both should decrypt to same value
        assert first != second
        assert fernet.decrypt(first) == data
        assert fernet.decrypt(second) == data
def test_key_derivation_consistency():
    """Test that same password produces same encryption key."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "test-password"):
        encryptor = backup.get_cipher()
        decryptor = backup.get_cipher()
        # Both should be able to decrypt each other's ciphertext
        payload = b"Test data"
        assert decryptor.decrypt(encryptor.encrypt(payload)) == payload
@pytest.mark.parametrize(
    "dir_name,is_private",
    [
        ("emails", True),
        ("notes", True),
        ("photos", True),
        ("comics", False),
        ("ebooks", False),
        ("webpages", False),
        ("lesswrong", False),
        ("chunks", False),
    ],
)
def test_directory_encryption_classification(dir_name, is_private, backup_settings):
    """Test that directories are correctly classified as encrypted or not."""
    # Create a mock PRIVATE_DIRS list
    mocked_private = [
        settings.FILE_STORAGE_DIR / d for d in ("emails", "notes", "photos")
    ]
    with patch.object(settings, "PRIVATE_DIRS", mocked_private):
        candidate = settings.FILE_STORAGE_DIR / dir_name
        assert (candidate in settings.PRIVATE_DIRS) == is_private

View File

@ -3,6 +3,7 @@ from datetime import datetime, timezone
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from memory.common.db.models import ( from memory.common.db.models import (
DiscordBotUser,
DiscordMessage, DiscordMessage,
DiscordUser, DiscordUser,
DiscordServer, DiscordServer,
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
@pytest.fixture @pytest.fixture
def mock_discord_user(db_session): def discord_bot_user(db_session):
bot = DiscordBotUser.create_with_api_key(
discord_users=[],
name="Test Bot",
email="bot@example.com",
)
db_session.add(bot)
db_session.commit()
return bot
@pytest.fixture
def mock_discord_user(db_session, discord_bot_user):
"""Create a Discord user for testing.""" """Create a Discord user for testing."""
user = DiscordUser( user = DiscordUser(
id=123456789, id=123456789,
username="testuser", username="testuser",
ignore_messages=False, ignore_messages=False,
system_user_id=discord_bot_user.id,
) )
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()

View File

@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import uuid import uuid
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer from memory.common.db.models import (
ScheduledLLMCall,
DiscordBotUser,
DiscordUser,
DiscordChannel,
DiscordServer,
)
from memory.workers.tasks import scheduled_calls from memory.workers.tasks import scheduled_calls
@pytest.fixture @pytest.fixture
def sample_user(db_session): def sample_user(db_session):
"""Create a sample user for testing.""" """Create a sample user for testing."""
user = HumanUser.create_with_password( user = DiscordBotUser.create_with_api_key(
name="testuser", discord_users=[],
email="test@example.com", name="testbot",
password="password", email="bot@example.com",
) )
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
scheduled_calls._send_to_discord(pending_scheduled_call, response) scheduled_calls._send_to_discord(pending_scheduled_call, response)
mock_send_dm.assert_called_once_with( mock_send_dm.assert_called_once_with(
pending_scheduled_call.user_id,
"testuser", # username, not ID "testuser", # username, not ID
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.", "**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
) )
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
scheduled_calls._send_to_discord(completed_scheduled_call, response) scheduled_calls._send_to_discord(completed_scheduled_call, response)
mock_broadcast.assert_called_once_with( mock_broadcast.assert_called_once_with(
completed_scheduled_call.user_id,
"test-channel", # channel name, not ID "test-channel", # channel name, not ID
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.", "**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
) )
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
# Verify the message was truncated # Verify the message was truncated
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
message = args[1] assert args[0] == pending_scheduled_call.user_id
message = args[2]
assert len(message) <= 1950 # Should be truncated assert len(message) <= 1950 # Should be truncated
assert message.endswith("... (response truncated)") assert message.endswith("... (response truncated)")
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response) scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
message = args[1] assert args[0] == pending_scheduled_call.user_id
message = args[2]
assert not message.endswith("... (response truncated)") assert not message.endswith("... (response truncated)")
assert "This is a normal length response." in message assert "This is a normal length response." in message
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
mock_discord_user.username = "testuser" mock_discord_user.username = "testuser"
mock_call = Mock() mock_call = Mock()
mock_call.user_id = 987
mock_call.topic = topic mock_call.topic = topic
mock_call.model = model mock_call.model = model
mock_call.discord_user = mock_discord_user mock_call.discord_user = mock_discord_user
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
# Get the actual message that was sent # Get the actual message that was sent
args, kwargs = mock_send_dm.call_args args, kwargs = mock_send_dm.call_args
actual_message = args[1] assert args[0] == mock_call.user_id
actual_message = args[2]
# Verify all expected parts are in the message # Verify all expected parts are in the message
for expected_part in expected_in_message: for expected_part in expected_in_message:

182
tools/backup_databases.sh Normal file
View File

@ -0,0 +1,182 @@
#!/bin/bash
# Backup Postgres and Qdrant databases to S3
set -euo pipefail

# Install AWS CLI if not present (postgres:15 image doesn't include it)
if ! command -v aws >/dev/null 2>&1; then
    echo "Installing AWS CLI, wget, and jq..."
    apt-get update -qq && apt-get install -y -qq awscli wget jq >/dev/null 2>&1
fi

# Configuration - read from environment or use defaults
BUCKET="${S3_BACKUP_BUCKET:-equistamp-memory-backup}"
PREFIX="${S3_BACKUP_PREFIX:-Daniel}/databases"
REGION="${S3_BACKUP_REGION:-eu-central-1}"
# Required: the `:?` expansion aborts the script with this message if unset.
PASSWORD="${BACKUP_ENCRYPTION_KEY:?BACKUP_ENCRYPTION_KEY not set}"
MAX_BACKUPS="${MAX_BACKUPS:-30}" # Keep last N backups

# Service names (docker-compose network)
POSTGRES_HOST="${POSTGRES_HOST:-postgres}"
POSTGRES_USER="${POSTGRES_USER:-kb}"
POSTGRES_DB="${POSTGRES_DB:-kb}"
QDRANT_URL="${QDRANT_URL:-http://qdrant:6333}"

# Timestamp for backups
DATE=$(date +%Y%m%d-%H%M%S)
# NOTE(review): DATE appears unused below; backup names use DATE_SIMPLE only,
# so two runs on the same day overwrite each other — confirm this is intended.
DATE_SIMPLE=$(date +%Y%m%d)

# Print a timestamped log line to stdout.
log() {
    echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*"
}

# Print a timestamped error line to stderr.
error() {
    echo "[$(date '+%Y-%m-%d %H:%M:%S')] ERROR: $*" >&2
}
# Clean old backups in s3://BUCKET/<prefix>/ whose key contains <pattern>,
# keeping only the newest MAX_BACKUPS (names sort chronologically by date).
cleanup_old_backups() {
    local prefix=$1
    local pattern=$2  # e.g., "postgres-" or "qdrant-"

    log "Checking for old ${pattern} backups to clean up..."

    # List all backups matching pattern, sorted by date (oldest first).
    # `|| true` keeps `set -e -o pipefail` from aborting the whole script
    # when grep matches nothing (exit 1), e.g. on the very first backup run.
    local backups
    backups=$(aws s3 ls "s3://${BUCKET}/${prefix}/" --region "${REGION}" | \
        grep "${pattern}" | \
        awk '{print $4}' | \
        sort || true)

    # Count non-empty lines; plain `wc -l` reports 1 for an empty string
    # because echo always emits a trailing newline.
    local count
    count=$(printf '%s\n' "$backups" | grep -c . || true)

    if [ "$count" -le "$MAX_BACKUPS" ]; then
        log "Found ${count} ${pattern} backups (max: ${MAX_BACKUPS}), no cleanup needed"
        return 0
    fi

    local to_delete=$((count - MAX_BACKUPS))
    log "Found ${count} ${pattern} backups, deleting ${to_delete} oldest..."

    # Oldest entries sort first, so deleting the head keeps the newest N.
    echo "$backups" | head -n "$to_delete" | while read -r file; do
        if [ -n "$file" ]; then
            log "Deleting old backup: ${file}"
            aws s3 rm "s3://${BUCKET}/${prefix}/${file}" --region "${REGION}"
        fi
    done
}
# Backup Postgres: pg_dump | gzip | openssl encrypt | stream to S3.
# Returns 0 on success, 1 on any failure (pipefail covers every stage).
backup_postgres() {
    log "Starting Postgres backup..."
    local output_path="s3://${BUCKET}/${PREFIX}/postgres-${DATE_SIMPLE}.sql.gz.enc"

    # Guard explicitly: with `set -u`, an unset POSTGRES_PASSWORD_FILE would
    # kill the whole script (skipping the Qdrant backup) with a cryptic error.
    if [ -z "${POSTGRES_PASSWORD_FILE:-}" ] || [ ! -r "${POSTGRES_PASSWORD_FILE}" ]; then
        error "POSTGRES_PASSWORD_FILE not set or not readable"
        return 1
    fi

    # Assign before exporting so a failing `cat` isn't masked by export's
    # own (always zero) exit status.
    PGPASSWORD=$(cat "${POSTGRES_PASSWORD_FILE}")
    export PGPASSWORD

    # Use pg_dump directly with service name (no docker exec needed)
    if pg_dump -h "${POSTGRES_HOST}" -U "${POSTGRES_USER}" "${POSTGRES_DB}" 2>/dev/null | \
        gzip | \
        openssl enc -aes-256-cbc -salt -pbkdf2 -pass "pass:${PASSWORD}" | \
        aws s3 cp - "${output_path}" --region "${REGION}"; then
        unset PGPASSWORD
        log "Postgres backup completed: ${output_path}"
        cleanup_old_backups "${PREFIX}" "postgres-"
        return 0
    else
        unset PGPASSWORD
        error "Postgres backup failed"
        return 1
    fi
}
# Backup Qdrant: create a snapshot via the HTTP API, stream it encrypted to
# S3, then delete the server-side snapshot. Returns 0 on success, 1 otherwise.
backup_qdrant() {
    log "Starting Qdrant backup..."

    # Create snapshot via HTTP API (no docker exec needed)
    local snapshot_response
    if ! snapshot_response=$(wget -q -O - --post-data='{}' \
        --header='Content-Type: application/json' \
        "${QDRANT_URL}/snapshots" 2>/dev/null); then
        error "Failed to create Qdrant snapshot"
        return 1
    fi

    local snapshot_name
    # Parse snapshot name - wget/busybox may not have jq, so use grep/sed
    if command -v jq >/dev/null 2>&1; then
        snapshot_name=$(echo "${snapshot_response}" | jq -r '.result.name // empty')
    else
        # Fallback: parse JSON without jq (fragile but works for simple case)
        snapshot_name=$(echo "${snapshot_response}" | grep -o '"name":"[^"]*"' | cut -d'"' -f4)
    fi

    if [ -z "${snapshot_name}" ]; then
        error "Could not extract snapshot name from response: ${snapshot_response}"
        return 1
    fi
    log "Created Qdrant snapshot: ${snapshot_name}"

    # Download snapshot and upload to S3
    local output_path="s3://${BUCKET}/${PREFIX}/qdrant-${DATE_SIMPLE}.snapshot.enc"
    # Streamed pipeline: download -> encrypt -> upload; `pipefail` makes the
    # `if` take the else-branch when any stage fails.
    if wget -q -O - "${QDRANT_URL}/snapshots/${snapshot_name}" | \
        openssl enc -aes-256-cbc -salt -pbkdf2 -pass "pass:${PASSWORD}" | \
        aws s3 cp - "${output_path}" --region "${REGION}"; then
        log "Qdrant backup completed: ${output_path}"

        # Delete the snapshot from Qdrant
        if wget -q -O - --method=DELETE \
            "${QDRANT_URL}/snapshots/${snapshot_name}" >/dev/null 2>&1; then
            log "Deleted Qdrant snapshot: ${snapshot_name}"
        else
            # Non-fatal: the backup itself succeeded, only cleanup failed.
            error "Failed to delete Qdrant snapshot: ${snapshot_name}"
        fi

        cleanup_old_backups "${PREFIX}" "qdrant-"
        return 0
    else
        error "Qdrant backup failed"
        # Try to clean up snapshot
        wget -q -O - --method=DELETE \
            "${QDRANT_URL}/snapshots/${snapshot_name}" >/dev/null 2>&1 || true
        return 1
    fi
}
# Main execution: run both backups (never aborting between them) and report
# a combined status. Exit code 0 only when both succeed.
main() {
    log "Database backup started"

    local postgres_result=0
    local qdrant_result=0

    # `|| var=1` records the failure without tripping `set -e`.
    backup_postgres || postgres_result=1
    backup_qdrant || qdrant_result=1

    # Summary
    case "${postgres_result}${qdrant_result}" in
        00)
            log "All database backups completed successfully"
            return 0
            ;;
        11)
            error "All database backups failed"
            return 1
            ;;
        *)
            error "Some database backups failed (Postgres: ${postgres_result}, Qdrant: ${qdrant_result})"
            return 1
            ;;
    esac
}

# Run main function
main

View File

@ -51,6 +51,8 @@ from memory.common.celery_app import (
UPDATE_METADATA_FOR_SOURCE_ITEMS, UPDATE_METADATA_FOR_SOURCE_ITEMS,
SETUP_GIT_NOTES, SETUP_GIT_NOTES,
TRACK_GIT_CHANGES, TRACK_GIT_CHANGES,
BACKUP_TO_S3_DIRECTORY,
BACKUP_ALL,
app, app,
) )
@ -97,6 +99,10 @@ TASK_MAPPINGS = {
"setup_git_notes": SETUP_GIT_NOTES, "setup_git_notes": SETUP_GIT_NOTES,
"track_git_changes": TRACK_GIT_CHANGES, "track_git_changes": TRACK_GIT_CHANGES,
}, },
"backup": {
"backup_to_s3_directory": BACKUP_TO_S3_DIRECTORY,
"backup_all": BACKUP_ALL,
},
} }
QUEUE_MAPPINGS = { QUEUE_MAPPINGS = {
"email": "email", "email": "email",
@ -177,6 +183,28 @@ def execute_task(ctx, category: str, task_name: str, **kwargs):
sys.exit(1) sys.exit(1)
@cli.group()
@click.pass_context
def backup(ctx):
"""Backup-related tasks."""
pass
@backup.command("all")
@click.pass_context
def backup_all(ctx):
"""Backup all directories."""
execute_task(ctx, "backup", "backup_all")
@backup.command("path")
@click.option("--path", required=True, help="Path to backup")
@click.pass_context
def backup_to_s3_directory(ctx, path):
"""Backup a specific path."""
execute_task(ctx, "backup", "backup_to_s3_directory", path=path)
@cli.group() @cli.group()
@click.pass_context @click.pass_context
def email(ctx): def email(ctx):