Mirror of https://github.com/mruwnik/memory.git (synced 2026-01-02 09:12:58 +01:00)
Fix 11 high-priority bugs from third deep dive
- Add IMAP connection cleanup on logout failure (email.py)
- Handle IntegrityError for concurrent email processing (tasks/email.py)
- Recover stale scheduled calls stuck in "executing" state (scheduled_calls.py)
- Move git operations outside DB transaction in notes sync (notes.py)
- Add null checks for recipient_user/from_user in Discord (discord.py)
- Add OAuth state and session cleanup tasks (maintenance.py)
- Add distributed lock for backup tasks (backup.py)
- Add /tmp storage warning in settings (settings.py)
- Fix health check error exposure (app.py)
- Remove sensitive data from logs (auth.py)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent 8251ad1d6e
commit f2161e09f3
@@ -173,7 +173,9 @@ async def health_check(request: Request):
             conn.execute(text("SELECT 1"))
             checks["database"] = "healthy"
     except Exception as e:
-        checks["database"] = f"unhealthy: {str(e)[:100]}"
+        # Log error details but don't expose in response
+        logger.error(f"Database health check failed: {e}")
+        checks["database"] = "unhealthy"
         all_healthy = False
 
     # Check Qdrant connection
@@ -184,7 +186,9 @@ async def health_check(request: Request):
         client.get_collections()
         checks["qdrant"] = "healthy"
     except Exception as e:
-        checks["qdrant"] = f"unhealthy: {str(e)[:100]}"
+        # Log error details but don't expose in response
+        logger.error(f"Qdrant health check failed: {e}")
+        checks["qdrant"] = "unhealthy"
         all_healthy = False
 
     checks["status"] = "healthy" if all_healthy else "degraded"
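With the raw exception text now kept to server-side logs, a degraded response only reveals coarse per-component status. An assumed example of the resulting payload when the database check fails but Qdrant is reachable (field names taken from the hunks above; the exact response wrapper is not shown in this diff):

    # Sketch only, not the literal endpoint output.
    checks = {"database": "unhealthy", "qdrant": "healthy", "status": "degraded"}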
@@ -176,9 +176,8 @@ async def oauth_callback_discord(request: Request):
     state = request.query_params.get("state")
     error = request.query_params.get("error")
 
-    logger.info(
-        f"Received OAuth callback: code={code and code[:20]}..., state={state and state[:20]}..."
-    )
+    # Log OAuth callback without sensitive data (code/state could be intercepted)
+    logger.info("Received OAuth callback")
 
     message, title, close, status_code = "", "", "", 200
     if error:
@@ -269,6 +268,7 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
                 headers={"WWW-Authenticate": "Bearer"},
             )
 
-        logger.debug(f"Authenticated request from user {user.email} to {path}")
+        # Log user ID instead of email for privacy
+        logger.debug(f"Authenticated request from user_id={user.id} to {path}")
 
         return await call_next(request)
@@ -44,6 +44,8 @@ UPDATE_METADATA_FOR_SOURCE_ITEMS = (
     f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
 )
 UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
+CLEANUP_EXPIRED_OAUTH_STATES = f"{MAINTENANCE_ROOT}.cleanup_expired_oauth_states"
+CLEANUP_EXPIRED_SESSIONS = f"{MAINTENANCE_ROOT}.cleanup_expired_sessions"
 SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
 SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
 SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
@@ -1,9 +1,12 @@
+import logging
 import os
 import pathlib
 from dotenv import load_dotenv
 
 load_dotenv()
 
+logger = logging.getLogger(__name__)
+
 
 def boolean_env(key: str, default: bool = False) -> bool:
     if key not in os.environ:
@@ -96,6 +99,14 @@ storage_dirs = [
 for dir in storage_dirs:
     dir.mkdir(parents=True, exist_ok=True)
 
+# Warn if using default /tmp storage - data will be lost on reboot
+if str(FILE_STORAGE_DIR).startswith("/tmp"):
+    logger.warning(
+        f"FILE_STORAGE_DIR is set to '{FILE_STORAGE_DIR}' which is a temporary directory. "
+        "Data stored here may be lost on system reboot. "
+        "Set FILE_STORAGE_DIR environment variable to a persistent location for production use."
+    )
+
 # Maximum attachment size to store directly in the database (10MB)
 MAX_INLINE_ATTACHMENT_SIZE = int(
     os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024)
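To silence the warning in a real deployment, point FILE_STORAGE_DIR at persistent storage before memory.common.settings is imported. A minimal sketch; the path below is only an illustrative example, not something the repo prescribes:

    import os

    # Must run before `from memory.common import settings`, since the warning
    # fires at import time when the directory defaults to somewhere under /tmp.
    os.environ.setdefault("FILE_STORAGE_DIR", "/var/lib/memory/files")  # example path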
@@ -262,6 +262,11 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
             conn.logout()
         except Exception as e:
             logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
+            # If logout fails, explicitly close the socket to prevent resource leak
+            try:
+                conn.shutdown()
+            except Exception:
+                pass  # Socket may already be closed
 
 
 def vectorize_email(email: MailMessage):
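imaplib's IMAP4.shutdown() closes the underlying socket directly, so a failed LOGOUT no longer leaks the file descriptor. A minimal sketch of the overall pattern, assuming the context manager's connection setup looks roughly like this (the connect/login details are not shown in the hunk and the real imap_connection() takes an EmailAccount):

    import imaplib
    from contextlib import contextmanager

    @contextmanager
    def imap_session(server: str, user: str, password: str):
        # Illustrative only; mirrors the cleanup added in the diff above.
        conn = imaplib.IMAP4_SSL(server)
        conn.login(user, password)
        try:
            yield conn
        finally:
            try:
                conn.logout()
            except Exception:
                # LOGOUT failed; close the socket explicitly so the fd is not leaked.
                try:
                    conn.shutdown()
                except Exception:
                    pass  # socket may already be closed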
@@ -6,9 +6,11 @@ import io
 import logging
 import subprocess
 import tarfile
+from contextlib import contextmanager
 from pathlib import Path
 
 import boto3
+import redis
 from cryptography.fernet import Fernet
 
 from memory.common import settings
@@ -16,6 +18,31 @@ from memory.common.celery_app import app, BACKUP_PATH, BACKUP_ALL
 
 logger = logging.getLogger(__name__)
 
+# Backup lock timeout (30 minutes max for backup to complete)
+BACKUP_LOCK_TIMEOUT = 30 * 60
+
+
+@contextmanager
+def backup_lock(lock_name: str = "backup_all"):
+    """Acquire a distributed lock for backup operations using Redis.
+
+    Prevents concurrent backup operations which could cause resource
+    contention and inconsistent state.
+    """
+    redis_client = redis.from_url(settings.REDIS_URL)
+    lock_key = f"memory:lock:{lock_name}"
+
+    # Try to acquire lock with NX (only if not exists) and expiry
+    acquired = redis_client.set(lock_key, "1", nx=True, ex=BACKUP_LOCK_TIMEOUT)
+    if not acquired:
+        raise RuntimeError(f"Could not acquire backup lock '{lock_name}' - backup already in progress")
+
+    try:
+        yield
+    finally:
+        # Release the lock
+        redis_client.delete(lock_key)
+
+
 def get_cipher() -> Fernet:
     """Create Fernet cipher from password in settings."""
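Usage of the new lock appears in the next hunk; in isolation the intended pattern is simply:

    # Condensed sketch mirroring the backup_all_to_s3 change below.
    try:
        with backup_lock("backup_all"):
            ...  # enqueue per-directory backups
    except RuntimeError:
        pass  # another worker already holds the lock; skip this run

Because the key is set with ex=BACKUP_LOCK_TIMEOUT, a worker that dies inside the block leaves a lock that expires on its own after 30 minutes instead of blocking backups forever; the trade-off is that a backup legitimately running longer than the TTL could overlap with the next run.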
@@ -136,17 +163,25 @@ def backup_to_s3(path: Path | str):
 
 @app.task(name=BACKUP_ALL)
 def backup_all_to_s3():
-    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs."""
+    """Main backup task that syncs unencrypted dirs and uploads encrypted dirs.
+
+    Uses a distributed lock to prevent concurrent backup operations.
+    """
     if not settings.S3_BACKUP_ENABLED:
         logger.info("S3 backup is disabled")
         return {"status": "disabled"}
 
-    logger.info("Starting S3 backup...")
+    try:
+        with backup_lock():
+            logger.info("Starting S3 backup...")
 
-    for dir_name in settings.storage_dirs:
-        backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())
+            for dir_name in settings.storage_dirs:
+                backup_to_s3.delay((settings.FILE_STORAGE_DIR / dir_name).as_posix())
 
-    return {
-        "status": "success",
-        "message": f"Started backup for {len(settings.storage_dirs)} directories",
-    }
+            return {
+                "status": "success",
+                "message": f"Started backup for {len(settings.storage_dirs)} directories",
+            }
+    except RuntimeError as e:
+        logger.warning(str(e))
+        return {"status": "skipped", "reason": str(e)}
@@ -196,6 +196,29 @@ def process_discord_message(message_id: int) -> dict[str, Any]:
                 "message_id": message_id,
             }
 
+        # Validate required relationships exist before processing
+        if not discord_message.recipient_user:
+            logger.warning(
+                "No recipient_user for message %s; skipping processing",
+                message_id,
+            )
+            return {
+                "status": "error",
+                "error": "No recipient user",
+                "message_id": message_id,
+            }
+
+        if not discord_message.from_user:
+            logger.warning(
+                "No from_user for message %s; skipping processing",
+                message_id,
+            )
+            return {
+                "status": "error",
+                "error": "No sender user",
+                "message_id": message_id,
+            }
+
         mcp_servers = discord_message.get_mcp_servers(session)
         system_prompt = discord_message.system_prompt or ""
         system_prompt += comm_channel_prompt(
@@ -1,6 +1,9 @@
 import logging
+from datetime import datetime
 from typing import cast
 
+from sqlalchemy.exc import IntegrityError
+
 from memory.common.db.connection import make_session
 from memory.common.db.models import EmailAccount, MailMessage
 from memory.common.celery_app import app, PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
@@ -44,41 +47,46 @@ def process_message(
         logger.warning(f"Empty email message received for account {account_id}")
         return {"status": "skipped", "reason": "empty_content"}
 
-    with make_session() as db:
-        account = db.get(EmailAccount, account_id)
-        if not account:
-            logger.error(f"Account {account_id} not found")
-            return {"status": "error", "error": "Account not found"}
+    try:
+        with make_session() as db:
+            account = db.get(EmailAccount, account_id)
+            if not account:
+                logger.error(f"Account {account_id} not found")
+                return {"status": "error", "error": "Account not found"}
 
-        parsed_email = parse_email_message(raw_email, message_id)
-        if check_content_exists(
-            db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
-        ):
-            return {"status": "already_exists", "message_id": message_id}
+            parsed_email = parse_email_message(raw_email, message_id)
+            if check_content_exists(
+                db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
+            ):
+                return {"status": "already_exists", "message_id": message_id}
 
-        mail_message = create_mail_message(db, account.tags, folder, parsed_email)
+            mail_message = create_mail_message(db, account.tags, folder, parsed_email)
 
-        db.flush()
-        vectorize_email(mail_message)
+            db.flush()
+            vectorize_email(mail_message)
 
-        db.commit()
+            db.commit()
 
-        logger.info(f"Stored embedding for message {mail_message.message_id}")
-        logger.info("Chunks:")
-        for chunk in mail_message.chunks:
-            logger.info(f" - {chunk.id}")
-        for attachment in mail_message.attachments:
-            logger.info(f" - Attachment {attachment.id}")
-            for chunk in attachment.chunks:
-                logger.info(f" - {chunk.id}")
+            logger.info(f"Stored embedding for message {mail_message.message_id}")
+            logger.info("Chunks:")
+            for chunk in mail_message.chunks:
+                logger.info(f" - {chunk.id}")
+            for attachment in mail_message.attachments:
+                logger.info(f" - Attachment {attachment.id}")
+                for chunk in attachment.chunks:
+                    logger.info(f" - {chunk.id}")
 
-        return {
-            "status": "processed",
-            "mail_message_id": cast(int, mail_message.id),
-            "message_id": message_id,
-            "chunks_count": len(mail_message.chunks),
-            "attachments_count": len(mail_message.attachments),
-        }
+            return {
+                "status": "processed",
+                "mail_message_id": cast(int, mail_message.id),
+                "message_id": message_id,
+                "chunks_count": len(mail_message.chunks),
+                "attachments_count": len(mail_message.attachments),
+            }
+    except IntegrityError:
+        # Another worker already processed this message (race condition)
+        logger.info(f"Message {message_id} already exists (concurrent insert)")
+        return {"status": "already_exists", "message_id": message_id}
 
 
 @app.task(name=SYNC_ACCOUNT)
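Catching IntegrityError here assumes the mail table enforces uniqueness (for example on the message id or content hash used by check_content_exists), so a concurrent duplicate insert fails at flush/commit instead of silently creating a second row. A hypothetical sketch of the kind of constraint this relies on; the real MailMessage columns and constraint names are not shown in this diff:

    from sqlalchemy import Column, Integer, String, UniqueConstraint
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class MailMessageSketch(Base):  # illustrative only, not the project's model
        __tablename__ = "mail_message_sketch"
        id = Column(Integer, primary_key=True)
        message_id = Column(String, nullable=False)
        sha256 = Column(String, nullable=False)
        __table_args__ = (
            UniqueConstraint("message_id", name="uq_mail_message_sketch_message_id"),
        )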
@@ -15,6 +15,8 @@ from memory.common.celery_app import (
     app,
     CLEAN_ALL_COLLECTIONS,
     CLEAN_COLLECTION,
+    CLEANUP_EXPIRED_OAUTH_STATES,
+    CLEANUP_EXPIRED_SESSIONS,
     REINGEST_MISSING_CHUNKS,
     REINGEST_CHUNK,
     REINGEST_ITEM,
@@ -338,3 +340,66 @@ def update_metadata_for_source_items(item_type: str):
         update_metadata_for_item.delay(item_id.id, item_type)  # type: ignore
 
     return {"status": "success", "items": len(item_ids)}
+
+
+@app.task(name=CLEANUP_EXPIRED_OAUTH_STATES)
+def cleanup_expired_oauth_states(max_age_hours: int = 1):
+    """Clean up OAuth states that are older than max_age_hours.
+
+    OAuth states should be short-lived - they're only needed during the
+    authorization flow which should complete within minutes.
+    """
+    from memory.common.db.models import MCPServer
+
+    logger.info(f"Cleaning up OAuth states older than {max_age_hours} hours")
+
+    cutoff = datetime.now() - timedelta(hours=max_age_hours)
+    cleaned = 0
+
+    with make_session() as session:
+        # Find MCP servers with stale OAuth state
+        stale_servers = (
+            session.query(MCPServer)
+            .filter(
+                MCPServer.state.isnot(None),
+                MCPServer.updated_at < cutoff,
+            )
+            .all()
+        )
+
+        for server in stale_servers:
+            # Clear the temporary OAuth state fields
+            server.state = None
+            server.code_verifier = None
+            cleaned += 1
+
+        session.commit()
+
+    logger.info(f"Cleaned up {cleaned} expired OAuth states")
+    return {"cleaned": cleaned}
+
+
+@app.task(name=CLEANUP_EXPIRED_SESSIONS)
+def cleanup_expired_sessions():
+    """Clean up expired user sessions from the database."""
+    from memory.common.db.models import UserSession
+    from datetime import timezone
+
+    logger.info("Cleaning up expired user sessions")
+
+    now = datetime.now(timezone.utc)
+    deleted = 0
+
+    with make_session() as session:
+        expired_sessions = (
+            session.query(UserSession).filter(UserSession.expires_at < now).all()
+        )
+
+        for user_session in expired_sessions:
+            session.delete(user_session)
+            deleted += 1
+
+        session.commit()
+
+    logger.info(f"Deleted {deleted} expired user sessions")
+    return {"deleted": deleted}
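For these tasks to run automatically they still need to be scheduled. A minimal sketch of how they might be wired into Celery beat; the schedule intervals and the place where the schedule is configured are assumptions, not part of this commit:

    from celery.schedules import crontab
    from memory.common.celery_app import (
        app,
        CLEANUP_EXPIRED_OAUTH_STATES,
        CLEANUP_EXPIRED_SESSIONS,
    )

    # Hypothetical beat entries; the project may register its schedule elsewhere.
    app.conf.beat_schedule.update({
        "cleanup-expired-oauth-states": {
            "task": CLEANUP_EXPIRED_OAUTH_STATES,
            "schedule": crontab(minute=0),  # hourly, matching max_age_hours=1
        },
        "cleanup-expired-sessions": {
            "task": CLEANUP_EXPIRED_SESSIONS,
            "schedule": crontab(minute=30, hour="*/6"),  # every six hours
        },
    })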
@@ -123,12 +123,23 @@ def sync_note(
         note.tags = tags  # type: ignore
 
     note.update_confidences(confidences)
-    if save_to_file:
-        with git_tracking(
-            settings.NOTES_STORAGE_DIR, f"Sync note {filename}: {subject}"
-        ):
-            note.save_to_file()
-    return process_content_item(note, session)
+
+    # Process the content item first (commits transaction)
+    result = process_content_item(note, session)
+
+    # Git operations MUST be outside the database transaction to avoid
+    # holding the connection during slow network I/O
+    if save_to_file:
+        with git_tracking(
+            settings.NOTES_STORAGE_DIR, f"Sync note {filename}: {subject}"
+        ):
+            # Re-fetch note for file operations (session is closed)
+            with make_session() as session:
+                note = session.get(Note, result.get(f"note_id"))
+                if note:
+                    note.save_to_file()
+
+    return result
 
 
 @app.task(name=SYNC_NOTES)
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from typing import cast
 
 from memory.common import settings
@@ -15,6 +15,10 @@ from memory.workers.tasks.content_processing import safe_task_execution
 
 logger = logging.getLogger(__name__)
 
+# Maximum time a task can be in "executing" state before being considered stale
+# Should be longer than task_time_limit (1 hour) to allow for legitimate long-running tasks
+STALE_EXECUTION_TIMEOUT_HOURS = 2
+
 
 def call_llm_for_scheduled(session, scheduled_call: ScheduledLLMCall) -> str | None:
     """Call LLM with tools support for scheduled calls."""
@@ -147,8 +151,36 @@ def run_scheduled_calls():
 
     Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when
     multiple workers query for due calls simultaneously.
+
+    Also recovers stale "executing" tasks that were abandoned due to worker crashes.
     """
     with make_session() as session:
+        # First, recover stale "executing" tasks that have been stuck too long
+        # This handles cases where workers crashed mid-execution
+        stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=STALE_EXECUTION_TIMEOUT_HOURS)
+        stale_calls = (
+            session.query(ScheduledLLMCall)
+            .filter(
+                ScheduledLLMCall.status == "executing",
+                ScheduledLLMCall.executed_at < stale_cutoff,
+            )
+            .with_for_update(skip_locked=True)
+            .all()
+        )
+
+        for stale_call in stale_calls:
+            logger.warning(
+                f"Recovering stale scheduled call {stale_call.id} "
+                f"(stuck in executing since {stale_call.executed_at})"
+            )
+            stale_call.status = "pending"
+            stale_call.executed_at = None
+            stale_call.error_message = "Recovered from stale execution state"
+
+        if stale_calls:
+            session.commit()
+            logger.info(f"Recovered {len(stale_calls)} stale scheduled calls")
+
         # Use FOR UPDATE SKIP LOCKED to atomically claim pending calls
         # This prevents multiple workers from processing the same call
         #