diff --git a/src/memory/api/app.py b/src/memory/api/app.py
index d56325d..959f276 100644
--- a/src/memory/api/app.py
+++ b/src/memory/api/app.py
@@ -173,7 +173,9 @@ async def health_check(request: Request):
             conn.execute(text("SELECT 1"))
         checks["database"] = "healthy"
     except Exception as e:
-        checks["database"] = f"unhealthy: {str(e)[:100]}"
+        # Log error details but don't expose in response
+        logger.error(f"Database health check failed: {e}")
+        checks["database"] = "unhealthy"
         all_healthy = False
 
     # Check Qdrant connection
@@ -184,7 +186,9 @@
         client.get_collections()
         checks["qdrant"] = "healthy"
     except Exception as e:
-        checks["qdrant"] = f"unhealthy: {str(e)[:100]}"
+        # Log error details but don't expose in response
+        logger.error(f"Qdrant health check failed: {e}")
+        checks["qdrant"] = "unhealthy"
         all_healthy = False
 
     checks["status"] = "healthy" if all_healthy else "degraded"
diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py
index 4ac8cd2..0f57109 100644
--- a/src/memory/api/auth.py
+++ b/src/memory/api/auth.py
@@ -176,9 +176,8 @@ async def oauth_callback_discord(request: Request):
     state = request.query_params.get("state")
     error = request.query_params.get("error")
 
-    logger.info(
-        f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
-    )
+    # Log OAuth callback without sensitive data (code/state could be intercepted)
+    logger.info("Received OAuth callback")
 
     message, title, close, status_code = "", "", "", 200
     if error:
@@ -269,6 +268,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
                 headers={"WWW-Authenticate": "Bearer"},
            )
 
-        logger.debug(f"Authenticated request from user {user.email} to {path}")
+        # Log user ID instead of email for privacy
+        logger.debug(f"Authenticated request from user_id={user.id} to {path}")
 
         return await call_next(request)
diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py
index 37ed296..fd97a42 100644
--- a/src/memory/common/celery_app.py
+++ b/src/memory/common/celery_app.py
@@ -44,6 +44,8 @@ UPDATE_METADATA_FOR_SOURCE_ITEMS = (
     f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
 )
 UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
+CLEANUP_EXPIRED_OAUTH_STATES = f"{MAINTENANCE_ROOT}.cleanup_expired_oauth_states"
+CLEANUP_EXPIRED_SESSIONS = f"{MAINTENANCE_ROOT}.cleanup_expired_sessions"
 SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
 SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
 SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
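The celery_app.py hunk registers the two new cleanup task names, but no beat-schedule entry appears in the hunks shown here, so the tasks only run periodically if they are wired up elsewhere. A minimal sketch of that wiring, assuming `app.conf.beat_schedule` is used and with intervals and kwargs that are illustrative choices rather than part of this change:

```python
# Hypothetical periodic scheduling for the new cleanup tasks (not part of this diff).
from celery.schedules import crontab

from memory.common.celery_app import (
    app,
    CLEANUP_EXPIRED_OAUTH_STATES,
    CLEANUP_EXPIRED_SESSIONS,
)

app.conf.beat_schedule.update(
    {
        "cleanup-expired-oauth-states": {
            "task": CLEANUP_EXPIRED_OAUTH_STATES,
            "schedule": crontab(minute=0),  # hourly (assumed)
            "kwargs": {"max_age_hours": 1},
        },
        "cleanup-expired-sessions": {
            "task": CLEANUP_EXPIRED_SESSIONS,
            "schedule": crontab(minute=30, hour=3),  # daily at 03:30 (assumed)
        },
    }
)
```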
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 0e42f72..dce52f0 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -1,9 +1,12 @@
+import logging
 import os
 import pathlib
 
 from dotenv import load_dotenv
 
 load_dotenv()
 
+logger = logging.getLogger(__name__)
+
 def boolean_env(key: str, default: bool = False) -> bool:
     if key not in os.environ:
@@ -96,6 +99,14 @@ storage_dirs = [
 for dir in storage_dirs:
     dir.mkdir(parents=True, exist_ok=True)
 
+# Warn if using default /tmp storage - data will be lost on reboot
+if str(FILE_STORAGE_DIR).startswith("/tmp"):
+    logger.warning(
+        f"FILE_STORAGE_DIR is set to '{FILE_STORAGE_DIR}' which is a temporary directory. "
+        "Data stored here may be lost on system reboot. "
+        "Set FILE_STORAGE_DIR environment variable to a persistent location for production use."
+    )
+
 # Maximum attachment size to store directly in the database (10MB)
 MAX_INLINE_ATTACHMENT_SIZE = int(
     os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024)
diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py
index c345def..a4ee3c7 100644
--- a/src/memory/workers/email.py
+++ b/src/memory/workers/email.py
@@ -262,6 +262,11 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
             conn.logout()
         except Exception as e:
             logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
+            # If logout fails, explicitly close the socket to prevent resource leak
+            try:
+                conn.shutdown()
+            except Exception:
+                pass  # Socket may already be closed
 
 
 def vectorize_email(email: MailMessage):
diff --git a/src/memory/workers/tasks/backup.py b/src/memory/workers/tasks/backup.py
index 2dffcb2..5695a93 100644
--- a/src/memory/workers/tasks/backup.py
+++ b/src/memory/workers/tasks/backup.py
@@ -6,9 +6,11 @@ import io
 import logging
 import subprocess
 import tarfile
+from contextlib import contextmanager
 from pathlib import Path
 
 import boto3
+import redis
 from cryptography.fernet import Fernet
 
 from memory.common import settings
@@ -16,6 +18,31 @@ from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL
 
 logger = logging.getLogger(__name__)
 
+# Backup lock timeout (30 minutes max for backup to complete)
+BACKUP_LOCK_TIMEOUT = 30 * 60
+
+
+@contextmanager
+def backup_lock(lock_name: str = "backup_all"):
+    """Acquire a distributed lock for backup operations using Redis.
+
+    Prevents concurrent backup operations which could cause resource
+    contention and inconsistent state.
+    """
+    redis_client = redis.from_url(settings.REDIS_URL)
+    lock_key = f"memory:lock:{lock_name}"
+
+    # Try to acquire lock with NX (only if not exists) and expiry
+    acquired = redis_client.set(lock_key, "1", nx=True, ex=BACKUP_LOCK_TIMEOUT)
+    if not acquired:
+        raise RuntimeError(f"Could not acquire backup lock '{lock_name}' - backup already in progress")
+
+    try:
+        yield
+    finally:
+        # Release the lock
+        redis_client.delete(lock_key)
+
 
 def get_cipher() -> Fernet:
     """Create Fernet cipher from password in settings."""
@@ -136,17 +163,25 @@ def backup_to_s3(path: Path | str):
 
 @app.task(name=BACKUP_ALL)
 def backup_all_to_s3():
-    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs."""
+    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs.
+
+    Uses a distributed lock to prevent concurrent backup operations.
+    """
     if not settings.S3_BACKUP_ENABLED:
         logger.info("S3 backup is disabled")
         return {"status": "disabled"}
 
-    logger.info("Starting S3 backup...")
+    try:
+        with backup_lock():
+            logger.info("Starting S3 backup...")
 
-    for dir_name in settings.storage_dirs:
-        backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())
+            for dir_name in settings.storage_dirs:
+                backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())
 
-    return {
-        "status": "success",
-        "message": f"Started backup for {len(settings.storage_dirs)} directories",
-    }
+            return {
+                "status": "success",
+                "message": f"Started backup for {len(settings.storage_dirs)} directories",
+            }
+    except RuntimeError as e:
+        logger.warning(str(e))
+        return {"status": "skipped", "reason": str(e)}
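`backup_lock` releases its key with an unconditional delete, which is safe as long as the holder finishes within `BACKUP_LOCK_TIMEOUT`; if the key expires first and another worker acquires the lock, the original holder's cleanup would remove the new owner's key. A common refinement is to store a random token and only delete the key when the token still matches. A sketch of that variant, not part of this change, reusing the same `REDIS_URL` setting and key naming; the function name is illustrative:

```python
# Sketch only: token-based release so a worker never deletes a lock it no longer owns.
import uuid
from contextlib import contextmanager

import redis

from memory.common import settings


@contextmanager
def backup_lock_tokened(lock_name: str = "backup_all", timeout: int = 30 * 60):
    client = redis.from_url(settings.REDIS_URL)
    key = f"memory:lock:{lock_name}"
    token = uuid.uuid4().hex

    if not client.set(key, token, nx=True, ex=timeout):
        raise RuntimeError(f"Could not acquire backup lock '{lock_name}'")
    try:
        yield
    finally:
        # Compare-and-delete; a small Lua script would make this fully atomic.
        if client.get(key) == token.encode():
            client.delete(key)
```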
+ """ if not settings.S3_BACKUP_ENABLED: logger.info("S3 backup is disabled") return {"status": "disabled"} - logger.info("Starting S3 backup...") + try: + with backup_lock(): + logger.info("Starting S3 backup...") - for dir_name in settings.storage_dirs: - backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix()) + for dir_name in settings.storage_dirs: + backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix()) - return { - "status": "success", - "message": f"Started backup for {len(settings.storage_dirs)} directories", - } + return { + "status": "success", + "message": f"Started backup for {len(settings.storage_dirs)} directories", + } + except RuntimeError as e: + logger.warning(str(e)) + return {"status": "skipped", "reason": str(e)} diff --git a/src/memory/workers/tasks/discord.py b/src/memory/workers/tasks/discord.py index 65290c4..74db3d9 100644 --- a/src/memory/workers/tasks/discord.py +++ b/src/memory/workers/tasks/discord.py @@ -196,6 +196,29 @@ def process_discord_message(message_id: int) -> dict[str, Any]: "message_id": message_id, } + # Validate required relationships exist before processing + if not discord_message.recipient_user: + logger.warning( + "No recipient_user for message %s; skipping processing", + message_id, + ) + return { + "status": "error", + "error": "No recipient user", + "message_id": message_id, + } + + if not discord_message.from_user: + logger.warning( + "No from_user for message %s; skipping processing", + message_id, + ) + return { + "status": "error", + "error": "No sender user", + "message_id": message_id, + } + mcp_servers = discord_message.get_mcp_servers(session) system_prompt = discord_message.system_prompt or "" system_prompt += comm_channel_prompt( diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index 7d56f2a..e2fd3dc 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -1,6 +1,9 @@ import logging from datetime import datetime from typing import cast + +from sqlalchemy.exc import IntegrityError + from memory.common.db.connection import make_session from memory.common.db.models import EmailAccount, MailMessage from memory.common.celery_app import app, PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS @@ -44,41 +47,46 @@ def process_message( logger.warning(f"Empty email message received for account {account_id}") return {"status": "skipped", "reason": "empty_content"} - with make_session() as db: - account = db.get(EmailAccount, account_id) - if not account: - logger.error(f"Account {account_id} not found") - return {"status": "error", "error": "Account not found"} + try: + with make_session() as db: + account = db.get(EmailAccount, account_id) + if not account: + logger.error(f"Account {account_id} not found") + return {"status": "error", "error": "Account not found"} - parsed_email = parse_email_message(raw_email, message_id) - if check_content_exists( - db, MailMessage, message_id=message_id, sha256=parsed_email["hash"] - ): - return {"status": "already_exists", "message_id": message_id} + parsed_email = parse_email_message(raw_email, message_id) + if check_content_exists( + db, MailMessage, message_id=message_id, sha256=parsed_email["hash"] + ): + return {"status": "already_exists", "message_id": message_id} - mail_message = create_mail_message(db, account.tags, folder, parsed_email) + mail_message = create_mail_message(db, account.tags, folder, parsed_email) - db.flush() - vectorize_email(mail_message) + db.flush() + vectorize_email(mail_message) - 
diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py
index cbf9e3d..897ec8d 100644
--- a/src/memory/workers/tasks/maintenance.py
+++ b/src/memory/workers/tasks/maintenance.py
@@ -15,6 +15,8 @@ from memory.common.celery_app import (
     app,
     CLEAN_ALL_COLLECTIONS,
     CLEAN_COLLECTION,
+    CLEANUP_EXPIRED_OAUTH_STATES,
+    CLEANUP_EXPIRED_SESSIONS,
     REINGEST_MISSING_CHUNKS,
     REINGEST_CHUNK,
     REINGEST_ITEM,
@@ -338,3 +340,66 @@ def update_metadata_for_source_items(item_type: str):
         update_metadata_for_item.delay(item_id.id, item_type)  # type: ignore
 
     return {"status": "success", "items": len(item_ids)}
+
+
+@app.task(name=CLEANUP_EXPIRED_OAUTH_STATES)
+def cleanup_expired_oauth_states(max_age_hours: int = 1):
+    """Clean up OAuth states that are older than max_age_hours.
+
+    OAuth states should be short-lived - they're only needed during the
+    authorization flow which should complete within minutes.
+    """
+    from memory.common.db.models import MCPServer
+
+    logger.info(f"Cleaning up OAuth states older than {max_age_hours} hours")
+
+    cutoff = datetime.now() - timedelta(hours=max_age_hours)
+    cleaned = 0
+
+    with make_session() as session:
+        # Find MCP servers with stale OAuth state
+        stale_servers = (
+            session.query(MCPServer)
+            .filter(
+                MCPServer.state.isnot(None),
+                MCPServer.updated_at < cutoff,
+            )
+            .all()
+        )
+
+        for server in stale_servers:
+            # Clear the temporary OAuth state fields
+            server.state = None
+            server.code_verifier = None
+            cleaned += 1
+
+        session.commit()
+
+    logger.info(f"Cleaned up {cleaned} expired OAuth states")
+    return {"cleaned": cleaned}
+
+
+@app.task(name=CLEANUP_EXPIRED_SESSIONS)
+def cleanup_expired_sessions():
+    """Clean up expired user sessions from the database."""
+    from memory.common.db.models import UserSession
+    from datetime import timezone
+
+    logger.info("Cleaning up expired user sessions")
+
+    now = datetime.now(timezone.utc)
+    deleted = 0
+
+    with make_session() as session:
+        expired_sessions = (
+            session.query(UserSession).filter(UserSession.expires_at < now).all()
+        )
+
+        for user_session in expired_sessions:
+            session.delete(user_session)
+            deleted += 1
+
+        session.commit()
+
+    logger.info(f"Deleted {deleted} expired user sessions")
+    return {"deleted": deleted}
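Both new cleanup tasks build a time cutoff and compare it to a timestamp column, but they do it slightly differently: the OAuth task uses a naive `datetime.now()`, the session task a timezone-aware `datetime.now(timezone.utc)`. A shared helper would keep the comparisons uniform; a sketch, assuming `updated_at` and `expires_at` are timezone-aware columns (the helper name is illustrative):

```python
from datetime import datetime, timedelta, timezone


def cleanup_cutoff(hours: float = 0) -> datetime:
    """Sketch: one place to build timezone-aware cutoffs for cleanup queries."""
    return datetime.now(timezone.utc) - timedelta(hours=hours)


# e.g. MCPServer.updated_at < cleanup_cutoff(max_age_hours)
#      UserSession.expires_at < cleanup_cutoff()
```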
+ """ + from memory.common.db.models import MCPServer + + logger.info(f"Cleaning up OAuth states older than {max_age_hours} hours") + + cutoff = datetime.now() - timedelta(hours=max_age_hours) + cleaned = 0 + + with make_session() as session: + # Find MCP servers with stale OAuth state + stale_servers = ( + session.query(MCPServer) + .filter( + MCPServer.state.isnot(None), + MCPServer.updated_at < cutoff, + ) + .all() + ) + + for server in stale_servers: + # Clear the temporary OAuth state fields + server.state = None + server.code_verifier = None + cleaned += 1 + + session.commit() + + logger.info(f"Cleaned up {cleaned} expired OAuth states") + return {"cleaned": cleaned} + + +@app.task(name=CLEANUP_EXPIRED_SESSIONS) +def cleanup_expired_sessions(): + """Clean up expired user sessions from the database.""" + from memory.common.db.models import UserSession + from datetime import timezone + + logger.info("Cleaning up expired user sessions") + + now = datetime.now(timezone.utc) + deleted = 0 + + with make_session() as session: + expired_sessions = ( + session.query(UserSession).filter(UserSession.expires_at < now).all() + ) + + for user_session in expired_sessions: + session.delete(user_session) + deleted += 1 + + session.commit() + + logger.info(f"Deleted {deleted} expired user sessions") + return {"deleted": deleted} diff --git a/src/memory/workers/tasks/notes.py b/src/memory/workers/tasks/notes.py index e6e9921..de55753 100644 --- a/src/memory/workers/tasks/notes.py +++ b/src/memory/workers/tasks/notes.py @@ -123,12 +123,23 @@ def sync_note( note.tags = tags # type: ignore note.update_confidences(confidences) - if save_to_file: - with git_tracking( - settings.NOTES_STORAGE_DIR, f"Sync note {filename}: {subject}" - ): - note.save_to_file() - return process_content_item(note, session) + + # Process the content item first (commits transaction) + result = process_content_item(note, session) + + # Git operations MUST be outside the database transaction to avoid + # holding the connection during slow network I/O + if save_to_file: + with git_tracking( + settings.NOTES_STORAGE_DIR, f"Sync note {filename}: {subject}" + ): + # Re-fetch note for file operations (session is closed) + with make_session() as session: + note = session.get(Note, result.get(f"note_id")) + if note: + note.save_to_file() + + return result @app.task(name=SYNC_NOTES) diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index 83cb386..b7d306d 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import cast from memory.common import settings @@ -15,6 +15,10 @@ from memory.workers.tasks.content_processing import safe_task_execution logger = logging.getLogger(__name__) +# Maximum time a task can be in "executing" state before being considered stale +# Should be longer than task_time_limit (1 hour) to allow for legitimate long-running tasks +STALE_EXECUTION_TIMEOUT_HOURS = 2 + def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None: """Call LLM with tools support for scheduled calls.""" @@ -147,8 +151,36 @@ def run_scheduled_calls(): Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when multiple workers query for due calls simultaneously. + + Also recovers stale "executing" tasks that were abandoned due to worker crashes. 
""" with make_session() as session: + # First, recover stale "executing" tasks that have been stuck too long + # This handles cases where workers crashed mid-execution + stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=STALE_EXECUTION_TIMEOUT_HOURS) + stale_calls = ( + session.query(ScheduledLLMCall) + .filter( + ScheduledLLMCall.status == "executing", + ScheduledLLMCall.executed_at < stale_cutoff, + ) + .with_for_update(skip_locked=True) + .all() + ) + + for stale_call in stale_calls: + logger.warning( + f"Recovering stale scheduled call {stale_call.id} " + f"(stuck in executing since {stale_call.executed_at})" + ) + stale_call.status = "pending" + stale_call.executed_at = None + stale_call.error_message = "Recovered from stale execution state" + + if stale_calls: + session.commit() + logger.info(f"Recovered {len(stale_calls)} stale scheduled calls") + # Use FOR UPDATE SKIP LOCKED to atomically claim pending calls # This prevents multiple workers from processing the same call #