mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 17:22:58 +01:00
Fix 19 bugs from investigation
Critical/High severity fixes: - BUG-001: Path traversal vulnerabilities (3 endpoints) - BUG-003: BM25 filters now apply size/observation_types - BUG-006: Remove API key from log messages - BUG-008: Chunk size validation before yielding - BUG-009: Race condition fix with FOR UPDATE SKIP LOCKED - BUG-010: Add mcp_servers property to MessageProcessor - BUG-011: Fix user_id type (BigInteger→Integer) - BUG-012: Swap inverted score thresholds - BUG-013: Add retry logic to embedding pipeline - BUG-014: Fix CORS to use specific origin - BUG-015: Add Celery retry/timeout defaults - BUG-016: Re-raise exceptions for Celery retries Medium severity fixes: - BUG-017: Add collection_name index on Chunk - BUG-031: Add SearchConfig limits (max 1000/300s) - BUG-033: Replace debug prints with logger calls - BUG-037: Clarify timezone handling in scheduler - BUG-043: Health check now validates DB + Qdrant - BUG-055: collection_model returns None not "unknown" 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a1444efaac
commit
52274f82a6
@ -27,6 +27,32 @@ from memory.common.formatters import observation
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_path_within_directory(
|
||||||
|
base_dir: pathlib.Path, requested_path: str
|
||||||
|
) -> pathlib.Path:
|
||||||
|
"""Validate that a requested path resolves within the base directory.
|
||||||
|
|
||||||
|
Prevents path traversal attacks using ../ or similar techniques.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_dir: The allowed base directory
|
||||||
|
requested_path: The user-provided path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The resolved absolute path if valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the path would escape the base directory
|
||||||
|
"""
|
||||||
|
resolved = (base_dir / requested_path.lstrip("/")).resolve()
|
||||||
|
base_resolved = base_dir.resolve()
|
||||||
|
|
||||||
|
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
|
||||||
|
raise ValueError(f"Path escapes allowed directory: {requested_path}")
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
def filter_observation_source_ids(
|
def filter_observation_source_ids(
|
||||||
tags: list[str] | None = None, observation_types: list[str] | None = None
|
tags: list[str] | None = None, observation_types: list[str] | None = None
|
||||||
):
|
):
|
||||||
@ -344,7 +370,11 @@ async def note_files(path: str = "/"):
|
|||||||
|
|
||||||
Returns: List of file paths relative to notes directory
|
Returns: List of file paths relative to notes directory
|
||||||
"""
|
"""
|
||||||
root = settings.NOTES_STORAGE_DIR / path.lstrip("/")
|
try:
|
||||||
|
root = validate_path_within_directory(settings.NOTES_STORAGE_DIR, path)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Invalid path: {e}")
|
||||||
|
|
||||||
return [
|
return [
|
||||||
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
|
f"/notes/{f.relative_to(settings.NOTES_STORAGE_DIR)}"
|
||||||
for f in root.rglob("*.md")
|
for f in root.rglob("*.md")
|
||||||
@ -359,15 +389,19 @@ def fetch_file(filename: str) -> dict:
|
|||||||
Returns dict with content, mime_type, is_text, file_size.
|
Returns dict with content, mime_type, is_text, file_size.
|
||||||
Text content as string, binary as base64.
|
Text content as string, binary as base64.
|
||||||
"""
|
"""
|
||||||
path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/")
|
try:
|
||||||
print("fetching file", path)
|
path = validate_path_within_directory(
|
||||||
|
settings.FILE_STORAGE_DIR, filename.strip()
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Invalid path: {e}")
|
||||||
|
|
||||||
|
logger.debug(f"Fetching file: {path}")
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise FileNotFoundError(f"File not found: {filename}")
|
raise FileNotFoundError(f"File not found: {filename}")
|
||||||
|
|
||||||
mime_type = extract.get_mime_type(path)
|
mime_type = extract.get_mime_type(path)
|
||||||
chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
|
chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
|
||||||
print("mime_type", mime_type)
|
|
||||||
print("chunks", chunks)
|
|
||||||
|
|
||||||
def serialize_chunk(
|
def serialize_chunk(
|
||||||
chunk: extract.DataChunk, data: extract.MulitmodalChunk
|
chunk: extract.DataChunk, data: extract.MulitmodalChunk
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import contextlib
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
import pathlib
|
||||||
|
|
||||||
from fastapi import FastAPI, UploadFile, Request, HTTPException
|
from fastapi import FastAPI, UploadFile, Request, HTTPException
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
@ -33,27 +34,58 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||||
app.add_middleware(AuthenticationMiddleware)
|
app.add_middleware(AuthenticationMiddleware)
|
||||||
|
# Configure CORS with specific origin to prevent CSRF attacks.
|
||||||
|
# allow_credentials=True requires specific origins, not wildcards.
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # [settings.SERVER_URL],
|
allow_origins=[settings.SERVER_URL],
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_path_within_directory(base_dir: pathlib.Path, requested_path: str) -> pathlib.Path:
|
||||||
|
"""Validate that a requested path resolves within the base directory.
|
||||||
|
|
||||||
|
Prevents path traversal attacks using ../ or similar techniques.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_dir: The allowed base directory
|
||||||
|
requested_path: The user-provided path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The resolved absolute path if valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the path would escape the base directory
|
||||||
|
"""
|
||||||
|
# Resolve to absolute path and ensure it's within base_dir
|
||||||
|
resolved = (base_dir / requested_path).resolve()
|
||||||
|
base_resolved = base_dir.resolve()
|
||||||
|
|
||||||
|
if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ui{full_path:path}")
|
@app.get("/ui{full_path:path}")
|
||||||
async def serve_react_app(full_path: str):
|
async def serve_react_app(full_path: str):
|
||||||
full_path = full_path.lstrip("/")
|
full_path = full_path.lstrip("/")
|
||||||
index_file = settings.STATIC_DIR / full_path
|
try:
|
||||||
|
index_file = validate_path_within_directory(settings.STATIC_DIR, full_path)
|
||||||
if index_file.is_file():
|
if index_file.is_file():
|
||||||
return FileResponse(index_file)
|
return FileResponse(index_file)
|
||||||
|
except HTTPException:
|
||||||
|
pass # Fall through to index.html for SPA routing
|
||||||
return FileResponse(settings.STATIC_DIR / "index.html")
|
return FileResponse(settings.STATIC_DIR / "index.html")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/files/{path:path}")
|
@app.get("/files/{path:path}")
|
||||||
async def serve_file(path: str):
|
async def serve_file(path: str):
|
||||||
file_path = settings.FILE_STORAGE_DIR / path
|
file_path = validate_path_within_directory(settings.FILE_STORAGE_DIR, path)
|
||||||
|
|
||||||
if not file_path.is_file():
|
if not file_path.is_file():
|
||||||
raise HTTPException(status_code=404, detail="File not found")
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
@ -86,10 +118,36 @@ app.include_router(auth_router)
|
|||||||
# Add health check to MCP server instead of main app
|
# Add health check to MCP server instead of main app
|
||||||
@mcp.custom_route("/health", methods=["GET"])
|
@mcp.custom_route("/health", methods=["GET"])
|
||||||
async def health_check(request: Request):
|
async def health_check(request: Request):
|
||||||
"""Simple health check endpoint on MCP server"""
|
"""Health check endpoint that verifies all dependencies are accessible."""
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
return JSONResponse({"status": "healthy", "mcp_oauth": "enabled"})
|
checks = {"mcp_oauth": "enabled"}
|
||||||
|
all_healthy = True
|
||||||
|
|
||||||
|
# Check database connection
|
||||||
|
try:
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(text("SELECT 1"))
|
||||||
|
checks["database"] = "healthy"
|
||||||
|
except Exception as e:
|
||||||
|
checks["database"] = f"unhealthy: {str(e)[:100]}"
|
||||||
|
all_healthy = False
|
||||||
|
|
||||||
|
# Check Qdrant connection
|
||||||
|
try:
|
||||||
|
from memory.common.qdrant import get_qdrant_client
|
||||||
|
|
||||||
|
client = get_qdrant_client()
|
||||||
|
client.get_collections()
|
||||||
|
checks["qdrant"] = "healthy"
|
||||||
|
except Exception as e:
|
||||||
|
checks["qdrant"] = f"unhealthy: {str(e)[:100]}"
|
||||||
|
all_healthy = False
|
||||||
|
|
||||||
|
checks["status"] = "healthy" if all_healthy else "degraded"
|
||||||
|
status_code = 200 if all_healthy else 503
|
||||||
|
return JSONResponse(checks, status_code=status_code)
|
||||||
|
|
||||||
|
|
||||||
# Mount MCP server at root - OAuth endpoints need to be at root level
|
# Mount MCP server at root - OAuth endpoints need to be at root level
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from memory.api.search.types import SearchFilters
|
|||||||
|
|
||||||
from memory.common import extract
|
from memory.common import extract
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import Chunk, ConfidenceScore
|
from memory.common.db.models import Chunk, ConfidenceScore, SourceItem
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -29,9 +29,28 @@ async def search_bm25(
|
|||||||
Chunk.content.isnot(None),
|
Chunk.content.isnot(None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Join with SourceItem if we need size filters
|
||||||
|
needs_source_join = any(filters.get(k) for k in ["min_size", "max_size"])
|
||||||
|
if needs_source_join:
|
||||||
|
items_query = items_query.join(
|
||||||
|
SourceItem, SourceItem.id == Chunk.source_id
|
||||||
|
)
|
||||||
|
|
||||||
if source_ids := filters.get("source_ids"):
|
if source_ids := filters.get("source_ids"):
|
||||||
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
|
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
|
||||||
|
|
||||||
|
# Size filters
|
||||||
|
if min_size := filters.get("min_size"):
|
||||||
|
items_query = items_query.filter(SourceItem.size >= min_size)
|
||||||
|
if max_size := filters.get("max_size"):
|
||||||
|
items_query = items_query.filter(SourceItem.size <= max_size)
|
||||||
|
|
||||||
|
# Observation type filter - restricts to specific collection types
|
||||||
|
if observation_types := filters.get("observation_types"):
|
||||||
|
items_query = items_query.filter(
|
||||||
|
Chunk.collection_name.in_(observation_types)
|
||||||
|
)
|
||||||
|
|
||||||
# Add confidence filtering if specified
|
# Add confidence filtering if specified
|
||||||
if min_confidences := filters.get("min_confidences"):
|
if min_confidences := filters.get("min_confidences"):
|
||||||
for confidence_type, min_score in min_confidences.items():
|
for confidence_type, min_score in min_confidences.items():
|
||||||
|
|||||||
@ -182,13 +182,16 @@ async def search_chunks_embeddings(
|
|||||||
filters: SearchFilters = SearchFilters(),
|
filters: SearchFilters = SearchFilters(),
|
||||||
timeout: int = 2,
|
timeout: int = 2,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
|
# Note: Multimodal embeddings typically produce higher similarity scores,
|
||||||
|
# so we use a higher threshold (0.4) to maintain selectivity.
|
||||||
|
# Text embeddings produce lower scores, so we use 0.25.
|
||||||
all_ids = await asyncio.gather(
|
all_ids = await asyncio.gather(
|
||||||
asyncio.wait_for(
|
asyncio.wait_for(
|
||||||
search_chunks(
|
search_chunks(
|
||||||
data,
|
data,
|
||||||
modalities & TEXT_COLLECTIONS,
|
modalities & TEXT_COLLECTIONS,
|
||||||
limit,
|
limit,
|
||||||
0.4,
|
0.25,
|
||||||
filters,
|
filters,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
@ -199,7 +202,7 @@ async def search_chunks_embeddings(
|
|||||||
data,
|
data,
|
||||||
modalities & MULTIMODAL_COLLECTIONS,
|
modalities & MULTIMODAL_COLLECTIONS,
|
||||||
limit,
|
limit,
|
||||||
0.25,
|
0.4,
|
||||||
filters,
|
filters,
|
||||||
True,
|
True,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from memory.common.db.models.source_item import Chunk
|
from memory.common.db.models.source_item import Chunk
|
||||||
from memory.common import llms, settings, tokens
|
from memory.common import llms, settings, tokens
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SCORE_CHUNK_SYSTEM_PROMPT = """
|
SCORE_CHUNK_SYSTEM_PROMPT = """
|
||||||
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
|
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
|
||||||
@ -32,7 +35,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
|
|||||||
try:
|
try:
|
||||||
data = chunk.data
|
data = chunk.data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting chunk data: {e}, {type(e)}")
|
logger.error(f"Error getting chunk data: {e}")
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
chunk_text = "\n".join(text for text in data if isinstance(text, str))
|
chunk_text = "\n".join(text for text in data if isinstance(text, str))
|
||||||
@ -47,7 +50,7 @@ async def score_chunk(query: str, chunk: Chunk) -> Chunk:
|
|||||||
system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
|
system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error scoring chunk: {e}, {type(e)}")
|
logger.error(f"Error scoring chunk: {e}")
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
soup = BeautifulSoup(response, "html.parser")
|
soup = BeautifulSoup(response, "html.parser")
|
||||||
|
|||||||
@ -80,3 +80,15 @@ class SearchConfig(BaseModel):
|
|||||||
timeout: int = 20
|
timeout: int = 20
|
||||||
previews: bool = False
|
previews: bool = False
|
||||||
useScores: bool = False
|
useScores: bool = False
|
||||||
|
|
||||||
|
def model_post_init(self, __context) -> None:
|
||||||
|
# Enforce reasonable limits
|
||||||
|
if self.limit < 1:
|
||||||
|
object.__setattr__(self, "limit", 1)
|
||||||
|
elif self.limit > 1000:
|
||||||
|
object.__setattr__(self, "limit", 1000)
|
||||||
|
|
||||||
|
if self.timeout < 1:
|
||||||
|
object.__setattr__(self, "timeout", 1)
|
||||||
|
elif self.timeout > 300:
|
||||||
|
object.__setattr__(self, "timeout", 300)
|
||||||
|
|||||||
@ -88,6 +88,14 @@ app.conf.update(
|
|||||||
task_acks_late=True,
|
task_acks_late=True,
|
||||||
task_reject_on_worker_lost=True,
|
task_reject_on_worker_lost=True,
|
||||||
worker_prefetch_multiplier=1,
|
worker_prefetch_multiplier=1,
|
||||||
|
# Default retry configuration for transient failures
|
||||||
|
task_autoretry_for=(Exception,),
|
||||||
|
task_retry_kwargs={"max_retries": 3},
|
||||||
|
task_retry_backoff=True,
|
||||||
|
task_retry_backoff_max=600, # Max 10 minutes between retries
|
||||||
|
task_retry_jitter=True,
|
||||||
|
task_time_limit=3600, # 1 hour hard limit
|
||||||
|
task_soft_time_limit=3000, # 50 minute soft limit
|
||||||
task_routes={
|
task_routes={
|
||||||
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
f"{EBOOK_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-ebooks"},
|
||||||
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
f"{BLOGS_ROOT}.*": {"queue": f"{settings.CELERY_QUEUE_PREFIX}-blogs"},
|
||||||
|
|||||||
@ -102,29 +102,35 @@ def chunk_text(
|
|||||||
current = ""
|
current = ""
|
||||||
|
|
||||||
for span in yield_spans(text, max_tokens):
|
for span in yield_spans(text, max_tokens):
|
||||||
current = f"{current} {span}".strip()
|
# Check if adding this span would exceed the limit
|
||||||
if tokens.approx_token_count(current) < max_tokens:
|
new_chunk = f"{current} {span}".strip() if current else span
|
||||||
|
if tokens.approx_token_count(new_chunk) <= max_tokens:
|
||||||
|
current = new_chunk
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if overlap <= 0:
|
# Adding span would exceed limit - yield current first (if non-empty)
|
||||||
|
if current:
|
||||||
yield current
|
yield current
|
||||||
current = ""
|
|
||||||
|
# Handle overlap for the next chunk
|
||||||
|
if overlap <= 0 or not current:
|
||||||
|
current = span
|
||||||
continue
|
continue
|
||||||
|
|
||||||
overlap_text = current[-overlap_chars:]
|
# Try to find a clean break point for overlap
|
||||||
|
overlap_text = current[-overlap_chars:] if len(current) > overlap_chars else current
|
||||||
clean_break = max(
|
clean_break = max(
|
||||||
overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
|
overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
|
||||||
)
|
)
|
||||||
|
|
||||||
if clean_break < 0:
|
if clean_break < 0:
|
||||||
yield current
|
current = span
|
||||||
current = ""
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
break_offset = -overlap_chars + clean_break + 1
|
# Start new chunk with overlap from clean break
|
||||||
chunk = current[break_offset:].strip()
|
break_offset = -len(overlap_text) + clean_break + 1
|
||||||
yield current
|
overlap_portion = current[break_offset:].strip()
|
||||||
current = chunk
|
current = f"{overlap_portion} {span}".strip() if overlap_portion else span
|
||||||
|
|
||||||
if current:
|
if current:
|
||||||
yield current.strip()
|
yield current.strip()
|
||||||
|
|||||||
@ -132,9 +132,14 @@ def get_modality(mime_type: str) -> str:
|
|||||||
def collection_model(
|
def collection_model(
|
||||||
collection: str, text: str, images: list[Image.Image]
|
collection: str, text: str, images: list[Image.Image]
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
|
"""Determine the appropriate embedding model for a collection.
|
||||||
|
|
||||||
|
Returns None if no suitable model can be determined, rather than
|
||||||
|
falling back to an invalid placeholder.
|
||||||
|
"""
|
||||||
config = ALL_COLLECTIONS.get(collection, {})
|
config = ALL_COLLECTIONS.get(collection, {})
|
||||||
if images and config.get("multimodal"):
|
if images and config.get("multimodal"):
|
||||||
return settings.MIXED_EMBEDDING_MODEL
|
return settings.MIXED_EMBEDDING_MODEL
|
||||||
if text and config.get("text"):
|
if text and config.get("text"):
|
||||||
return settings.TEXT_EMBEDDING_MODEL
|
return settings.TEXT_EMBEDDING_MODEL
|
||||||
return "unknown"
|
return None
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from sqlalchemy import (
|
|||||||
Text,
|
Text,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship, object_session
|
||||||
|
|
||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
@ -27,6 +27,25 @@ class MessageProcessor:
|
|||||||
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
allowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
disallowed_tools = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mcp_servers(self) -> list:
|
||||||
|
"""Get MCP servers assigned to this entity via MCPServerAssignment."""
|
||||||
|
from memory.common.db.models.mcp import MCPServer, MCPServerAssignment
|
||||||
|
|
||||||
|
session = object_session(self)
|
||||||
|
if not session:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return (
|
||||||
|
session.query(MCPServer)
|
||||||
|
.join(MCPServerAssignment)
|
||||||
|
.filter(
|
||||||
|
MCPServerAssignment.entity_type == self.entity_type,
|
||||||
|
MCPServerAssignment.entity_id == self.id,
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
system_prompt = Column(
|
system_prompt = Column(
|
||||||
Text,
|
Text,
|
||||||
nullable=True,
|
nullable=True,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from sqlalchemy import (
|
|||||||
DateTime,
|
DateTime,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
BigInteger,
|
BigInteger,
|
||||||
|
Integer,
|
||||||
JSON,
|
JSON,
|
||||||
Text,
|
Text,
|
||||||
)
|
)
|
||||||
@ -20,7 +21,7 @@ class ScheduledLLMCall(Base):
|
|||||||
__tablename__ = "scheduled_llm_calls"
|
__tablename__ = "scheduled_llm_calls"
|
||||||
|
|
||||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
user_id = Column(BigInteger, ForeignKey("users.id"), nullable=False)
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
topic = Column(Text, nullable=True)
|
topic = Column(Text, nullable=True)
|
||||||
|
|
||||||
# Scheduling info
|
# Scheduling info
|
||||||
|
|||||||
@ -165,6 +165,7 @@ class Chunk(Base):
|
|||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
|
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
|
||||||
Index("chunk_source_idx", "source_id"),
|
Index("chunk_source_idx", "source_id"),
|
||||||
|
Index("chunk_collection_idx", "collection_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Literal, cast
|
from typing import Literal, cast
|
||||||
|
|
||||||
import voyageai
|
import voyageai
|
||||||
@ -15,6 +16,12 @@ from memory.common.db.models import Chunk, SourceItem
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingError(Exception):
|
||||||
|
"""Raised when embedding generation fails after retries."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def as_string(
|
def as_string(
|
||||||
chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk],
|
chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk],
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -29,9 +36,33 @@ def embed_chunks(
|
|||||||
chunks: list[list[extract.MulitmodalChunk]],
|
chunks: list[list[extract.MulitmodalChunk]],
|
||||||
model: str = settings.TEXT_EMBEDDING_MODEL,
|
model: str = settings.TEXT_EMBEDDING_MODEL,
|
||||||
input_type: Literal["document", "query"] = "document",
|
input_type: Literal["document", "query"] = "document",
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
) -> list[Vector]:
|
) -> list[Vector]:
|
||||||
logger.debug(f"Embedding chunks: {model} - {str(chunks)} {len(chunks)}")
|
"""Embed chunks with retry logic for transient failures.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks: List of chunk lists to embed
|
||||||
|
model: Embedding model to use
|
||||||
|
input_type: Whether embedding documents or queries
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
retry_delay: Base delay between retries (exponential backoff)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embedding vectors
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EmbeddingError: If embedding fails after all retries
|
||||||
|
"""
|
||||||
|
if not chunks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.debug(f"Embedding {len(chunks)} chunks with model {model}")
|
||||||
vo = voyageai.Client() # type: ignore
|
vo = voyageai.Client() # type: ignore
|
||||||
|
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
if model == settings.MIXED_EMBEDDING_MODEL:
|
if model == settings.MIXED_EMBEDDING_MODEL:
|
||||||
return vo.multimodal_embed(
|
return vo.multimodal_embed(
|
||||||
chunks,
|
chunks,
|
||||||
@ -40,10 +71,25 @@ def embed_chunks(
|
|||||||
).embeddings
|
).embeddings
|
||||||
|
|
||||||
texts = [as_string(c) for c in chunks]
|
texts = [as_string(c) for c in chunks]
|
||||||
logger.debug(f"Embedding texts: {texts}")
|
|
||||||
return cast(
|
return cast(
|
||||||
list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
|
list[Vector],
|
||||||
|
vo.embed(texts, model=model, input_type=input_type).embeddings,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
delay = retry_delay * (2**attempt)
|
||||||
|
logger.warning(
|
||||||
|
f"Embedding attempt {attempt + 1}/{max_retries} failed: {e}. "
|
||||||
|
f"Retrying in {delay:.1f}s..."
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
logger.error(f"Embedding failed after {max_retries} attempts: {e}")
|
||||||
|
|
||||||
|
raise EmbeddingError(
|
||||||
|
f"Failed to generate embeddings after {max_retries} attempts"
|
||||||
|
) from last_error
|
||||||
|
|
||||||
|
|
||||||
def break_chunk(
|
def break_chunk(
|
||||||
|
|||||||
@ -284,6 +284,8 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
# Include server info if present
|
# Include server info if present
|
||||||
if current_tool_use.get("server_name"):
|
if current_tool_use.get("server_name"):
|
||||||
tool_data["server_name"] = current_tool_use["server_name"]
|
tool_data["server_name"] = current_tool_use["server_name"]
|
||||||
|
if current_tool_use.get("is_server_call"):
|
||||||
|
tool_data["is_server_call"] = current_tool_use["is_server_call"]
|
||||||
|
|
||||||
# Emit different event type for MCP server tools
|
# Emit different event type for MCP server tools
|
||||||
if current_tool_use.get("is_server_call"):
|
if current_tool_use.get("is_server_call"):
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class Collector:
|
|||||||
bot_name: str
|
bot_name: str
|
||||||
|
|
||||||
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
def __init__(self, collector: MessageCollector, bot: DiscordBotUser):
|
||||||
logger.error(f"Initialized collector for {bot.name} woth {bot.api_key}")
|
logger.info(f"Initialized collector for {bot.name}")
|
||||||
self.collector = collector
|
self.collector = collector
|
||||||
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
self.collector_task = asyncio.create_task(collector.start(str(bot.api_key)))
|
||||||
self.bot_id = cast(int, bot.id)
|
self.bot_id = cast(int, bot.id)
|
||||||
@ -80,8 +80,8 @@ async def lifespan(app: FastAPI):
|
|||||||
bots = session.query(DiscordBotUser).all()
|
bots = session.query(DiscordBotUser).all()
|
||||||
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
app.bots = {bot.id: make_collector(bot) for bot in bots}
|
||||||
|
|
||||||
logger.error(
|
logger.info(
|
||||||
f"Discord collectors started for {len(app.bots)} bots: {app.bots.keys()}"
|
f"Discord collectors started for {len(app.bots)} bots: {list(app.bots.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|||||||
@ -354,7 +354,6 @@ class MessageCollector(commands.Bot):
|
|||||||
"users_updated": users_updated,
|
"users_updated": users_updated,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"✅ Metadata refresh complete: {result}")
|
|
||||||
logger.info(f"Metadata refresh complete: {result}")
|
logger.info(f"Metadata refresh complete: {result}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -259,22 +259,26 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
|||||||
"""
|
"""
|
||||||
Decorator for safe task execution with comprehensive error handling.
|
Decorator for safe task execution with comprehensive error handling.
|
||||||
|
|
||||||
Wraps task functions to catch and log exceptions, ensuring tasks
|
Wraps task functions to log exceptions and notify on failures while
|
||||||
always return a result dictionary even when they fail.
|
still allowing Celery to handle retries. Exceptions are re-raised after
|
||||||
|
logging to allow Celery's retry mechanism to work.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: Task function to wrap
|
func: Task function to wrap
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Wrapped function that handles exceptions gracefully
|
Wrapped function that logs exceptions and re-raises for retry
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@app.task(bind=True)
|
||||||
@safe_task_execution
|
@safe_task_execution
|
||||||
def my_task(arg1, arg2):
|
def my_task(self, arg1, arg2):
|
||||||
# Task implementation
|
# Task implementation
|
||||||
return {"status": "success"}
|
return {"status": "success"}
|
||||||
"""
|
"""
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs) -> dict:
|
def wrapper(*args, **kwargs) -> dict:
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
@ -283,14 +287,25 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
|||||||
traceback_str = traceback.format_exc()
|
traceback_str = traceback.format_exc()
|
||||||
logger.error(traceback_str)
|
logger.error(traceback_str)
|
||||||
|
|
||||||
|
# Check if this is a bound task and if retries are exhausted
|
||||||
|
task_self = args[0] if args and hasattr(args[0], "request") else None
|
||||||
|
is_final_retry = (
|
||||||
|
task_self
|
||||||
|
and hasattr(task_self, "request")
|
||||||
|
and task_self.request.retries >= task_self.max_retries
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify on final failure only
|
||||||
|
if is_final_retry or task_self is None:
|
||||||
notify_task_failure(
|
notify_task_failure(
|
||||||
task_name=func.__name__,
|
task_name=func.__name__,
|
||||||
error_message=str(e),
|
error_message=str(e),
|
||||||
task_args=args,
|
task_args=args[1:] if task_self else args,
|
||||||
task_kwargs=kwargs,
|
task_kwargs=kwargs,
|
||||||
traceback_str=traceback_str,
|
traceback_str=traceback_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "error", "error": str(e), "traceback": traceback_str}
|
# Re-raise to allow Celery retries
|
||||||
|
raise
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@ -189,7 +189,7 @@ def sync_book(
|
|||||||
# Create book and sections with relationships
|
# Create book and sections with relationships
|
||||||
book, all_sections = create_book_and_sections(ebook, session, tags)
|
book, all_sections = create_book_and_sections(ebook, session, tags)
|
||||||
for section in all_sections:
|
for section in all_sections:
|
||||||
print(section.section_title, section.book)
|
logger.debug(f"Created section: {section.section_title}")
|
||||||
|
|
||||||
if title:
|
if title:
|
||||||
book.title = title # type: ignore
|
book.title = title # type: ignore
|
||||||
|
|||||||
@ -76,12 +76,12 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
logger.error(f"Scheduled call {scheduled_call_id} not found")
|
logger.error(f"Scheduled call {scheduled_call_id} not found")
|
||||||
return {"error": "Scheduled call not found"}
|
return {"error": "Scheduled call not found"}
|
||||||
|
|
||||||
# Check if the call is still pending
|
# Check if the call is ready to execute (pending or queued)
|
||||||
if not scheduled_call.is_pending():
|
if scheduled_call.status not in ("pending", "queued"):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Scheduled call {scheduled_call_id} is not pending (status: {scheduled_call.status})"
|
f"Scheduled call {scheduled_call_id} is not ready (status: {scheduled_call.status})"
|
||||||
)
|
)
|
||||||
return {"error": f"Call is not pending (status: {scheduled_call.status})"}
|
return {"error": f"Call is not ready (status: {scheduled_call.status})"}
|
||||||
|
|
||||||
# Update status to executing
|
# Update status to executing
|
||||||
scheduled_call.status = "executing"
|
scheduled_call.status = "executing"
|
||||||
@ -143,21 +143,40 @@ def execute_scheduled_call(self, scheduled_call_id: str):
|
|||||||
@app.task(name=RUN_SCHEDULED_CALLS)
|
@app.task(name=RUN_SCHEDULED_CALLS)
|
||||||
@safe_task_execution
|
@safe_task_execution
|
||||||
def run_scheduled_calls():
|
def run_scheduled_calls():
|
||||||
"""Run scheduled calls that are due."""
|
"""Run scheduled calls that are due.
|
||||||
|
|
||||||
|
Uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions when
|
||||||
|
multiple workers query for due calls simultaneously.
|
||||||
|
"""
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
|
# Use FOR UPDATE SKIP LOCKED to atomically claim pending calls
|
||||||
|
# This prevents multiple workers from processing the same call
|
||||||
|
#
|
||||||
|
# Note: scheduled_time is stored as naive datetime (assumed UTC).
|
||||||
|
# We compare against current UTC time, also as naive datetime.
|
||||||
|
now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
calls = (
|
calls = (
|
||||||
session.query(ScheduledLLMCall)
|
session.query(ScheduledLLMCall)
|
||||||
.filter(
|
.filter(
|
||||||
ScheduledLLMCall.status.in_(["pending"]),
|
ScheduledLLMCall.status.in_(["pending"]),
|
||||||
ScheduledLLMCall.scheduled_time
|
ScheduledLLMCall.scheduled_time < now_utc,
|
||||||
< datetime.now(timezone.utc).replace(tzinfo=None),
|
|
||||||
)
|
)
|
||||||
|
.with_for_update(skip_locked=True)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mark calls as queued before dispatching to prevent re-processing
|
||||||
|
call_ids = []
|
||||||
for call in calls:
|
for call in calls:
|
||||||
execute_scheduled_call.delay(call.id)
|
call.status = "queued"
|
||||||
|
call_ids.append(call.id)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Now dispatch tasks for queued calls
|
||||||
|
for call_id in call_ids:
|
||||||
|
execute_scheduled_call.delay(call_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"calls": [call.id for call in calls],
|
"calls": call_ids,
|
||||||
"count": len(calls),
|
"count": len(call_ids),
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user