mirror of
https://github.com/mruwnik/memory.git
synced 2025-12-16 00:51:18 +01:00
Compare commits
10 Commits
8af07f0dac
...
e95a082147
| Author | SHA1 | Date | |
|---|---|---|---|
| e95a082147 | |||
| c42513100b | |||
| a5bc53326d | |||
| 131427255a | |||
|
|
ff3ca4f109 | ||
| 3b216953ab | |||
|
|
d7e403fb83 | ||
| 57145ac7b4 | |||
|
|
814090dccb | ||
|
|
9639fa3dd7 |
114
db/migrations/versions/20251101_203810_allow_no_chattiness.py
Normal file
114
db/migrations/versions/20251101_203810_allow_no_chattiness.py
Normal file
@ -0,0 +1,114 @@
|
||||
"""allow no chattiness
|
||||
|
||||
Revision ID: 2024235e37e7
|
||||
Revises: 7dc03dbf184c
|
||||
Create Date: 2025-11-01 20:38:10.849651
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2024235e37e7"
|
||||
down_revision: Union[str, None] = "7dc03dbf184c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Allow a NULL ``chattiness_threshold`` and remove ``track_messages``.

    For each of ``discord_channels``, ``discord_servers`` and
    ``discord_users`` the ``chattiness_threshold`` column becomes nullable
    and the ``track_messages`` column is dropped.  On ``discord_servers``
    the ``discord_servers_active_idx`` index is rebuilt over
    ``(ignore_messages, last_sync_at)`` before ``track_messages`` is dropped.
    """

    def allow_null_threshold(table_name: str) -> None:
        # Same column change on every table; only the table name varies.
        op.alter_column(
            table_name,
            "chattiness_threshold",
            existing_type=sa.INTEGER(),
            nullable=True,
            existing_server_default=sa.text("50"),
        )

    # --- discord_channels
    allow_null_threshold("discord_channels")
    op.drop_column("discord_channels", "track_messages")

    # --- discord_servers (index is swapped before the old column goes away)
    allow_null_threshold("discord_servers")
    op.drop_index("discord_servers_active_idx", table_name="discord_servers")
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["ignore_messages", "last_sync_at"],
        unique=False,
    )
    op.drop_column("discord_servers", "track_messages")

    # --- discord_users
    allow_null_threshold("discord_users")
    op.drop_column("discord_users", "track_messages")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore ``track_messages`` and the NOT NULL ``chattiness_threshold``.

    For each of ``discord_users``, ``discord_servers`` and
    ``discord_channels`` the ``track_messages`` boolean column is re-added
    (NOT NULL, server default ``true``) and ``chattiness_threshold`` is made
    NOT NULL again.  On ``discord_servers`` the ``discord_servers_active_idx``
    index is rebuilt over ``(track_messages, last_sync_at)``.

    Rows whose ``chattiness_threshold`` is NULL (allowed after ``upgrade``)
    are first set to 50, the column's server default; otherwise restoring
    the NOT NULL constraint would fail on those rows.
    """

    def restore_track_messages(table_name: str) -> None:
        op.add_column(
            table_name,
            sa.Column(
                "track_messages",
                sa.BOOLEAN(),
                server_default=sa.text("true"),
                autoincrement=False,
                nullable=False,
            ),
        )

    def require_threshold(table_name: str) -> None:
        # Backfill before tightening the constraint.  ``table_name`` only
        # ever comes from the hard-coded literals below, so interpolating it
        # into the statement is safe.
        op.execute(
            sa.text(
                f"UPDATE {table_name} SET chattiness_threshold = 50 "
                "WHERE chattiness_threshold IS NULL"
            )
        )
        op.alter_column(
            table_name,
            "chattiness_threshold",
            existing_type=sa.INTEGER(),
            nullable=False,
            existing_server_default=sa.text("50"),
        )

    # --- discord_users
    restore_track_messages("discord_users")
    require_threshold("discord_users")

    # --- discord_servers (column must exist again before the index uses it)
    restore_track_messages("discord_servers")
    op.drop_index("discord_servers_active_idx", table_name="discord_servers")
    op.create_index(
        "discord_servers_active_idx",
        "discord_servers",
        ["track_messages", "last_sync_at"],
        unique=False,
    )
    require_threshold("discord_servers")

    # --- discord_channels
    restore_track_messages("discord_channels")
    require_threshold("discord_channels")
|
||||
@ -1,5 +1,3 @@
|
||||
version: "3.9"
|
||||
|
||||
# --------------------------------------------------------------------- networks
|
||||
networks:
|
||||
kbnet:
|
||||
@ -26,7 +24,7 @@ x-common-env: &env
|
||||
REDIS_HOST: redis
|
||||
REDIS_PORT: 6379
|
||||
REDIS_DB: 0
|
||||
CELERY_BROKER_PASSWORD: ${CELERY_BROKER_PASSWORD}
|
||||
REDIS_PASSWORD: ${REDIS_PASSWORD}
|
||||
QDRANT_HOST: qdrant
|
||||
DB_HOST: postgres
|
||||
DB_PORT: 5432
|
||||
@ -107,11 +105,11 @@ services:
|
||||
image: redis:7.2-alpine
|
||||
restart: unless-stopped
|
||||
networks: [ kbnet ]
|
||||
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${CELERY_BROKER_PASSWORD}"]
|
||||
command: ["redis-server", "--save", "", "--appendonly", "no", "--requirepass", "${REDIS_PASSWORD}"]
|
||||
volumes:
|
||||
- redis_data:/data:rw
|
||||
healthcheck:
|
||||
test: [ "CMD", "redis-cli", "--pass", "${CELERY_BROKER_PASSWORD}", "ping" ]
|
||||
test: [ "CMD", "redis-cli", "--pass", "${REDIS_PASSWORD}", "ping" ]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
@ -175,7 +173,7 @@ services:
|
||||
<<: *worker-base
|
||||
environment:
|
||||
<<: *worker-env
|
||||
QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
||||
QUEUES: "backup,email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
|
||||
|
||||
ingest-hub:
|
||||
<<: *worker-base
|
||||
@ -196,6 +194,22 @@ services:
|
||||
- /var/run/supervisor
|
||||
deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
|
||||
|
||||
# ------------------------------------------------------------ database backups
|
||||
backup:
|
||||
image: postgres:15 # Has pg_dump, wget, curl
|
||||
networks: [kbnet]
|
||||
depends_on: [postgres, qdrant]
|
||||
env_file: [ .env ]
|
||||
environment:
|
||||
<<: *worker-env
|
||||
secrets: [postgres_password]
|
||||
volumes:
|
||||
- ./tools/backup_databases.sh:/backup.sh:ro
|
||||
entrypoint: ["/bin/bash"]
|
||||
command: ["/backup.sh"]
|
||||
profiles: [backup] # Only start when explicitly called
|
||||
security_opt: ["no-new-privileges=true"]
|
||||
|
||||
# ------------------------------------------------------------ watchtower (auto-update)
|
||||
# watchtower:
|
||||
# image: containrrr/watchtower
|
||||
|
||||
@ -16,7 +16,7 @@ RUN apt-get update && apt-get install -y \
|
||||
COPY requirements ./requirements/
|
||||
COPY setup.py ./
|
||||
RUN mkdir src
|
||||
RUN pip install -e ".[common]"
|
||||
RUN pip install -e ".[workers]"
|
||||
|
||||
# Install Python dependencies
|
||||
COPY src/ ./src/
|
||||
@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
|
||||
git config --global user.name "${GIT_USER_NAME}"
|
||||
|
||||
# Default queues to process
|
||||
ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
||||
ENV QUEUES="backup,ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
|
||||
ENV PYTHONPATH="/app"
|
||||
|
||||
ENTRYPOINT ["./entry.sh"]
|
||||
@ -10,3 +10,4 @@ openai==2.3.0
|
||||
# Pin the httpx version, as newer versions break the anthropic client
|
||||
httpx==0.27.0
|
||||
celery[redis,sqs]==5.3.6
|
||||
cryptography==43.0.0
|
||||
2
requirements/requirements-workers.txt
Normal file
2
requirements/requirements-workers.txt
Normal file
@ -0,0 +1,2 @@
|
||||
boto3
|
||||
awscli==1.42.64
|
||||
5
setup.py
5
setup.py
@ -18,6 +18,7 @@ parsers_requires = read_requirements("requirements-parsers.txt")
|
||||
api_requires = read_requirements("requirements-api.txt")
|
||||
dev_requires = read_requirements("requirements-dev.txt")
|
||||
ingesters_requires = read_requirements("requirements-ingesters.txt")
|
||||
workers_requires = read_requirements("requirements-workers.txt")
|
||||
|
||||
setup(
|
||||
name="memory",
|
||||
@ -30,10 +31,12 @@ setup(
|
||||
"common": common_requires + parsers_requires,
|
||||
"dev": dev_requires,
|
||||
"ingesters": common_requires + parsers_requires + ingesters_requires,
|
||||
"workers": common_requires + parsers_requires + workers_requires,
|
||||
"all": api_requires
|
||||
+ common_requires
|
||||
+ dev_requires
|
||||
+ parsers_requires
|
||||
+ ingesters_requires,
|
||||
+ ingesters_requires
|
||||
+ workers_requires,
|
||||
},
|
||||
)
|
||||
|
||||
@ -290,6 +290,7 @@ class ScheduledLLMCallAdmin(ModelView, model=ScheduledLLMCall):
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
column_sortable_list = ["executed_at", "scheduled_time", "created_at", "updated_at"]
|
||||
|
||||
|
||||
def setup_admin(admin: Admin):
|
||||
|
||||
@ -13,6 +13,7 @@ NOTES_ROOT = "memory.workers.tasks.notes"
|
||||
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
|
||||
SCHEDULED_CALLS_ROOT = "memory.workers.tasks.scheduled_calls"
|
||||
DISCORD_ROOT = "memory.workers.tasks.discord"
|
||||
BACKUP_ROOT = "memory.workers.tasks.backup"
|
||||
ADD_DISCORD_MESSAGE = f"{DISCORD_ROOT}.add_discord_message"
|
||||
EDIT_DISCORD_MESSAGE = f"{DISCORD_ROOT}.edit_discord_message"
|
||||
PROCESS_DISCORD_MESSAGE = f"{DISCORD_ROOT}.process_discord_message"
|
||||
@ -53,13 +54,25 @@ SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
|
||||
EXECUTE_SCHEDULED_CALL = f"{SCHEDULED_CALLS_ROOT}.execute_scheduled_call"
|
||||
RUN_SCHEDULED_CALLS = f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls"
|
||||
|
||||
# Backup tasks
|
||||
BACKUP_PATH = f"{BACKUP_ROOT}.backup_path"
|
||||
BACKUP_ALL = f"{BACKUP_ROOT}.backup_all"
|
||||
|
||||
|
||||
def get_broker_url() -> str:
|
||||
protocol = settings.CELERY_BROKER_TYPE
|
||||
user = safequote(settings.CELERY_BROKER_USER)
|
||||
password = safequote(settings.CELERY_BROKER_PASSWORD)
|
||||
password = safequote(settings.CELERY_BROKER_PASSWORD or "")
|
||||
host = settings.CELERY_BROKER_HOST
|
||||
return f"{protocol}://{user}:{password}@{host}"
|
||||
|
||||
if password:
|
||||
url = f"{protocol}://{user}:{password}@{host}"
|
||||
else:
|
||||
url = f"{protocol}://{host}"
|
||||
|
||||
if protocol == "redis":
|
||||
url += f"/{settings.REDIS_DB}"
|
||||
return url
|
||||
|
||||
|
||||
app = Celery(
|
||||
@ -91,6 +104,7 @@ app.conf.update(
|
||||
f"{SCHEDULED_CALLS_ROOT}.*": {
|
||||
"queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler"
|
||||
},
|
||||
f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -22,7 +22,6 @@ from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
class MessageProcessor:
|
||||
track_messages = Column(Boolean, nullable=False, server_default="true")
|
||||
ignore_messages = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
@ -35,8 +34,7 @@ class MessageProcessor:
|
||||
)
|
||||
chattiness_threshold = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=50,
|
||||
nullable=True,
|
||||
doc="The threshold for the bot to continue the conversation, between 0 and 100.",
|
||||
)
|
||||
|
||||
@ -90,7 +88,7 @@ class DiscordServer(Base, MessageProcessor):
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("discord_servers_active_idx", "track_messages", "last_sync_at"),
|
||||
Index("discord_servers_active_idx", "ignore_messages", "last_sync_at"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -138,6 +138,8 @@ class DiscordBotUser(BotUser):
|
||||
email: str,
|
||||
api_key: str | None = None,
|
||||
) -> "DiscordBotUser":
|
||||
if not discord_users:
|
||||
raise ValueError("discord_users must be provided")
|
||||
bot = super().create_with_api_key(name, email, api_key)
|
||||
bot.discord_users = discord_users
|
||||
return bot
|
||||
|
||||
@ -5,9 +5,10 @@ Simple HTTP client that communicates with the Discord collector's API server.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -20,12 +21,12 @@ def get_api_url() -> str:
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def send_dm(user_identifier: str, message: str) -> bool:
|
||||
def send_dm(bot_id: int, user_identifier: str, message: str) -> bool:
|
||||
"""Send a DM via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_dm",
|
||||
json={"user": user_identifier, "message": message},
|
||||
json={"bot_id": bot_id, "user": user_identifier, "message": message},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -37,12 +38,33 @@ def send_dm(user_identifier: str, message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def send_to_channel(channel_name: str, message: str) -> bool:
|
||||
def trigger_typing_dm(bot_id: int, user_identifier: int | str) -> bool:
    """Trigger typing indicator for a DM via the Discord collector API.

    POSTs ``bot_id`` and ``user`` to the collector's ``/typing/dm`` endpoint
    with a 10 second timeout.

    Returns:
        ``True`` only when the collector answers with a JSON object whose
        ``success`` entry is truthy.  Request failures, HTTP error statuses
        and replies that are not a JSON object all yield ``False``.
    """
    try:
        response = requests.post(
            f"{get_api_url()}/typing/dm",
            json={"bot_id": bot_id, "user": user_identifier},
            timeout=10,
        )
        response.raise_for_status()
        result = response.json()
    # ValueError covers a body that is not valid JSON on requests versions
    # where the decode error is not a RequestException subclass.
    except (requests.RequestException, ValueError) as e:
        logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
        return False

    # A JSON list/string/number has no ``.get``; treat it as "not successful".
    if not isinstance(result, dict):
        return False
    return bool(result.get("success", False))
|
||||
|
||||
|
||||
def send_to_channel(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
"""Send a DM via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={"channel_name": channel_name, "message": message},
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel_name": channel_name,
|
||||
"message": message,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -55,12 +77,33 @@ def send_to_channel(channel_name: str, message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def broadcast_message(channel_name: str, message: str) -> bool:
|
||||
def trigger_typing_channel(bot_id: int, channel_name: str) -> bool:
    """Trigger typing indicator for a channel via the Discord collector API.

    POSTs ``bot_id`` and ``channel_name`` to the collector's
    ``/typing/channel`` endpoint with a 10 second timeout.

    Returns:
        ``True`` only when the collector answers with a JSON object whose
        ``success`` entry is truthy.  Request failures, HTTP error statuses
        and replies that are not a JSON object all yield ``False``.
    """
    try:
        response = requests.post(
            f"{get_api_url()}/typing/channel",
            json={"bot_id": bot_id, "channel_name": channel_name},
            timeout=10,
        )
        response.raise_for_status()
        result = response.json()
    # ValueError covers a body that is not valid JSON on requests versions
    # where the decode error is not a RequestException subclass.
    except (requests.RequestException, ValueError) as e:
        logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
        return False

    # A JSON list/string/number has no ``.get``; treat it as "not successful".
    if not isinstance(result, dict):
        return False
    return bool(result.get("success", False))
|
||||
|
||||
|
||||
def broadcast_message(bot_id: int, channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel via the Discord collector API"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{get_api_url()}/send_channel",
|
||||
json={"channel_name": channel_name, "message": message},
|
||||
json={
|
||||
"bot_id": bot_id,
|
||||
"channel_name": channel_name,
|
||||
"message": message,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@ -72,19 +115,22 @@ def broadcast_message(channel_name: str, message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_collector_healthy() -> bool:
|
||||
def is_collector_healthy(bot_id: int) -> bool:
|
||||
"""Check if the Discord collector is running and healthy"""
|
||||
try:
|
||||
response = requests.get(f"{get_api_url()}/health", timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("status") == "healthy"
|
||||
bot_status = result.get(str(bot_id))
|
||||
if not isinstance(bot_status, dict):
|
||||
return False
|
||||
return bool(bot_status.get("connected"))
|
||||
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
def refresh_discord_metadata() -> dict[str, int] | None:
|
||||
def refresh_discord_metadata() -> dict[str, Any] | None:
|
||||
"""Refresh Discord server/channel/user metadata from Discord API"""
|
||||
try:
|
||||
response = requests.post(f"{get_api_url()}/refresh_metadata", timeout=30)
|
||||
@ -96,24 +142,24 @@ def refresh_discord_metadata() -> dict[str, int] | None:
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def send_error_message(message: str) -> bool:
|
||||
def send_error_message(bot_id: int, message: str) -> bool:
|
||||
"""Send an error message to the error channel"""
|
||||
return broadcast_message(settings.DISCORD_ERROR_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_ERROR_CHANNEL, message)
|
||||
|
||||
|
||||
def send_activity_message(message: str) -> bool:
|
||||
def send_activity_message(bot_id: int, message: str) -> bool:
|
||||
"""Send an activity message to the activity channel"""
|
||||
return broadcast_message(settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_ACTIVITY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_discovery_message(message: str) -> bool:
|
||||
def send_discovery_message(bot_id: int, message: str) -> bool:
|
||||
"""Send a discovery message to the discovery channel"""
|
||||
return broadcast_message(settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_DISCOVERY_CHANNEL, message)
|
||||
|
||||
|
||||
def send_chat_message(message: str) -> bool:
|
||||
def send_chat_message(bot_id: int, message: str) -> bool:
|
||||
"""Send a chat message to the chat channel"""
|
||||
return broadcast_message(settings.DISCORD_CHAT_CHANNEL, message)
|
||||
return broadcast_message(bot_id, settings.DISCORD_CHAT_CHANNEL, message)
|
||||
|
||||
|
||||
def notify_task_failure(
|
||||
@ -122,6 +168,7 @@ def notify_task_failure(
|
||||
task_args: tuple = (),
|
||||
task_kwargs: dict[str, Any] | None = None,
|
||||
traceback_str: str | None = None,
|
||||
bot_id: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send a task failure notification to Discord.
|
||||
@ -137,6 +184,15 @@ def notify_task_failure(
|
||||
logger.debug("Discord notifications disabled")
|
||||
return
|
||||
|
||||
if bot_id is None:
|
||||
bot_id = settings.DISCORD_BOT_ID
|
||||
|
||||
if not bot_id:
|
||||
logger.debug(
|
||||
"No Discord bot ID provided for task failure notification; skipping"
|
||||
)
|
||||
return
|
||||
|
||||
message = f"🚨 **Task Failed: {task_name}**\n\n"
|
||||
message += f"**Error:** {error_message[:500]}\n"
|
||||
|
||||
@ -150,7 +206,7 @@ def notify_task_failure(
|
||||
message += f"**Traceback:**\n```\n{traceback_str[-800:]}\n```"
|
||||
|
||||
try:
|
||||
send_error_message(message)
|
||||
send_error_message(bot_id, message)
|
||||
logger.info(f"Discord error notification sent for task: {task_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Discord notification: {e}")
|
||||
|
||||
@ -24,13 +24,6 @@ from memory.common.llms.base import (
|
||||
)
|
||||
from memory.common.llms.anthropic_provider import AnthropicProvider
|
||||
from memory.common.llms.openai_provider import OpenAIProvider
|
||||
from memory.common.llms.usage_tracker import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
TokenAllowance,
|
||||
UsageBreakdown,
|
||||
UsageTracker,
|
||||
)
|
||||
from memory.common import tokens
|
||||
|
||||
__all__ = [
|
||||
@ -49,11 +42,6 @@ __all__ = [
|
||||
"StreamEvent",
|
||||
"LLMSettings",
|
||||
"create_provider",
|
||||
"InMemoryUsageTracker",
|
||||
"RateLimitConfig",
|
||||
"TokenAllowance",
|
||||
"UsageBreakdown",
|
||||
"UsageTracker",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -93,28 +81,3 @@ def truncate(content: str, target_tokens: int) -> str:
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
||||
|
||||
|
||||
# bla = 1
|
||||
# from memory.common.llms import *
|
||||
# from memory.common.llms.tools.discord import make_discord_tools
|
||||
# from memory.common.db.connection import make_session
|
||||
# from memory.common.db.models import *
|
||||
|
||||
# model = "anthropic/claude-sonnet-4-5"
|
||||
# provider = create_provider(model=model)
|
||||
# with make_session() as session:
|
||||
# bot = session.query(DiscordBotUser).first()
|
||||
# server = session.query(DiscordServer).first()
|
||||
# channel = server.channels[0]
|
||||
# tools = make_discord_tools(bot, None, channel, model)
|
||||
|
||||
# def demo(msg: str):
|
||||
# messages = [
|
||||
# Message(
|
||||
# role=MessageRole.USER,
|
||||
# content=msg,
|
||||
# )
|
||||
# ]
|
||||
# for m in provider.stream_with_tools(messages, tools):
|
||||
# print(m)
|
||||
|
||||
@ -23,6 +23,8 @@ logger = logging.getLogger(__name__)
|
||||
class AnthropicProvider(BaseLLMProvider):
|
||||
"""Anthropic LLM provider with streaming, tool support, and extended thinking."""
|
||||
|
||||
provider = "anthropic"
|
||||
|
||||
# Models that support extended thinking
|
||||
THINKING_MODELS = {
|
||||
"claude-opus-4",
|
||||
@ -262,7 +264,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
Usage(
|
||||
input_tokens=usage.input_tokens,
|
||||
output_tokens=usage.output_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
total_tokens=usage.input_tokens + usage.output_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from PIL import Image
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.llms.tools import ToolCall, ToolDefinition, ToolResult
|
||||
from memory.common.llms.usage import UsageTracker, RedisUsageTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -204,7 +205,11 @@ class LLMSettings:
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Base class for LLM providers."""
|
||||
|
||||
def __init__(self, api_key: str, model: str):
|
||||
provider: str = ""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model: str, usage_tracker: UsageTracker | None = None
|
||||
):
|
||||
"""
|
||||
Initialize the LLM provider.
|
||||
|
||||
@ -215,6 +220,7 @@ class BaseLLMProvider(ABC):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self._client: Any = None
|
||||
self.usage_tracker: UsageTracker = usage_tracker or RedisUsageTracker()
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_client(self) -> Any:
|
||||
@ -230,8 +236,14 @@ class BaseLLMProvider(ABC):
|
||||
|
||||
def log_usage(self, usage: Usage):
|
||||
"""Log usage data."""
|
||||
logger.debug(f"Token usage: {usage.to_dict()}")
|
||||
print(f"Token usage: {usage.to_dict()}")
|
||||
logger.debug(
|
||||
f"Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.total_tokens} total"
|
||||
)
|
||||
self.usage_tracker.record_usage(
|
||||
model=f"{self.provider}/{self.model}",
|
||||
input_tokens=usage.input_tokens,
|
||||
output_tokens=usage.output_tokens,
|
||||
)
|
||||
|
||||
def execute_tool(
|
||||
self,
|
||||
|
||||
@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
|
||||
class OpenAIProvider(BaseLLMProvider):
|
||||
"""OpenAI LLM provider with streaming and tool support."""
|
||||
|
||||
provider = "openai"
|
||||
|
||||
# Models that use max_completion_tokens instead of max_tokens
|
||||
# These are reasoning models with different parameter requirements
|
||||
NON_REASONING_MODELS = {"gpt-4o"}
|
||||
|
||||
17
src/memory/common/llms/usage/__init__.py
Normal file
17
src/memory/common/llms/usage/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from memory.common.llms.usage.redis_usage_tracker import RedisUsageTracker
|
||||
from memory.common.llms.usage.usage_tracker import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
TokenAllowance,
|
||||
UsageBreakdown,
|
||||
UsageTracker,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"InMemoryUsageTracker",
|
||||
"RateLimitConfig",
|
||||
"RedisUsageTracker",
|
||||
"TokenAllowance",
|
||||
"UsageBreakdown",
|
||||
"UsageTracker",
|
||||
]
|
||||
83
src/memory/common/llms/usage/redis_usage_tracker.py
Normal file
83
src/memory/common/llms/usage/redis_usage_tracker.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""Redis-backed usage tracker implementation."""
|
||||
|
||||
import json
|
||||
from typing import Any, Iterable, Protocol
|
||||
|
||||
import redis
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.llms.usage.usage_tracker import (
|
||||
RateLimitConfig,
|
||||
UsageState,
|
||||
UsageTracker,
|
||||
)
|
||||
|
||||
|
||||
class RedisClientProtocol(Protocol):
    """Structural type listing the client methods the usage tracker calls.

    Any object offering ``get``, ``set`` and ``scan_iter`` with these
    signatures satisfies the protocol.
    """

    def get(self, key: str) -> Any:  # pragma: no cover - Protocol definition
        """Return the value stored under ``key``."""
        ...

    def set(self, key: str, value: Any) -> Any:  # pragma: no cover - Protocol definition
        """Store ``value`` under ``key``."""
        ...

    def scan_iter(self, match: str) -> Iterable[Any]:  # pragma: no cover - Protocol definition
        """Iterate over the keys selected by ``match``."""
        ...
|
||||
|
||||
|
||||
class RedisUsageTracker(UsageTracker):
    """Tracks LLM usage for providers and models using Redis for persistence.

    Each model's ``UsageState`` is stored as a compact JSON string under the
    key ``<key_prefix>:<model>``.
    """

    def __init__(
        self,
        configs: dict[str, RateLimitConfig] | None = None,
        default_config: RateLimitConfig | None = None,
        *,
        redis_client: RedisClientProtocol | None = None,
        key_prefix: str | None = None,
    ) -> None:
        """Set up the tracker.

        Args:
            configs: Forwarded to ``UsageTracker``.
            default_config: Forwarded to ``UsageTracker``.
            redis_client: Client to use; when ``None`` one is created from
                ``settings.REDIS_URL``.
            key_prefix: Key namespace; a falsy value falls back to
                ``settings.LLM_USAGE_REDIS_PREFIX``.  Trailing ``:``
                characters are stripped.
        """
        super().__init__(configs=configs, default_config=default_config)
        self._redis = (
            redis_client
            if redis_client is not None
            else redis.Redis.from_url(settings.REDIS_URL)
        )
        chosen_prefix = key_prefix or settings.LLM_USAGE_REDIS_PREFIX
        self._key_prefix = chosen_prefix.rstrip(":")

    def get_state(self, model: str) -> UsageState:
        """Load the stored state for ``model``; a fresh state if none is stored."""
        raw = self._redis.get(self._format_key(model))
        return self._parse_state(raw) if raw else UsageState()

    def iter_state_items(self) -> Iterable[tuple[str, UsageState]]:
        """Yield ``(model, state)`` for every key found under the prefix.

        Keys whose value is empty or missing when fetched are skipped.
        """
        for raw_key in self._redis.scan_iter(match=f"{self._key_prefix}:*"):
            full_key = self._ensure_str(raw_key)
            raw = self._redis.get(full_key)
            if raw:
                # Drop "<prefix>:" so callers see only the model identifier.
                yield full_key[len(self._key_prefix) + 1 :], self._parse_state(raw)

    def save_state(self, model: str, state: UsageState) -> None:
        """Serialise ``state`` to compact JSON and store it for ``model``."""
        target_key = self._format_key(model)
        serialized = json.dumps(state.to_payload(), separators=(",", ":"))
        self._redis.set(target_key, serialized)

    def _format_key(self, model: str) -> str:
        """Build the Redis key for ``model``."""
        return f"{self._key_prefix}:{model}"

    @staticmethod
    def _parse_state(raw: Any) -> UsageState:
        """Decode a non-empty stored value (bytes or str) into a ``UsageState``."""
        text = raw.decode() if isinstance(raw, bytes) else raw
        return UsageState.from_payload(json.loads(text))

    @staticmethod
    def _ensure_str(value: Any) -> str:
        """Return ``value`` as ``str``, decoding it first when it is ``bytes``."""
        return value.decode() if isinstance(value, bytes) else str(value)
|
||||
@ -5,6 +5,9 @@ from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -45,6 +48,40 @@ class UsageState:
|
||||
lifetime_input_tokens: int = 0
|
||||
lifetime_output_tokens: int = 0
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
    """Return this state as a dict of plain values.

    Each event becomes a dict holding its ISO-8601 timestamp and its token
    counts; the window and lifetime counters are copied through unchanged.
    """
    serialized_events = []
    for event in self.events:
        serialized_events.append(
            {
                "timestamp": event.timestamp.isoformat(),
                "input_tokens": event.input_tokens,
                "output_tokens": event.output_tokens,
            }
        )

    payload: dict[str, Any] = {"events": serialized_events}
    payload["window_input_tokens"] = self.window_input_tokens
    payload["window_output_tokens"] = self.window_output_tokens
    payload["lifetime_input_tokens"] = self.lifetime_input_tokens
    payload["lifetime_output_tokens"] = self.lifetime_output_tokens
    return payload
|
||||
|
||||
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> "UsageState":
    """Build a state from a dict shaped like the output of ``to_payload``.

    A missing ``events`` entry is treated as no events and missing counters
    as 0.  Every event entry must carry ``timestamp`` (ISO-8601 string),
    ``input_tokens`` and ``output_tokens``.
    """
    restored_events = deque()
    for entry in payload.get("events", []):
        restored_events.append(
            UsageEvent(
                timestamp=datetime.fromisoformat(entry["timestamp"]),
                input_tokens=entry["input_tokens"],
                output_tokens=entry["output_tokens"],
            )
        )

    return cls(
        events=restored_events,
        window_input_tokens=payload.get("window_input_tokens", 0),
        window_output_tokens=payload.get("window_output_tokens", 0),
        lifetime_input_tokens=payload.get("lifetime_input_tokens", 0),
        lifetime_output_tokens=payload.get("lifetime_output_tokens", 0),
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenAllowance:
|
||||
@ -77,13 +114,13 @@ class UsageBreakdown:
|
||||
def split_model_key(model: str) -> tuple[str, str]:
|
||||
if "/" not in model:
|
||||
raise ValueError(
|
||||
"model must be formatted as '<provider>/<model_name>'"
|
||||
f"model must be formatted as '<provider>/<model_name>': got '{model}'"
|
||||
)
|
||||
|
||||
provider, model_name = model.split("/", maxsplit=1)
|
||||
if not provider or not model_name:
|
||||
raise ValueError(
|
||||
"model must include both provider and model name separated by '/'"
|
||||
f"model must include both provider and model name separated by '/': got '{model}'"
|
||||
)
|
||||
return provider, model_name
|
||||
|
||||
@ -93,11 +130,15 @@ class UsageTracker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configs: dict[str, RateLimitConfig],
|
||||
configs: dict[str, RateLimitConfig] | None = None,
|
||||
default_config: RateLimitConfig | None = None,
|
||||
) -> None:
|
||||
self._configs = configs
|
||||
self._default_config = default_config
|
||||
self._configs = configs or {}
|
||||
self._default_config = default_config or RateLimitConfig(
|
||||
window=timedelta(minutes=settings.DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES),
|
||||
max_input_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS,
|
||||
max_output_tokens=settings.DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS,
|
||||
)
|
||||
self._lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@ -180,15 +221,14 @@ class UsageTracker:
|
||||
"""
|
||||
|
||||
split_model_key(model)
|
||||
key = model
|
||||
with self._lock:
|
||||
config = self._get_config(key)
|
||||
config = self._get_config(model)
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
state = self.get_state(key)
|
||||
state = self.get_state(model)
|
||||
self._prune_expired_events(state, config, now=timestamp)
|
||||
self.save_state(key, state)
|
||||
self.save_state(model, state)
|
||||
|
||||
if config.max_total_tokens is None:
|
||||
total_remaining = None
|
||||
@ -205,9 +245,7 @@ class UsageTracker:
|
||||
if config.max_output_tokens is None:
|
||||
output_remaining = None
|
||||
else:
|
||||
output_remaining = (
|
||||
config.max_output_tokens - state.window_output_tokens
|
||||
)
|
||||
output_remaining = config.max_output_tokens - state.window_output_tokens
|
||||
|
||||
return TokenAllowance(
|
||||
input_tokens=clamp_non_negative(input_remaining),
|
||||
@ -222,8 +260,8 @@ class UsageTracker:
|
||||
|
||||
with self._lock:
|
||||
providers: dict[str, dict[str, UsageBreakdown]] = defaultdict(dict)
|
||||
for key, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(key)
|
||||
for model, state in self.iter_state_items():
|
||||
prov, model_name = split_model_key(model)
|
||||
if provider and provider != prov:
|
||||
continue
|
||||
if model and model != model_name:
|
||||
@ -265,8 +303,8 @@ class UsageTracker:
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_config(self, key: str) -> RateLimitConfig | None:
|
||||
return self._configs.get(key) or self._default_config
|
||||
def _get_config(self, model: str) -> RateLimitConfig | None:
|
||||
return self._configs.get(model) or self._default_config
|
||||
|
||||
def _prune_expired_events(
|
||||
self,
|
||||
@ -313,4 +351,3 @@ def clamp_non_negative(value: int | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return 0 if value < 0 else value
|
||||
|
||||
@ -31,33 +31,25 @@ def make_db_url(
|
||||
|
||||
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
||||
|
||||
# Redis settings
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
||||
REDIS_DB = os.getenv("REDIS_DB", "0")
|
||||
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||
if REDIS_PASSWORD:
|
||||
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||
else:
|
||||
REDIS_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||
|
||||
# Broker settings
|
||||
CELERY_QUEUE_PREFIX = os.getenv("CELERY_QUEUE_PREFIX", "memory")
|
||||
CELERY_BROKER_TYPE = os.getenv("CELERY_BROKER_TYPE", "redis").lower()
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "redis")
|
||||
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
|
||||
REDIS_DB = os.getenv("REDIS_DB", "0")
|
||||
CELERY_BROKER_USER = os.getenv(
|
||||
"CELERY_BROKER_USER", "kb" if CELERY_BROKER_TYPE == "amqp" else ""
|
||||
)
|
||||
CELERY_BROKER_PASSWORD = os.getenv(
|
||||
"CELERY_BROKER_PASSWORD", "" if CELERY_BROKER_TYPE == "redis" else "kb"
|
||||
)
|
||||
|
||||
|
||||
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "")
|
||||
if not CELERY_BROKER_HOST:
|
||||
if CELERY_BROKER_TYPE == "amqp":
|
||||
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
|
||||
RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
|
||||
CELERY_BROKER_HOST = f"{RABBITMQ_HOST}:{RABBITMQ_PORT}//"
|
||||
elif CELERY_BROKER_TYPE == "redis":
|
||||
CELERY_BROKER_HOST = f"{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||
CELERY_BROKER_USER = os.getenv("CELERY_BROKER_USER", "")
|
||||
CELERY_BROKER_PASSWORD = os.getenv("CELERY_BROKER_PASSWORD", REDIS_PASSWORD)
|
||||
|
||||
CELERY_BROKER_HOST = os.getenv("CELERY_BROKER_HOST", "") or f"{REDIS_HOST}:{REDIS_PORT}"
|
||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
|
||||
|
||||
|
||||
# File storage settings
|
||||
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
|
||||
EBOOK_STORAGE_DIR = pathlib.Path(
|
||||
@ -81,9 +73,14 @@ WEBPAGE_STORAGE_DIR = pathlib.Path(
|
||||
NOTES_STORAGE_DIR = pathlib.Path(
|
||||
os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
|
||||
)
|
||||
PRIVATE_DIRS = [
|
||||
EMAIL_STORAGE_DIR,
|
||||
NOTES_STORAGE_DIR,
|
||||
PHOTO_STORAGE_DIR,
|
||||
CHUNK_STORAGE_DIR,
|
||||
]
|
||||
|
||||
storage_dirs = [
|
||||
FILE_STORAGE_DIR,
|
||||
EBOOK_STORAGE_DIR,
|
||||
EMAIL_STORAGE_DIR,
|
||||
CHUNK_STORAGE_DIR,
|
||||
@ -148,6 +145,18 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-haiku-4-5")
|
||||
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
||||
|
||||
DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES = int(
|
||||
os.getenv("DEFAULT_LLM_RATE_LIMIT_WINDOW_MINUTES", 30)
|
||||
)
|
||||
DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS = int(
|
||||
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_INPUT_TOKENS", 1_000_000)
|
||||
)
|
||||
DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS = int(
|
||||
os.getenv("DEFAULT_LLM_RATE_LIMIT_MAX_OUTPUT_TOKENS", 1_000_000)
|
||||
)
|
||||
LLM_USAGE_REDIS_PREFIX = os.getenv("LLM_USAGE_REDIS_PREFIX", "llm_usage")
|
||||
|
||||
|
||||
# Search settings
|
||||
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
||||
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
||||
@ -193,3 +202,14 @@ DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True)
|
||||
DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003))
|
||||
DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0")
|
||||
DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10))
|
||||
|
||||
|
||||
# S3 Backup settings
|
||||
S3_BACKUP_BUCKET = os.getenv("S3_BACKUP_BUCKET", "equistamp-memory-backup")
|
||||
S3_BACKUP_PREFIX = os.getenv("S3_BACKUP_PREFIX", "Daniel")
|
||||
S3_BACKUP_REGION = os.getenv("S3_BACKUP_REGION", "eu-central-1")
|
||||
BACKUP_ENCRYPTION_KEY = os.getenv("BACKUP_ENCRYPTION_KEY", "")
|
||||
S3_BACKUP_ENABLED = boolean_env("S3_BACKUP_ENABLED", bool(BACKUP_ENCRYPTION_KEY))
|
||||
S3_BACKUP_INTERVAL = int(
|
||||
os.getenv("S3_BACKUP_INTERVAL", 60 * 60 * 24)
|
||||
) # Daily by default
|
||||
|
||||
@ -7,17 +7,18 @@ providing HTTP endpoints for sending Discord messages.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
import traceback
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
from memory.common import settings
|
||||
from memory.discord.collector import MessageCollector
|
||||
from memory.common.db.models.users import BotUser
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models.users import DiscordBotUser
|
||||
from memory.discord.collector import MessageCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -34,6 +35,16 @@ class SendChannelRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class TypingDMRequest(BaseModel):
|
||||
bot_id: int
|
||||
user: int | str
|
||||
|
||||
|
||||
class TypingChannelRequest(BaseModel):
|
||||
bot_id: int
|
||||
channel_name: str
|
||||
|
||||
|
||||
class Collector:
|
||||
collector: MessageCollector
|
||||
collector_task: asyncio.Task
|
||||
@ -41,37 +52,25 @@ class Collector:
|
||||
bot_token: str
|
||||
bot_name: str
|
||||
|
||||
def __init__(self, collector: MessageCollector, bot: BotUser):
|
||||
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||
self.collector = collector
|
||||
self.collector_task = asyncio.create_task(collector.start(bot.api_key))
|
||||
self.bot_id = bot.id
|
||||
self.bot_token = bot.api_key
|
||||
self.bot_name = bot.name
|
||||
|
||||
|
||||
# Application state
|
||||
class AppState:
|
||||
def __init__(self):
|
||||
self.collector: MessageCollector | None = None
|
||||
self.collector_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
app_state = AppState()
|
||||
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||
self.bot_id = cast(int, bot.id)
|
||||
self.bot_token = str(bot.api_key)
|
||||
self.bot_name = str(bot.name)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage Discord collector lifecycle"""
|
||||
if not settings.DISCORD_BOT_TOKEN:
|
||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||
return
|
||||
|
||||
def make_collector(bot: BotUser):
|
||||
def make_collector(bot: DiscordBotUser):
|
||||
collector = MessageCollector()
|
||||
return Collector(collector=collector, bot=bot)
|
||||
|
||||
with make_session() as session:
|
||||
app.bots = {bot.id: make_collector(bot) for bot in session.query(BotUser).all()}
|
||||
bots = session.query(DiscordBotUser).all()
|
||||
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||
|
||||
logger.info(f"Discord collectors started for {len(app.bots)} bots")
|
||||
|
||||
@ -120,6 +119,32 @@ async def send_dm_endpoint(request: SendDMRequest):
|
||||
}
|
||||
|
||||
|
||||
@app.post("/typing/dm")
async def trigger_dm_typing(request: TypingDMRequest):
    """Trigger a typing indicator for a DM via the collector"""
    # Each bot has its own collector, keyed by bot id in app.bots.
    bot = app.bots.get(request.bot_id)
    if not bot:
        raise HTTPException(status_code=404, detail="Bot not found")

    target = request.user
    try:
        triggered = await bot.collector.trigger_typing_dm(target)
    except Exception as e:
        # Anything raised by the collector is reported as a server error.
        logger.error(f"Failed to trigger DM typing: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    if triggered:
        return {
            "success": True,
            "user": target,
            "message": f"Typing triggered for {target}",
        }

    # The collector ran but reported failure (it returns False rather than raising).
    raise HTTPException(
        status_code=400,
        detail=f"Failed to trigger typing for {target}",
    )
|
||||
|
||||
|
||||
@app.post("/send_channel")
|
||||
async def send_channel_endpoint(request: SendChannelRequest):
|
||||
"""Send a message to a channel via the collector's Discord client"""
|
||||
@ -131,23 +156,48 @@ async def send_channel_endpoint(request: SendChannelRequest):
|
||||
success = await collector.collector.send_to_channel(
|
||||
request.channel_name, request.message
|
||||
)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Message sent to channel {request.channel_name}",
|
||||
"channel": request.channel_name,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send message to channel {request.channel_name}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send channel message: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Message sent to channel {request.channel_name}",
|
||||
"channel": request.channel_name,
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to send message to channel {request.channel_name}",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/typing/channel")
async def trigger_channel_typing(request: TypingChannelRequest):
    """Trigger a typing indicator for a channel via the collector"""
    # Each bot has its own collector, keyed by bot id in app.bots.
    bot = app.bots.get(request.bot_id)
    if not bot:
        raise HTTPException(status_code=404, detail="Bot not found")

    channel_name = request.channel_name
    try:
        triggered = await bot.collector.trigger_typing_channel(channel_name)
    except Exception as e:
        # Anything raised by the collector is reported as a server error.
        logger.error(f"Failed to trigger channel typing: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    if triggered:
        return {
            "success": True,
            "channel": channel_name,
            "message": f"Typing triggered for channel {channel_name}",
        }

    # The collector ran but reported failure (it returns False rather than raising).
    raise HTTPException(
        status_code=400,
        detail=f"Failed to trigger typing for channel {channel_name}",
    )
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
@ -155,9 +205,8 @@ async def health_check():
|
||||
if not app.bots:
|
||||
raise HTTPException(status_code=503, detail="Discord collector not running")
|
||||
|
||||
collector = app_state.collector
|
||||
return {
|
||||
collector.bot_name: {
|
||||
bot.bot_name: {
|
||||
"status": "healthy",
|
||||
"connected": not bot.collector.is_closed(),
|
||||
"user": str(bot.collector.user) if bot.collector.user else None,
|
||||
|
||||
@ -203,7 +203,7 @@ class MessageCollector(commands.Bot):
|
||||
async def setup_hook(self):
|
||||
"""Register slash commands when the bot is ready."""
|
||||
|
||||
register_slash_commands(self)
|
||||
register_slash_commands(self, name=self.user.name)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Called when bot connects to Discord"""
|
||||
@ -381,6 +381,27 @@ class MessageCollector(commands.Bot):
|
||||
logger.error(f"Failed to send DM to {user_identifier}: {e}")
|
||||
return False
|
||||
|
||||
async def trigger_typing_dm(self, user_identifier: int | str) -> bool:
|
||||
"""Trigger typing indicator in a DM"""
|
||||
try:
|
||||
user = await self.get_user(user_identifier)
|
||||
if not user:
|
||||
logger.error(f"User {user_identifier} not found")
|
||||
return False
|
||||
|
||||
channel = user.dm_channel or await user.create_dm()
|
||||
if not channel:
|
||||
logger.error(f"DM channel not available for {user_identifier}")
|
||||
return False
|
||||
|
||||
async with channel.typing():
|
||||
pass
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger DM typing for {user_identifier}: {e}")
|
||||
return False
|
||||
|
||||
async def send_to_channel(self, channel_name: str, message: str) -> bool:
|
||||
"""Send a message to a channel by name across all guilds"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
@ -400,23 +421,21 @@ class MessageCollector(commands.Bot):
|
||||
logger.error(f"Failed to send message to channel {channel_name}: {e}")
|
||||
return False
|
||||
|
||||
async def trigger_typing_channel(self, channel_name: str) -> bool:
|
||||
"""Trigger typing indicator in a channel"""
|
||||
if not settings.DISCORD_NOTIFICATIONS_ENABLED:
|
||||
return False
|
||||
|
||||
async def run_collector():
|
||||
"""Run the Discord message collector"""
|
||||
if not settings.DISCORD_BOT_TOKEN:
|
||||
logger.error("DISCORD_BOT_TOKEN not configured")
|
||||
return
|
||||
try:
|
||||
channel = await self.get_channel_by_name(channel_name)
|
||||
if not channel:
|
||||
logger.error(f"Channel {channel_name} not found")
|
||||
return False
|
||||
|
||||
collector = MessageCollector()
|
||||
async with channel.typing():
|
||||
pass
|
||||
return True
|
||||
|
||||
try:
|
||||
await collector.start(settings.DISCORD_BOT_TOKEN)
|
||||
except Exception as e:
|
||||
logger.error(f"Discord collector failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run_collector())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger typing for channel {channel_name}: {e}")
|
||||
return False
|
||||
|
||||
@ -41,8 +41,13 @@ class CommandContext:
|
||||
CommandHandler = Callable[..., CommandResponse]
|
||||
|
||||
|
||||
def register_slash_commands(bot: discord.Client) -> None:
|
||||
"""Register the collector slash commands on the provided bot."""
|
||||
def register_slash_commands(bot: discord.Client, name: str = "memory") -> None:
|
||||
"""Register the collector slash commands on the provided bot.
|
||||
|
||||
Args:
|
||||
bot: Discord bot client
|
||||
name: Prefix for command names (e.g., "memory" creates "memory_prompt")
|
||||
"""
|
||||
|
||||
if getattr(bot, "_memory_commands_registered", False):
|
||||
return
|
||||
@ -54,12 +59,14 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
|
||||
tree = bot.tree
|
||||
|
||||
@tree.command(name="memory_prompt", description="Show the current system prompt")
|
||||
@tree.command(
|
||||
name=f"{name}_show_prompt", description="Show the current system prompt"
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def prompt_command(
|
||||
async def show_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
user: discord.User | None = None,
|
||||
@ -72,12 +79,35 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_chattiness",
|
||||
description="Show or update the chattiness threshold for the target",
|
||||
name=f"{name}_set_prompt",
|
||||
description="Set the system prompt for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to modify",
|
||||
prompt="The system prompt to set",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def set_prompt_command(
|
||||
interaction: discord.Interaction,
|
||||
scope: ScopeLiteral,
|
||||
prompt: str,
|
||||
user: discord.User | None = None,
|
||||
) -> None:
|
||||
await _run_interaction_command(
|
||||
interaction,
|
||||
scope=scope,
|
||||
handler=handle_set_prompt,
|
||||
target_user=user,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name=f"{name}_chattiness",
|
||||
description="Show or update the chattiness for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
value="Optional new threshold value between 0 and 100",
|
||||
value="Optional new chattiness value between 0 and 100",
|
||||
user="Target user when the scope is 'user'",
|
||||
)
|
||||
async def chattiness_command(
|
||||
@ -95,7 +125,7 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
)
|
||||
|
||||
@tree.command(
|
||||
name="memory_ignore",
|
||||
name=f"{name}_ignore",
|
||||
description="Toggle whether the bot should ignore messages for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
@ -117,7 +147,10 @@ def register_slash_commands(bot: discord.Client) -> None:
|
||||
ignore_enabled=enabled,
|
||||
)
|
||||
|
||||
@tree.command(name="memory_summary", description="Show the stored summary for the target")
|
||||
@tree.command(
|
||||
name=f"{name}_show_summary",
|
||||
description="Show the stored summary for the target",
|
||||
)
|
||||
@discord.app_commands.describe(
|
||||
scope="Which configuration to inspect",
|
||||
user="Target user when the scope is 'user'",
|
||||
@ -337,6 +370,18 @@ def handle_prompt(context: CommandContext) -> CommandResponse:
|
||||
)
|
||||
|
||||
|
||||
def handle_set_prompt(
    context: CommandContext,
    *,
    prompt: str,
) -> CommandResponse:
    """Store a new system prompt on the command's target.

    Args:
        context: Command context; its ``target`` receives the prompt and its
            ``display_name`` is used in the confirmation text.
        prompt: The system prompt to assign.

    Returns:
        A response confirming the update. This function only assigns the
        attribute; it performs no commit itself.
    """
    context.target.system_prompt = prompt
    confirmation = f"Updated system prompt for {context.display_name}."
    return CommandResponse(content=confirmation)
|
||||
|
||||
|
||||
def handle_chattiness(
|
||||
context: CommandContext,
|
||||
*,
|
||||
@ -347,20 +392,22 @@ def handle_chattiness(
|
||||
if value is None:
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Chattiness threshold for {context.display_name}: "
|
||||
f"Chattiness for {context.display_name}: "
|
||||
f"{getattr(model, 'chattiness_threshold', 'not set')}"
|
||||
)
|
||||
)
|
||||
|
||||
if not 0 <= value <= 100:
|
||||
raise CommandError("Chattiness threshold must be between 0 and 100.")
|
||||
raise CommandError("Chattiness must be between 0 and 100.")
|
||||
|
||||
setattr(model, "chattiness_threshold", value)
|
||||
|
||||
return CommandResponse(
|
||||
content=(
|
||||
f"Updated chattiness threshold for {context.display_name} "
|
||||
f"to {value}."
|
||||
f"Updated chattiness for {context.display_name} to {value}."
|
||||
"\n"
|
||||
"This can be treated as how much you want the bot to pipe up by itself, as a percentage, "
|
||||
"where 0 is never and 100 is always."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from memory.common.celery_app import (
|
||||
TRACK_GIT_CHANGES,
|
||||
SYNC_LESSWRONG,
|
||||
RUN_SCHEDULED_CALLS,
|
||||
BACKUP_ALL,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -48,4 +49,8 @@ app.conf.beat_schedule = {
|
||||
"task": RUN_SCHEDULED_CALLS,
|
||||
"schedule": settings.SCHEDULED_CALL_RUN_INTERVAL,
|
||||
},
|
||||
"backup-all": {
|
||||
"task": BACKUP_ALL,
|
||||
"schedule": settings.S3_BACKUP_INTERVAL,
|
||||
},
|
||||
}
|
||||
|
||||
@ -3,11 +3,12 @@ Import sub-modules so Celery can register their @app.task decorators.
|
||||
"""
|
||||
|
||||
from memory.workers.tasks import (
|
||||
email,
|
||||
comic,
|
||||
backup,
|
||||
blogs,
|
||||
comic,
|
||||
discord,
|
||||
ebook,
|
||||
email,
|
||||
forums,
|
||||
maintenance,
|
||||
notes,
|
||||
@ -15,8 +16,8 @@ from memory.workers.tasks import (
|
||||
scheduled_calls,
|
||||
) # noqa
|
||||
|
||||
|
||||
__all__ = [
|
||||
"backup",
|
||||
"email",
|
||||
"comic",
|
||||
"blogs",
|
||||
|
||||
152
src/memory/workers/tasks/backup.py
Normal file
152
src/memory/workers/tasks/backup.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""S3 backup tasks for memory files."""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import subprocess
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
||||
import boto3
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_cipher() -> Fernet:
    """Build the Fernet cipher used for encrypted backups.

    The SHA-256 digest of ``settings.BACKUP_ENCRYPTION_KEY`` (32 bytes) is
    url-safe base64 encoded and used as the Fernet key.

    Raises:
        ValueError: If ``settings.BACKUP_ENCRYPTION_KEY`` is empty.
    """
    passphrase = settings.BACKUP_ENCRYPTION_KEY
    if not passphrase:
        raise ValueError("BACKUP_ENCRYPTION_KEY not set in environment")

    # The derivation must stay exactly as is: changing it would make
    # previously uploaded backups undecryptable.
    digest = hashlib.sha256(passphrase.encode()).digest()
    return Fernet(base64.urlsafe_b64encode(digest))
|
||||
|
||||
|
||||
def create_tarball(directory: Path) -> bytes:
    """Pack a directory into an in-memory gzip-compressed tar archive.

    Args:
        directory: Path to archive; entries are stored under its last path
            component (``directory.name``).

    Returns:
        The archive bytes, or ``b""`` when ``directory`` does not exist.
    """
    if not directory.exists():
        logger.warning(f"Directory does not exist: {directory}")
        return b""

    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as archive:
        archive.add(directory, arcname=directory.name)

    # The archive is finalised when the ``with`` block closes it; only then
    # does the buffer hold the complete gzip stream.
    return buffer.getvalue()
|
||||
|
||||
|
||||
def sync_unencrypted_directory(path: Path) -> dict:
    """Mirror a directory to the backup bucket with ``aws s3 sync --delete``.

    Args:
        path: Local directory to sync. It is uploaded under
            ``s3://<S3_BACKUP_BUCKET>/<S3_BACKUP_PREFIX>/<path.name>``.

    Returns:
        A dict with ``"synced"`` set to True or False. On success it also
        carries ``"directory"`` and ``"s3_uri"``; on failure either
        ``"reason"`` (directory missing) or ``"directory"`` and ``"error"``
        (the CLI exited non-zero). ``"directory"`` is a plain string so the
        dict can be stored as a Celery task result.
    """
    if not path.exists():
        logger.warning(f"Directory does not exist: {path}")
        return {"synced": False, "reason": "directory_not_found"}

    s3_uri = f"s3://{settings.S3_BACKUP_BUCKET}/{settings.S3_BACKUP_PREFIX}/{path.name}"

    # Passed as an argument list (no shell), so path components are never
    # interpreted by a shell. ``--delete`` removes remote objects that no
    # longer exist locally.
    cmd = [
        "aws",
        "s3",
        "sync",
        str(path),
        s3_uri,
        "--delete",
        "--region",
        settings.S3_BACKUP_REGION,
    ]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to sync {path}: {e.stderr}")
        # str(path): a Path object cannot be JSON-encoded in a task result.
        return {"synced": False, "directory": str(path), "error": str(e)}

    logger.info(f"Synced {path} to {s3_uri}")
    logger.debug(f"Output: {result.stdout}")
    return {"synced": True, "directory": str(path), "s3_uri": s3_uri}
|
||||
|
||||
|
||||
def backup_encrypted_directory(path: Path) -> dict:
    """Tar, encrypt and upload a directory to the backup bucket.

    The directory is packed with ``create_tarball``, encrypted with the
    cipher from ``get_cipher`` and stored under
    ``<S3_BACKUP_PREFIX>/<path.name>.tar.gz.enc``. The whole archive is held
    in memory while this runs.

    Args:
        path: Local directory to back up.

    Returns:
        A dict with ``"uploaded"`` set to True or False. On success it also
        carries ``"directory"``, ``"size_bytes"`` and ``"s3_key"``; on
        failure either ``"reason"`` (missing directory or empty tarball) or
        ``"directory"`` and ``"error"`` (upload failed). ``"directory"`` is
        a plain string so the dict can be stored as a Celery task result.

    Raises:
        ValueError: Propagated from ``get_cipher`` when no encryption key is
            configured.
    """
    if not path.exists():
        logger.warning(f"Directory does not exist: {path}")
        return {"uploaded": False, "reason": "directory_not_found"}

    # Create tarball
    logger.info(f"Creating tarball of {path}...")
    tarball_bytes = create_tarball(path)
    if not tarball_bytes:
        logger.warning(f"Empty tarball for {path}, skipping")
        return {"uploaded": False, "reason": "empty_directory"}

    # Encrypt
    logger.info(f"Encrypting {path} ({len(tarball_bytes)} bytes)...")
    cipher = get_cipher()
    encrypted_bytes = cipher.encrypt(tarball_bytes)

    # Upload to S3
    s3_client = boto3.client("s3", region_name=settings.S3_BACKUP_REGION)
    s3_key = f"{settings.S3_BACKUP_PREFIX}/{path.name}.tar.gz.enc"

    try:
        logger.info(
            f"Uploading encrypted {path} to s3://{settings.S3_BACKUP_BUCKET}/{s3_key}"
        )
        s3_client.put_object(
            Bucket=settings.S3_BACKUP_BUCKET,
            Key=s3_key,
            Body=encrypted_bytes,
            ServerSideEncryption="AES256",
        )
    except Exception as e:
        logger.error(f"Failed to upload {path}: {e}")
        # str(path): a Path object cannot be JSON-encoded in a task result.
        return {"uploaded": False, "directory": str(path), "error": str(e)}

    return {
        "uploaded": True,
        "directory": str(path),
        "size_bytes": len(encrypted_bytes),
        "s3_key": s3_key,
    }
|
||||
|
||||
|
||||
@app.task(name=BACKUP_PATH)
def backup_to_s3(path: Path | str):
    """Back up one directory to S3, encrypting it when it is private.

    Directories listed in ``settings.PRIVATE_DIRS`` are uploaded as an
    encrypted tarball; every other directory is synced as-is.

    Args:
        path: Directory to back up, as a ``Path`` or a string.

    Returns:
        The result dict of the chosen backup helper, or a
        ``directory_not_found`` dict when the path does not exist.
    """
    target = Path(path)
    if not target.exists():
        logger.warning(f"Directory does not exist: {target}")
        return {"uploaded": False, "reason": "directory_not_found"}

    is_private = target in settings.PRIVATE_DIRS
    handler = backup_encrypted_directory if is_private else sync_unencrypted_directory
    return handler(target)
|
||||
|
||||
|
||||
@app.task(name=BACKUP_ALL)
def backup_all_to_s3():
    """Queue one ``backup_to_s3`` task per configured storage directory.

    Returns:
        ``{"status": "disabled"}`` when ``settings.S3_BACKUP_ENABLED`` is
        false; otherwise a ``"success"`` dict reporting how many backups
        were queued. The backups run in their own tasks, so "success" only
        means they were scheduled.
    """
    if not settings.S3_BACKUP_ENABLED:
        logger.info("S3 backup is disabled")
        return {"status": "disabled"}

    logger.info("Starting S3 backup...")

    for directory in settings.storage_dirs:
        # storage_dirs already holds full Path objects, so each one is used
        # as-is. Joining it onto FILE_STORAGE_DIR again is a no-op for
        # absolute paths but duplicates the prefix for relative ones.
        backup_to_s3.delay(Path(directory).as_posix())

    return {
        "status": "success",
        "message": f"Started backup for {len(settings.storage_dirs)} directories",
    }
|
||||
@ -7,7 +7,7 @@ import logging
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import exc as sqlalchemy_exc
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
@ -56,8 +56,15 @@ def call_llm(
|
||||
message: DiscordMessage,
|
||||
model: str,
|
||||
msgs: list[str] = [],
|
||||
allowed_tools: list[str] = [],
|
||||
allowed_tools: list[str] | None = None,
|
||||
) -> str | None:
|
||||
provider = create_provider(model=model)
|
||||
if provider.usage_tracker.is_rate_limited(model):
|
||||
logger.error(
|
||||
f"Rate limited for model {model}: {provider.usage_tracker.get_usage_breakdown(model=model)}"
|
||||
)
|
||||
return None
|
||||
|
||||
tools = make_discord_tools(
|
||||
message.recipient_user.system_user,
|
||||
message.from_user,
|
||||
@ -67,13 +74,13 @@ def call_llm(
|
||||
tools = {
|
||||
name: tool
|
||||
for name, tool in tools.items()
|
||||
if message.tool_allowed(name) and name in allowed_tools
|
||||
if message.tool_allowed(name)
|
||||
and (allowed_tools is None or name in allowed_tools)
|
||||
}
|
||||
system_prompt = message.system_prompt or ""
|
||||
system_prompt += comm_channel_prompt(
|
||||
session, message.recipient_user, message.channel
|
||||
)
|
||||
provider = create_provider(model=model)
|
||||
messages = previous_messages(
|
||||
session,
|
||||
message.recipient_user and message.recipient_user.id,
|
||||
@ -127,10 +134,32 @@ def should_process(message: DiscordMessage) -> bool:
|
||||
if not (res := re.search(r"<number>(.*)</number>", response)):
|
||||
return False
|
||||
try:
|
||||
return int(res.group(1)) > message.chattiness_threshold
|
||||
if int(res.group(1)) < 100 - message.chattiness_threshold:
|
||||
return False
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if not (bot_id := _resolve_bot_id(message)):
|
||||
return False
|
||||
|
||||
if message.channel and message.channel.server:
|
||||
discord.trigger_typing_channel(bot_id, message.channel.name)
|
||||
else:
|
||||
discord.trigger_typing_dm(bot_id, cast(int | str, message.from_id))
|
||||
return True
|
||||
|
||||
|
||||
def _resolve_bot_id(discord_message: DiscordMessage) -> int | None:
|
||||
recipient = discord_message.recipient_user
|
||||
if not recipient:
|
||||
return None
|
||||
|
||||
system_user = recipient.system_user
|
||||
if not system_user:
|
||||
return None
|
||||
|
||||
return getattr(system_user, "id", None)
|
||||
|
||||
|
||||
@app.task(name=PROCESS_DISCORD_MESSAGE)
|
||||
@safe_task_execution
|
||||
@ -152,14 +181,33 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||
bot_id = _resolve_bot_id(discord_message)
|
||||
if not bot_id:
|
||||
logger.warning(
|
||||
"No associated Discord bot user for message %s; skipping send",
|
||||
message_id,
|
||||
)
|
||||
return {
|
||||
"status": "processed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = call_llm(session, discord_message, settings.DISCORD_MODEL)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate Discord response")
|
||||
|
||||
print("response:", response)
|
||||
if not response:
|
||||
pass
|
||||
elif discord_message.channel.server:
|
||||
discord.send_to_channel(discord_message.channel.name, response)
|
||||
return {
|
||||
"status": "processed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
if discord_message.channel.server:
|
||||
discord.send_to_channel(bot_id, discord_message.channel.name, response)
|
||||
else:
|
||||
discord.send_dm(discord_message.from_user.username, response)
|
||||
discord.send_dm(bot_id, discord_message.from_user.username, response)
|
||||
|
||||
return {
|
||||
"status": "processed",
|
||||
|
||||
@ -37,12 +37,22 @@ def _send_to_discord(scheduled_call: ScheduledLLMCall, response: str):
|
||||
if len(message) > 1900: # Leave some buffer
|
||||
message = message[:1900] + "\n\n... (response truncated)"
|
||||
|
||||
bot_id_value = scheduled_call.user_id
|
||||
if bot_id_value is None:
|
||||
logger.warning(
|
||||
"Scheduled call %s has no associated bot user; skipping Discord send",
|
||||
scheduled_call.id,
|
||||
)
|
||||
return
|
||||
|
||||
bot_id = cast(int, bot_id_value)
|
||||
|
||||
if discord_user := scheduled_call.discord_user:
|
||||
logger.info(f"Sending DM to {discord_user.username}: {message}")
|
||||
discord.send_dm(discord_user.username, message)
|
||||
discord.send_dm(bot_id, discord_user.username, message)
|
||||
elif discord_channel := scheduled_call.discord_channel:
|
||||
logger.info(f"Broadcasting message to {discord_channel.name}: {message}")
|
||||
discord.broadcast_message(discord_channel.name, message)
|
||||
discord.broadcast_message(bot_id, discord_channel.name, message)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No Discord user or channel found for scheduled call {scheduled_call.id}"
|
||||
|
||||
@ -1,7 +1,24 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Iterable
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import redis # noqa: F401 # pragma: no cover - optional test dependency
|
||||
except ModuleNotFoundError: # pragma: no cover - import guard for test envs
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
|
||||
class _RedisStub(SimpleNamespace):
|
||||
class Redis: # type: ignore[no-redef]
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
raise ModuleNotFoundError(
|
||||
"The 'redis' package is required to use RedisUsageTracker"
|
||||
)
|
||||
|
||||
sys.modules.setdefault("redis", _RedisStub())
|
||||
|
||||
from memory.common.llms.redis_usage_tracker import RedisUsageTracker
|
||||
from memory.common.llms.usage_tracker import (
|
||||
InMemoryUsageTracker,
|
||||
RateLimitConfig,
|
||||
@ -9,6 +26,24 @@ from memory.common.llms.usage_tracker import (
|
||||
)
|
||||
|
||||
|
||||
class FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, str] = {}
|
||||
|
||||
def get(self, key: str) -> str | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(self, key: str, value: str) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
def scan_iter(self, match: str) -> Iterable[str]:
|
||||
from fnmatch import fnmatch
|
||||
|
||||
for key in list(self._store.keys()):
|
||||
if fnmatch(key, match):
|
||||
yield key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracker() -> InMemoryUsageTracker:
|
||||
config = RateLimitConfig(
|
||||
@ -25,6 +60,23 @@ def tracker() -> InMemoryUsageTracker:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_tracker() -> RedisUsageTracker:
|
||||
config = RateLimitConfig(
|
||||
window=timedelta(minutes=1),
|
||||
max_input_tokens=1_000,
|
||||
max_output_tokens=2_000,
|
||||
max_total_tokens=2_500,
|
||||
)
|
||||
return RedisUsageTracker(
|
||||
{
|
||||
"anthropic/claude-3": config,
|
||||
"anthropic/haiku": config,
|
||||
},
|
||||
redis_client=FakeRedis(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"window, kwargs",
|
||||
[
|
||||
@ -139,6 +191,22 @@ def test_is_rate_limited_when_only_output_exceeds_limit() -> None:
|
||||
assert tracker.is_rate_limited("openai/gpt-4o")
|
||||
|
||||
|
||||
def test_redis_usage_tracker_persists_state(redis_tracker: RedisUsageTracker) -> None:
    """Usage recorded for two models is visible through every read API."""
    now = datetime(2024, 1, 1, tzinfo=timezone.utc)
    redis_tracker.record_usage("anthropic/claude-3", 100, 200, timestamp=now)
    redis_tracker.record_usage("anthropic/haiku", 50, 75, timestamp=now)

    # 1,000 input tokens allowed by the fixture minus the 100 just recorded.
    allowance = redis_tracker.get_available_tokens("anthropic/claude-3", timestamp=now)
    assert allowance is not None
    assert allowance.input_tokens == 900

    breakdown = redis_tracker.get_usage_breakdown()
    assert breakdown["anthropic"]["claude-3"].window_output_tokens == 200

    items = dict(redis_tracker.iter_state_items())
    assert set(items.keys()) == {"anthropic/claude-3", "anthropic/haiku"}
|
||||
|
||||
|
||||
def test_usage_tracker_base_not_instantiable() -> None:
|
||||
class DummyTracker(UsageTracker):
|
||||
pass
|
||||
|
||||
@ -4,6 +4,8 @@ import requests
|
||||
|
||||
from memory.common import discord
|
||||
|
||||
BOT_ID = 42
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_url():
|
||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user": "user123", "message": "Hello!"},
|
||||
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "activity", "User logged in"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "discoveries", "Found interesting pattern"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||
)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
bot_id=BOT_ID,
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert "Exception: test" in message
|
||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
function(BOT_ID, message)
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == message_with_special_chars
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == long_message
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -4,6 +4,8 @@ import requests
|
||||
|
||||
from memory.common import discord
|
||||
|
||||
BOT_ID = 42
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_url():
|
||||
@ -29,12 +31,12 @@ def test_send_dm_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_dm",
|
||||
json={"user": "user123", "message": "Hello!"},
|
||||
json={"bot_id": BOT_ID, "user": "user123", "message": "Hello!"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -47,7 +49,7 @@ def test_send_dm_api_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -57,7 +59,7 @@ def test_send_dm_request_exception(mock_post, mock_api_url):
|
||||
"""Test DM sending when request raises exception"""
|
||||
mock_post.side_effect = requests.RequestException("Network error")
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -69,7 +71,7 @@ def test_send_dm_http_error(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.send_dm("user123", "Hello!")
|
||||
result = discord.send_dm(BOT_ID, "user123", "Hello!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -82,12 +84,16 @@ def test_broadcast_message_success(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"http://localhost:8000/send_channel",
|
||||
json={"channel_name": "general", "message": "Announcement!"},
|
||||
json={
|
||||
"bot_id": BOT_ID,
|
||||
"channel_name": "general",
|
||||
"message": "Announcement!",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
@ -100,7 +106,7 @@ def test_broadcast_message_failure(mock_post, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -110,7 +116,7 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
"""Test channel message broadcast with exception"""
|
||||
mock_post.side_effect = requests.Timeout("Request timeout")
|
||||
|
||||
result = discord.broadcast_message("general", "Announcement!")
|
||||
result = discord.broadcast_message(BOT_ID, "general", "Announcement!")
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -119,11 +125,11 @@ def test_broadcast_message_exception(mock_post, mock_api_url):
|
||||
def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
"""Test health check when collector is healthy"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "healthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": True}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
@ -133,11 +139,11 @@ def test_is_collector_healthy_true(mock_get, mock_api_url):
|
||||
def test_is_collector_healthy_false_status(mock_get, mock_api_url):
|
||||
"""Test health check when collector returns unhealthy status"""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"status": "unhealthy"}
|
||||
mock_response.json.return_value = {str(BOT_ID): {"connected": False}}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -147,7 +153,7 @@ def test_is_collector_healthy_exception(mock_get, mock_api_url):
|
||||
"""Test health check when request fails"""
|
||||
mock_get.side_effect = requests.ConnectionError("Connection refused")
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
@ -200,10 +206,10 @@ def test_send_error_message(mock_broadcast):
|
||||
"""Test sending error message to error channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_error_message("Something broke")
|
||||
result = discord.send_error_message(BOT_ID, "Something broke")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("errors", "Something broke")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "errors", "Something broke")
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -212,10 +218,12 @@ def test_send_activity_message(mock_broadcast):
|
||||
"""Test sending activity message to activity channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_activity_message("User logged in")
|
||||
result = discord.send_activity_message(BOT_ID, "User logged in")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("activity", "User logged in")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "activity", "User logged in"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -224,10 +232,12 @@ def test_send_discovery_message(mock_broadcast):
|
||||
"""Test sending discovery message to discovery channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_discovery_message("Found interesting pattern")
|
||||
result = discord.send_discovery_message(BOT_ID, "Found interesting pattern")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("discoveries", "Found interesting pattern")
|
||||
mock_broadcast.assert_called_once_with(
|
||||
BOT_ID, "discoveries", "Found interesting pattern"
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.common.discord.broadcast_message")
|
||||
@ -236,20 +246,23 @@ def test_send_chat_message(mock_broadcast):
|
||||
"""Test sending chat message to chat channel"""
|
||||
mock_broadcast.return_value = True
|
||||
|
||||
result = discord.send_chat_message("Hello from bot")
|
||||
result = discord.send_chat_message(BOT_ID, "Hello from bot")
|
||||
|
||||
assert result is True
|
||||
mock_broadcast.assert_called_once_with("chat", "Hello from bot")
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "chat", "Hello from bot")
|
||||
|
||||
|
||||
@patch("memory.common.discord.send_error_message")
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", True)
|
||||
def test_notify_task_failure_basic(mock_send_error):
|
||||
"""Test basic task failure notification"""
|
||||
discord.notify_task_failure("test_task", "Something went wrong")
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Something went wrong", bot_id=BOT_ID
|
||||
)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
message = mock_send_error.call_args[0][0]
|
||||
assert mock_send_error.call_args[0][0] == BOT_ID
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "🚨 **Task Failed: test_task**" in message
|
||||
assert "**Error:** Something went wrong" in message
|
||||
@ -264,9 +277,10 @@ def test_notify_task_failure_with_args(mock_send_error):
|
||||
"Error occurred",
|
||||
task_args=("arg1", 42),
|
||||
task_kwargs={"key": "value", "number": 123},
|
||||
bot_id=BOT_ID,
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Args:** `('arg1', 42)" in message
|
||||
assert "**Kwargs:** `{'key': 'value', 'number': 123}" in message
|
||||
@ -278,9 +292,11 @@ def test_notify_task_failure_with_traceback(mock_send_error):
|
||||
"""Test task failure notification with traceback"""
|
||||
traceback = "Traceback (most recent call last):\n File test.py, line 10\n raise Exception('test')\nException: test"
|
||||
|
||||
discord.notify_task_failure("test_task", "Error occurred", traceback_str=traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error occurred", traceback_str=traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
assert "**Traceback:**" in message
|
||||
assert "Exception: test" in message
|
||||
@ -292,9 +308,9 @@ def test_notify_task_failure_truncates_long_error(mock_send_error):
|
||||
"""Test that long error messages are truncated"""
|
||||
long_error = "x" * 600
|
||||
|
||||
discord.notify_task_failure("test_task", long_error)
|
||||
discord.notify_task_failure("test_task", long_error, bot_id=BOT_ID)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Error should be truncated to 500 chars - check that the full 600 char string is not there
|
||||
assert "**Error:** " + long_error[:500] in message
|
||||
@ -309,9 +325,11 @@ def test_notify_task_failure_truncates_long_traceback(mock_send_error):
|
||||
"""Test that long tracebacks are truncated"""
|
||||
long_traceback = "x" * 1000
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", traceback_str=long_traceback)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", traceback_str=long_traceback, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Traceback should show last 800 chars
|
||||
assert long_traceback[-800:] in message
|
||||
@ -326,9 +344,11 @@ def test_notify_task_failure_truncates_long_args(mock_send_error):
|
||||
"""Test that long task arguments are truncated"""
|
||||
long_args = ("x" * 300,)
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_args=long_args)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_args=long_args, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Args should be truncated to 200 chars
|
||||
assert (
|
||||
@ -342,9 +362,11 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
"""Test that long task kwargs are truncated"""
|
||||
long_kwargs = {"key": "x" * 300}
|
||||
|
||||
discord.notify_task_failure("test_task", "Error", task_kwargs=long_kwargs)
|
||||
discord.notify_task_failure(
|
||||
"test_task", "Error", task_kwargs=long_kwargs, bot_id=BOT_ID
|
||||
)
|
||||
|
||||
message = mock_send_error.call_args[0][0]
|
||||
message = mock_send_error.call_args[0][1]
|
||||
|
||||
# Kwargs should be truncated to 200 chars
|
||||
assert len(message.split("**Kwargs:**")[1].split("\n")[0]) <= 210
|
||||
@ -354,7 +376,7 @@ def test_notify_task_failure_truncates_long_kwargs(mock_send_error):
|
||||
@patch("memory.common.settings.DISCORD_NOTIFICATIONS_ENABLED", False)
|
||||
def test_notify_task_failure_disabled(mock_send_error):
|
||||
"""Test that notifications are not sent when disabled"""
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_not_called()
|
||||
|
||||
@ -366,7 +388,7 @@ def test_notify_task_failure_send_error_exception(mock_send_error):
|
||||
mock_send_error.side_effect = Exception("Failed to send")
|
||||
|
||||
# Should not raise
|
||||
discord.notify_task_failure("test_task", "Error occurred")
|
||||
discord.notify_task_failure("test_task", "Error occurred", bot_id=BOT_ID)
|
||||
|
||||
mock_send_error.assert_called_once()
|
||||
|
||||
@ -386,8 +408,8 @@ def test_convenience_functions_use_correct_channels(
|
||||
):
|
||||
"""Test that convenience functions use the correct channel settings"""
|
||||
with patch(f"memory.common.settings.{channel_setting}", "test-channel"):
|
||||
function(message)
|
||||
mock_broadcast.assert_called_once_with("test-channel", message)
|
||||
function(BOT_ID, message)
|
||||
mock_broadcast.assert_called_once_with(BOT_ID, "test-channel", message)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -399,11 +421,13 @@ def test_send_dm_with_special_characters(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
message_with_special_chars = "Hello! 🎉 <@123> #general"
|
||||
result = discord.send_dm("user123", message_with_special_chars)
|
||||
result = discord.send_dm(BOT_ID, "user123", message_with_special_chars)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == message_with_special_chars
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == message_with_special_chars
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
@ -415,11 +439,13 @@ def test_broadcast_message_with_long_message(mock_post, mock_api_url):
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
long_message = "A" * 2000
|
||||
result = discord.broadcast_message("general", long_message)
|
||||
result = discord.broadcast_message(BOT_ID, "general", long_message)
|
||||
|
||||
assert result is True
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["message"] == long_message
|
||||
json_payload = call_args[1]["json"]
|
||||
assert json_payload["message"] == long_message
|
||||
assert json_payload["bot_id"] == BOT_ID
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
@ -430,6 +456,6 @@ def test_is_collector_healthy_missing_status_key(mock_get, mock_api_url):
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = discord.is_collector_healthy()
|
||||
result = discord.is_collector_healthy(BOT_ID)
|
||||
|
||||
assert result is False
|
||||
|
||||
393
tests/memory/workers/tasks/test_backup_tasks.py
Normal file
393
tests/memory/workers/tasks/test_backup_tasks.py
Normal file
@ -0,0 +1,393 @@
|
||||
import io
|
||||
import subprocess
|
||||
import tarfile
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from memory.common import settings
|
||||
from memory.workers.tasks import backup
|
||||
|
||||
|
||||
@pytest.fixture
def sample_files():
    """Create sample files in memory_files structure.

    For each directory name below, the directory is created under
    ``settings.FILE_STORAGE_DIR`` and filled with the listed files. Every file
    holds 100 repetitions of a line naming its own relative path. The fixture
    itself yields no value.
    """
    base = settings.FILE_STORAGE_DIR

    dirs_with_files = {
        "emails": ["email1.txt", "email2.txt"],
        "notes": ["note1.md", "note2.md"],
        "photos": ["photo1.jpg"],
        "comics": ["comic1.png", "comic2.png"],
        "ebooks": ["book1.epub"],
        "webpages": ["page1.html"],
    }

    for dir_name, filenames in dirs_with_files.items():
        dir_path = base / dir_name
        dir_path.mkdir(parents=True, exist_ok=True)

        for filename in filenames:
            file_path = dir_path / filename
            content = f"Content of {dir_name}/{filename}\n" * 100
            file_path.write_text(content)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_s3_client():
    """Mock boto3 S3 client.

    Patches ``boto3.client`` for the duration of the test so that it returns a
    single ``MagicMock``, and yields that mock.
    """
    with patch("boto3.client") as mock_client:
        s3_mock = MagicMock()
        mock_client.return_value = s3_mock
        yield s3_mock
|
||||
|
||||
|
||||
@pytest.fixture
def backup_settings():
    """Override the S3 backup settings with fixed test values for one test."""
    overrides = {
        "S3_BACKUP_ENABLED": True,
        "BACKUP_ENCRYPTION_KEY": "test-password-123",
        "S3_BACKUP_BUCKET": "test-bucket",
        "S3_BACKUP_PREFIX": "test-prefix",
        "S3_BACKUP_REGION": "us-east-1",
    }
    # patch.multiple restores every attribute when the fixture is torn down.
    with patch.multiple(settings, **overrides):
        yield
|
||||
|
||||
|
||||
@pytest.fixture
def get_test_path():
    """Return a callable that maps a directory name to its path under FILE_STORAGE_DIR."""

    def resolve(dir_name):
        # Looked up at call time, so a patched FILE_STORAGE_DIR is honoured.
        return settings.FILE_STORAGE_DIR / dir_name

    return resolve
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "data,key",
    [
        (b"This is a test message", "my-secret-key"),
        (b"\x00\x01\x02\xff" * 10000, "another-key"),
        (b"x" * 1000000, "large-data-key"),
    ],
)
def test_encrypt_decrypt_roundtrip(data, key):
    """Encrypting then decrypting returns the input, and the ciphertext differs from it."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
        cipher = backup.get_cipher()
        token = cipher.encrypt(data)
        recovered = cipher.decrypt(token)

    assert recovered == data
    assert token != data
|
||||
|
||||
|
||||
def test_encrypt_decrypt_tarball(sample_files):
    """A tarball of the emails directory survives an encrypt/decrypt cycle intact."""
    archive = backup.create_tarball(settings.FILE_STORAGE_DIR / "emails")
    assert len(archive) > 0

    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "tarball-key"):
        cipher = backup.get_cipher()
        token = cipher.encrypt(archive)

    recovered = cipher.decrypt(token)
    assert recovered == archive

    # The decrypted bytes must still be a readable gzip tarball.
    with tarfile.open(fileobj=io.BytesIO(recovered), mode="r:gz") as tar:
        entries = tar.getmembers()
        assert len(entries) >= 2  # At least 2 email files

        for entry in (e for e in entries if e.isfile()):
            handle = tar.extractfile(entry)
            assert handle is not None
            assert "Content of emails/" in handle.read().decode()
|
||||
|
||||
|
||||
def test_different_keys_produce_different_ciphertext():
    """The same plaintext encrypted under two different keys yields different ciphertext."""
    plaintext = b"Same data encrypted with different keys"

    tokens = []
    for key in ("key1", "key2"):
        with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
            tokens.append(backup.get_cipher().encrypt(plaintext))

    first, second = tokens
    assert first != second
|
||||
|
||||
|
||||
def test_missing_encryption_key_raises_error():
    """An empty BACKUP_ENCRYPTION_KEY makes ``get_cipher`` raise ValueError."""
    empty_key = patch.object(settings, "BACKUP_ENCRYPTION_KEY", "")
    expect_error = pytest.raises(ValueError, match="BACKUP_ENCRYPTION_KEY not set")
    with empty_key, expect_error:
        backup.get_cipher()
|
||||
|
||||
|
||||
def test_create_tarball_with_files(sample_files):
    """``create_tarball`` on a populated directory yields a gzip tarball with its files."""
    archive = backup.create_tarball(settings.FILE_STORAGE_DIR / "notes")
    assert len(archive) > 0

    # Opening in "r:gz" mode doubles as the check that the bytes are gzipped tar data.
    with tarfile.open(fileobj=io.BytesIO(archive), mode="r:gz") as tar:
        file_names = [entry.name for entry in tar.getmembers() if entry.isfile()]

    assert len(file_names) >= 2
    for expected in ("note1.md", "note2.md"):
        assert any(expected in name for name in file_names)
|
||||
|
||||
|
||||
def test_create_tarball_nonexistent_directory():
    """``create_tarball`` returns empty bytes for a directory that does not exist."""
    missing_dir = settings.FILE_STORAGE_DIR / "does_not_exist"
    assert backup.create_tarball(missing_dir) == b""
|
||||
|
||||
|
||||
def test_create_tarball_empty_directory():
    """An empty directory still produces a tarball whose first member is a directory."""
    target = settings.FILE_STORAGE_DIR / "empty"
    target.mkdir(parents=True, exist_ok=True)

    archive = backup.create_tarball(target)

    # Even with no files inside, the archive should hold the directory entry itself.
    assert len(archive) > 0
    with tarfile.open(fileobj=io.BytesIO(archive), mode="r:gz") as tar:
        entries = tar.getmembers()
        assert len(entries) >= 1
        assert entries[0].isdir()
|
||||
|
||||
|
||||
def test_sync_unencrypted_success(sample_files, backup_settings):
    """A successful ``aws s3 sync`` is reported as synced with the expected S3 URI."""
    comics_path = settings.FILE_STORAGE_DIR / "comics"
    completed = Mock(stdout="Synced files", returncode=0)

    with patch("subprocess.run", return_value=completed) as mock_run:
        result = backup.sync_unencrypted_directory(comics_path)

    assert result["synced"] is True
    assert result["directory"] == comics_path
    assert "s3_uri" in result
    assert "test-bucket" in result["s3_uri"]
    assert "test-prefix/comics" in result["s3_uri"]

    # The command is the first positional argument handed to subprocess.run.
    mock_run.assert_called_once()
    command = mock_run.call_args[0][0]
    assert list(command[:3]) == ["aws", "s3", "sync"]
    for flag in ("--delete", "--region"):
        assert flag in command
|
||||
|
||||
|
||||
def test_sync_unencrypted_nonexistent_directory(backup_settings):
    """Syncing a missing directory reports ``directory_not_found`` instead of syncing."""
    missing_dir = settings.FILE_STORAGE_DIR / "does_not_exist"

    outcome = backup.sync_unencrypted_directory(missing_dir)

    assert outcome["synced"] is False
    assert outcome["reason"] == "directory_not_found"
|
||||
|
||||
|
||||
def test_sync_unencrypted_aws_cli_failure(sample_files, backup_settings):
    """A failing AWS CLI call is reported in the result rather than raised."""
    cli_failure = subprocess.CalledProcessError(1, "aws", stderr="AWS CLI error")

    with patch("subprocess.run", side_effect=cli_failure):
        outcome = backup.sync_unencrypted_directory(
            settings.FILE_STORAGE_DIR / "comics"
        )

    assert outcome["synced"] is False
    assert "error" in outcome
|
||||
|
||||
|
||||
def test_backup_encrypted_success(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """An encrypted backup uploads a non-empty object with the expected key and options."""
    outcome = backup.backup_encrypted_directory(get_test_path("emails"))

    assert outcome["uploaded"] is True
    assert outcome["size_bytes"] > 0
    assert outcome["s3_key"].endswith("emails.tar.gz.enc")

    # Inspect the keyword arguments of the most recent put_object call.
    upload_kwargs = mock_s3_client.put_object.call_args.kwargs
    assert upload_kwargs["Bucket"] == "test-bucket"
    assert upload_kwargs["ServerSideEncryption"] == "AES256"
|
||||
|
||||
|
||||
def test_backup_encrypted_nonexistent_directory(
    mock_s3_client, backup_settings, get_test_path
):
    """A missing directory is reported as not found and nothing is uploaded."""
    missing_dir = get_test_path("does_not_exist")

    outcome = backup.backup_encrypted_directory(missing_dir)

    assert outcome["uploaded"] is False
    assert outcome["reason"] == "directory_not_found"
    mock_s3_client.put_object.assert_not_called()
|
||||
|
||||
|
||||
def test_backup_encrypted_empty_directory(
    mock_s3_client, backup_settings, get_test_path
):
    """Backing up an empty directory returns a result that carries an ``uploaded`` key."""
    target = get_test_path("empty_encrypted")
    target.mkdir(parents=True, exist_ok=True)

    outcome = backup.backup_encrypted_directory(target)

    # Only the presence of the key is checked, not its value.
    assert "uploaded" in outcome
|
||||
|
||||
|
||||
def test_backup_encrypted_s3_failure(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """A ClientError from ``put_object`` is reported in the result rather than raised."""
    error_response = {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}
    mock_s3_client.put_object.side_effect = ClientError(error_response, "PutObject")

    outcome = backup.backup_encrypted_directory(get_test_path("notes"))

    assert outcome["uploaded"] is False
    assert "error" in outcome
|
||||
|
||||
|
||||
def test_backup_encrypted_data_integrity(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """The bytes handed to S3 decrypt back to a tarball holding the original note."""
    outcome = backup.backup_encrypted_directory(get_test_path("notes"))
    assert outcome["uploaded"] is True

    # Decrypt exactly what was passed to put_object as the object body.
    cipher = backup.get_cipher()
    uploaded_body = mock_s3_client.put_object.call_args[1]["Body"]
    archive = cipher.decrypt(uploaded_body)

    with tarfile.open(fileobj=io.BytesIO(archive), mode="r:gz") as tar:
        matches = [
            entry
            for entry in tar.getmembers()
            if entry.name.endswith("note1.md") and entry.isfile()
        ]
        for entry in matches:
            text = tar.extractfile(entry).read().decode()
            assert "Content of notes/note1.md" in text
        assert matches, "note1.md not found in tarball"
|
||||
|
||||
|
||||
def test_backup_disabled():
    """With S3_BACKUP_ENABLED false, ``backup_all_to_s3`` reports status ``disabled``."""
    with patch.object(settings, "S3_BACKUP_ENABLED", False):
        outcome = backup.backup_all_to_s3()

    assert outcome["status"] == "disabled"
|
||||
|
||||
|
||||
def test_backup_full_execution(sample_files, mock_s3_client, backup_settings):
    """``backup_all_to_s3`` queues one ``backup_to_s3`` task per storage directory."""
    delay_spy = Mock()

    with patch.object(backup, "backup_to_s3") as fake_task:
        fake_task.delay = delay_spy
        outcome = backup.backup_all_to_s3()

    assert outcome["status"] == "success"
    assert "message" in outcome

    # One .delay() call is expected for every entry in settings.storage_dirs.
    assert delay_spy.call_count == len(settings.storage_dirs)
|
||||
|
||||
|
||||
def test_backup_handles_partial_failures(
    sample_files, mock_s3_client, backup_settings, get_test_path
):
    """A failed sync of one directory is returned as an error result, not raised.

    Only a single ``sync_unencrypted_directory`` call is exercised here; the
    multi-directory flow is not driven by this test.
    """
    sync_failure = subprocess.CalledProcessError(1, "aws", stderr="Sync failed")

    with patch("subprocess.run", side_effect=sync_failure):
        outcome = backup.sync_unencrypted_directory(get_test_path("comics"))

    assert outcome["synced"] is False
    assert "error" in outcome
|
||||
|
||||
|
||||
def test_same_key_different_runs_different_ciphertext():
    """Encrypting the same data twice with one cipher gives two distinct ciphertexts.

    Both ciphertexts must still decrypt to the original plaintext.
    """
    plaintext = b"Consistent data"

    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "same-key"):
        cipher = backup.get_cipher()
        first, second = (cipher.encrypt(plaintext) for _ in range(2))

    # Per-message randomness should make the two outputs differ.
    assert first != second

    assert cipher.decrypt(first) == cipher.decrypt(second) == plaintext
|
||||
|
||||
|
||||
def test_key_derivation_consistency():
    """Two ciphers built from the same password can read each other's ciphertext."""
    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "test-password"):
        writer = backup.get_cipher()
        reader = backup.get_cipher()

    # Encrypt with one cipher, decrypt with the other.
    plaintext = b"Test data"
    assert reader.decrypt(writer.encrypt(plaintext)) == plaintext
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "dir_name,is_private",
    [
        ("emails", True),
        ("notes", True),
        ("photos", True),
        ("comics", False),
        ("ebooks", False),
        ("webpages", False),
        ("lesswrong", False),
        ("chunks", False),
    ],
)
def test_directory_encryption_classification(dir_name, is_private, backup_settings):
    """Membership in ``settings.PRIVATE_DIRS`` matches the expected private flag.

    NOTE(review): this checks membership in the list patched in below and
    never calls into ``backup``, so it does not exercise the production
    classification logic.
    """
    private_paths = [
        settings.FILE_STORAGE_DIR / name for name in ("emails", "notes", "photos")
    ]

    with patch.object(settings, "PRIVATE_DIRS", private_paths):
        candidate = settings.FILE_STORAGE_DIR / dir_name
        found = candidate in settings.PRIVATE_DIRS

    assert found == is_private
|
||||
@ -3,6 +3,7 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import (
|
||||
DiscordBotUser,
|
||||
DiscordMessage,
|
||||
DiscordUser,
|
||||
DiscordServer,
|
||||
@ -12,12 +13,25 @@ from memory.workers.tasks import discord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(db_session):
|
||||
def discord_bot_user(db_session):
|
||||
bot = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
name="Test Bot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(bot)
|
||||
db_session.commit()
|
||||
return bot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(db_session, discord_bot_user):
|
||||
"""Create a Discord user for testing."""
|
||||
user = DiscordUser(
|
||||
id=123456789,
|
||||
username="testuser",
|
||||
ignore_messages=False,
|
||||
system_user_id=discord_bot_user.id,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@ -3,17 +3,23 @@ from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
import uuid
|
||||
|
||||
from memory.common.db.models import ScheduledLLMCall, HumanUser, DiscordUser, DiscordChannel, DiscordServer
|
||||
from memory.common.db.models import (
|
||||
ScheduledLLMCall,
|
||||
DiscordBotUser,
|
||||
DiscordUser,
|
||||
DiscordChannel,
|
||||
DiscordServer,
|
||||
)
|
||||
from memory.workers.tasks import scheduled_calls
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user(db_session):
|
||||
"""Create a sample user for testing."""
|
||||
user = HumanUser.create_with_password(
|
||||
name="testuser",
|
||||
email="test@example.com",
|
||||
password="password",
|
||||
user = DiscordBotUser.create_with_api_key(
|
||||
discord_users=[],
|
||||
name="testbot",
|
||||
email="bot@example.com",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
@ -124,6 +130,7 @@ def test_send_to_discord_user(mock_send_dm, pending_scheduled_call):
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, response)
|
||||
|
||||
mock_send_dm.assert_called_once_with(
|
||||
pending_scheduled_call.user_id,
|
||||
"testuser", # username, not ID
|
||||
"**Topic:** Test Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a test response.",
|
||||
)
|
||||
@ -137,6 +144,7 @@ def test_send_to_discord_channel(mock_broadcast, completed_scheduled_call):
|
||||
scheduled_calls._send_to_discord(completed_scheduled_call, response)
|
||||
|
||||
mock_broadcast.assert_called_once_with(
|
||||
completed_scheduled_call.user_id,
|
||||
"test-channel", # channel name, not ID
|
||||
"**Topic:** Completed Topic\n**Model:** anthropic/claude-3-5-sonnet-20241022\n**Response:** This is a channel response.",
|
||||
)
|
||||
@ -151,7 +159,8 @@ def test_send_to_discord_long_message_truncation(mock_send_dm, pending_scheduled
|
||||
|
||||
# Verify the message was truncated
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
message = args[1]
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
message = args[2]
|
||||
assert len(message) <= 1950 # Should be truncated
|
||||
assert message.endswith("... (response truncated)")
|
||||
|
||||
@ -164,7 +173,8 @@ def test_send_to_discord_normal_length_message(mock_send_dm, pending_scheduled_c
|
||||
scheduled_calls._send_to_discord(pending_scheduled_call, normal_response)
|
||||
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
message = args[1]
|
||||
assert args[0] == pending_scheduled_call.user_id
|
||||
message = args[2]
|
||||
assert not message.endswith("... (response truncated)")
|
||||
assert "This is a normal length response." in message
|
||||
|
||||
@ -574,6 +584,7 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
mock_discord_user.username = "testuser"
|
||||
|
||||
mock_call = Mock()
|
||||
mock_call.user_id = 987
|
||||
mock_call.topic = topic
|
||||
mock_call.model = model
|
||||
mock_call.discord_user = mock_discord_user
|
||||
@ -583,7 +594,8 @@ def test_message_formatting(mock_send_dm, topic, model, response, expected_in_me
|
||||
|
||||
# Get the actual message that was sent
|
||||
args, kwargs = mock_send_dm.call_args
|
||||
actual_message = args[1]
|
||||
assert args[0] == mock_call.user_id
|
||||
actual_message = args[2]
|
||||
|
||||
# Verify all expected parts are in the message
|
||||
for expected_part in expected_in_message:
|
||||
|
||||
182
tools/backup_databases.sh
Normal file
182
tools/backup_databases.sh
Normal file
@ -0,0 +1,182 @@
|
||||
#!/bin/bash
|
||||
# Backup Postgres and Qdrant databases to S3
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Install the helper tools if a required one is missing (the postgres:15
# image doesn't include the AWS CLI). Both aws and wget are probed: checking
# only aws would skip the install on an image that has aws but no wget, and
# the Qdrant step would then fail. jq is installed alongside but not probed,
# because the Qdrant step has a fallback when it is absent.
need_install=0
for tool in aws wget; do
    if ! command -v "${tool}" >/dev/null 2>&1; then
        need_install=1
    fi
done
if [ "${need_install}" -eq 1 ]; then
    echo "Installing AWS CLI, wget, and jq..."
    apt-get update -qq && apt-get install -y -qq awscli wget jq >/dev/null 2>&1
fi
|
||||
|
||||
# Configuration - every value is taken from the environment when set,
# otherwise the default after ":-" is used.
BUCKET="${S3_BACKUP_BUCKET:-equistamp-memory-backup}"
PREFIX="${S3_BACKUP_PREFIX:-Daniel}/databases"
REGION="${S3_BACKUP_REGION:-eu-central-1}"
# No default for the encryption key: the script aborts here if it is missing.
PASSWORD="${BACKUP_ENCRYPTION_KEY:?BACKUP_ENCRYPTION_KEY not set}"
MAX_BACKUPS="${MAX_BACKUPS:-30}" # Keep last N backups

# Service names (docker-compose network)
POSTGRES_HOST="${POSTGRES_HOST:-postgres}"
POSTGRES_USER="${POSTGRES_USER:-kb}"
POSTGRES_DB="${POSTGRES_DB:-kb}"
QDRANT_URL="${QDRANT_URL:-http://qdrant:6333}"

# Timestamps for backups (second and day resolution)
DATE="$(date +%Y%m%d-%H%M%S)"
DATE_SIMPLE="$(date +%Y%m%d)"
|
||||
|
||||
# Print a timestamped message to stdout.
log() {
    local stamp
    stamp="$(date '+%Y-%m-%d %H:%M:%S')"
    printf '[%s] %s\n' "${stamp}" "$*"
}
|
||||
|
||||
# Print a timestamped message, tagged ERROR, to stderr.
error() {
    local stamp
    stamp="$(date '+%Y-%m-%d %H:%M:%S')"
    printf '[%s] ERROR: %s\n' "${stamp}" "$*" >&2
}
|
||||
|
||||
# Clean old backups - keep only last N
|
||||
# Delete the oldest backups under an S3 prefix so that at most MAX_BACKUPS
# objects matching a name pattern remain.
#   $1 - S3 key prefix to list (without bucket)
#   $2 - name pattern to match, e.g. "postgres-" or "qdrant-"
# Always returns 0; deletion failures are logged and skipped.
cleanup_old_backups() {
    local prefix=$1
    local pattern=$2

    log "Checking for old ${pattern} backups to clean up..."

    # List matching backups sorted by name. The names embed a YYYYMMDD date,
    # so lexical order is oldest first.
    # "|| true": under pipefail a grep with no matches (or a failed listing)
    # fails the whole pipeline; treat that as "no backups" instead of letting
    # set -e abort the script.
    local backups
    backups=$(aws s3 ls "s3://${BUCKET}/${prefix}/" --region "${REGION}" | \
        grep "${pattern}" | \
        awk '{print $4}' | \
        sort) || true

    # Count only when the list is non-empty: piping an empty string through
    # "wc -l" would report one line.
    local count=0
    if [ -n "${backups}" ]; then
        count=$(printf '%s\n' "${backups}" | wc -l)
    fi

    if [ "${count}" -le "${MAX_BACKUPS}" ]; then
        log "Found ${count} ${pattern} backups (max: ${MAX_BACKUPS}), no cleanup needed"
        return 0
    fi

    local to_delete=$((count - MAX_BACKUPS))
    log "Found ${count} ${pattern} backups, deleting ${to_delete} oldest..."

    printf '%s\n' "${backups}" | head -n "${to_delete}" | while read -r file; do
        if [ -n "${file}" ]; then
            log "Deleting old backup: ${file}"
            # Report a failed delete, then carry on with the remaining files.
            aws s3 rm "s3://${BUCKET}/${prefix}/${file}" --region "${REGION}" ||
                error "Failed to delete old backup: ${file}"
        fi
    done

    return 0
}
|
||||
|
||||
# Backup Postgres
|
||||
# Dump the Postgres database, gzip and encrypt it, and stream it to S3.
# Reads BUCKET, PREFIX, REGION, PASSWORD, DATE_SIMPLE, POSTGRES_HOST,
# POSTGRES_USER, POSTGRES_DB and POSTGRES_PASSWORD_FILE.
# Returns 0 on success (after pruning old backups), 1 on any failure.
backup_postgres() {
    log "Starting Postgres backup..."

    local output_path="s3://${BUCKET}/${PREFIX}/postgres-${DATE_SIMPLE}.sql.gz.enc"

    # Check the password file explicitly. With "set -u", expanding an unset
    # POSTGRES_PASSWORD_FILE would terminate the whole script, so the Qdrant
    # backup would never run.
    local password_file="${POSTGRES_PASSWORD_FILE:-}"
    if [ -z "${password_file}" ] || [ ! -r "${password_file}" ]; then
        error "POSTGRES_PASSWORD_FILE is unset or not readable; cannot run pg_dump"
        return 1
    fi

    # Declare and assign separately so a failed read is not masked.
    local pg_password
    if ! pg_password=$(cat "${password_file}"); then
        error "Could not read Postgres password from ${password_file}"
        return 1
    fi

    # Use pg_dump directly with the service name (no docker exec needed).
    # - PGPASSWORD is scoped to the pg_dump process instead of being exported.
    # - pg_dump's stderr is left visible so a failure can be diagnosed.
    # - The passphrase reaches openssl through the environment ("env:") so it
    #   does not show up in the process list as a "pass:" argument would.
    if PGPASSWORD="${pg_password}" pg_dump -h "${POSTGRES_HOST}" -U "${POSTGRES_USER}" "${POSTGRES_DB}" | \
        gzip | \
        BACKUP_PASSPHRASE="${PASSWORD}" openssl enc -aes-256-cbc -salt -pbkdf2 -pass env:BACKUP_PASSPHRASE | \
        aws s3 cp - "${output_path}" --region "${REGION}"; then
        log "Postgres backup completed: ${output_path}"
        cleanup_old_backups "${PREFIX}" "postgres-"
        return 0
    else
        error "Postgres backup failed"
        return 1
    fi
}
|
||||
|
||||
# Backup Qdrant
|
||||
# Create a Qdrant snapshot over HTTP, encrypt it, stream it to S3, and
# delete the snapshot from Qdrant afterwards.
# Reads BUCKET, PREFIX, REGION, PASSWORD, DATE_SIMPLE and QDRANT_URL.
# Returns 0 on success (after pruning old backups), 1 on any failure.
backup_qdrant() {
    log "Starting Qdrant backup..."

    # Create snapshot via HTTP API (no docker exec needed)
    local snapshot_response
    if ! snapshot_response=$(wget -q -O - --post-data='{}' \
        --header='Content-Type: application/json' \
        "${QDRANT_URL}/snapshots" 2>/dev/null); then
        error "Failed to create Qdrant snapshot"
        return 1
    fi

    # Extract the snapshot name from the JSON response. Prefer jq; fall back
    # to grep/cut when jq is not installed.
    local snapshot_name
    if command -v jq >/dev/null 2>&1; then
        snapshot_name=$(echo "${snapshot_response}" | jq -r '.result.name // empty')
    else
        # Fallback: parse JSON without jq (fragile but works for simple case)
        snapshot_name=$(echo "${snapshot_response}" | grep -o '"name":"[^"]*"' | cut -d'"' -f4)
    fi

    if [ -z "${snapshot_name}" ]; then
        error "Could not extract snapshot name from response: ${snapshot_response}"
        return 1
    fi

    log "Created Qdrant snapshot: ${snapshot_name}"

    # Download snapshot and upload to S3
    local output_path="s3://${BUCKET}/${PREFIX}/qdrant-${DATE_SIMPLE}.snapshot.enc"
    local snapshot_url="${QDRANT_URL}/snapshots/${snapshot_name}"

    # The passphrase reaches openssl through the environment ("env:") so it
    # does not show up in the process list as a "pass:" argument would.
    if wget -q -O - "${snapshot_url}" | \
        BACKUP_PASSPHRASE="${PASSWORD}" openssl enc -aes-256-cbc -salt -pbkdf2 -pass env:BACKUP_PASSPHRASE | \
        aws s3 cp - "${output_path}" --region "${REGION}"; then
        log "Qdrant backup completed: ${output_path}"

        # Delete the snapshot from Qdrant; a failed delete is logged but does
        # not fail the backup.
        if wget -q -O - --method=DELETE \
            "${snapshot_url}" >/dev/null 2>&1; then
            log "Deleted Qdrant snapshot: ${snapshot_name}"
        else
            error "Failed to delete Qdrant snapshot: ${snapshot_name}"
        fi

        cleanup_old_backups "${PREFIX}" "qdrant-"
        return 0
    else
        error "Qdrant backup failed"

        # Try to clean up snapshot (best effort)
        wget -q -O - --method=DELETE \
            "${snapshot_url}" >/dev/null 2>&1 || true

        return 1
    fi
}
|
||||
|
||||
# Main execution
|
||||
# Run both database backups and report the combined outcome.
# Returns 0 only when both backups succeed, 1 otherwise.
main() {
    log "Database backup started"

    local postgres_result=0
    local qdrant_result=0

    # Each backup runs regardless of the other's outcome.
    backup_postgres || postgres_result=1
    backup_qdrant || qdrant_result=1

    # Summary, keyed on how many of the two backups failed.
    case $((postgres_result + qdrant_result)) in
        0)
            log "All database backups completed successfully"
            return 0
            ;;
        2)
            error "All database backups failed"
            return 1
            ;;
        *)
            error "Some database backups failed (Postgres: ${postgres_result}, Qdrant: ${qdrant_result})"
            return 1
            ;;
    esac
}
|
||||
|
||||
# Run main function
|
||||
main
|
||||
|
||||
@ -51,6 +51,8 @@ from memory.common.celery_app import (
|
||||
UPDATE_METADATA_FOR_SOURCE_ITEMS,
|
||||
SETUP_GIT_NOTES,
|
||||
TRACK_GIT_CHANGES,
|
||||
BACKUP_TO_S3_DIRECTORY,
|
||||
BACKUP_ALL,
|
||||
app,
|
||||
)
|
||||
|
||||
@ -97,6 +99,10 @@ TASK_MAPPINGS = {
|
||||
"setup_git_notes": SETUP_GIT_NOTES,
|
||||
"track_git_changes": TRACK_GIT_CHANGES,
|
||||
},
|
||||
"backup": {
|
||||
"backup_to_s3_directory": BACKUP_TO_S3_DIRECTORY,
|
||||
"backup_all": BACKUP_ALL,
|
||||
},
|
||||
}
|
||||
QUEUE_MAPPINGS = {
|
||||
"email": "email",
|
||||
@ -177,6 +183,28 @@ def execute_task(ctx, category: str, task_name: str, **kwargs):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.group()
@click.pass_context
def backup(ctx):
    """Backup-related tasks."""
    # Container group only: the subcommands registered on it do the work.
|
||||
|
||||
|
||||
@backup.command("all")
@click.pass_context
def backup_all(ctx):
    """Backup all directories."""
    # Dispatch the "backup_all" task from the "backup" category.
    execute_task(ctx, category="backup", task_name="backup_all")
|
||||
|
||||
|
||||
@backup.command("path")
@click.option("--path", required=True, help="Path to backup")
@click.pass_context
def backup_to_s3_directory(ctx, path):
    """Backup a specific path."""
    # The --path value is forwarded to the task as its "path" keyword argument.
    execute_task(
        ctx,
        category="backup",
        task_name="backup_to_s3_directory",
        path=path,
    )
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def email(ctx):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user