Fix 11 high-priority bugs from third deep dive

- Add IMAP connection cleanup on logout failure (email.py)
- Handle IntegrityError for concurrent email processing (tasks/email.py)
- Recover stale scheduled calls stuck in "executing" state (scheduled_calls.py)
- Move git operations outside DB transaction in notes sync (notes.py)
- Add null checks for recipient_user/from_user in Discord (discord.py)
- Add OAuth state and session cleanup tasks (maintenance.py)
- Add distributed lock for backup tasks (backup.py)
- Add /tmp storage warning in settings (settings.py)
- Fix health check error exposure (app.py)
- Remove sensitive data from logs (auth.py)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
mruwnik 2025-12-19 22:15:25 +00:00
parent 8251ad1d6e
commit f2161e09f3
11 changed files with 246 additions and 50 deletions

app.py

@@ -173,7 +173,9 @@ async def health_check(request: Request):
             conn.execute(text("SELECT 1"))
         checks["database"] = "healthy"
     except Exception as e:
-        checks["database"] = f"unhealthy: {str(e)[:100]}"
+        # Log error details but don't expose in response
+        logger.error(f"Database health check failed: {e}")
+        checks["database"] = "unhealthy"
         all_healthy = False

     # Check Qdrant connection
@@ -184,7 +186,9 @@ async def health_check(request: Request):
         client.get_collections()
         checks["qdrant"] = "healthy"
     except Exception as e:
-        checks["qdrant"] = f"unhealthy: {str(e)[:100]}"
+        # Log error details but don't expose in response
+        logger.error(f"Qdrant health check failed: {e}")
+        checks["qdrant"] = "unhealthy"
         all_healthy = False

     checks["status"] = "healthy" if all_healthy else "degraded"

auth.py

@@ -176,9 +176,8 @@ async def oauth_callback_discord(request: Request):
     state = request.query_params.get("state")
     error = request.query_params.get("error")

-    logger.info(
-        f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
-    )
+    # Log OAuth callback without sensitive data (code/state could be intercepted)
+    logger.info("Received OAuth callback")

     message, title, close, status_code = "", "", "", 200
     if error:
@@ -269,6 +268,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
                 headers={"WWW-Authenticate": "Bearer"},
             )

-        logger.debug(f"Authenticated request from user {user.email} to {path}")
+        # Log user ID instead of email for privacy
+        logger.debug(f"Authenticated request from user_id={user.id} to {path}")
         return await call_next(request)

celery_app.py

@@ -44,6 +44,8 @@ UPDATE_METADATA_FOR_SOURCE_ITEMS = (
     f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
 )
 UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
+CLEANUP_EXPIRED_OAUTH_STATES = f"{MAINTENANCE_ROOT}.cleanup_expired_oauth_states"
+CLEANUP_EXPIRED_SESSIONS = f"{MAINTENANCE_ROOT}.cleanup_expired_sessions"
 SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
 SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
 SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"

settings.py

@@ -1,9 +1,12 @@
+import logging
 import os
 import pathlib

 from dotenv import load_dotenv

 load_dotenv()

+logger = logging.getLogger(__name__)
+

 def boolean_env(key: str, default: bool = False) -> bool:
     if key not in os.environ:
@@ -96,6 +99,14 @@ storage_dirs = [
 for dir in storage_dirs:
     dir.mkdir(parents=True, exist_ok=True)

+# Warn if using default /tmp storage - data will be lost on reboot
+if str(FILE_STORAGE_DIR).startswith("/tmp"):
+    logger.warning(
+        f"FILE_STORAGE_DIR is set to '{FILE_STORAGE_DIR}' which is a temporary directory. "
+        "Data stored here may be lost on system reboot. "
+        "Set FILE_STORAGE_DIR environment variable to a persistent location for production use."
+    )
+
 # Maximum attachment size to store directly in the database (10MB)
 MAX_INLINE_ATTACHMENT_SIZE = int(
     os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024)

email.py

@@ -262,6 +262,11 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
             conn.logout()
         except Exception as e:
             logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
+            # If logout fails, explicitly close the socket to prevent resource leak
+            try:
+                conn.shutdown()
+            except Exception:
+                pass  # Socket may already be closed


 def vectorize_email(email: MailMessage):
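
For reference (not part of this diff), a rough sketch of the shape of such a connection context manager, showing where the forced shutdown fits; the host/user/password parameters are illustrative, not the project's actual signature:

import imaplib
from contextlib import contextmanager

@contextmanager
def _imap(host: str, user: str, password: str):
    conn = imaplib.IMAP4_SSL(host)
    conn.login(user, password)
    try:
        yield conn
    finally:
        try:
            conn.logout()          # polite close; may fail on a broken connection
        except Exception:
            try:
                conn.shutdown()    # force-close the socket so it isn't leaked
            except Exception:
                pass               # socket already closed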

backup.py

@@ -6,9 +6,11 @@ import io
 import logging
 import subprocess
 import tarfile
+from contextlib import contextmanager
 from pathlib import Path

 import boto3
+import redis
 from cryptography.fernet import Fernet

 from memory.common import settings
@@ -16,6 +18,31 @@ from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL

 logger = logging.getLogger(__name__)

+# Backup lock timeout (30 minutes max for backup to complete)
+BACKUP_LOCK_TIMEOUT = 30 * 60
+
+
+@contextmanager
+def backup_lock(lock_name: str = "backup_all"):
+    """Acquire a distributed lock for backup operations using Redis.
+
+    Prevents concurrent backup operations which could cause resource
+    contention and inconsistent state.
+    """
+    redis_client = redis.from_url(settings.REDIS_URL)
+    lock_key = f"memory:lock:{lock_name}"
+
+    # Try to acquire lock with NX (only if not exists) and expiry
+    acquired = redis_client.set(lock_key, "1", nx=True, ex=BACKUP_LOCK_TIMEOUT)
+    if not acquired:
+        raise RuntimeError(f"Could not acquire backup lock '{lock_name}' - backup already in progress")
+
+    try:
+        yield
+    finally:
+        # Release the lock
+        redis_client.delete(lock_key)
+

 def get_cipher() -> Fernet:
     """Create Fernet cipher from password in settings."""
@@ -136,11 +163,16 @@ def backup_to_s3(path: Path | str):

 @app.task(name=BACKUP_ALL)
 def backup_all_to_s3():
-    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs."""
+    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs.
+
+    Uses a distributed lock to prevent concurrent backup operations.
+    """
     if not settings.S3_BACKUP_ENABLED:
         logger.info("S3 backup is disabled")
         return {"status": "disabled"}

+    try:
+        with backup_lock():
             logger.info("Starting S3 backup...")

             for dir_name in settings.storage_dirs:
@@ -150,3 +182,6 @@ def backup_all_to_s3():
                 "status": "success",
                 "message": f"Started backup for {len(settings.storage_dirs)} directories",
             }
+    except RuntimeError as e:
+        logger.warning(str(e))
+        return {"status": "skipped", "reason": str(e)}

discord.py

@@ -196,6 +196,29 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
                 "message_id": message_id,
             }

+        # Validate required relationships exist before processing
+        if not discord_message.recipient_user:
+            logger.warning(
+                "No recipient_user for message %s; skipping processing",
+                message_id,
+            )
+            return {
+                "status": "error",
+                "error": "No recipient user",
+                "message_id": message_id,
+            }
+
+        if not discord_message.from_user:
+            logger.warning(
+                "No from_user for message %s; skipping processing",
+                message_id,
+            )
+            return {
+                "status": "error",
+                "error": "No sender user",
+                "message_id": message_id,
+            }
+
         mcp_servers = discord_message.get_mcp_servers(session)

         system_prompt = discord_message.system_prompt or ""
         system_prompt += comm_channel_prompt(

tasks/email.py

@@ -1,6 +1,9 @@
 import logging
 from datetime import datetime
 from typing import cast

+from sqlalchemy.exc import IntegrityError
+
 from memory.common.db.connection import make_session
 from memory.common.db.models import EmailAccount, MailMessage
 from memory.common.celery_app import app, PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
@@ -44,6 +47,7 @@ def process_message(
         logger.warning(f"Empty email message received for account {account_id}")
         return {"status": "skipped", "reason": "empty_content"}

+    try:
         with make_session() as db:
             account = db.get(EmailAccount, account_id)
             if not account:
@@ -79,6 +83,10 @@ def process_message(
                 "chunks_count": len(mail_message.chunks),
                 "attachments_count": len(mail_message.attachments),
             }
+    except IntegrityError:
+        # Another worker already processed this message (race condition)
+        logger.info(f"Message {message_id} already exists (concurrent insert)")
+        return {"status": "already_exists", "message_id": message_id}


 @app.task(name=SYNC_ACCOUNT)
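
The pattern in isolation, as a hypothetical helper that is not part of this commit; it assumes a unique constraint on the message identifier, which is what turns the losing worker's INSERT into an IntegrityError:

from sqlalchemy.exc import IntegrityError

def store_once(db, mail_message) -> dict:
    """Insert a message, treating a duplicate-key race as already handled by another worker."""
    db.add(mail_message)
    try:
        db.commit()
        return {"status": "processed", "message_id": mail_message.message_id}
    except IntegrityError:
        db.rollback()  # keep the session usable; the other worker's row stands
        return {"status": "already_exists", "message_id": mail_message.message_id}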

maintenance.py

@@ -15,6 +15,8 @@ from memory.common.celery_app import (
     app,
     CLEAN_ALL_COLLECTIONS,
     CLEAN_COLLECTION,
+    CLEANUP_EXPIRED_OAUTH_STATES,
+    CLEANUP_EXPIRED_SESSIONS,
     REINGEST_MISSING_CHUNKS,
     REINGEST_CHUNK,
     REINGEST_ITEM,
@@ -338,3 +340,66 @@ def update_metadata_for_source_items(item_type: str):
         update_metadata_for_item.delay(item_id.id, item_type)  # type: ignore

     return {"status": "success", "items": len(item_ids)}
+
+
+@app.task(name=CLEANUP_EXPIRED_OAUTH_STATES)
+def cleanup_expired_oauth_states(max_age_hours: int = 1):
+    """Clean up OAuth states that are older than max_age_hours.
+
+    OAuth states should be short-lived - they're only needed during the
+    authorization flow which should complete within minutes.
+    """
+    from memory.common.db.models import MCPServer
+
+    logger.info(f"Cleaning up OAuth states older than {max_age_hours} hours")
+    cutoff = datetime.now() - timedelta(hours=max_age_hours)
+    cleaned = 0
+
+    with make_session() as session:
+        # Find MCP servers with stale OAuth state
+        stale_servers = (
+            session.query(MCPServer)
+            .filter(
+                MCPServer.state.isnot(None),
+                MCPServer.updated_at < cutoff,
+            )
+            .all()
+        )
+        for server in stale_servers:
+            # Clear the temporary OAuth state fields
+            server.state = None
+            server.code_verifier = None
+            cleaned += 1
+
+        session.commit()
+
+    logger.info(f"Cleaned up {cleaned} expired OAuth states")
+    return {"cleaned": cleaned}
+
+
+@app.task(name=CLEANUP_EXPIRED_SESSIONS)
+def cleanup_expired_sessions():
+    """Clean up expired user sessions from the database."""
+    from memory.common.db.models import UserSession
+    from datetime import timezone
+
+    logger.info("Cleaning up expired user sessions")
+    now = datetime.now(timezone.utc)
+    deleted = 0
+
+    with make_session() as session:
+        expired_sessions = (
+            session.query(UserSession).filter(UserSession.expires_at < now).all()
+        )
+        for user_session in expired_sessions:
+            session.delete(user_session)
+            deleted += 1
+
+        session.commit()
+
+    logger.info(f"Deleted {deleted} expired user sessions")
+    return {"deleted": deleted}

notes.py

@@ -123,12 +123,23 @@ def sync_note(
         note.tags = tags  # type: ignore
         note.update_confidences(confidences)

+        # Process the content item first (commits transaction)
+        result = process_content_item(note, session)
+
+    # Git operations MUST be outside the database transaction to avoid
+    # holding the connection during slow network I/O
     if save_to_file:
         with git_tracking(
             settings.NOTES_STORAGE_DIR, f"Sync note {filename}: {subject}"
         ):
+            # Re-fetch note for file operations (session is closed)
+            with make_session() as session:
+                note = session.get(Note, result.get(f"note_id"))
+                if note:
                     note.save_to_file()

-    return process_content_item(note, session)
+    return result


 @app.task(name=SYNC_NOTES)
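
The same split shown schematically, outside this commit; do_db_work and do_slow_io are placeholders: commit inside a short session, run the slow network step with no connection held, then reopen a session only if the slow step needs fresh ORM objects.

def sync_with_slow_io(payload):
    # 1. Short transaction: all DB writes commit here, the connection goes back to the pool.
    with make_session() as session:
        item = do_db_work(session, payload)   # placeholder for the real persistence step
        item_id = item.id

    # 2. Slow network I/O (e.g. a git push) runs with no session open.
    do_slow_io(item_id)                       # placeholder

    # 3. If the slow step needs ORM objects, re-fetch them in a fresh, short session.
    with make_session() as session:
        item = session.get(Note, item_id)
        if item:
            item.save_to_file()
    return {"note_id": item_id}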

scheduled_calls.py

@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from typing import cast

 from memory.common import settings
@@ -15,6 +15,10 @@ from memory.workers.tasks.content_processing import safe_task_execution

 logger = logging.getLogger(__name__)

+# Maximum time a task can be in "executing" state before being considered stale
+# Should be longer than task_time_limit (1 hour) to allow for legitimate long-running tasks
+STALE_EXECUTION_TIMEOUT_HOURS = 2
+

 def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
     """Call LLM with tools support for scheduled calls."""
@@ -147,8 +151,36 @@ def run_scheduled_calls():
     Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when
     multiple workers query for due calls simultaneously.
+
+    Also recovers stale "executing" tasks that were abandoned due to worker crashes.
     """
     with make_session() as session:
+        # First, recover stale "executing" tasks that have been stuck too long
+        # This handles cases where workers crashed mid-execution
+        stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=STALE_EXECUTION_TIMEOUT_HOURS)
+        stale_calls = (
+            session.query(ScheduledLLMCall)
+            .filter(
+                ScheduledLLMCall.status == "executing",
+                ScheduledLLMCall.executed_at < stale_cutoff,
+            )
+            .with_for_update(skip_locked=True)
+            .all()
+        )
+        for stale_call in stale_calls:
+            logger.warning(
+                f"Recovering stale scheduled call {stale_call.id} "
+                f"(stuck in executing since {stale_call.executed_at})"
+            )
+            stale_call.status = "pending"
+            stale_call.executed_at = None
+            stale_call.error_message = "Recovered from stale execution state"
+
+        if stale_calls:
+            session.commit()
+            logger.info(f"Recovered {len(stale_calls)} stale scheduled calls")
+
         # Use FOR UPDATE SKIP LOCKED to atomically claim pending calls
         # This prevents multiple workers from processing the same call
         #
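
A sketch of the claim step that comment refers to, not taken from this diff; scheduled_at is an assumed column name. Both the recovery block and this claim use with_for_update(skip_locked=True), so rows locked by another worker are skipped rather than waited on:

due_calls = (
    session.query(ScheduledLLMCall)
    .filter(
        ScheduledLLMCall.status == "pending",
        ScheduledLLMCall.scheduled_at <= datetime.now(timezone.utc),  # assumed column name
    )
    .with_for_update(skip_locked=True)
    .all()
)
for call in due_calls:
    call.status = "executing"
    call.executed_at = datetime.now(timezone.utc)
session.commit()  # releases the row locks; other workers never saw these rows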