Fix 19 bugs from investigation

Critical/High severity fixes:
- BUG-001: Path traversal vulnerabilities (3 endpoints)
- BUG-003: BM25 filters now apply size/observation_types
- BUG-006: Remove API key from log messages
- BUG-008: Chunk size validation before yielding
- BUG-009: Race condition fix with FOR UPDATE SKIP LOCKED
- BUG-010: Add mcp_servers property to MessageProcessor
- BUG-011: Fix user_id type (BigInteger→Integer)
- BUG-012: Swap inverted score thresholds
- BUG-013: Add retry logic to embedding pipeline
- BUG-014: Fix CORS to use specific origin
- BUG-015: Add Celery retry/timeout defaults
- BUG-016: Re-raise exceptions for Celery retries

Medium severity fixes:
- BUG-017: Add collection_name index on Chunk
- BUG-031: Add SearchConfig limits (max 1000/300s)
- BUG-033: Replace debug prints with logger calls
- BUG-037: Clarify timezone handling in scheduler
- BUG-043: Health check now validates DB + Qdrant
- BUG-055: collection_model returns None not "unknown"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Daniel O'Connell 2025-12-19 18:59:14 +01:00
parent a1444efaac
commit 52274f82a6
19 changed files with 320 additions and 70 deletions

View File

@@ -27,6 +27,32 @@ from memory.common.formatters import observation

 logger = logging.getLogger(__name__)

+
+def validate_path_within_directory(
+    base_dir: pathlib.Path, requested_path: str
+) -> pathlib.Path:
+    """Validate that a requested path resolves within the base directory.
+
+    Prevents path traversal attacks using ../ or similar techniques.
+
+    Args:
+        base_dir: The allowed base directory
+        requested_path: The user-provided path
+
+    Returns:
+        The resolved absolute path if valid
+
+    Raises:
+        ValueError: If the path would escape the base directory
+    """
+    resolved = (base_dir / requested_path.lstrip("/")).resolve()
+    base_resolved = base_dir.resolve()
+    if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+        raise ValueError(f"Path escapes allowed directory: {requested_path}")
+    return resolved
+
+
 def filter_observation_source_ids(
     tags: list[str] | None = None, observation_types: list[str] | None = None
 ):
@@ -344,7 +370,11 @@ async def note_files(path: str = "/"):

     Returns: List of file paths relative to notes directory
     """
-    root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
+    try:
+        root = validate_path_within_directory(settings.NOTES_STORAGE_DIR, path)
+    except ValueError as e:
+        raise ValueError(f"Invalid path: {e}")
     return [
         f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
         for f in root.rglob("*.md")
@@ -359,15 +389,19 @@ def fetch_file(filename: str) -> dict:

     Returns dict with content, mime_type, is_text, file_size.
     Text content as string, binary as base64.
     """
-    path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/")
-    print("fetching file", path)
+    try:
+        path = validate_path_within_directory(
+            settings.FILE_STORAGE_DIR, filename.strip()
+        )
+    except ValueError as e:
+        raise ValueError(f"Invalid path: {e}")
+
+    logger.debug(f"Fetching file: {path}")
     if not path.exists():
         raise FileNotFoundError(f"File not found: {filename}")

     mime_type = extract.get_mime_type(path)
     chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
-    print("mime_type", mime_type)
-    print("chunks", chunks)

     def serialize_chunk(
         chunk: extract.DataChunk, data: extract.MulitmodalChunk
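
The same check guards note_files and fetch_file here, with an HTTPException variant in the API app below. A quick behavioral sketch of the validator (base directory hypothetical):

    import pathlib

    base = pathlib.Path("/srv/notes")

    # A normal relative path resolves inside the base directory
    validate_path_within_directory(base, "projects/todo.md")  # -> /srv/notes/projects/todo.md

    # Traversal attempts resolve outside the base and are rejected
    try:
        validate_path_within_directory(base, "../../etc/passwd")
    except ValueError as e:
        print(e)  # Path escapes allowed directory: ../../etc/passwd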

View File

@@ -6,6 +6,7 @@ import contextlib
 import os
 import logging
 import mimetypes
+import pathlib

 from fastapi import FastAPI, UploadFile, Request, HTTPException
 from fastapi.responses import FileResponse
@@ -33,27 +34,58 @@ async def lifespan(app: FastAPI):

 app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
 app.add_middleware(AuthenticationMiddleware)

+# Configure CORS with specific origin to prevent CSRF attacks.
+# allow_credentials=True requires specific origins, not wildcards.
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # [settings.SERVER_URL],
+    allow_origins=[settings.SERVER_URL],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )

+
+def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path:
+    """Validate that a requested path resolves within the base directory.
+
+    Prevents path traversal attacks using ../ or similar techniques.
+
+    Args:
+        base_dir: The allowed base directory
+        requested_path: The user-provided path
+
+    Returns:
+        The resolved absolute path if valid
+
+    Raises:
+        HTTPException: If the path would escape the base directory
+    """
+    # Resolve to absolute path and ensure it's within base_dir
+    resolved = (base_dir / requested_path).resolve()
+    base_resolved = base_dir.resolve()
+    if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+        raise HTTPException(status_code=403, detail="Access denied")
+    return resolved
+
+
 @app.get("/ui{full_path:path}")
 async def serve_react_app(full_path: str):
     full_path = full_path.lstrip("/")
-    index_file = settings.STATIC_DIR / full_path
-    if index_file.is_file():
-        return FileResponse(index_file)
+    try:
+        index_file = validate_path_within_directory(settings.STATIC_DIR, full_path)
+        if index_file.is_file():
+            return FileResponse(index_file)
+    except HTTPException:
+        pass  # Fall through to index.html for SPA routing
     return FileResponse(settings.STATIC_DIR / "index.html")


 @app.get("/files/{path:path}")
 async def serve_file(path: str):
-    file_path = settings.FILE_STORAGE_DIR / path
+    file_path = validate_path_within_directory(settings.FILE_STORAGE_DIR, path)
     if not file_path.is_file():
         raise HTTPException(status_code=404, detail="File not found")
@@ -86,10 +118,36 @@ app.include_router(auth_router)

 # Add health check to MCP server instead of main app
 @mcp.custom_route("/health", methods=["GET"])
 async def health_check(request: Request):
-    """Simple health check endpoint on MCP server"""
+    """Health check endpoint that verifies all dependencies are accessible."""
     from fastapi.responses import JSONResponse
+    from sqlalchemy import text

-    return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
+    checks = {"mcp_oauth": "enabled"}
+    all_healthy = True
+
+    # Check database connection
+    try:
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+        checks["database"] = "healthy"
+    except Exception as e:
+        checks["database"] = f"unhealthy: {str(e)[:100]}"
+        all_healthy = False
+
+    # Check Qdrant connection
+    try:
+        from memory.common.qdrant import get_qdrant_client
+
+        client = get_qdrant_client()
+        client.get_collections()
+        checks["qdrant"] = "healthy"
+    except Exception as e:
+        checks["qdrant"] = f"unhealthy: {str(e)[:100]}"
+        all_healthy = False
+
+    checks["status"] = "healthy" if all_healthy else "degraded"
+    status_code = 200 if all_healthy else 503
+    return JSONResponse(checks, status_code=status_code)


 # Mount MCP server at root - OAuth endpoints need to be at root level
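
With both checks wired in, the endpoint reports degradation instead of always answering healthy. A minimal probe sketch (server URL hypothetical):

    import httpx

    resp = httpx.get("http://localhost:8000/health")
    print(resp.status_code)  # 200 when everything passes, 503 when degraded
    print(resp.json())
    # e.g. {"mcp_oauth": "enabled", "database": "healthy",
    #       "qdrant": "unhealthy: timed out", "status": "degraded"}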

View File

@@ -12,7 +12,7 @@ from memory.api.search.types import SearchFilters
 from memory.common import extract
 from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, ConfidenceScore
+from memory.common.db.models import Chunk, ConfidenceScore, SourceItem

 logger = logging.getLogger(__name__)
@@ -29,9 +29,28 @@ async def search_bm25(
         Chunk.content.isnot(None),
     )

+    # Join with SourceItem if we need size filters
+    needs_source_join = any(filters.get(k) for k in ["min_size", "max_size"])
+    if needs_source_join:
+        items_query = items_query.join(
+            SourceItem, SourceItem.id == Chunk.source_id
+        )
+
     if source_ids := filters.get("source_ids"):
         items_query = items_query.filter(Chunk.source_id.in_(source_ids))

+    # Size filters
+    if min_size := filters.get("min_size"):
+        items_query = items_query.filter(SourceItem.size >= min_size)
+    if max_size := filters.get("max_size"):
+        items_query = items_query.filter(SourceItem.size <= max_size)
+
+    # Observation type filter - restricts to specific collection types
+    if observation_types := filters.get("observation_types"):
+        items_query = items_query.filter(
+            Chunk.collection_name.in_(observation_types)
+        )
+
     # Add confidence filtering if specified
     if min_confidences := filters.get("min_confidences"):
         for confidence_type, min_score in min_confidences.items():
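
An illustrative filter payload exercising the new branches (assuming SearchFilters accepts these keys, as the lookups above imply; the type values are hypothetical):

    filters: SearchFilters = {
        "min_size": 1_024,        # only sources of at least 1 KiB ...
        "max_size": 1_000_000,    # ... and at most ~1 MB
        "observation_types": ["belief", "behavior"],
    }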

View File

@@ -182,13 +182,16 @@ async def search_chunks_embeddings(
     filters: SearchFilters = SearchFilters(),
     timeout: int = 2,
 ) -> list[str]:
+    # Note: Multimodal embeddings typically produce higher similarity scores,
+    # so we use a higher threshold (0.4) to maintain selectivity.
+    # Text embeddings produce lower scores, so we use 0.25.
     all_ids = await asyncio.gather(
         asyncio.wait_for(
             search_chunks(
                 data,
                 modalities & TEXT_COLLECTIONS,
                 limit,
-                0.4,
+                0.25,
                 filters,
                 False,
             ),
@@ -199,7 +202,7 @@ async def search_chunks_embeddings(
                 data,
                 modalities & MULTIMODAL_COLLECTIONS,
                 limit,
-                0.25,
+                0.4,
                 filters,
                 True,
             ),

View File

@@ -1,10 +1,13 @@
 import asyncio
+import logging

 from bs4 import BeautifulSoup
 from PIL import Image

 from memory.common.db.models.source_item import Chunk
 from memory.common import llms, settings, tokens

+logger = logging.getLogger(__name__)
+
 SCORE_CHUNK_SYSTEM_PROMPT = """
 You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
@@ -32,7 +35,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
     try:
         data = chunk.data
     except Exception as e:
-        print(f"Error getting chunk data: {e}, {type(e)}")
+        logger.error(f"Error getting chunk data: {e}")
         return chunk

     chunk_text = "\n".join(text for text in data if isinstance(text, str))
@@ -47,7 +50,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
             system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
         )
     except Exception as e:
-        print(f"Error scoring chunk: {e}, {type(e)}")
+        logger.error(f"Error scoring chunk: {e}")
         return chunk

     soup = BeautifulSoup(response, "html.parser")

View File

@@ -80,3 +80,15 @@ class SearchConfig(BaseModel):
     timeout: int = 20
     previews: bool = False
     useScores: bool = False
+
+    def model_post_init(self, __context) -> None:
+        # Enforce reasonable limits
+        if self.limit < 1:
+            object.__setattr__(self, "limit", 1)
+        elif self.limit > 1000:
+            object.__setattr__(self, "limit", 1000)
+        if self.timeout < 1:
+            object.__setattr__(self, "timeout", 1)
+        elif self.timeout > 300:
+            object.__setattr__(self, "timeout", 300)

View File

@@ -88,6 +88,14 @@ app.conf.update(
     task_acks_late=True,
     task_reject_on_worker_lost=True,
     worker_prefetch_multiplier=1,
+    # Default retry configuration for transient failures
+    task_autoretry_for=(Exception,),
+    task_retry_kwargs={"max_retries": 3},
+    task_retry_backoff=True,
+    task_retry_backoff_max=600,  # Max 10 minutes between retries
+    task_retry_jitter=True,
+    task_time_limit=3600,  # 1 hour hard limit
+    task_soft_time_limit=3000,  # 50 minute soft limit
     task_routes={
         f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
         f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},

View File

@@ -102,29 +102,35 @@ def chunk_text(
     current = ""
     for span in yield_spans(text, max_tokens):
-        current = f"{current} {span}".strip()
-        if tokens.approx_token_count(current) < max_tokens:
+        # Check if adding this span would exceed the limit
+        new_chunk = f"{current} {span}".strip() if current else span
+        if tokens.approx_token_count(new_chunk) <= max_tokens:
+            current = new_chunk
             continue

-        if overlap <= 0:
+        # Adding span would exceed limit - yield current first (if non-empty)
+        if current:
             yield current
-            current = ""
+
+        # Handle overlap for the next chunk
+        if overlap <= 0 or not current:
+            current = span
             continue

-        overlap_text = current[-overlap_chars:]
+        # Try to find a clean break point for overlap
+        overlap_text = current[-overlap_chars:] if len(current) > overlap_chars else current
         clean_break = max(
             overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )
         if clean_break < 0:
-            yield current
-            current = ""
+            current = span
             continue

-        break_offset = -overlap_chars + clean_break + 1
-        chunk = current[break_offset:].strip()
-        yield current
-        current = chunk
+        # Start new chunk with overlap from clean break
+        break_offset = -len(overlap_text) + clean_break + 1
+        overlap_portion = current[break_offset:].strip()
+        current = f"{overlap_portion} {span}".strip() if overlap_portion else span

     if current:
         yield current.strip()
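
The clean-break search is the subtle part; here is the same logic run standalone on a plain string (values hypothetical):

    current = "First point. Second point! Third point"
    overlap_chars = 20

    overlap_text = current[-overlap_chars:] if len(current) > overlap_chars else current
    clean_break = max(
        overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
    )
    if clean_break >= 0:
        break_offset = -len(overlap_text) + clean_break + 1
        print(current[break_offset:].strip())  # -> "Third point", carried into the next chunk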

View File

@@ -132,9 +132,14 @@ def get_modality(mime_type: str) -> str:
 def collection_model(
     collection: str, text: str, images: list[Image.Image]
 ) -> str | None:
+    """Determine the appropriate embedding model for a collection.
+
+    Returns None if no suitable model can be determined, rather than
+    falling back to an invalid placeholder.
+    """
     config = ALL_COLLECTIONS.get(collection, {})
     if images and config.get("multimodal"):
         return settings.MIXED_EMBEDDING_MODEL
     if text and config.get("text"):
         return settings.TEXT_EMBEDDING_MODEL
-    return "unknown"
+    return None

View File

@@ -16,7 +16,7 @@ from sqlalchemy import (
     Text,
     func,
 )
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, object_session

 from memory.common.db.models.base import Base
@@ -27,6 +27,25 @@ class MessageProcessor:
     allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
     disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")

+    @property
+    def mcp_servers(self) -> list:
+        """Get MCP servers assigned to this entity via MCPServerAssignment."""
+        from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
+
+        session = object_session(self)
+        if not session:
+            return []
+        return (
+            session.query(MCPServer)
+            .join(MCPServerAssignment)
+            .filter(
+                MCPServerAssignment.entity_type == self.entity_type,
+                MCPServerAssignment.entity_id == self.id,
+            )
+            .all()
+        )
+
     system_prompt = Column(
         Text,
         nullable=True,

View File

@@ -7,6 +7,7 @@ from sqlalchemy import (
     DateTime,
     ForeignKey,
     BigInteger,
+    Integer,
     JSON,
     Text,
 )
@@ -20,7 +21,7 @@ class ScheduledLLMCall(Base):
     __tablename__ = "scheduled_llm_calls"

     id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
-    user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False)
+    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
     topic = Column(Text, nullable=True)

     # Scheduling info

View File

@@ -165,6 +165,7 @@ class Chunk(Base):
     __table_args__ = (
         CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
         Index("chunk_source_idx", "source_id"),
+        Index("chunk_collection_idx", "collection_name"),
     )

     @property

View File

@@ -1,4 +1,5 @@
 import logging
+import time
 from typing import Literal, cast

 import voyageai
@@ -15,6 +16,12 @@ from memory.common.db.models import Chunk, SourceItem

 logger = logging.getLogger(__name__)

+
+class EmbeddingError(Exception):
+    """Raised when embedding generation fails after retries."""
+    pass
+
+
 def as_string(
     chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk],
 ) -> str:
@@ -29,21 +36,60 @@ def embed_chunks(
     chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    max_retries: int = 3,
+    retry_delay: float = 1.0,
 ) -> list[Vector]:
-    logger.debug(f"Embedding chunks: {model} - {str(chunks)} {len(chunks)}")
-    vo = voyageai.Client()  # type: ignore
-    if model == settings.MIXED_EMBEDDING_MODEL:
-        return vo.multimodal_embed(
-            chunks,
-            model=model,
-            input_type=input_type,
-        ).embeddings
-
-    texts = [as_string(c) for c in chunks]
-    logger.debug(f"Embedding texts: {texts}")
-    return cast(
-        list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
-    )
+    """Embed chunks with retry logic for transient failures.
+
+    Args:
+        chunks: List of chunk lists to embed
+        model: Embedding model to use
+        input_type: Whether embedding documents or queries
+        max_retries: Maximum number of retry attempts
+        retry_delay: Base delay between retries (exponential backoff)
+
+    Returns:
+        List of embedding vectors
+
+    Raises:
+        EmbeddingError: If embedding fails after all retries
+    """
+    if not chunks:
+        return []
+
+    logger.debug(f"Embedding {len(chunks)} chunks with model {model}")
+    vo = voyageai.Client()  # type: ignore
+
+    last_error = None
+    for attempt in range(max_retries):
+        try:
+            if model == settings.MIXED_EMBEDDING_MODEL:
+                return vo.multimodal_embed(
+                    chunks,
+                    model=model,
+                    input_type=input_type,
+                ).embeddings
+            texts = [as_string(c) for c in chunks]
+            return cast(
+                list[Vector],
+                vo.embed(texts, model=model, input_type=input_type).embeddings,
+            )
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries - 1:
+                delay = retry_delay * (2**attempt)
+                logger.warning(
+                    f"Embedding attempt {attempt + 1}/{max_retries} failed: {e}. "
+                    f"Retrying in {delay:.1f}s..."
+                )
+                time.sleep(delay)
+            else:
+                logger.error(f"Embedding failed after {max_retries} attempts: {e}")
+
+    raise EmbeddingError(
+        f"Failed to generate embeddings after {max_retries} attempts"
+    ) from last_error


 def break_chunk(
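
Callers now see either a vector list or a single typed failure; a usage sketch (the batch variable is illustrative):

    try:
        vectors = embed_chunks(batches, model=settings.TEXT_EMBEDDING_MODEL)
    except EmbeddingError:
        logger.exception("Embedding batch failed after retries; requeueing")
        raise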

View File

@@ -284,6 +284,8 @@ class AnthropicProvider(BaseLLMProvider):
             # Include server info if present
             if current_tool_use.get("server_name"):
                 tool_data["server_name"] = current_tool_use["server_name"]
+            if current_tool_use.get("is_server_call"):
+                tool_data["is_server_call"] = current_tool_use["is_server_call"]

             # Emit different event type for MCP server tools
             if current_tool_use.get("is_server_call"):

View File

@@ -60,7 +60,7 @@ class Collector:
     bot_name: str

     def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
-        logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
+        logger.info(f"Initialized collector for {bot.name}")
        self.collector = collector
         self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
         self.bot_id = cast(int, bot.id)
@@ -80,8 +80,8 @@ async def lifespan(app: FastAPI):
         bots = session.query(DiscordBotUser).all()
         app.bots = {bot.id: make_collector(bot) for bot in bots}
-    logger.error(
-        f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
+    logger.info(
+        f"Discord collectors started for {len(app.bots)} bots: {list(app.bots.keys())}"
     )

     yield

View File

@@ -354,7 +354,6 @@ class MessageCollector(commands.Bot):
             "users_updated": users_updated,
         }

-        print(f"✅ Metadata refresh complete: {result}")
         logger.info(f"Metadata refresh complete: {result}")
         return result

View File

@@ -259,22 +259,26 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
     """
     Decorator for safe task execution with comprehensive error handling.

-    Wraps task functions to catch and log exceptions, ensuring tasks
-    always return a result dictionary even when they fail.
+    Wraps task functions to log exceptions and notify on failures while
+    still allowing Celery to handle retries. Exceptions are re-raised after
+    logging to allow Celery's retry mechanism to work.

     Args:
         func: Task function to wrap

     Returns:
-        Wrapped function that handles exceptions gracefully
+        Wrapped function that logs exceptions and re-raises for retry

     Example:
+        @app.task(bind=True)
         @safe_task_execution
-        def my_task(arg1, arg2):
+        def my_task(self, arg1, arg2):
             # Task implementation
             return {"status": "success"}
     """
+    from functools import wraps

+    @wraps(func)
     def wrapper(*args, **kwargs) -> dict:
         try:
             return func(*args, **kwargs)
@@ -283,14 +287,25 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
             traceback_str = traceback.format_exc()
             logger.error(traceback_str)

-            notify_task_failure(
-                task_name=func.__name__,
-                error_message=str(e),
-                task_args=args,
-                task_kwargs=kwargs,
-                traceback_str=traceback_str,
-            )
-            return {"status": "error", "error": str(e), "traceback": traceback_str}
+            # Check if this is a bound task and if retries are exhausted
+            task_self = args[0] if args and hasattr(args[0], "request") else None
+            is_final_retry = (
+                task_self
+                and hasattr(task_self, "request")
+                and task_self.request.retries >= task_self.max_retries
+            )
+
+            # Notify on final failure only
+            if is_final_retry or task_self is None:
+                notify_task_failure(
+                    task_name=func.__name__,
+                    error_message=str(e),
+                    task_args=args[1:] if task_self else args,
+                    task_kwargs=kwargs,
+                    traceback_str=traceback_str,
+                )
+            # Re-raise to allow Celery retries
+            raise

     return wrapper
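
Under the new contract, tasks should be bound so the decorator can inspect retry state; a sketch (task name and body hypothetical):

    @app.task(bind=True, max_retries=3)
    @safe_task_execution
    def sync_feed(self, feed_id: int) -> dict:
        # Exceptions raised here are logged and re-raised; Celery retries per
        # the defaults above, and notify_task_failure fires only on the final try.
        return {"status": "success", "feed_id": feed_id}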

View File

@@ -189,7 +189,7 @@ def sync_book(
     # Create book and sections with relationships
     book, all_sections = create_book_and_sections(ebook, session, tags)
     for section in all_sections:
-        print(section.section_title, section.book)
+        logger.debug(f"Created section: {section.section_title}")

     if title:
         book.title = title  # type: ignore

View File

@@ -76,12 +76,12 @@ def execute_scheduled_call(self, scheduled_call_id: str):
             logger.error(f"Scheduled call {scheduled_call_id} not found")
             return {"error": "Scheduled call not found"}

-        # Check if the call is still pending
-        if not scheduled_call.is_pending():
+        # Check if the call is ready to execute (pending or queued)
+        if scheduled_call.status not in ("pending", "queued"):
             logger.warning(
-                f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})"
+                f"Scheduled call {scheduled_call_id} is not ready (status: {scheduled_call.status})"
             )
-            return {"error": f"Call is not pending (status: {scheduled_call.status})"}
+            return {"error": f"Call is not ready (status: {scheduled_call.status})"}

         # Update status to executing
         scheduled_call.status = "executing"
@@ -143,21 +143,40 @@ def execute_scheduled_call(self, scheduled_call_id: str):
 @app.task(name=RUN_SCHEDULED_CALLS)
 @safe_task_execution
 def run_scheduled_calls():
-    """Run scheduled calls that are due."""
+    """Run scheduled calls that are due.
+
+    Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when
+    multiple workers query for due calls simultaneously.
+    """
     with make_session() as session:
+        # Use FOR UPDATE SKIP LOCKED to atomically claim pending calls
+        # This prevents multiple workers from processing the same call
+        #
+        # Note: scheduled_time is stored as naive datetime (assumed UTC).
+        # We compare against current UTC time, also as naive datetime.
+        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
         calls = (
             session.query(ScheduledLLMCall)
             .filter(
                 ScheduledLLMCall.status.in_(["pending"]),
-                ScheduledLLMCall.scheduled_time
-                < datetime.now(timezone.utc).replace(tzinfo=None),
+                ScheduledLLMCall.scheduled_time < now_utc,
             )
+            .with_for_update(skip_locked=True)
             .all()
         )

+        # Mark calls as queued before dispatching to prevent re-processing
+        call_ids = []
         for call in calls:
-            execute_scheduled_call.delay(call.id)
+            call.status = "queued"
+            call_ids.append(call.id)
+        session.commit()
+
+        # Now dispatch tasks for queued calls
+        for call_id in call_ids:
+            execute_scheduled_call.delay(call_id)
+
         return {
-            "calls": [call.id for call in calls],
-            "count": len(calls),
+            "calls": call_ids,
+            "count": len(call_ids),
         }
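
Why SKIP LOCKED closes the race: each worker's SELECT locks the rows it returns, and a concurrent identical SELECT skips (rather than blocks on) those locked rows, so overlapping pollers claim disjoint sets of due calls. A hedged two-session illustration (assumes make_session opens a fresh connection each time, so the sessions contend like separate workers):

    with make_session() as s1, make_session() as s2:
        a = (
            s1.query(ScheduledLLMCall)
            .filter(ScheduledLLMCall.status.in_(["pending"]))
            .with_for_update(skip_locked=True)
            .all()
        )
        b = (
            s2.query(ScheduledLLMCall)
            .filter(ScheduledLLMCall.status.in_(["pending"]))
            .with_for_update(skip_locked=True)
            .all()
        )
        # Rows locked by s1 are invisible to s2's claim, so a and b never overlap.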