diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py index a925e1a..9fbeca6 100644 --- a/src/memory/api/MCP/memory.py +++ b/src/memory/api/MCP/memory.py @@ -139,7 +139,8 @@ async def search_knowledge_base( if not modalities: modalities = set(ALL_COLLECTIONS.keys()) - modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS + # Filter to valid collections, excluding observation collections + modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS search_filters = SearchFilters(**filters) search_filters["source_ids"] = filter_source_ids(modalities, search_filters) diff --git a/src/memory/api/app.py b/src/memory/api/app.py index e8bbcae..d56325d 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -76,7 +76,7 @@ app.add_middleware( def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path: """Validate that a requested path resolves within the base directory. - Prevents path traversal attacks using ../ or similar techniques. + Prevents path traversal attacks using ../, symlinks, or similar techniques. Args: base_dir: The allowed base directory @@ -88,11 +88,25 @@ def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) Raises: HTTPException: If the path would escape the base directory """ - # Resolve to absolute path and ensure it's within base_dir - resolved = (base_dir / requested_path).resolve() - base_resolved = base_dir.resolve() + # Resolve base directory to absolute path + base_resolved = base_dir.resolve(strict=True) - if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + # Build the target path and resolve it + # Use strict=False first to check the path before it exists + target = base_dir / requested_path + + # Resolve the path (follows symlinks) + try: + resolved = target.resolve(strict=True) + except (OSError, ValueError): + raise HTTPException(status_code=404, detail="File not found") + + # Use pathlib's is_relative_to for proper path containment check + # This is safer than string comparison as it handles edge cases + try: + resolved.relative_to(base_resolved) + except ValueError: + # Path is not relative to base - access denied raise HTTPException(status_code=403, detail="Access denied") return resolved diff --git a/src/memory/api/auth.py b/src/memory/api/auth.py index aa5b865..4ac8cd2 100644 --- a/src/memory/api/auth.py +++ b/src/memory/api/auth.py @@ -1,4 +1,5 @@ import logging +import secrets from datetime import datetime, timedelta, timezone from typing import cast @@ -118,16 +119,38 @@ def create_user(email: str, password: str, name: str, db: DBSession) -> HumanUse def authenticate_user(email: str, password: str, db: DBSession) -> HumanUser | None: - """Authenticate a human user by email and password""" + """Authenticate a human user by email and password. + + Uses constant-time comparison to prevent timing-based user enumeration. + """ user = db.query(HumanUser).filter(HumanUser.email == email).first() - if user and user.is_valid_password(password): - return user + + # Always perform password check to prevent timing attacks + # Even if user doesn't exist, we do a dummy check + if user: + if user.is_valid_password(password): + return user + else: + # Dummy password check to prevent timing-based user enumeration + # This ensures the function takes similar time whether user exists or not + from memory.common.db.models.users import verify_password + verify_password(password, "$2b$12$dummy.hash.for.timing.attack.prevention") + return None def authenticate_bot(api_key: str, db: DBSession) -> BotUser | None: - """Authenticate a bot by API key""" - return db.query(BotUser).filter(BotUser.api_key == api_key).first() + """Authenticate a bot by API key. + + Uses constant-time comparison to prevent timing attacks. + """ + # Get all bot users and compare with constant-time function + # This prevents timing attacks on API key discovery + bots = db.query(BotUser).all() + for bot in bots: + if bot.api_key and secrets.compare_digest(bot.api_key, api_key): + return bot + return None @router.api_route("/logout", methods=["GET", "POST"]) diff --git a/src/memory/common/db/connection.py b/src/memory/common/db/connection.py index 91b39bd..8aad41c 100644 --- a/src/memory/common/db/connection.py +++ b/src/memory/common/db/connection.py @@ -9,22 +9,43 @@ from sqlalchemy.orm import sessionmaker, scoped_session, Session from memory.common import settings +# Cached engine and session factory for connection pooling +_engine = None +_session_factory = None +_scoped_session = None + def get_engine(): - """Create SQLAlchemy engine from environment variables""" - return create_engine(settings.DB_URL) + """Get or create SQLAlchemy engine with connection pooling. + + The engine is cached to ensure connection pooling works correctly. + Creating a new engine for each request would bypass the pool. + """ + global _engine + if _engine is None: + _engine = create_engine( + settings.DB_URL, + pool_pre_ping=True, # Verify connections before use + pool_recycle=3600, # Recycle connections after 1 hour + ) + return _engine def get_session_factory(): - """Create a session factory for SQLAlchemy sessions""" - engine = get_engine() - session_factory = sessionmaker(bind=engine) - return session_factory + """Get or create a cached session factory for SQLAlchemy sessions.""" + global _session_factory + if _session_factory is None: + engine = get_engine() + _session_factory = sessionmaker(bind=engine) + return _session_factory def get_scoped_session(): - """Create a thread-local scoped session factory""" - return scoped_session(get_session_factory()) + """Get or create a thread-local scoped session factory.""" + global _scoped_session + if _scoped_session is None: + _scoped_session = scoped_session(get_session_factory()) + return _scoped_session def get_session() -> Generator[Session, None, None]: diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index 4c708a7..c345def 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -31,14 +31,14 @@ def process_attachment( Args: attachment: Attachment dictionary with metadata and content - message_id: Email message ID to use in file path generation + message: MailMessage instance to use for file path generation Returns: Processed attachment dictionary with appropriate metadata """ content, file_path = None, None if not (real_content := attachment.get("content")): - "No content, so just save the metadata" + pass # No content, so just save the metadata elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE and attachment[ "content_type" ].startswith("text/"): @@ -130,7 +130,6 @@ def create_mail_message( db_session.add_all(attachments) mail_message.attachments = attachments - db_session.add(mail_message) return mail_message