Fix 8 security and code quality issues from deep dive

Security fixes:
- Issue #1: Improved path traversal validation using pathlib's Path.relative_to()
- Issue #4: Added timing attack prevention for user enumeration
- Issue #5: Added constant-time API key comparison using secrets.compare_digest()

Performance fixes:
- Issue #20: Cache database engine and session factory for proper connection pooling

Code quality fixes:
- Issue #28: Fixed a string literal that had no effect (now a proper comment)
- Issue #29: Removed duplicate db_session.add() call
- Issue #30: Fixed incorrect docstring parameter name
- Issue #31: Added parentheses for clear operator precedence in set operations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
mruwnik 2025-12-19 21:55:59 +00:00
parent 92dad5b9fd
commit b9d6ff8745
5 changed files with 80 additions and 22 deletions

View File

@ -139,7 +139,8 @@ async def search_knowledge_base(
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
# Filter to valid collections, excluding observation collections
modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS
search_filters = SearchFilters(**filters)
search_filters["source_ids"] = filter_source_ids(modalities, search_filters)

View File

@ -76,7 +76,7 @@ app.add_middleware(
def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path:
"""Validate that a requested path resolves within the base directory.
Prevents path traversal attacks using ../ or similar techniques.
Prevents path traversal attacks using ../, symlinks, or similar techniques.
Args:
base_dir: The allowed base directory
@ -88,11 +88,25 @@ def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str)
Raises:
HTTPException: If the path would escape the base directory
"""
# Resolve to absolute path and ensure it's within base_dir
resolved = (base_dir / requested_path).resolve()
base_resolved = base_dir.resolve()
# Resolve base directory to absolute path
base_resolved = base_dir.resolve(strict=True)
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
# Build the target path and resolve it
# strict=True makes resolve() raise if the path does not exist; that is reported as a 404 below
target = base_dir / requested_path
# Resolve the path (follows symlinks)
try:
resolved = target.resolve(strict=True)
except (OSError, ValueError):
raise HTTPException(status_code=404, detail="File not found")
# Use pathlib's relative_to (raises ValueError when outside base) for a proper path containment check
# This is safer than string comparison as it handles edge cases
try:
resolved.relative_to(base_resolved)
except ValueError:
# Path is not relative to base - access denied
raise HTTPException(status_code=403, detail="Access denied")
return resolved

View File

@ -1,4 +1,5 @@
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import cast
@ -118,16 +119,38 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUse
def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None:
"""Authenticate a human user by email and password"""
"""Authenticate a human user by email and password.
Always runs a password verification, even when no user matches the email, to reduce timing-based user enumeration.
"""
user = db.query(HumanUser).filter(HumanUser.email == email).first()
if user and user.is_valid_password(password):
return user
# Always perform password check to prevent timing attacks
# Even if user doesn't exist, we do a dummy check
if user:
if user.is_valid_password(password):
return user
else:
# Dummy password check to prevent timing-based user enumeration
# This ensures the function takes similar time whether user exists or not
from memory.common.db.models.users import verify_password
verify_password(password, "$2b$12$dummy.hash.for.timing.attack.prevention")
return None
def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None:
"""Authenticate a bot by API key"""
return db.query(BotUser).filter(BotUser.api_key == api_key).first()
"""Authenticate a bot by API key.
Uses constant-time comparison to prevent timing attacks.
"""
# Get all bot users and compare with constant-time function
# This prevents timing attacks on API key discovery
bots = db.query(BotUser).all()
for bot in bots:
if bot.api_key and secrets.compare_digest(bot.api_key, api_key):
return bot
return None
@router.api_route("/logout", methods=["GET", "POST"])

View File

@ -9,22 +9,43 @@ from sqlalchemy.orm import sessionmaker, scoped_session, Session
from memory.common import settings
# Cached engine and session factory for connection pooling
_engine = None
_session_factory = None
_scoped_session = None
def get_engine():
"""Create SQLAlchemy engine from environment variables"""
return create_engine(settings.DB_URL)
"""Get or create SQLAlchemy engine with connection pooling.
The engine is cached to ensure connection pooling works correctly.
Creating a new engine for each request would bypass the pool.
"""
global _engine
if _engine is None:
_engine = create_engine(
settings.DB_URL,
pool_pre_ping=True, # Verify connections before use
pool_recycle=3600, # Recycle connections after 1 hour
)
return _engine
def get_session_factory():
"""Create a session factory for SQLAlchemy sessions"""
engine = get_engine()
session_factory = sessionmaker(bind=engine)
return session_factory
"""Get or create a cached session factory for SQLAlchemy sessions."""
global _session_factory
if _session_factory is None:
engine = get_engine()
_session_factory = sessionmaker(bind=engine)
return _session_factory
def get_scoped_session():
"""Create a thread-local scoped session factory"""
return scoped_session(get_session_factory())
"""Get or create a thread-local scoped session factory."""
global _scoped_session
if _scoped_session is None:
_scoped_session = scoped_session(get_session_factory())
return _scoped_session
def get_session() -> Generator[Session, None, None]:

View File

@ -31,14 +31,14 @@ def process_attachment(
Args:
attachment: Attachment dictionary with metadata and content
message_id: Email message ID to use in file path generation
message: MailMessage instance to use for file path generation
Returns:
Processed attachment dictionary with appropriate metadata
"""
content, file_path = None, None
if not (real_content := attachment.get("content")):
"No content, so just save the metadata"
pass # No content, so just save the metadata
elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE and attachment[
"content_type"
].startswith("text/"):
@ -130,7 +130,6 @@ def create_mail_message(
db_session.add_all(attachments)
mail_message.attachments = attachments
db_session.add(mail_message)
return mail_message