mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 17:22:58 +01:00
Fix 19 bugs from investigation
Critical/High severity fixes:

- BUG-001: Path traversal vulnerabilities (3 endpoints)
- BUG-003: BM25 filters now apply size/observation_types
- BUG-006: Remove API key from log messages
- BUG-008: Chunk size validation before yielding
- BUG-009: Race condition fix with FOR UPDATE SKIP LOCKED
- BUG-010: Add mcp_servers property to MessageProcessor
- BUG-011: Fix user_id type (BigInteger→Integer)
- BUG-012: Swap inverted score thresholds
- BUG-013: Add retry logic to embedding pipeline
- BUG-014: Fix CORS to use specific origin
- BUG-015: Add Celery retry/timeout defaults
- BUG-016: Re-raise exceptions for Celery retries

Medium severity fixes:

- BUG-017: Add collection_name index on Chunk
- BUG-031: Add SearchConfig limits (max 1000/300s)
- BUG-033: Replace debug prints with logger calls
- BUG-037: Clarify timezone handling in scheduler
- BUG-043: Health check now validates DB + Qdrant
- BUG-055: collection_model returns None not "unknown"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a1444efaac
commit
52274f82a6
@@ -27,6 +27,32 @@ from memory.common.formatters import observation
 logger = logging.getLogger(__name__)
 
 
+def validate_path_within_directory(
+    base_dir: pathlib.Path, requested_path: str
+) -> pathlib.Path:
+    """Validate that a requested path resolves within the base directory.
+
+    Prevents path traversal attacks using ../ or similar techniques.
+
+    Args:
+        base_dir: The allowed base directory
+        requested_path: The user-provided path
+
+    Returns:
+        The resolved absolute path if valid
+
+    Raises:
+        ValueError: If the path would escape the base directory
+    """
+    resolved = (base_dir / requested_path.lstrip("/")).resolve()
+    base_resolved = base_dir.resolve()
+
+    if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+        raise ValueError(f"Path escapes allowed directory: {requested_path}")
+
+    return resolved
+
+
 def filter_observation_source_ids(
     tags: list[str] | None = None, observation_types: list[str] | None = None
 ):
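Note: the containment check above compares resolved paths as strings. A minimal standalone sketch of the same rule (directory names are illustrative, and the helper is inlined rather than imported from the project):

```python
import pathlib
import tempfile

def is_within(base_dir: pathlib.Path, requested: str) -> bool:
    # Same rule as validate_path_within_directory: resolve ".." and symlinks
    # first, then require the base directory as a strict prefix (or equality).
    resolved = (base_dir / requested.lstrip("/")).resolve()
    base = base_dir.resolve()
    return resolved == base or str(resolved).startswith(str(base) + "/")

base = pathlib.Path(tempfile.mkdtemp())
print(is_within(base, "notes/todo.md"))       # True - stays inside
print(is_within(base, "../../etc/passwd"))    # False - ".." escapes after resolve()
print(is_within(base, "a/../../etc/passwd"))  # False - nested ".." still escapes
```

On Python 3.9+ the prefix test could be written `resolved.is_relative_to(base)`; the `+ "/"` suffix in the patch serves the same purpose, preventing a sibling like `/data/notes2` from matching a base of `/data/notes`.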
@ -344,7 +370,11 @@ async def note_files(path: str = "/"):
|
||||
|
||||
Returns: List of file paths relative to notes directory
|
||||
"""
|
||||
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
|
||||
try:
|
||||
root = validate_path_within_directory(settings.NOTES_STORAGE_DIR, path)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid path: {e}")
|
||||
|
||||
return [
|
||||
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
|
||||
for f in root.rglob("*.md")
|
||||
@@ -359,15 +389,19 @@ def fetch_file(filename: str) -> dict:
     Returns dict with content, mime_type, is_text, file_size.
     Text content as string, binary as base64.
     """
-    path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/")
-    print("fetching file", path)
+    try:
+        path = validate_path_within_directory(
+            settings.FILE_STORAGE_DIR, filename.strip()
+        )
+    except ValueError as e:
+        raise ValueError(f"Invalid path: {e}")
+
+    logger.debug(f"Fetching file: {path}")
     if not path.exists():
         raise FileNotFoundError(f"File not found: {filename}")
 
     mime_type = extract.get_mime_type(path)
     chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
-    print("mime_type", mime_type)
-    print("chunks", chunks)
 
     def serialize_chunk(
         chunk: extract.DataChunk, data: extract.MulitmodalChunk
@@ -6,6 +6,7 @@ import contextlib
 import os
 import logging
 import mimetypes
+import pathlib
 
 from fastapi import FastAPI, UploadFile, Request, HTTPException
 from fastapi.responses import FileResponse
@@ -33,27 +34,58 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
 app.add_middleware(AuthenticationMiddleware)
+# Configure CORS with specific origin to prevent CSRF attacks.
+# allow_credentials=True requires specific origins, not wildcards.
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # [settings.SERVER_URL],
+    allow_origins=[settings.SERVER_URL],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
 
+def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path:
+    """Validate that a requested path resolves within the base directory.
+
+    Prevents path traversal attacks using ../ or similar techniques.
+
+    Args:
+        base_dir: The allowed base directory
+        requested_path: The user-provided path
+
+    Returns:
+        The resolved absolute path if valid
+
+    Raises:
+        HTTPException: If the path would escape the base directory
+    """
+    # Resolve to absolute path and ensure it's within base_dir
+    resolved = (base_dir / requested_path).resolve()
+    base_resolved = base_dir.resolve()
+
+    if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
+        raise HTTPException(status_code=403, detail="Access denied")
+
+    return resolved
+
+
 @app.get("/ui{full_path:path}")
 async def serve_react_app(full_path: str):
     full_path = full_path.lstrip("/")
-    index_file = settings.STATIC_DIR / full_path
+    try:
+        index_file = validate_path_within_directory(settings.STATIC_DIR, full_path)
+        if index_file.is_file():
+            return FileResponse(index_file)
+    except HTTPException:
+        pass  # Fall through to index.html for SPA routing
     return FileResponse(settings.STATIC_DIR / "index.html")
 
 
 @app.get("/files/{path:path}")
 async def serve_file(path: str):
-    file_path = settings.FILE_STORAGE_DIR / path
+    file_path = validate_path_within_directory(settings.FILE_STORAGE_DIR, path)
 
     if not file_path.is_file():
         raise HTTPException(status_code=404, detail="File not found")
 
@@ -86,10 +118,36 @@ app.include_router(auth_router)
 # Add health check to MCP server instead of main app
 @mcp.custom_route("/health", methods=["GET"])
 async def health_check(request: Request):
-    """Simple health check endpoint on MCP server"""
+    """Health check endpoint that verifies all dependencies are accessible."""
     from fastapi.responses import JSONResponse
+    from sqlalchemy import text
 
-    return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
+    checks = {"mcp_oauth": "enabled"}
+    all_healthy = True
+
+    # Check database connection
+    try:
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+        checks["database"] = "healthy"
+    except Exception as e:
+        checks["database"] = f"unhealthy: {str(e)[:100]}"
+        all_healthy = False
+
+    # Check Qdrant connection
+    try:
+        from memory.common.qdrant import get_qdrant_client
+
+        client = get_qdrant_client()
+        client.get_collections()
+        checks["qdrant"] = "healthy"
+    except Exception as e:
+        checks["qdrant"] = f"unhealthy: {str(e)[:100]}"
+        all_healthy = False
+
+    checks["status"] = "healthy" if all_healthy else "degraded"
+    status_code = 200 if all_healthy else 503
+    return JSONResponse(checks, status_code=status_code)
 
 
 # Mount MCP server at root - OAuth endpoints need to be at root level
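A quick way to exercise the new endpoint; the URL and port are assumptions about the local deployment:

```python
import json
import urllib.request
from urllib.error import HTTPError

URL = "http://localhost:8000/health"  # adjust to wherever the MCP server is mounted

try:
    with urllib.request.urlopen(URL, timeout=5) as resp:
        print(resp.status, json.load(resp))  # 200 when DB and Qdrant both respond
except HTTPError as e:
    # A 503 "degraded" response still carries the per-dependency JSON body.
    print(e.code, json.load(e))
```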
@@ -12,7 +12,7 @@ from memory.api.search.types import SearchFilters
 
 from memory.common import extract
 from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, ConfidenceScore
+from memory.common.db.models import Chunk, ConfidenceScore, SourceItem
 
 logger = logging.getLogger(__name__)
 
@@ -29,9 +29,28 @@ async def search_bm25(
         Chunk.content.isnot(None),
     )
 
+    # Join with SourceItem if we need size filters
+    needs_source_join = any(filters.get(k) for k in ["min_size", "max_size"])
+    if needs_source_join:
+        items_query = items_query.join(
+            SourceItem, SourceItem.id == Chunk.source_id
+        )
+
     if source_ids := filters.get("source_ids"):
         items_query = items_query.filter(Chunk.source_id.in_(source_ids))
 
+    # Size filters
+    if min_size := filters.get("min_size"):
+        items_query = items_query.filter(SourceItem.size >= min_size)
+    if max_size := filters.get("max_size"):
+        items_query = items_query.filter(SourceItem.size <= max_size)
+
+    # Observation type filter - restricts to specific collection types
+    if observation_types := filters.get("observation_types"):
+        items_query = items_query.filter(
+            Chunk.collection_name.in_(observation_types)
+        )
+
     # Add confidence filtering if specified
     if min_confidences := filters.get("min_confidences"):
         for confidence_type, min_score in min_confidences.items():
@@ -182,13 +182,16 @@ async def search_chunks_embeddings(
     filters: SearchFilters = SearchFilters(),
     timeout: int = 2,
 ) -> list[str]:
+    # Note: Multimodal embeddings typically produce higher similarity scores,
+    # so we use a higher threshold (0.4) to maintain selectivity.
+    # Text embeddings produce lower scores, so we use 0.25.
     all_ids = await asyncio.gather(
         asyncio.wait_for(
             search_chunks(
                 data,
                 modalities & TEXT_COLLECTIONS,
                 limit,
-                0.4,
+                0.25,
                 filters,
                 False,
             ),
@@ -199,7 +202,7 @@
                 data,
                 modalities & MULTIMODAL_COLLECTIONS,
                 limit,
-                0.25,
+                0.4,
                 filters,
                 True,
             ),
@@ -1,10 +1,13 @@
 import asyncio
+import logging
 from bs4 import BeautifulSoup
 from PIL import Image
 
 from memory.common.db.models.source_item import Chunk
 from memory.common import llms, settings, tokens
 
+logger = logging.getLogger(__name__)
+
 
 SCORE_CHUNK_SYSTEM_PROMPT = """
 You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
@@ -32,7 +35,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
     try:
         data = chunk.data
     except Exception as e:
-        print(f"Error getting chunk data: {e}, {type(e)}")
+        logger.error(f"Error getting chunk data: {e}")
         return chunk
 
     chunk_text = "\n".join(text for text in data if isinstance(text, str))
@@ -47,7 +50,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
             system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
         )
     except Exception as e:
-        print(f"Error scoring chunk: {e}, {type(e)}")
+        logger.error(f"Error scoring chunk: {e}")
         return chunk
 
     soup = BeautifulSoup(response, "html.parser")
@@ -80,3 +80,15 @@ class SearchConfig(BaseModel):
     timeout: int = 20
     previews: bool = False
     useScores: bool = False
+
+    def model_post_init(self, __context) -> None:
+        # Enforce reasonable limits
+        if self.limit < 1:
+            object.__setattr__(self, "limit", 1)
+        elif self.limit > 1000:
+            object.__setattr__(self, "limit", 1000)
+
+        if self.timeout < 1:
+            object.__setattr__(self, "timeout", 1)
+        elif self.timeout > 300:
+            object.__setattr__(self, "timeout", 300)
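A condensed, runnable sketch of the clamping behaviour (pydantic v2 assumed; min/max is equivalent to the if/elif chains above):

```python
from pydantic import BaseModel

class SearchConfig(BaseModel):
    # Trimmed to the two clamped fields; the real model has more.
    limit: int = 10
    timeout: int = 20

    def model_post_init(self, __context) -> None:
        # object.__setattr__ writes the corrected value without
        # re-triggering validation on assignment.
        object.__setattr__(self, "limit", min(max(self.limit, 1), 1000))
        object.__setattr__(self, "timeout", min(max(self.timeout, 1), 300))

print(SearchConfig(limit=50_000).limit)   # 1000
print(SearchConfig(timeout=-5).timeout)   # 1
print(SearchConfig().limit)               # 10 - defaults pass through untouched
```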
@@ -88,6 +88,14 @@ app.conf.update(
     task_acks_late=True,
     task_reject_on_worker_lost=True,
     worker_prefetch_multiplier=1,
+    # Default retry configuration for transient failures
+    task_autoretry_for=(Exception,),
+    task_retry_kwargs={"max_retries": 3},
+    task_retry_backoff=True,
+    task_retry_backoff_max=600,  # Max 10 minutes between retries
+    task_retry_jitter=True,
+    task_time_limit=3600,  # 1 hour hard limit
+    task_soft_time_limit=3000,  # 50 minute soft limit
     task_routes={
         f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
         f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
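As I read Celery's documented backoff behaviour, retry n waits roughly min(retry_backoff_max, 2 ** n) seconds, and task_retry_jitter randomises the actual sleep between 0 and that bound. A sketch of the schedule these defaults produce (pure Python illustration, not project code):

```python
import random

def retry_delay(retries: int, backoff_max: int = 600, jitter: bool = True) -> float:
    # Exponential backoff capped at backoff_max (600s in the config above).
    bound = min(backoff_max, 2 ** retries)
    # Jitter spreads simultaneous retries to avoid a thundering herd.
    return random.uniform(0, bound) if jitter else float(bound)

for n in range(3):  # max_retries=3 in the config above
    print(f"retry {n + 1}: up to {min(600, 2 ** n)}s")
```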
@@ -102,29 +102,35 @@ def chunk_text(
     current = ""
 
     for span in yield_spans(text, max_tokens):
-        current = f"{current} {span}".strip()
-        if tokens.approx_token_count(current) < max_tokens:
+        # Check if adding this span would exceed the limit
+        new_chunk = f"{current} {span}".strip() if current else span
+        if tokens.approx_token_count(new_chunk) <= max_tokens:
+            current = new_chunk
             continue
 
-        if overlap <= 0:
+        # Adding span would exceed limit - yield current first (if non-empty)
+        if current:
             yield current
-            current = ""
+
+        # Handle overlap for the next chunk
+        if overlap <= 0 or not current:
+            current = span
             continue
 
-        overlap_text = current[-overlap_chars:]
+        # Try to find a clean break point for overlap
+        overlap_text = current[-overlap_chars:] if len(current) > overlap_chars else current
         clean_break = max(
             overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )
 
         if clean_break < 0:
-            yield current
-            current = ""
+            current = span
             continue
 
-        break_offset = -overlap_chars + clean_break + 1
-        chunk = current[break_offset:].strip()
-        yield current
-        current = chunk
+        # Start new chunk with overlap from clean break
+        break_offset = -len(overlap_text) + clean_break + 1
+        overlap_portion = current[break_offset:].strip()
+        current = f"{overlap_portion} {span}".strip() if overlap_portion else span
 
     if current:
         yield current.strip()
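To see what the clean-break arithmetic does in isolation (toy string and overlap size; the real overlap_chars comes from the chunker's token budget):

```python
current = "First sentence here. Second one! Third trails off"
overlap_chars = 30

overlap_text = current[-overlap_chars:] if len(current) > overlap_chars else current
clean_break = max(
    overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
)

# Walk back from the end of `current` to the last sentence boundary inside
# the overlap window, and carry only that tail into the next chunk.
break_offset = -len(overlap_text) + clean_break + 1
overlap_portion = current[break_offset:].strip()
print(repr(overlap_portion))  # 'Third trails off'
```

Using -len(overlap_text) instead of the old -overlap_chars is what fixes the case where the whole string is shorter than the overlap window.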
@@ -132,9 +132,14 @@ def get_modality(mime_type: str) -> str:
 def collection_model(
     collection: str, text: str, images: list[Image.Image]
 ) -> str | None:
+    """Determine the appropriate embedding model for a collection.
+
+    Returns None if no suitable model can be determined, rather than
+    falling back to an invalid placeholder.
+    """
     config = ALL_COLLECTIONS.get(collection, {})
     if images and config.get("multimodal"):
         return settings.MIXED_EMBEDDING_MODEL
     if text and config.get("text"):
         return settings.TEXT_EMBEDDING_MODEL
-    return "unknown"
+    return None
@@ -16,7 +16,7 @@ from sqlalchemy import (
     Text,
     func,
 )
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, object_session
 
 from memory.common.db.models.base import Base
 
@@ -27,6 +27,25 @@ class MessageProcessor:
     allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
     disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
 
+    @property
+    def mcp_servers(self) -> list:
+        """Get MCP servers assigned to this entity via MCPServerAssignment."""
+        from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
+
+        session = object_session(self)
+        if not session:
+            return []
+
+        return (
+            session.query(MCPServer)
+            .join(MCPServerAssignment)
+            .filter(
+                MCPServerAssignment.entity_type == self.entity_type,
+                MCPServerAssignment.entity_id == self.id,
+            )
+            .all()
+        )
+
     system_prompt = Column(
         Text,
         nullable=True,
@@ -7,6 +7,7 @@ from sqlalchemy import (
     DateTime,
     ForeignKey,
     BigInteger,
+    Integer,
     JSON,
     Text,
 )
@@ -20,7 +21,7 @@ class ScheduledLLMCall(Base):
     __tablename__ = "scheduled_llm_calls"
 
     id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
-    user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False)
+    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
     topic = Column(Text, nullable=True)
 
     # Scheduling info
@@ -165,6 +165,7 @@ class Chunk(Base):
     __table_args__ = (
         CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
         Index("chunk_source_idx", "source_id"),
+        Index("chunk_collection_idx", "collection_name"),
     )
 
     @property
@@ -1,4 +1,5 @@
 import logging
+import time
 from typing import Literal, cast
 
 import voyageai
@@ -15,6 +16,12 @@ from memory.common.db.models import Chunk, SourceItem
 logger = logging.getLogger(__name__)
 
 
+class EmbeddingError(Exception):
+    """Raised when embedding generation fails after retries."""
+
+    pass
+
+
 def as_string(
     chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk],
 ) -> str:
@@ -29,9 +36,33 @@ def embed_chunks(
     chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    max_retries: int = 3,
+    retry_delay: float = 1.0,
 ) -> list[Vector]:
-    logger.debug(f"Embedding chunks: {model} - {str(chunks)} {len(chunks)}")
+    """Embed chunks with retry logic for transient failures.
+
+    Args:
+        chunks: List of chunk lists to embed
+        model: Embedding model to use
+        input_type: Whether embedding documents or queries
+        max_retries: Maximum number of retry attempts
+        retry_delay: Base delay between retries (exponential backoff)
+
+    Returns:
+        List of embedding vectors
+
+    Raises:
+        EmbeddingError: If embedding fails after all retries
+    """
+    if not chunks:
+        return []
+
+    logger.debug(f"Embedding {len(chunks)} chunks with model {model}")
     vo = voyageai.Client()  # type: ignore
 
+    last_error = None
+    for attempt in range(max_retries):
+        try:
+            if model == settings.MIXED_EMBEDDING_MODEL:
+                return vo.multimodal_embed(
+                    chunks,
@@ -40,10 +71,25 @@ def embed_chunks(
                 ).embeddings
 
             texts = [as_string(c) for c in chunks]
-            logger.debug(f"Embedding texts: {texts}")
             return cast(
-                list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
+                list[Vector],
+                vo.embed(texts, model=model, input_type=input_type).embeddings,
             )
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries - 1:
+                delay = retry_delay * (2**attempt)
+                logger.warning(
+                    f"Embedding attempt {attempt + 1}/{max_retries} failed: {e}. "
+                    f"Retrying in {delay:.1f}s..."
+                )
+                time.sleep(delay)
+            else:
+                logger.error(f"Embedding failed after {max_retries} attempts: {e}")
+
+    raise EmbeddingError(
+        f"Failed to generate embeddings after {max_retries} attempts"
+    ) from last_error
 
 
 def break_chunk(
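The retry loop above is a standard exponential-backoff pattern. A self-contained illustration of the schedule it produces (a stub stands in for the Voyage client):

```python
import time

class FlakyService:
    def __init__(self, fail_times: int):
        self.calls = 0
        self.fail_times = fail_times

    def embed(self) -> str:
        self.calls += 1
        if self.calls <= self.fail_times:
            raise ConnectionError(f"transient failure #{self.calls}")
        return "vector"

def with_retries(service: FlakyService, max_retries: int = 3, retry_delay: float = 0.01) -> str:
    last_error = None
    for attempt in range(max_retries):
        try:
            return service.embed()
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                # Base delay doubled on each attempt: 1x, 2x, 4x ...
                time.sleep(retry_delay * (2 ** attempt))
    raise RuntimeError(f"failed after {max_retries} attempts") from last_error

print(with_retries(FlakyService(fail_times=2)))  # succeeds on the third attempt
```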
@@ -284,6 +284,8 @@ class AnthropicProvider(BaseLLMProvider):
             # Include server info if present
             if current_tool_use.get("server_name"):
                 tool_data["server_name"] = current_tool_use["server_name"]
+            if current_tool_use.get("is_server_call"):
+                tool_data["is_server_call"] = current_tool_use["is_server_call"]
 
             # Emit different event type for MCP server tools
             if current_tool_use.get("is_server_call"):
@@ -60,7 +60,7 @@ class Collector:
     bot_name: str
 
     def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
-        logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
+        logger.info(f"Initialized collector for {bot.name}")
         self.collector = collector
         self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
         self.bot_id = cast(int, bot.id)
@@ -80,8 +80,8 @@ async def lifespan(app: FastAPI):
         bots = session.query(DiscordBotUser).all()
         app.bots = {bot.id: make_collector(bot) for bot in bots}
 
-    logger.error(
-        f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
+    logger.info(
+        f"Discord collectors started for {len(app.bots)} bots: {list(app.bots.keys())}"
     )
 
     yield
@@ -354,7 +354,6 @@ class MessageCollector(commands.Bot):
             "users_updated": users_updated,
         }
 
-        print(f"✅ Metadata refresh complete: {result}")
         logger.info(f"Metadata refresh complete: {result}")
 
         return result
@@ -259,22 +259,26 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
     """
     Decorator for safe task execution with comprehensive error handling.
 
-    Wraps task functions to catch and log exceptions, ensuring tasks
-    always return a result dictionary even when they fail.
+    Wraps task functions to log exceptions and notify on failures while
+    still allowing Celery to handle retries. Exceptions are re-raised after
+    logging to allow Celery's retry mechanism to work.
 
     Args:
         func: Task function to wrap
 
     Returns:
-        Wrapped function that handles exceptions gracefully
+        Wrapped function that logs exceptions and re-raises for retry
 
     Example:
+        @app.task(bind=True)
         @safe_task_execution
-        def my_task(arg1, arg2):
+        def my_task(self, arg1, arg2):
             # Task implementation
             return {"status": "success"}
     """
     from functools import wraps
 
     @wraps(func)
     def wrapper(*args, **kwargs) -> dict:
         try:
             return func(*args, **kwargs)
@@ -283,14 +287,25 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
             traceback_str = traceback.format_exc()
             logger.error(traceback_str)
 
+            # Check if this is a bound task and if retries are exhausted
+            task_self = args[0] if args and hasattr(args[0], "request") else None
+            is_final_retry = (
+                task_self
+                and hasattr(task_self, "request")
+                and task_self.request.retries >= task_self.max_retries
+            )
+
+            # Notify on final failure only
+            if is_final_retry or task_self is None:
+                notify_task_failure(
+                    task_name=func.__name__,
+                    error_message=str(e),
-                    task_args=args,
+                    task_args=args[1:] if task_self else args,
+                    task_kwargs=kwargs,
+                    traceback_str=traceback_str,
+                )
 
-            return {"status": "error", "error": str(e), "traceback": traceback_str}
+            # Re-raise to allow Celery retries
+            raise
 
     return wrapper
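The retries-exhausted check leans on Celery passing the bound task instance as the first positional argument. A toy illustration of just the detection logic (SimpleNamespace stands in for Celery's task object):

```python
from types import SimpleNamespace

def is_final_retry(*args) -> bool:
    # Bound tasks receive `self` first; request.retries counts attempts so far
    # and max_retries is the configured ceiling.
    task_self = args[0] if args and hasattr(args[0], "request") else None
    return bool(task_self and task_self.request.retries >= task_self.max_retries)

first_try = SimpleNamespace(request=SimpleNamespace(retries=0), max_retries=3)
last_try = SimpleNamespace(request=SimpleNamespace(retries=3), max_retries=3)

print(is_final_retry(first_try, "arg"))  # False - Celery will retry again
print(is_final_retry(last_try, "arg"))   # True - notify, then re-raise
print(is_final_retry("arg"))             # False - unbound task, notified immediately
```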
@@ -189,7 +189,7 @@ def sync_book(
     # Create book and sections with relationships
     book, all_sections = create_book_and_sections(ebook, session, tags)
     for section in all_sections:
-        print(section.section_title, section.book)
+        logger.debug(f"Created section: {section.section_title}")
 
     if title:
         book.title = title  # type: ignore
@@ -76,12 +76,12 @@ def execute_scheduled_call(self, scheduled_call_id: str):
         logger.error(f"Scheduled call {scheduled_call_id} not found")
         return {"error": "Scheduled call not found"}
 
-    # Check if the call is still pending
-    if not scheduled_call.is_pending():
+    # Check if the call is ready to execute (pending or queued)
+    if scheduled_call.status not in ("pending", "queued"):
         logger.warning(
-            f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})"
+            f"Scheduled call {scheduled_call_id} is not ready (status: {scheduled_call.status})"
        )
-        return {"error": f"Call is not pending (status: {scheduled_call.status})"}
+        return {"error": f"Call is not ready (status: {scheduled_call.status})"}
 
     # Update status to executing
     scheduled_call.status = "executing"
@@ -143,21 +143,40 @@ def execute_scheduled_call(self, scheduled_call_id: str):
 @app.task(name=RUN_SCHEDULED_CALLS)
 @safe_task_execution
 def run_scheduled_calls():
-    """Run scheduled calls that are due."""
+    """Run scheduled calls that are due.
+
+    Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when
+    multiple workers query for due calls simultaneously.
+    """
     with make_session() as session:
+        # Use FOR UPDATE SKIP LOCKED to atomically claim pending calls
+        # This prevents multiple workers from processing the same call
+        #
+        # Note: scheduled_time is stored as naive datetime (assumed UTC).
+        # We compare against current UTC time, also as naive datetime.
+        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
         calls = (
             session.query(ScheduledLLMCall)
             .filter(
                 ScheduledLLMCall.status.in_(["pending"]),
-                ScheduledLLMCall.scheduled_time
-                < datetime.now(timezone.utc).replace(tzinfo=None),
+                ScheduledLLMCall.scheduled_time < now_utc,
             )
+            .with_for_update(skip_locked=True)
             .all()
         )
 
+        # Mark calls as queued before dispatching to prevent re-processing
+        call_ids = []
         for call in calls:
-            execute_scheduled_call.delay(call.id)
+            call.status = "queued"
+            call_ids.append(call.id)
+        session.commit()
+
+        # Now dispatch tasks for queued calls
+        for call_id in call_ids:
+            execute_scheduled_call.delay(call_id)
 
         return {
-            "calls": [call.id for call in calls],
-            "count": len(calls),
+            "calls": call_ids,
+            "count": len(call_ids),
         }
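For reference, with_for_update(skip_locked=True) is what emits Postgres's FOR UPDATE SKIP LOCKED, so each worker claims a disjoint set of due rows. A minimal sketch showing the generated SQL (toy model standing in for ScheduledLLMCall; no database connection needed):

```python
from sqlalchemy import Column, DateTime, String, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Call(Base):
    __tablename__ = "calls"
    id = Column(String, primary_key=True)
    status = Column(String)
    scheduled_time = Column(DateTime)

stmt = (
    select(Call)
    .where(Call.status == "pending")
    .with_for_update(skip_locked=True)
)
print(stmt.compile(dialect=postgresql.dialect()))
# ... WHERE calls.status = %(status_1)s FOR UPDATE SKIP LOCKED
```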