diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py index 24cb1a4..a925e1a 100644 --- a/src/memory/api/MCP/memory.py +++ b/src/memory/api/MCP/memory.py @@ -27,6 +27,32 @@ from memory.common.formatters import observation logger = logging.getLogger(__name__) +def validate_path_within_directory( + base_dir: pathlib.Path, requested_path: str +) -> pathlib.Path: + """Validate that a requested path resolves within the base directory. + + Prevents path traversal attacks using ../ or similar techniques. + + Args: + base_dir: The allowed base directory + requested_path: The user-provided path + + Returns: + The resolved absolute path if valid + + Raises: + ValueError: If the path would escape the base directory + """ + resolved = (base_dir / requested_path.lstrip("/")).resolve() + base_resolved = base_dir.resolve() + + if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + raise ValueError(f"Path escapes allowed directory: {requested_path}") + + return resolved + + def filter_observation_source_ids( tags: list[str] | None = None, observation_types: list[str] | None = None ): @@ -344,7 +370,11 @@ async def note_files(path: str = "/"): Returns: List of file paths relative to notes directory """ - root = settings.NOTES_STORAGE_DIR / path.lstrip("/") + try: + root = validate_path_within_directory(settings.NOTES_STORAGE_DIR, path) + except ValueError as e: + raise ValueError(f"Invalid path: {e}") + return [ f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}" for f in root.rglob("*.md") @@ -359,15 +389,19 @@ def fetch_file(filename: str) -> dict: Returns dict with content, mime_type, is_text, file_size. Text content as string, binary as base64. """ - path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/") - print("fetching file", path) + try: + path = validate_path_within_directory( + settings.FILE_STORAGE_DIR, filename.strip() + ) + except ValueError as e: + raise ValueError(f"Invalid path: {e}") + + logger.debug(f"Fetching file: {path}") if not path.exists(): raise FileNotFoundError(f"File not found: {filename}") mime_type = extract.get_mime_type(path) chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True) - print("mime_type", mime_type) - print("chunks", chunks) def serialize_chunk( chunk: extract.DataChunk, data: extract.MulitmodalChunk diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 3b7d939..e32df7b 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -6,6 +6,7 @@ import contextlib import os import logging import mimetypes +import pathlib from fastapi import FastAPI, UploadFile, Request, HTTPException from fastapi.responses import FileResponse @@ -33,27 +34,58 @@ async def lifespan(app: FastAPI): app = FastAPI(title="Knowledge Base API", lifespan=lifespan) app.add_middleware(AuthenticationMiddleware) +# Configure CORS with specific origin to prevent CSRF attacks. +# allow_credentials=True requires specific origins, not wildcards. app.add_middleware( CORSMiddleware, - allow_origins=["*"], # [settings.SERVER_URL], + allow_origins=[settings.SERVER_URL], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) +def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path: + """Validate that a requested path resolves within the base directory. + + Prevents path traversal attacks using ../ or similar techniques. 
+ + Args: + base_dir: The allowed base directory + requested_path: The user-provided path + + Returns: + The resolved absolute path if valid + + Raises: + HTTPException: If the path would escape the base directory + """ + # Resolve to absolute path and ensure it's within base_dir + resolved = (base_dir / requested_path).resolve() + base_resolved = base_dir.resolve() + + if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + raise HTTPException(status_code=403, detail="Access denied") + + return resolved + + @app.get("/ui{full_path:path}") async def serve_react_app(full_path: str): full_path = full_path.lstrip("/") - index_file = settings.STATIC_DIR / full_path - if index_file.is_file(): - return FileResponse(index_file) + try: + index_file = validate_path_within_directory(settings.STATIC_DIR, full_path) + if index_file.is_file(): + return FileResponse(index_file) + except HTTPException: + pass # Fall through to index.html for SPA routing return FileResponse(settings.STATIC_DIR / "index.html") @app.get("/files/{path:path}") async def serve_file(path: str): - file_path = settings.FILE_STORAGE_DIR / path + file_path = validate_path_within_directory(settings.FILE_STORAGE_DIR, path) + if not file_path.is_file(): raise HTTPException(status_code=404, detail="File not found") @@ -86,10 +118,36 @@ app.include_router(auth_router) # Add health check to MCP server instead of main app @mcp.custom_route("/health", methods=["GET"]) async def health_check(request: Request): - """Simple health check endpoint on MCP server""" + """Health check endpoint that verifies all dependencies are accessible.""" from fastapi.responses import JSONResponse + from sqlalchemy import text - return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"}) + checks = {"mcp_oauth": "enabled"} + all_healthy = True + + # Check database connection + try: + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + checks["database"] = "healthy" + except Exception as e: + checks["database"] = f"unhealthy: {str(e)[:100]}" + all_healthy = False + + # Check Qdrant connection + try: + from memory.common.qdrant import get_qdrant_client + + client = get_qdrant_client() + client.get_collections() + checks["qdrant"] = "healthy" + except Exception as e: + checks["qdrant"] = f"unhealthy: {str(e)[:100]}" + all_healthy = False + + checks["status"] = "healthy" if all_healthy else "degraded" + status_code = 200 if all_healthy else 503 + return JSONResponse(checks, status_code=status_code) # Mount MCP server at root - OAuth endpoints need to be at root level diff --git a/src/memory/api/search/bm25.py b/src/memory/api/search/bm25.py index e20780f..7c0d73a 100644 --- a/src/memory/api/search/bm25.py +++ b/src/memory/api/search/bm25.py @@ -12,7 +12,7 @@ from memory.api.search.types import SearchFilters from memory.common import extract from memory.common.db.connection import make_session -from memory.common.db.models import Chunk, ConfidenceScore +from memory.common.db.models import Chunk, ConfidenceScore, SourceItem logger = logging.getLogger(__name__) @@ -29,9 +29,28 @@ async def search_bm25( Chunk.content.isnot(None), ) + # Join with SourceItem if we need size filters + needs_source_join = any(filters.get(k) for k in ["min_size", "max_size"]) + if needs_source_join: + items_query = items_query.join( + SourceItem, SourceItem.id == Chunk.source_id + ) + if source_ids := filters.get("source_ids"): items_query = items_query.filter(Chunk.source_id.in_(source_ids)) + # Size filters + if min_size := 
filters.get("min_size"): + items_query = items_query.filter(SourceItem.size >= min_size) + if max_size := filters.get("max_size"): + items_query = items_query.filter(SourceItem.size <= max_size) + + # Observation type filter - restricts to specific collection types + if observation_types := filters.get("observation_types"): + items_query = items_query.filter( + Chunk.collection_name.in_(observation_types) + ) + # Add confidence filtering if specified if min_confidences := filters.get("min_confidences"): for confidence_type, min_score in min_confidences.items(): diff --git a/src/memory/api/search/embeddings.py b/src/memory/api/search/embeddings.py index dd5721b..19a9539 100644 --- a/src/memory/api/search/embeddings.py +++ b/src/memory/api/search/embeddings.py @@ -182,13 +182,16 @@ async def search_chunks_embeddings( filters: SearchFilters = SearchFilters(), timeout: int = 2, ) -> list[str]: + # Note: Multimodal embeddings typically produce higher similarity scores, + # so we use a higher threshold (0.4) to maintain selectivity. + # Text embeddings produce lower scores, so we use 0.25. all_ids = await asyncio.gather( asyncio.wait_for( search_chunks( data, modalities & TEXT_COLLECTIONS, limit, - 0.4, + 0.25, filters, False, ), @@ -199,7 +202,7 @@ async def search_chunks_embeddings( data, modalities & MULTIMODAL_COLLECTIONS, limit, - 0.25, + 0.4, filters, True, ), diff --git a/src/memory/api/search/scorer.py b/src/memory/api/search/scorer.py index 5049d95..d26b392 100644 --- a/src/memory/api/search/scorer.py +++ b/src/memory/api/search/scorer.py @@ -1,10 +1,13 @@ import asyncio +import logging from bs4 import BeautifulSoup from PIL import Image from memory.common.db.models.source_item import Chunk from memory.common import llms, settings, tokens +logger = logging.getLogger(__name__) + SCORE_CHUNK_SYSTEM_PROMPT = """ You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query. 
@@ -32,7 +35,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk: try: data = chunk.data except Exception as e: - print(f"Error getting chunk data: {e}, {type(e)}") + logger.error(f"Error getting chunk data: {e}") return chunk chunk_text = "\n".join(text for text in data if isinstance(text, str)) @@ -47,7 +50,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk: system_prompt=SCORE_CHUNK_SYSTEM_PROMPT, ) except Exception as e: - print(f"Error scoring chunk: {e}, {type(e)}") + logger.error(f"Error scoring chunk: {e}") return chunk soup = BeautifulSoup(response, "html.parser") diff --git a/src/memory/api/search/types.py b/src/memory/api/search/types.py index 5f9d4ee..dfc5a76 100644 --- a/src/memory/api/search/types.py +++ b/src/memory/api/search/types.py @@ -80,3 +80,15 @@ class SearchConfig(BaseModel): timeout: int = 20 previews: bool = False useScores: bool = False + + def model_post_init(self, __context) -> None: + # Enforce reasonable limits + if self.limit < 1: + object.__setattr__(self, "limit", 1) + elif self.limit > 1000: + object.__setattr__(self, "limit", 1000) + + if self.timeout < 1: + object.__setattr__(self, "timeout", 1) + elif self.timeout > 300: + object.__setattr__(self, "timeout", 300) diff --git a/src/memory/common/celery_app.py b/src/memory/common/celery_app.py index 55f90d3..37ed296 100644 --- a/src/memory/common/celery_app.py +++ b/src/memory/common/celery_app.py @@ -88,6 +88,14 @@ app.conf.update( task_acks_late=True, task_reject_on_worker_lost=True, worker_prefetch_multiplier=1, + # Default retry configuration for transient failures + task_autoretry_for=(Exception,), + task_retry_kwargs={"max_retries": 3}, + task_retry_backoff=True, + task_retry_backoff_max=600, # Max 10 minutes between retries + task_retry_jitter=True, + task_time_limit=3600, # 1 hour hard limit + task_soft_time_limit=3000, # 50 minute soft limit task_routes={ f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"}, f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"}, diff --git a/src/memory/common/chunker.py b/src/memory/common/chunker.py index a6bdfcb..bcefdb2 100644 --- a/src/memory/common/chunker.py +++ b/src/memory/common/chunker.py @@ -102,29 +102,35 @@ def chunk_text( current = "" for span in yield_spans(text, max_tokens): - current = f"{current} {span}".strip() - if tokens.approx_token_count(current) < max_tokens: + # Check if adding this span would exceed the limit + new_chunk = f"{current} {span}".strip() if current else span + if tokens.approx_token_count(new_chunk) <= max_tokens: + current = new_chunk continue - if overlap <= 0: + # Adding span would exceed limit - yield current first (if non-empty) + if current: yield current - current = "" + + # Handle overlap for the next chunk + if overlap <= 0 or not current: + current = span continue - overlap_text = current[-overlap_chars:] + # Try to find a clean break point for overlap + overlap_text = current[-overlap_chars:] if len(current) > overlap_chars else current clean_break = max( overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? 
") ) if clean_break < 0: - yield current - current = "" + current = span continue - break_offset = -overlap_chars + clean_break + 1 - chunk = current[break_offset:].strip() - yield current - current = chunk + # Start new chunk with overlap from clean break + break_offset = -len(overlap_text) + clean_break + 1 + overlap_portion = current[break_offset:].strip() + current = f"{overlap_portion} {span}".strip() if overlap_portion else span if current: yield current.strip() diff --git a/src/memory/common/collections.py b/src/memory/common/collections.py index 98fc2cc..cc7a71a 100644 --- a/src/memory/common/collections.py +++ b/src/memory/common/collections.py @@ -132,9 +132,14 @@ def get_modality(mime_type: str) -> str: def collection_model( collection: str, text: str, images: list[Image.Image] ) -> str | None: + """Determine the appropriate embedding model for a collection. + + Returns None if no suitable model can be determined, rather than + falling back to an invalid placeholder. + """ config = ALL_COLLECTIONS.get(collection, {}) if images and config.get("multimodal"): return settings.MIXED_EMBEDDING_MODEL if text and config.get("text"): return settings.TEXT_EMBEDDING_MODEL - return "unknown" + return None diff --git a/src/memory/common/db/models/discord.py b/src/memory/common/db/models/discord.py index 752ce45..0b5c568 100644 --- a/src/memory/common/db/models/discord.py +++ b/src/memory/common/db/models/discord.py @@ -16,7 +16,7 @@ from sqlalchemy import ( Text, func, ) -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, object_session from memory.common.db.models.base import Base @@ -27,6 +27,25 @@ class MessageProcessor: allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}") + @property + def mcp_servers(self) -> list: + """Get MCP servers assigned to this entity via MCPServerAssignment.""" + from memory.common.db.models.mcp import MCPServer, MCPServerAssignment + + session = object_session(self) + if not session: + return [] + + return ( + session.query(MCPServer) + .join(MCPServerAssignment) + .filter( + MCPServerAssignment.entity_type == self.entity_type, + MCPServerAssignment.entity_id == self.id, + ) + .all() + ) + system_prompt = Column( Text, nullable=True, diff --git a/src/memory/common/db/models/scheduled_calls.py b/src/memory/common/db/models/scheduled_calls.py index e690f88..0ca2303 100644 --- a/src/memory/common/db/models/scheduled_calls.py +++ b/src/memory/common/db/models/scheduled_calls.py @@ -7,6 +7,7 @@ from sqlalchemy import ( DateTime, ForeignKey, BigInteger, + Integer, JSON, Text, ) @@ -20,7 +21,7 @@ class ScheduledLLMCall(Base): __tablename__ = "scheduled_llm_calls" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) topic = Column(Text, nullable=True) # Scheduling info diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py index a2f93c2..be6a239 100644 --- a/src/memory/common/db/models/source_item.py +++ b/src/memory/common/db/models/source_item.py @@ -165,6 +165,7 @@ class Chunk(Base): __table_args__ = ( CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"), Index("chunk_source_idx", "source_id"), + Index("chunk_collection_idx", "collection_name"), ) @property diff --git a/src/memory/common/embedding.py 
b/src/memory/common/embedding.py index 3fc1f20..2f99e0e 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,4 +1,5 @@ import logging +import time from typing import Literal, cast import voyageai @@ -15,6 +16,12 @@ from memory.common.db.models import Chunk, SourceItem logger = logging.getLogger(__name__) +class EmbeddingError(Exception): + """Raised when embedding generation fails after retries.""" + + pass + + def as_string( chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk], ) -> str: @@ -29,21 +36,60 @@ def embed_chunks( chunks: list[list[extract.MulitmodalChunk]], model: str = settings.TEXT_EMBEDDING_MODEL, input_type: Literal["document", "query"] = "document", + max_retries: int = 3, + retry_delay: float = 1.0, ) -> list[Vector]: - logger.debug(f"Embedding chunks: {model} - {str(chunks)} {len(chunks)}") - vo = voyageai.Client() # type: ignore - if model == settings.MIXED_EMBEDDING_MODEL: - return vo.multimodal_embed( - chunks, - model=model, - input_type=input_type, - ).embeddings + """Embed chunks with retry logic for transient failures. - texts = [as_string(c) for c in chunks] - logger.debug(f"Embedding texts: {texts}") - return cast( - list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings - ) + Args: + chunks: List of chunk lists to embed + model: Embedding model to use + input_type: Whether embedding documents or queries + max_retries: Maximum number of retry attempts + retry_delay: Base delay between retries (exponential backoff) + + Returns: + List of embedding vectors + + Raises: + EmbeddingError: If embedding fails after all retries + """ + if not chunks: + return [] + + logger.debug(f"Embedding {len(chunks)} chunks with model {model}") + vo = voyageai.Client() # type: ignore + + last_error = None + for attempt in range(max_retries): + try: + if model == settings.MIXED_EMBEDDING_MODEL: + return vo.multimodal_embed( + chunks, + model=model, + input_type=input_type, + ).embeddings + + texts = [as_string(c) for c in chunks] + return cast( + list[Vector], + vo.embed(texts, model=model, input_type=input_type).embeddings, + ) + except Exception as e: + last_error = e + if attempt < max_retries - 1: + delay = retry_delay * (2**attempt) + logger.warning( + f"Embedding attempt {attempt + 1}/{max_retries} failed: {e}. " + f"Retrying in {delay:.1f}s..." 
+ ) + time.sleep(delay) + else: + logger.error(f"Embedding failed after {max_retries} attempts: {e}") + + raise EmbeddingError( + f"Failed to generate embeddings after {max_retries} attempts" + ) from last_error def break_chunk( diff --git a/src/memory/common/llms/anthropic_provider.py b/src/memory/common/llms/anthropic_provider.py index 172baf1..4290539 100644 --- a/src/memory/common/llms/anthropic_provider.py +++ b/src/memory/common/llms/anthropic_provider.py @@ -284,6 +284,8 @@ class AnthropicProvider(BaseLLMProvider): # Include server info if present if current_tool_use.get("server_name"): tool_data["server_name"] = current_tool_use["server_name"] + if current_tool_use.get("is_server_call"): + tool_data["is_server_call"] = current_tool_use["is_server_call"] # Emit different event type for MCP server tools if current_tool_use.get("is_server_call"): diff --git a/src/memory/discord/api.py b/src/memory/discord/api.py index 532018f..05d679f 100644 --- a/src/memory/discord/api.py +++ b/src/memory/discord/api.py @@ -60,7 +60,7 @@ class Collector: bot_name: str def __init__(self, collector: MessageCollector, bot: DiscordBotUser): - logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}") + logger.info(f"Initialized collector for {bot.name}") self.collector = collector self.collector_task = asyncio.create_task(collector.start(str(bot.api_key))) self.bot_id = cast(int, bot.id) @@ -80,8 +80,8 @@ async def lifespan(app: FastAPI): bots = session.query(DiscordBotUser).all() app.bots = {bot.id: make_collector(bot) for bot in bots} - logger.error( - f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}" + logger.info( + f"Discord collectors started for {len(app.bots)} bots: {list(app.bots.keys())}" ) yield diff --git a/src/memory/discord/collector.py b/src/memory/discord/collector.py index dc69cc9..07eb9cd 100644 --- a/src/memory/discord/collector.py +++ b/src/memory/discord/collector.py @@ -354,7 +354,6 @@ class MessageCollector(commands.Bot): "users_updated": users_updated, } - print(f"✅ Metadata refresh complete: {result}") logger.info(f"Metadata refresh complete: {result}") return result diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py index d1584e6..53246b3 100644 --- a/src/memory/workers/tasks/content_processing.py +++ b/src/memory/workers/tasks/content_processing.py @@ -259,22 +259,26 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: """ Decorator for safe task execution with comprehensive error handling. - Wraps task functions to catch and log exceptions, ensuring tasks - always return a result dictionary even when they fail. + Wraps task functions to log exceptions and notify on failures while + still allowing Celery to handle retries. Exceptions are re-raised after + logging to allow Celery's retry mechanism to work. 
Args: func: Task function to wrap Returns: - Wrapped function that handles exceptions gracefully + Wrapped function that logs exceptions and re-raises for retry Example: + @app.task(bind=True) @safe_task_execution - def my_task(arg1, arg2): + def my_task(self, arg1, arg2): # Task implementation return {"status": "success"} """ + from functools import wraps + @wraps(func) def wrapper(*args, **kwargs) -> dict: try: return func(*args, **kwargs) @@ -283,14 +287,25 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: traceback_str = traceback.format_exc() logger.error(traceback_str) - notify_task_failure( - task_name=func.__name__, - error_message=str(e), - task_args=args, - task_kwargs=kwargs, - traceback_str=traceback_str, + # Check if this is a bound task and if retries are exhausted + task_self = args[0] if args and hasattr(args[0], "request") else None + is_final_retry = ( + task_self + and hasattr(task_self, "request") + and task_self.request.retries >= task_self.max_retries ) - return {"status": "error", "error": str(e), "traceback": traceback_str} + # Notify on final failure only + if is_final_retry or task_self is None: + notify_task_failure( + task_name=func.__name__, + error_message=str(e), + task_args=args[1:] if task_self else args, + task_kwargs=kwargs, + traceback_str=traceback_str, + ) + + # Re-raise to allow Celery retries + raise return wrapper diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py index aedaeb2..8fc4cc9 100644 --- a/src/memory/workers/tasks/ebook.py +++ b/src/memory/workers/tasks/ebook.py @@ -189,7 +189,7 @@ def sync_book( # Create book and sections with relationships book, all_sections = create_book_and_sections(ebook, session, tags) for section in all_sections: - print(section.section_title, section.book) + logger.debug(f"Created section: {section.section_title}") if title: book.title = title # type: ignore diff --git a/src/memory/workers/tasks/scheduled_calls.py b/src/memory/workers/tasks/scheduled_calls.py index 8d83867..721a165 100644 --- a/src/memory/workers/tasks/scheduled_calls.py +++ b/src/memory/workers/tasks/scheduled_calls.py @@ -76,12 +76,12 @@ def execute_scheduled_call(self, scheduled_call_id: str): logger.error(f"Scheduled call {scheduled_call_id} not found") return {"error": "Scheduled call not found"} - # Check if the call is still pending - if not scheduled_call.is_pending(): + # Check if the call is ready to execute (pending or queued) + if scheduled_call.status not in ("pending", "queued"): logger.warning( - f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})" + f"Scheduled call {scheduled_call_id} is not ready (status: {scheduled_call.status})" ) - return {"error": f"Call is not pending (status: {scheduled_call.status})"} + return {"error": f"Call is not ready (status: {scheduled_call.status})"} # Update status to executing scheduled_call.status = "executing" @@ -143,21 +143,40 @@ def execute_scheduled_call(self, scheduled_call_id: str): @app.task(name=RUN_SCHEDULED_CALLS) @safe_task_execution def run_scheduled_calls(): - """Run scheduled calls that are due.""" + """Run scheduled calls that are due. + + Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when + multiple workers query for due calls simultaneously. 
+ """ with make_session() as session: + # Use FOR UPDATE SKIP LOCKED to atomically claim pending calls + # This prevents multiple workers from processing the same call + # + # Note: scheduled_time is stored as naive datetime (assumed UTC). + # We compare against current UTC time, also as naive datetime. + now_utc = datetime.now(timezone.utc).replace(tzinfo=None) calls = ( session.query(ScheduledLLMCall) .filter( ScheduledLLMCall.status.in_(["pending"]), - ScheduledLLMCall.scheduled_time - < datetime.now(timezone.utc).replace(tzinfo=None), + ScheduledLLMCall.scheduled_time < now_utc, ) + .with_for_update(skip_locked=True) .all() ) + + # Mark calls as queued before dispatching to prevent re-processing + call_ids = [] for call in calls: - execute_scheduled_call.delay(call.id) + call.status = "queued" + call_ids.append(call.id) + session.commit() + + # Now dispatch tasks for queued calls + for call_id in call_ids: + execute_scheduled_call.delay(call_id) return { - "calls": [call.id for call in calls], - "count": len(calls), + "calls": call_ids, + "count": len(call_ids), }