mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Fix 8 security and code quality issues from deep dive
Security fixes: - Issue #1: Improved path traversal validation using pathlib.relative_to() - Issue #4: Added timing attack prevention for user enumeration - Issue #5: Added constant-time API key comparison using secrets.compare_digest() Performance fixes: - Issue #20: Cache database engine and session factory for proper connection pooling Code quality fixes: - Issue #28: Fixed string literal without effect (now proper comment) - Issue #29: Removed duplicate db_session.add() call - Issue #30: Fixed incorrect docstring parameter name - Issue #31: Added parentheses for clear operator precedence in set operations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
92dad5b9fd
commit
b9d6ff8745
@ -139,7 +139,8 @@ async def search_knowledge_base(
|
||||
|
||||
if not modalities:
|
||||
modalities = set(ALL_COLLECTIONS.keys())
|
||||
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
|
||||
# Filter to valid collections, excluding observation collections
|
||||
modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
|
||||
|
||||
search_filters = SearchFilters(**filters)
|
||||
search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
|
||||
|
||||
@ -76,7 +76,7 @@ app.add_middleware(
|
||||
def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path:
|
||||
"""Validate that a requested path resolves within the base directory.
|
||||
|
||||
Prevents path traversal attacks using ../ or similar techniques.
|
||||
Prevents path traversal attacks using ../, symlinks, or similar techniques.
|
||||
|
||||
Args:
|
||||
base_dir: The allowed base directory
|
||||
@ -88,11 +88,25 @@ def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str)
|
||||
Raises:
|
||||
HTTPException: If the path would escape the base directory
|
||||
"""
|
||||
# Resolve to absolute path and ensure it's within base_dir
|
||||
resolved = (base_dir / requested_path).resolve()
|
||||
base_resolved = base_dir.resolve()
|
||||
# Resolve base directory to absolute path
|
||||
base_resolved = base_dir.resolve(strict=True)
|
||||
|
||||
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
|
||||
# Build the target path and resolve it
|
||||
# The path is resolved with strict=True below, so a nonexistent target becomes a 404
|
||||
target = base_dir / requested_path
|
||||
|
||||
# Resolve the path (follows symlinks)
|
||||
try:
|
||||
resolved = target.resolve(strict=True)
|
||||
except (OSError, ValueError):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Use pathlib's relative_to() for the containment check (raises ValueError if outside base)
|
||||
# This is safer than string comparison as it handles edge cases
|
||||
try:
|
||||
resolved.relative_to(base_resolved)
|
||||
except ValueError:
|
||||
# Path is not relative to base - access denied
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return resolved
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
@ -118,16 +119,38 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUse
|
||||
|
||||
|
||||
def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None:
    """Look up a human user by email and check the supplied password.

    When no user matches the email, a throwaway verification is still run
    against a placeholder hash, so the unknown-email path also performs a
    password-check call. This is intended to hinder timing-based user
    enumeration; how closely the two paths match in practice depends on
    ``verify_password``, which is not visible here.

    Args:
        email: Email address used to look up the account.
        password: Plain-text password supplied by the caller.
        db: Database session used for the lookup.

    Returns:
        The matching ``HumanUser`` when the password is valid, otherwise
        ``None``.
    """
    account = db.query(HumanUser).filter(HumanUser.email == email).first()

    if not account:
        # Unknown email: spend a verification call anyway so this path is not
        # trivially distinguishable from a wrong-password attempt.
        from memory.common.db.models.users import verify_password

        verify_password(password, "$2b$12$dummy.hash.for.timing.attack.prevention")
        return None

    return account if account.is_valid_password(password) else None
|
||||
|
||||
|
||||
def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None:
    """Authenticate a bot by API key.

    Every stored key is compared against the supplied key with
    ``secrets.compare_digest``. Both sides are encoded to UTF-8 bytes first,
    because ``compare_digest`` raises ``TypeError`` for ``str`` arguments that
    contain non-ASCII characters, and the supplied key is caller-controlled.

    Args:
        api_key: API key presented by the caller.
        db: Database session used to load bot users.

    Returns:
        The first ``BotUser`` whose key matches, or ``None`` when the key is
        empty or nothing matches.
    """
    if not api_key:
        return None

    supplied = api_key.encode("utf-8")
    matched = None

    # Scan every row instead of returning on the first hit, so the time taken
    # does not depend on where (or whether) the matching bot appears.
    for bot in db.query(BotUser).all():
        stored = bot.api_key
        if not stored:
            continue
        # Assumes api_key is stored as str - TODO confirm against the model.
        is_match = secrets.compare_digest(stored.encode("utf-8"), supplied)
        if is_match and matched is None:
            matched = bot

    return matched
|
||||
|
||||
|
||||
@router.api_route("/logout", methods=["GET", "POST"])
|
||||
|
||||
@ -9,22 +9,43 @@ from sqlalchemy.orm import sessionmaker, scoped_session, Session
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
# Lazily created engine, session factory and scoped session, cached at module level so the connection pool is reused
|
||||
_engine = None
|
||||
_session_factory = None
|
||||
_scoped_session = None
|
||||
|
||||
|
||||
def get_engine():
    """Return the shared SQLAlchemy engine, creating it on first use.

    The engine is stored in the module-level ``_engine`` so later calls reuse
    the same object (and therefore the same connection pool) instead of
    building a new engine each time.
    """
    global _engine
    if _engine is not None:
        return _engine

    pool_options = {
        "pool_pre_ping": True,  # Verify connections before use
        "pool_recycle": 3600,  # Recycle connections after 1 hour
    }
    _engine = create_engine(settings.DB_URL, **pool_options)
    return _engine
|
||||
|
||||
|
||||
def get_session_factory():
    """Return the shared ``sessionmaker``, building it on first use.

    The factory is bound to the engine from ``get_engine()`` and stored in the
    module-level ``_session_factory`` for reuse by later calls.
    """
    global _session_factory
    if _session_factory is not None:
        return _session_factory

    _session_factory = sessionmaker(bind=get_engine())
    return _session_factory
|
||||
|
||||
|
||||
def get_scoped_session():
    """Return the shared ``scoped_session`` registry, building it on first use.

    The registry wraps the factory from ``get_session_factory()`` and is
    stored in the module-level ``_scoped_session`` for reuse by later calls.
    """
    global _scoped_session
    if _scoped_session is not None:
        return _scoped_session

    factory = get_session_factory()
    _scoped_session = scoped_session(factory)
    return _scoped_session
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@ -31,14 +31,14 @@ def process_attachment(
|
||||
|
||||
Args:
|
||||
attachment: Attachment dictionary with metadata and content
|
||||
message_id: Email message ID to use in file path generation
|
||||
message: MailMessage instance to use for file path generation
|
||||
|
||||
Returns:
|
||||
Processed attachment dictionary with appropriate metadata
|
||||
"""
|
||||
content, file_path = None, None
|
||||
if not (real_content := attachment.get("content")):
|
||||
"No content, so just save the metadata"
|
||||
pass # No content, so just save the metadata
|
||||
elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE and attachment[
|
||||
"content_type"
|
||||
].startswith("text/"):
|
||||
@ -130,7 +130,6 @@ def create_mail_message(
|
||||
db_session.add_all(attachments)
|
||||
mail_message.attachments = attachments
|
||||
|
||||
db_session.add(mail_message)
|
||||
return mail_message
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user