diff --git a/docker-compose.yaml b/docker-compose.yaml
index e3aa586..5ef41cb 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -173,7 +173,7 @@ services:
     <<: *worker-base
     environment:
       <<: *worker-env
-      QUEUES: "email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
+      QUEUES: "backup,email,ebooks,discord,comic,blogs,forums,maintenance,notes,scheduler"
 
   ingest-hub:
     <<: *worker-base
diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile
index c616fe6..b3a0e5c 100644
--- a/docker/workers/Dockerfile
+++ b/docker/workers/Dockerfile
@@ -16,7 +16,7 @@ RUN apt-get update && apt-get install -y \
 COPY requirements ./requirements/
 COPY setup.py ./
 RUN mkdir src
-RUN pip install -e ".[common]"
+RUN pip install -e ".[workers]"
 
 # Install Python dependencies
 COPY src/ ./src/
@@ -44,7 +44,7 @@ RUN git config --global user.email "${GIT_USER_EMAIL}" && \
     git config --global user.name "${GIT_USER_NAME}"
 
 # Default queues to process
-ENV QUEUES="ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
+ENV QUEUES="backup,ebooks,email,discord,comic,blogs,forums,photo_embed,maintenance"
 ENV PYTHONPATH="/app"
 
 ENTRYPOINT ["./entry.sh"]
\ No newline at end of file
diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt
index 3683f32..f94da8b 100644
--- a/requirements/requirements-common.txt
+++ b/requirements/requirements-common.txt
@@ -9,4 +9,4 @@ anthropic==0.69.0
 openai==2.3.0
 # Pin the httpx version, as newer versions break the anthropic client
 httpx==0.27.0
-celery[redis,sqs]==5.3.6
+celery[redis,sqs]==5.3.6
\ No newline at end of file
diff --git a/requirements/requirements-workers.txt b/requirements/requirements-workers.txt
new file mode 100644
index 0000000..7a77524
--- /dev/null
+++ b/requirements/requirements-workers.txt
@@ -0,0 +1,3 @@
+cryptography==43.0.0
+boto3
+awscli==1.42.64
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 7e60b99..fed655e 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@ parsers_requires = read_requirements("requirements-parsers.txt")
 api_requires = read_requirements("requirements-api.txt")
 dev_requires = read_requirements("requirements-dev.txt")
 ingesters_requires = read_requirements("requirements-ingesters.txt")
+workers_requires = read_requirements("requirements-workers.txt")
 
 setup(
     name="memory",
@@ -30,10 +31,12 @@ setup(
         "common": common_requires + parsers_requires,
         "dev": dev_requires,
         "ingesters": common_requires + parsers_requires + ingesters_requires,
+        "workers": common_requires + parsers_requires + workers_requires,
         "all": api_requires
         + common_requires
         + dev_requires
         + parsers_requires
-        + ingesters_requires,
+        + ingesters_requires
+        + workers_requires,
     },
 )
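For context, `read_requirements` is an existing setup.py helper that this diff reuses for the new `workers` extra, which the Dockerfile then installs via `pip install -e ".[workers]"`. The helper itself is not shown here, so the following is only a minimal sketch of what it presumably does:

```python
# Hypothetical sketch of setup.py's read_requirements helper (assumed, not part
# of this diff): read one file from requirements/ and drop comments and blanks.
import pathlib


def read_requirements(filename: str) -> list[str]:
    lines = (pathlib.Path("requirements") / filename).read_text().splitlines()
    return [ln.strip() for ln in lines if ln.strip() and not ln.strip().startswith("#")]
```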
f"{SCHEDULED_CALLS_ROOT}.run_scheduled_calls" +# Backup tasks +BACKUP_TO_S3_DIRECTORY = f"{BACKUP_ROOT}.backup_to_s3_directory" +BACKUP_ALL = f"{BACKUP_ROOT}.backup_all" + def get_broker_url() -> str: protocol = settings.CELERY_BROKER_TYPE @@ -99,6 +104,7 @@ app.conf.update( f"{SCHEDULED_CALLS_ROOT}.*": { "queue": f"{settings.CELERY_QUEUE_PREFIX}-scheduler" }, + f"{BACKUP_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-backup"}, }, ) diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 5617f5b..1053b26 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -73,9 +73,14 @@ WEBPAGE_STORAGE_DIR = pathlib.Path( NOTES_STORAGE_DIR = pathlib.Path( os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes") ) +PRIVATE_DIRS = [ + EMAIL_STORAGE_DIR, + NOTES_STORAGE_DIR, + PHOTO_STORAGE_DIR, + CHUNK_STORAGE_DIR, +] storage_dirs = [ - FILE_STORAGE_DIR, EBOOK_STORAGE_DIR, EMAIL_STORAGE_DIR, CHUNK_STORAGE_DIR, @@ -197,3 +202,14 @@ DISCORD_COLLECT_BOTS = boolean_env("DISCORD_COLLECT_BOTS", True) DISCORD_COLLECTOR_PORT = int(os.getenv("DISCORD_COLLECTOR_PORT", 8003)) DISCORD_COLLECTOR_SERVER_URL = os.getenv("DISCORD_COLLECTOR_SERVER_URL", "0.0.0.0") DISCORD_CONTEXT_WINDOW = int(os.getenv("DISCORD_CONTEXT_WINDOW", 10)) + + +# S3 Backup settings +S3_BACKUP_BUCKET = os.getenv("S3_BACKUP_BUCKET", "equistamp-memory-backup") +S3_BACKUP_PREFIX = os.getenv("S3_BACKUP_PREFIX", "Daniel") +S3_BACKUP_REGION = os.getenv("S3_BACKUP_REGION", "eu-central-1") +BACKUP_ENCRYPTION_KEY = os.getenv("BACKUP_ENCRYPTION_KEY", "") +S3_BACKUP_ENABLED = boolean_env("S3_BACKUP_ENABLED", bool(BACKUP_ENCRYPTION_KEY)) +S3_BACKUP_INTERVAL = int( + os.getenv("S3_BACKUP_INTERVAL", 60 * 60 * 24) +) # Daily by default diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py index 2b577df..e79131e 100644 --- a/src/memory/workers/ingest.py +++ b/src/memory/workers/ingest.py @@ -10,6 +10,7 @@ from memory.common.celery_app import ( TRACK_GIT_CHANGES, SYNC_LESSWRONG, RUN_SCHEDULED_CALLS, + BACKUP_TO_S3, ) logger = logging.getLogger(__name__) @@ -48,4 +49,8 @@ app.conf.beat_schedule = { "task": RUN_SCHEDULED_CALLS, "schedule": settings.SCHEDULED_CALL_RUN_INTERVAL, }, + "backup-to-s3": { + "task": BACKUP_TO_S3, + "schedule": settings.S3_BACKUP_INTERVAL, + }, } diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index 80cc6e9..41c142d 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -3,11 +3,12 @@ Import sub-modules so Celery can register their @app.task decorators. 
""" from memory.workers.tasks import ( - email, - comic, + backup, blogs, + comic, discord, ebook, + email, forums, maintenance, notes, @@ -15,8 +16,8 @@ from memory.workers.tasks import ( scheduled_calls, ) # noqa - __all__ = [ + "backup", "email", "comic", "blogs", diff --git a/src/memory/workers/tasks/backup.py b/src/memory/workers/tasks/backup.py new file mode 100644 index 0000000..7edda60 --- /dev/null +++ b/src/memory/workers/tasks/backup.py @@ -0,0 +1,152 @@ +"""S3 backup tasks for memory files.""" + +import base64 +import hashlib +import io +import logging +import subprocess +import tarfile +from pathlib import Path + +import boto3 +from cryptography.fernet import Fernet + +from memory.common import settings +from memory.common.celery_app import app, BACKUP_TO_S3_DIRECTORY, BACKUP_ALL + +logger = logging.getLogger(__name__) + + +def get_cipher() -> Fernet: + """Create Fernet cipher from password in settings.""" + if not settings.BACKUP_ENCRYPTION_KEY: + raise ValueError("BACKUP_ENCRYPTION_KEY not set in environment") + + # Derive key from password using SHA256 + key_bytes = hashlib.sha256(settings.BACKUP_ENCRYPTION_KEY.encode()).digest() + key = base64.urlsafe_b64encode(key_bytes) + return Fernet(key) + + +def create_tarball(directory: Path) -> bytes: + """Create a gzipped tarball of a directory in memory.""" + if not directory.exists(): + logger.warning(f"Directory does not exist: {directory}") + return b"" + + tar_buffer = io.BytesIO() + with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar: + tar.add(directory, arcname=directory.name) + + tar_buffer.seek(0) + return tar_buffer.read() + + +def sync_unencrypted_directory(path: Path) -> dict: + """Sync an unencrypted directory to S3 using aws s3 sync.""" + if not path.exists(): + logger.warning(f"Directory does not exist: {path}") + return {"synced": False, "reason": "directory_not_found"} + + s3_uri = f"s3://{settings.S3_BACKUP_BUCKET}/{settings.S3_BACKUP_PREFIX}/{path.name}" + + cmd = [ + "aws", + "s3", + "sync", + str(path), + s3_uri, + "--delete", + "--region", + settings.S3_BACKUP_REGION, + ] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + logger.info(f"Synced {path} to {s3_uri}") + logger.debug(f"Output: {result.stdout}") + return {"synced": True, "directory": path, "s3_uri": s3_uri} + except subprocess.CalledProcessError as e: + logger.error(f"Failed to sync {path}: {e.stderr}") + return {"synced": False, "directory": path, "error": str(e)} + + +def backup_encrypted_directory(path: Path) -> dict: + """Create encrypted tarball of directory and upload to S3.""" + if not path.exists(): + logger.warning(f"Directory does not exist: {path}") + return {"uploaded": False, "reason": "directory_not_found"} + + # Create tarball + logger.info(f"Creating tarball of {path}...") + tarball_bytes = create_tarball(path) + + if not tarball_bytes: + logger.warning(f"Empty tarball for {path}, skipping") + return {"uploaded": False, "reason": "empty_directory"} + + # Encrypt + logger.info(f"Encrypting {path} ({len(tarball_bytes)} bytes)...") + cipher = get_cipher() + encrypted_bytes = cipher.encrypt(tarball_bytes) + + # Upload to S3 + s3_client = boto3.client("s3", region_name=settings.S3_BACKUP_REGION) + s3_key = f"{settings.S3_BACKUP_PREFIX}/{path.name}.tar.gz.enc" + + try: + logger.info( + f"Uploading encrypted {path} to s3://{settings.S3_BACKUP_BUCKET}/{s3_key}" + ) + s3_client.put_object( + Bucket=settings.S3_BACKUP_BUCKET, + Key=s3_key, + Body=encrypted_bytes, + 
diff --git a/src/memory/workers/tasks/backup.py b/src/memory/workers/tasks/backup.py
new file mode 100644
index 0000000..7edda60
--- /dev/null
+++ b/src/memory/workers/tasks/backup.py
@@ -0,0 +1,152 @@
+"""S3 backup tasks for memory files."""
+
+import base64
+import hashlib
+import io
+import logging
+import subprocess
+import tarfile
+from pathlib import Path
+
+import boto3
+from cryptography.fernet import Fernet
+
+from memory.common import settings
+from memory.common.celery_app import app, BACKUP_TO_S3_DIRECTORY, BACKUP_ALL
+
+logger = logging.getLogger(__name__)
+
+
+def get_cipher() -> Fernet:
+    """Create Fernet cipher from password in settings."""
+    if not settings.BACKUP_ENCRYPTION_KEY:
+        raise ValueError("BACKUP_ENCRYPTION_KEY not set in environment")
+
+    # Derive key from password using SHA256
+    key_bytes = hashlib.sha256(settings.BACKUP_ENCRYPTION_KEY.encode()).digest()
+    key = base64.urlsafe_b64encode(key_bytes)
+    return Fernet(key)
+
+
+def create_tarball(directory: Path) -> bytes:
+    """Create a gzipped tarball of a directory in memory."""
+    if not directory.exists():
+        logger.warning(f"Directory does not exist: {directory}")
+        return b""
+
+    tar_buffer = io.BytesIO()
+    with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar:
+        tar.add(directory, arcname=directory.name)
+
+    tar_buffer.seek(0)
+    return tar_buffer.read()
+
+
+def sync_unencrypted_directory(path: Path) -> dict:
+    """Sync an unencrypted directory to S3 using aws s3 sync."""
+    if not path.exists():
+        logger.warning(f"Directory does not exist: {path}")
+        return {"synced": False, "reason": "directory_not_found"}
+
+    s3_uri = f"s3://{settings.S3_BACKUP_BUCKET}/{settings.S3_BACKUP_PREFIX}/{path.name}"
+
+    cmd = [
+        "aws",
+        "s3",
+        "sync",
+        str(path),
+        s3_uri,
+        "--delete",
+        "--region",
+        settings.S3_BACKUP_REGION,
+    ]
+
+    try:
+        result = subprocess.run(
+            cmd,
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+        logger.info(f"Synced {path} to {s3_uri}")
+        logger.debug(f"Output: {result.stdout}")
+        return {"synced": True, "directory": path, "s3_uri": s3_uri}
+    except subprocess.CalledProcessError as e:
+        logger.error(f"Failed to sync {path}: {e.stderr}")
+        return {"synced": False, "directory": path, "error": str(e)}
+
+
+def backup_encrypted_directory(path: Path) -> dict:
+    """Create encrypted tarball of directory and upload to S3."""
+    if not path.exists():
+        logger.warning(f"Directory does not exist: {path}")
+        return {"uploaded": False, "reason": "directory_not_found"}
+
+    # Create tarball
+    logger.info(f"Creating tarball of {path}...")
+    tarball_bytes = create_tarball(path)
+
+    if not tarball_bytes:
+        logger.warning(f"Empty tarball for {path}, skipping")
+        return {"uploaded": False, "reason": "empty_directory"}
+
+    # Encrypt
+    logger.info(f"Encrypting {path} ({len(tarball_bytes)} bytes)...")
+    cipher = get_cipher()
+    encrypted_bytes = cipher.encrypt(tarball_bytes)
+
+    # Upload to S3
+    s3_client = boto3.client("s3", region_name=settings.S3_BACKUP_REGION)
+    s3_key = f"{settings.S3_BACKUP_PREFIX}/{path.name}.tar.gz.enc"
+
+    try:
+        logger.info(
+            f"Uploading encrypted {path} to s3://{settings.S3_BACKUP_BUCKET}/{s3_key}"
+        )
+        s3_client.put_object(
+            Bucket=settings.S3_BACKUP_BUCKET,
+            Key=s3_key,
+            Body=encrypted_bytes,
+            ServerSideEncryption="AES256",
+        )
+        return {
+            "uploaded": True,
+            "directory": path,
+            "size_bytes": len(encrypted_bytes),
+            "s3_key": s3_key,
+        }
+    except Exception as e:
+        logger.error(f"Failed to upload {path}: {e}")
+        return {"uploaded": False, "directory": path, "error": str(e)}
+
+
+@app.task(name=BACKUP_TO_S3_DIRECTORY)
+def backup_to_s3(path: Path | str):
+    """Backup a specific directory to S3."""
+    path = Path(path)
+
+    if not path.exists():
+        logger.warning(f"Directory does not exist: {path}")
+        return {"uploaded": False, "reason": "directory_not_found"}
+
+    if path in settings.PRIVATE_DIRS:
+        return backup_encrypted_directory(path)
+    return sync_unencrypted_directory(path)
+
+
+@app.task(name=BACKUP_ALL)
+def backup_all_to_s3():
+    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs."""
+    if not settings.S3_BACKUP_ENABLED:
+        logger.info("S3 backup is disabled")
+        return {"status": "disabled"}
+
+    logger.info("Starting S3 backup...")
+
+    # storage_dirs already holds absolute paths, so they can be passed directly
+    for dir_path in settings.storage_dirs:
+        backup_to_s3.delay(dir_path.as_posix())
+
+    return {
+        "status": "success",
+        "message": f"Started backup for {len(settings.storage_dirs)} directories",
+    }
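The diff adds no restore path. For reference, a hedged sketch of how an encrypted archive could be pulled back, mirroring `get_cipher`'s SHA-256 to Fernet key derivation (bucket, prefix, and key layout come from the settings and `backup_encrypted_directory` above; this is an illustration, not part of the change). Unencrypted directories are plain `aws s3 sync` copies, so they can be restored by running the same command with source and destination swapped.

```python
# Restore sketch (not part of this diff): download one <name>.tar.gz.enc
# archive and unpack it, using the same key derivation as get_cipher().
import base64
import hashlib
import io
import tarfile

import boto3
from cryptography.fernet import Fernet

from memory.common import settings


def restore_encrypted_archive(name: str, dest: str = ".") -> None:
    s3 = boto3.client("s3", region_name=settings.S3_BACKUP_REGION)
    key = f"{settings.S3_BACKUP_PREFIX}/{name}.tar.gz.enc"
    body = s3.get_object(Bucket=settings.S3_BACKUP_BUCKET, Key=key)["Body"].read()

    fernet_key = base64.urlsafe_b64encode(
        hashlib.sha256(settings.BACKUP_ENCRYPTION_KEY.encode()).digest()
    )
    tarball = Fernet(fernet_key).decrypt(body)

    # Each archive contains a single top-level directory named after the source.
    with tarfile.open(fileobj=io.BytesIO(tarball), mode="r:gz") as tar:
        tar.extractall(dest)
```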
diff --git a/tests/memory/workers/tasks/test_backup_tasks.py b/tests/memory/workers/tasks/test_backup_tasks.py
new file mode 100644
index 0000000..98ccffd
--- /dev/null
+++ b/tests/memory/workers/tasks/test_backup_tasks.py
@@ -0,0 +1,393 @@
+import io
+import subprocess
+import tarfile
+from unittest.mock import Mock, patch, MagicMock
+
+import pytest
+from botocore.exceptions import ClientError
+
+from memory.common import settings
+from memory.workers.tasks import backup
+
+
+@pytest.fixture
+def sample_files():
+    """Create sample files in memory_files structure."""
+    base = settings.FILE_STORAGE_DIR
+
+    dirs_with_files = {
+        "emails": ["email1.txt", "email2.txt"],
+        "notes": ["note1.md", "note2.md"],
+        "photos": ["photo1.jpg"],
+        "comics": ["comic1.png", "comic2.png"],
+        "ebooks": ["book1.epub"],
+        "webpages": ["page1.html"],
+    }
+
+    for dir_name, filenames in dirs_with_files.items():
+        dir_path = base / dir_name
+        dir_path.mkdir(parents=True, exist_ok=True)
+
+        for filename in filenames:
+            file_path = dir_path / filename
+            content = f"Content of {dir_name}/{filename}\n" * 100
+            file_path.write_text(content)
+
+
+@pytest.fixture
+def mock_s3_client():
+    """Mock boto3 S3 client."""
+    with patch("boto3.client") as mock_client:
+        s3_mock = MagicMock()
+        mock_client.return_value = s3_mock
+        yield s3_mock
+
+
+@pytest.fixture
+def backup_settings():
+    """Mock backup settings."""
+    with (
+        patch.object(settings, "S3_BACKUP_ENABLED", True),
+        patch.object(settings, "BACKUP_ENCRYPTION_KEY", "test-password-123"),
+        patch.object(settings, "S3_BACKUP_BUCKET", "test-bucket"),
+        patch.object(settings, "S3_BACKUP_PREFIX", "test-prefix"),
+        patch.object(settings, "S3_BACKUP_REGION", "us-east-1"),
+    ):
+        yield
+
+
+@pytest.fixture
+def get_test_path():
+    """Helper to construct test paths."""
+    return lambda dir_name: settings.FILE_STORAGE_DIR / dir_name
+
+
+@pytest.mark.parametrize(
+    "data,key",
+    [
+        (b"This is a test message", "my-secret-key"),
+        (b"\x00\x01\x02\xff" * 10000, "another-key"),
+        (b"x" * 1000000, "large-data-key"),
+    ],
+)
+def test_encrypt_decrypt_roundtrip(data, key):
+    """Test encryption and decryption produces original data."""
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", key):
+        cipher = backup.get_cipher()
+        encrypted = cipher.encrypt(data)
+        decrypted = cipher.decrypt(encrypted)
+
+        assert decrypted == data
+        assert encrypted != data
+
+
+def test_encrypt_decrypt_tarball(sample_files):
+    """Test full tarball creation, encryption, and decryption."""
+    emails_dir = settings.FILE_STORAGE_DIR / "emails"
+
+    # Create tarball
+    tarball_bytes = backup.create_tarball(emails_dir)
+    assert len(tarball_bytes) > 0
+
+    # Encrypt
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "tarball-key"):
+        cipher = backup.get_cipher()
+        encrypted = cipher.encrypt(tarball_bytes)
+
+        # Decrypt
+        decrypted = cipher.decrypt(encrypted)
+
+        assert decrypted == tarball_bytes
+
+        # Verify tarball can be extracted
+        tar_buffer = io.BytesIO(decrypted)
+        with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
+            members = tar.getmembers()
+            assert len(members) >= 2  # At least 2 email files
+
+            # Extract and verify content
+            for member in members:
+                if member.isfile():
+                    extracted = tar.extractfile(member)
+                    assert extracted is not None
+                    content = extracted.read().decode()
+                    assert "Content of emails/" in content
+
+
+def test_different_keys_produce_different_ciphertext():
+    """Test that different encryption keys produce different ciphertext."""
+    data = b"Same data encrypted with different keys"
+
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "key1"):
+        cipher1 = backup.get_cipher()
+        encrypted1 = cipher1.encrypt(data)
+
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "key2"):
+        cipher2 = backup.get_cipher()
+        encrypted2 = cipher2.encrypt(data)
+
+    assert encrypted1 != encrypted2
+
+
+def test_missing_encryption_key_raises_error():
+    """Test that missing encryption key raises ValueError."""
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", ""):
+        with pytest.raises(ValueError, match="BACKUP_ENCRYPTION_KEY not set"):
+            backup.get_cipher()
+
+
+def test_create_tarball_with_files(sample_files):
+    """Test creating tarball from directory with files."""
+    notes_dir = settings.FILE_STORAGE_DIR / "notes"
+    tarball_bytes = backup.create_tarball(notes_dir)
+
+    assert len(tarball_bytes) > 0
+
+    # Verify it's a valid gzipped tarball
+    tar_buffer = io.BytesIO(tarball_bytes)
+    with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
+        members = tar.getmembers()
+        filenames = [m.name for m in members if m.isfile()]
+        assert len(filenames) >= 2
+        assert any("note1.md" in f for f in filenames)
+        assert any("note2.md" in f for f in filenames)
+
+
+def test_create_tarball_nonexistent_directory():
+    """Test creating tarball from nonexistent directory."""
+    nonexistent = settings.FILE_STORAGE_DIR / "does_not_exist"
+    tarball_bytes = backup.create_tarball(nonexistent)
+
+    assert tarball_bytes == b""
+
+
+def test_create_tarball_empty_directory():
+    """Test creating tarball from empty directory."""
+    empty_dir = settings.FILE_STORAGE_DIR / "empty"
+    empty_dir.mkdir(parents=True, exist_ok=True)
+
+    tarball_bytes = backup.create_tarball(empty_dir)
+
+    # Should create tarball with just the directory entry
+    assert len(tarball_bytes) > 0
+    tar_buffer = io.BytesIO(tarball_bytes)
+    with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar:
+        members = tar.getmembers()
+        assert len(members) >= 1
+        assert members[0].isdir()
+
+
+def test_sync_unencrypted_success(sample_files, backup_settings):
+    """Test successful sync of unencrypted directory."""
+    with patch("subprocess.run") as mock_run:
+        mock_run.return_value = Mock(stdout="Synced files", returncode=0)
+
+        comics_path = settings.FILE_STORAGE_DIR / "comics"
+        result = backup.sync_unencrypted_directory(comics_path)
assert result["synced"] is True + assert result["directory"] == comics_path + assert "s3_uri" in result + assert "test-bucket" in result["s3_uri"] + assert "test-prefix/comics" in result["s3_uri"] + + # Verify aws s3 sync was called correctly + mock_run.assert_called_once() + call_args = mock_run.call_args[0][0] + assert call_args[0] == "aws" + assert call_args[1] == "s3" + assert call_args[2] == "sync" + assert "--delete" in call_args + assert "--region" in call_args + + +def test_sync_unencrypted_nonexistent_directory(backup_settings): + """Test syncing nonexistent directory.""" + nonexistent_path = settings.FILE_STORAGE_DIR / "does_not_exist" + result = backup.sync_unencrypted_directory(nonexistent_path) + + assert result["synced"] is False + assert result["reason"] == "directory_not_found" + + +def test_sync_unencrypted_aws_cli_failure(sample_files, backup_settings): + """Test handling of AWS CLI failure.""" + with patch("subprocess.run") as mock_run: + mock_run.side_effect = subprocess.CalledProcessError( + 1, "aws", stderr="AWS CLI error" + ) + + comics_path = settings.FILE_STORAGE_DIR / "comics" + result = backup.sync_unencrypted_directory(comics_path) + + assert result["synced"] is False + assert "error" in result + + +def test_backup_encrypted_success( + sample_files, mock_s3_client, backup_settings, get_test_path +): + """Test successful encrypted backup.""" + result = backup.backup_encrypted_directory(get_test_path("emails")) + + assert result["uploaded"] is True + assert result["size_bytes"] > 0 + assert result["s3_key"].endswith("emails.tar.gz.enc") + + call_kwargs = mock_s3_client.put_object.call_args[1] + assert call_kwargs["Bucket"] == "test-bucket" + assert call_kwargs["ServerSideEncryption"] == "AES256" + + +def test_backup_encrypted_nonexistent_directory( + mock_s3_client, backup_settings, get_test_path +): + """Test backing up nonexistent directory.""" + result = backup.backup_encrypted_directory(get_test_path("does_not_exist")) + + assert result["uploaded"] is False + assert result["reason"] == "directory_not_found" + mock_s3_client.put_object.assert_not_called() + + +def test_backup_encrypted_empty_directory( + mock_s3_client, backup_settings, get_test_path +): + """Test backing up empty directory.""" + empty_dir = get_test_path("empty_encrypted") + empty_dir.mkdir(parents=True, exist_ok=True) + + result = backup.backup_encrypted_directory(empty_dir) + assert "uploaded" in result + + +def test_backup_encrypted_s3_failure( + sample_files, mock_s3_client, backup_settings, get_test_path +): + """Test handling of S3 upload failure.""" + mock_s3_client.put_object.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, "PutObject" + ) + + result = backup.backup_encrypted_directory(get_test_path("notes")) + assert result["uploaded"] is False + assert "error" in result + + +def test_backup_encrypted_data_integrity( + sample_files, mock_s3_client, backup_settings, get_test_path +): + """Test that encrypted backup maintains data integrity through full cycle.""" + result = backup.backup_encrypted_directory(get_test_path("notes")) + assert result["uploaded"] is True + + # Decrypt uploaded data + cipher = backup.get_cipher() + encrypted_data = mock_s3_client.put_object.call_args[1]["Body"] + decrypted_tarball = cipher.decrypt(encrypted_data) + + # Verify content + tar_buffer = io.BytesIO(decrypted_tarball) + with tarfile.open(fileobj=tar_buffer, mode="r:gz") as tar: + note1_found = False + for member in tar.getmembers(): + if 
+            if member.name.endswith("note1.md") and member.isfile():
+                content = tar.extractfile(member).read().decode()
+                assert "Content of notes/note1.md" in content
+                note1_found = True
+        assert note1_found, "note1.md not found in tarball"
+
+
+def test_backup_disabled():
+    """Test that backup returns early when disabled."""
+    with patch.object(settings, "S3_BACKUP_ENABLED", False):
+        result = backup.backup_all_to_s3()
+
+        assert result["status"] == "disabled"
+
+
+def test_backup_full_execution(sample_files, mock_s3_client, backup_settings):
+    """Test full backup execution dispatches tasks for all directories."""
+    with patch.object(backup, "backup_to_s3") as mock_task:
+        mock_task.delay = Mock()
+
+        result = backup.backup_all_to_s3()
+
+        assert result["status"] == "success"
+        assert "message" in result
+
+        # Verify task was queued for each storage directory
+        assert mock_task.delay.call_count == len(settings.storage_dirs)
+
+
+def test_backup_handles_partial_failures(
+    sample_files, mock_s3_client, backup_settings, get_test_path
+):
+    """Test that backup continues even if some directories fail."""
+    with patch("subprocess.run") as mock_run:
+        mock_run.side_effect = subprocess.CalledProcessError(
+            1, "aws", stderr="Sync failed"
+        )
+        result = backup.sync_unencrypted_directory(get_test_path("comics"))
+
+        assert result["synced"] is False
+        assert "error" in result
+
+
+def test_same_key_different_runs_different_ciphertext():
+    """Test that Fernet produces different ciphertext each run (due to nonce)."""
+    data = b"Consistent data"
+
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", "same-key"):
+        cipher = backup.get_cipher()
+        encrypted1 = cipher.encrypt(data)
+        encrypted2 = cipher.encrypt(data)
+
+        # Should be different due to random nonce, but both should decrypt to same value
+        assert encrypted1 != encrypted2
+
+        decrypted1 = cipher.decrypt(encrypted1)
+        decrypted2 = cipher.decrypt(encrypted2)
+        assert decrypted1 == decrypted2 == data
+
+
+def test_key_derivation_consistency():
+    """Test that same password produces same encryption key."""
+    password = "test-password"
+
+    with patch.object(settings, "BACKUP_ENCRYPTION_KEY", password):
+        cipher1 = backup.get_cipher()
+        cipher2 = backup.get_cipher()
+
+        # Both should be able to decrypt each other's ciphertext
+        data = b"Test data"
+        encrypted = cipher1.encrypt(data)
+        decrypted = cipher2.decrypt(encrypted)
+        assert decrypted == data
+
+
+@pytest.mark.parametrize(
+    "dir_name,is_private",
+    [
+        ("emails", True),
+        ("notes", True),
+        ("photos", True),
+        ("comics", False),
+        ("ebooks", False),
+        ("webpages", False),
+        ("lesswrong", False),
+        ("chunks", False),
+    ],
+)
+def test_directory_encryption_classification(dir_name, is_private, backup_settings):
+    """Test that directories are correctly classified as encrypted or not."""
+    # Create a mock PRIVATE_DIRS list
+    private_dirs = ["emails", "notes", "photos"]
+
+    with patch.object(
+        settings, "PRIVATE_DIRS", [settings.FILE_STORAGE_DIR / d for d in private_dirs]
+    ):
+        test_path = settings.FILE_STORAGE_DIR / dir_name
+        is_in_private = test_path in settings.PRIVATE_DIRS
+
+        assert is_in_private == is_private
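These tests write real files under `settings.FILE_STORAGE_DIR`, so they presumably rely on an existing test fixture that redirects file storage to a temporary directory. Something along these lines — purely hypothetical, the repository's actual conftest is not shown in this diff:

```python
# Hypothetical conftest fixture (not part of this diff): point FILE_STORAGE_DIR
# at a per-test tmp directory so sample_files never touches real storage.
import pytest

from memory.common import settings


@pytest.fixture(autouse=True)
def isolated_file_storage(tmp_path, monkeypatch):
    monkeypatch.setattr(settings, "FILE_STORAGE_DIR", tmp_path)
    yield tmp_path
```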
diff --git a/tools/run_celery_task.py b/tools/run_celery_task.py
index b5d1f35..9ca9fa0 100644
--- a/tools/run_celery_task.py
+++ b/tools/run_celery_task.py
@@ -51,6 +51,8 @@ from memory.common.celery_app import (
     UPDATE_METADATA_FOR_SOURCE_ITEMS,
     SETUP_GIT_NOTES,
     TRACK_GIT_CHANGES,
+    BACKUP_TO_S3_DIRECTORY,
+    BACKUP_ALL,
     app,
 )
 
@@ -97,6 +99,10 @@ TASK_MAPPINGS = {
         "setup_git_notes": SETUP_GIT_NOTES,
         "track_git_changes": TRACK_GIT_CHANGES,
     },
+    "backup": {
+        "backup_to_s3_directory": BACKUP_TO_S3_DIRECTORY,
+        "backup_all": BACKUP_ALL,
+    },
 }
 QUEUE_MAPPINGS = {
     "email": "email",
@@ -177,6 +183,28 @@ def execute_task(ctx, category: str, task_name: str, **kwargs):
         sys.exit(1)
 
 
+@cli.group()
+@click.pass_context
+def backup(ctx):
+    """Backup-related tasks."""
+    pass
+
+
+@backup.command("all")
+@click.pass_context
+def backup_all(ctx):
+    """Backup all directories."""
+    execute_task(ctx, "backup", "backup_all")
+
+
+@backup.command("path")
+@click.option("--path", required=True, help="Path to backup")
+@click.pass_context
+def backup_to_s3_directory(ctx, path):
+    """Backup a specific path."""
+    execute_task(ctx, "backup", "backup_to_s3_directory", path=path)
+
+
 @cli.group()
 @click.pass_context
 def email(ctx):
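The new subcommands are invoked as `backup all` and `backup path --path <dir>`. A smoke-test sketch using click's test runner that only checks the group's help text, so no task is actually dispatched (the import path is an assumption; adjust it to however the repo loads tools/run_celery_task.py):

```python
# Smoke-test sketch (not part of this diff): verify the backup group and its
# subcommands are registered, without sending anything to the broker.
from click.testing import CliRunner

from run_celery_task import cli  # hypothetical import path for tools/run_celery_task.py

runner = CliRunner()
result = runner.invoke(cli, ["backup", "--help"])
assert result.exit_code == 0
assert "all" in result.output
assert "path" in result.output
```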