Add less wrong tasks + reindexer

This commit is contained in:
Daniel O'Connell 2025-05-28 02:46:43 +02:00
parent ab87bced81
commit ed8033bdd3
27 changed files with 2160 additions and 110 deletions

View File

@ -0,0 +1,52 @@
"""Add forum posts
Revision ID: 2524646f56f6
Revises: 1b535e1b044e
Create Date: 2025-05-28 01:23:00.079366
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "2524646f56f6"
down_revision: Union[str, None] = "1b535e1b044e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"forum_post",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("url", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("authors", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("modified_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("slug", sa.Text(), nullable=True),
sa.Column("karma", sa.Integer(), nullable=True),
sa.Column("votes", sa.Integer(), nullable=True),
sa.Column("comments", sa.Integer(), nullable=True),
sa.Column("words", sa.Integer(), nullable=True),
sa.Column("score", sa.Integer(), nullable=True),
sa.Column("images", sa.ARRAY(sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
op.create_index("forum_post_slug_idx", "forum_post", ["slug"], unique=False)
op.create_index("forum_post_title_idx", "forum_post", ["title"], unique=False)
op.create_index("forum_post_url_idx", "forum_post", ["url"], unique=False)
def downgrade() -> None:
op.drop_index("forum_post_url_idx", table_name="forum_post")
op.drop_index("forum_post_title_idx", table_name="forum_post")
op.drop_index("forum_post_slug_idx", table_name="forum_post")
op.drop_table("forum_post")
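Since forum_post shares its primary key with source_item (with ON DELETE CASCADE), rows are meant to be read through the ForumPost model added later in this commit. A minimal hedged sketch of querying the new table, assuming the project's make_session helper:

from memory.common.db.connection import make_session
from memory.common.db.models import ForumPost

# Hedged sketch: list recent high-karma forum posts from the new table.
with make_session() as session:
    posts = (
        session.query(ForumPost)
        .filter(ForumPost.karma >= 25)
        .order_by(ForumPost.published_at.desc())
        .limit(10)
        .all()
    )
    for post in posts:
        print(post.published_at, post.karma, post.title)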

View File

@ -207,6 +207,12 @@ services:
<<: *worker-env
QUEUES: "blogs"
worker-forums:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "forums"
worker-photo:
<<: *worker-base
environment:

View File

@ -33,7 +33,7 @@ RUN mkdir -p /var/log/supervisor /var/run/supervisor
RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files
USER kb
ENV QUEUES="docs,email,maintenance"
ENV QUEUES="maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]

View File

@ -39,7 +39,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
USER kb
# Default queues to process
ENV QUEUES="ebooks,email,comic,blogs,photo_embed,maintenance"
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"]

View File

@ -3,4 +3,5 @@ pytest-cov==4.1.0
black==23.12.1
mypy==1.8.0
isort==5.13.2
testcontainers[qdrant]==4.10.0
testcontainers[qdrant]==4.10.0
click==8.1.7

View File

@ -17,6 +17,7 @@ from memory.common.db.models import (
MiscDoc,
ArticleFeed,
EmailAccount,
ForumPost,
)
@ -33,7 +34,11 @@ DEFAULT_COLUMNS = (
def source_columns(model: type[SourceItem], *columns: str):
return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]
return [
getattr(model, c)
for c in ("id",) + columns + DEFAULT_COLUMNS
if hasattr(model, c)
]
# Create admin views for all models
@ -86,6 +91,21 @@ class BlogPostAdmin(ModelView, model=BlogPost):
column_searchable_list = ["title", "author", "domain"]
class ForumPostAdmin(ModelView, model=ForumPost):
column_list = source_columns(
ForumPost,
"title",
"authors",
"published_at",
"url",
"karma",
"votes",
"comments",
"score",
)
column_searchable_list = ["title", "authors"]
class PhotoAdmin(ModelView, model=Photo):
column_list = source_columns(Photo, "exif_taken_at", "camera")
@ -166,5 +186,6 @@ def setup_admin(admin: Admin):
admin.add_view(MiscDocAdmin)
admin.add_view(ArticleFeedAdmin)
admin.add_view(BlogPostAdmin)
admin.add_view(ForumPostAdmin)
admin.add_view(ComicAdmin)
admin.add_view(PhotoAdmin)

View File

@ -43,7 +43,12 @@ ALL_COLLECTIONS: dict[str, Collection] = {
"blog": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"model": settings.MIXED_EMBEDDING_MODEL,
},
"forum": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"text": {
"dimension": 1024,

View File

@ -105,6 +105,21 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh
]
def chunk_mixed(
content: str, image_paths: Sequence[str]
) -> list[list[extract.MulitmodalChunk]]:
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
full_text: list[extract.MulitmodalChunk] = [content.strip(), *images]
chunks = []
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
all_chunks = [full_text] + chunks
return [c for c in all_chunks if c and all(i for i in c)]
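# Note: chunk_mixed always returns one "full" chunk (stripped content plus all
# images) and, when the content exceeds roughly twice DEFAULT_CHUNK_TOKENS,
# additional per-section chunks with the images re-attached via add_pics.
# Chunks that are empty or contain an empty element are filtered out.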
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
@ -134,6 +149,19 @@ class Chunk(Base):
Index("chunk_source_idx", "source_id"),
)
@property
def chunks(self) -> list[extract.MulitmodalChunk]:
chunks: list[extract.MulitmodalChunk] = []
if cast(str | None, self.content):
chunks = [cast(str, self.content)]
if self.images:
chunks += self.images
elif cast(Sequence[str] | None, self.file_paths):
chunks += [
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
]
return chunks
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_paths is None:
@ -638,20 +666,57 @@ class BlogPost(SourceItem):
return {k: v for k, v in payload.items() if v}
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
images = [
Image.open(settings.FILE_STORAGE_DIR / image) for image in self.images
]
content = cast(str, self.content)
full_text = [content.strip(), *images]
chunks = []
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
all_chunks = [full_text] + chunks
return [c for c in all_chunks if c and all(i for i in c)]
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
class ForumPost(SourceItem):
__tablename__ = "forum_post"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
url = Column(Text, unique=True)
title = Column(Text)
description = Column(Text, nullable=True)
authors = Column(ARRAY(Text), nullable=True)
published_at = Column(DateTime(timezone=True), nullable=True)
modified_at = Column(DateTime(timezone=True), nullable=True)
slug = Column(Text, nullable=True)
karma = Column(Integer, nullable=True)
votes = Column(Integer, nullable=True)
comments = Column(Integer, nullable=True)
words = Column(Integer, nullable=True)
score = Column(Integer, nullable=True)
images = Column(ARRAY(Text), nullable=True)
__mapper_args__ = {
"polymorphic_identity": "forum_post",
}
__table_args__ = (
Index("forum_post_url_idx", "url"),
Index("forum_post_slug_idx", "slug"),
Index("forum_post_title_idx", "title"),
)
def as_payload(self) -> dict:
return {
"source_id": self.id,
"url": self.url,
"title": self.title,
"description": self.description,
"authors": self.authors,
"published_at": self.published_at,
"slug": self.slug,
"karma": self.karma,
"votes": self.votes,
"score": self.score,
"comments": self.comments,
"tags": self.tags,
}
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
class MiscDoc(SourceItem):
@ -710,7 +775,7 @@ class ArticleFeed(Base):
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
check_interval = Column(
Integer, nullable=False, server_default="3600", doc="Seconds between checks"
Integer, nullable=False, server_default="60", doc="Minutes between checks"
)
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")

View File

@ -20,7 +20,7 @@ def embed_chunks(
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]}")
logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]} {len(chunks)}")
vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed(
@ -79,7 +79,7 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
if not model_chunks:
return []
vectors = embed_chunks([chunk.content for chunk in model_chunks], model) # type: ignore
vectors = embed_chunks([chunk.chunks for chunk in model_chunks], model)
for chunk, vector in zip(model_chunks, vectors):
chunk.vector = vector
return model_chunks

View File

@ -86,9 +86,11 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
# Worker settings
# Intervals are in seconds
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400))
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600))
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 60 * 60))
COMIC_SYNC_INTERVAL = int(os.getenv("COMIC_SYNC_INTERVAL", 60 * 60))
ARTICLE_FEED_SYNC_INTERVAL = int(os.getenv("ARTICLE_FEED_SYNC_INTERVAL", 30 * 60))
CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 * 60))
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))

View File

@ -0,0 +1,304 @@
from dataclasses import dataclass, field
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Generator, TypedDict, NotRequired
from bs4 import BeautifulSoup
from PIL import Image as PILImage
from memory.common import settings
import requests
from markdownify import markdownify
from memory.parsers.html import parse_date, process_images
logger = logging.getLogger(__name__)
class LessWrongPost(TypedDict):
"""Represents a post from LessWrong."""
title: str
url: str
description: str
content: str
authors: list[str]
published_at: datetime | None
guid: str | None
karma: int
votes: int
comments: int
words: int
tags: list[str]
af: bool
score: int
extended_score: int
modified_at: NotRequired[str | None]
slug: NotRequired[str | None]
images: NotRequired[list[str]]
def make_graphql_query(
after: datetime, af: bool = False, limit: int = 50, min_karma: int = 10
) -> str:
"""Create GraphQL query for fetching posts."""
return f"""
{{
posts(input: {{
terms: {{
excludeEvents: true
view: "old"
af: {str(af).lower()}
limit: {limit}
karmaThreshold: {min_karma}
after: "{after.isoformat()}Z"
filter: "tagged"
}}
}}) {{
totalCount
results {{
_id
title
slug
pageUrl
postedAt
modifiedAt
score
extendedScore
baseScore
voteCount
commentCount
wordCount
tags {{
name
}}
user {{
displayName
}}
coauthors {{
displayName
}}
af
htmlBody
}}
}}
}}
"""
def fetch_posts_from_api(url: str, query: str) -> dict[str, Any]:
"""Fetch posts from LessWrong GraphQL API."""
response = requests.post(
url,
headers={
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
},
json={"query": query},
timeout=30,
)
response.raise_for_status()
return response.json()["data"]["posts"]
def is_valid_post(post: dict[str, Any], min_karma: int = 10) -> bool:
"""Check if post should be included."""
# Must have content
if not post.get("htmlBody"):
return False
# Must meet karma threshold
if post.get("baseScore", 0) < min_karma:
return False
return True
def extract_authors(post: dict[str, Any]) -> list[str]:
"""Extract authors from post data."""
authors = post.get("coauthors", []) or []
if post.get("user"):
authors = [post["user"]] + authors
return [a["displayName"] for a in authors] or ["anonymous"]
def extract_description(body: str) -> str:
"""Extract description from post HTML content."""
first_paragraph = body.split("\n\n")[0]
# Truncate if too long
if len(first_paragraph) > 300:
first_paragraph = first_paragraph[:300] + "..."
return first_paragraph
def parse_lesswrong_date(date_str: str | None) -> datetime | None:
"""Parse ISO date string from LessWrong API to datetime."""
if not date_str:
return None
# Try multiple ISO formats that LessWrong might use
formats = [
"%Y-%m-%dT%H:%M:%S.%fZ", # 2023-01-15T10:30:00.000Z
"%Y-%m-%dT%H:%M:%SZ", # 2023-01-15T10:30:00Z
"%Y-%m-%dT%H:%M:%S.%f", # 2023-01-15T10:30:00.000
"%Y-%m-%dT%H:%M:%S", # 2023-01-15T10:30:00
]
for fmt in formats:
if result := parse_date(date_str, fmt):
return result
# Fallback: try removing 'Z' and using fromisoformat
try:
clean_date = date_str.rstrip("Z")
return datetime.fromisoformat(clean_date)
except (ValueError, TypeError):
logger.warning(f"Could not parse date: {date_str}")
return None
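# Examples: "2023-01-15T10:30:00.000Z" and "2023-01-15T10:30:00Z" both parse to
# datetime(2023, 1, 15, 10, 30); "2023-01-15T10:30:00.123456" yields
# datetime(2023, 1, 15, 10, 30, 0, 123456); "invalid-date", "" and None return None.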
def extract_body(post: dict[str, Any]) -> tuple[str, dict[str, PILImage.Image]]:
"""Extract body from post data."""
if not (body := post.get("htmlBody", "").strip()):
return "", {}
url = post.get("pageUrl", "")
image_dir = settings.FILE_STORAGE_DIR / "lesswrong" / url
soup = BeautifulSoup(body, "html.parser")
soup, images = process_images(soup, url, image_dir)
body = markdownify(str(soup)).strip()
return body, images
def format_post(post: dict[str, Any]) -> LessWrongPost:
"""Convert raw API post data to GreaterWrongPost."""
body, images = extract_body(post)
result: LessWrongPost = {
"title": post.get("title", "Untitled"),
"url": post.get("pageUrl", ""),
"description": extract_description(body),
"content": body,
"authors": extract_authors(post),
"published_at": parse_lesswrong_date(post.get("postedAt")),
"guid": post.get("_id"),
"karma": post.get("baseScore", 0),
"votes": post.get("voteCount", 0),
"comments": post.get("commentCount", 0),
"words": post.get("wordCount", 0),
"tags": [tag["name"] for tag in post.get("tags", [])],
"af": post.get("af", False),
"score": post.get("score", 0),
"extended_score": post.get("extendedScore", 0),
"modified_at": post.get("modifiedAt"),
"slug": post.get("slug"),
"images": list(images.keys()),
}
return result
def fetch_lesswrong(
url: str,
current_date: datetime,
af: bool = False,
min_karma: int = 10,
limit: int = 50,
last_url: str | None = None,
) -> list[LessWrongPost]:
"""
Fetch a batch of posts from LessWrong.
Returns:
(posts, next_date, last_item) where next_date is None if iteration should stop
"""
query = make_graphql_query(current_date, af, limit, min_karma)
api_response = fetch_posts_from_api(url, query)
if not api_response["results"]:
return []
# If we only get the same item we started with, we're done
if (
len(api_response["results"]) == 1
and last_url
and api_response["results"][0]["pageUrl"] == last_url
):
return []
return [
format_post(post)
for post in api_response["results"]
if is_valid_post(post, min_karma)
]
def fetch_lesswrong_posts(
since: datetime | None = None,
min_karma: int = 10,
limit: int = 50,
cooldown: float = 0.5,
max_items: int = 1000,
af: bool = False,
url: str = "https://www.lesswrong.com/graphql",
) -> Generator[LessWrongPost, None, None]:
"""
Fetch posts from LessWrong.
Args:
url: GraphQL endpoint URL
af: Whether to fetch Alignment Forum posts
min_karma: Minimum karma threshold for posts
limit: Number of posts per API request
start_year: Default start year if no since date provided
since: Start date for fetching posts
cooldown: Delay between API requests in seconds
max_pages: Maximum number of pages to fetch
Returns:
List of GreaterWrongPost objects
"""
if not since:
since = datetime.now() - timedelta(days=1)
logger.info(f"Starting from {since}")
last_url = None
next_date = since
items_count = 0
while next_date and items_count < max_items:
try:
page_posts = fetch_lesswrong(url, next_date, af, min_karma, limit, last_url)
except Exception as e:
logger.error(f"Error fetching posts: {e}")
break
if not page_posts or next_date is None:
break
for post in page_posts:
yield post
last_item = page_posts[-1]
prev_date = next_date
next_date = last_item.get("published_at")
if not next_date or prev_date == next_date:
logger.warning(
f"Could not advance through dataset, stopping at {next_date}"
)
break
# The articles are paged by date (inclusive), so passing the last date as-is
# would return the same article again.
next_date += timedelta(seconds=1)
items_count += len(page_posts)
last_url = last_item["url"]
if cooldown > 0:
time.sleep(cooldown)
logger.info(f"Fetched {items_count} items")

View File

@ -1,6 +1,14 @@
from celery import Celery
from memory.common import settings
EMAIL_ROOT = "memory.workers.tasks.email"
FORUMS_ROOT = "memory.workers.tasks.forums"
BLOGS_ROOT = "memory.workers.tasks.blogs"
PHOTO_ROOT = "memory.workers.tasks.photo"
COMIC_ROOT = "memory.workers.tasks.comic"
EBOOK_ROOT = "memory.workers.tasks.ebook"
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
def rabbit_url() -> str:
return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//"
@ -21,12 +29,13 @@ app.conf.update(
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
task_routes={
"memory.workers.tasks.email.*": {"queue": "email"},
"memory.workers.tasks.photo.*": {"queue": "photo_embed"},
"memory.workers.tasks.comic.*": {"queue": "comic"},
"memory.workers.tasks.ebook.*": {"queue": "ebooks"},
"memory.workers.tasks.blogs.*": {"queue": "blogs"},
"memory.workers.tasks.maintenance.*": {"queue": "maintenance"},
f"{EMAIL_ROOT}.*": {"queue": "email"},
f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
f"{COMIC_ROOT}.*": {"queue": "comic"},
f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
f"{BLOGS_ROOT}.*": {"queue": "blogs"},
f"{FORUMS_ROOT}.*": {"queue": "forums"},
f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
},
)
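The new forums queue is only consumed by workers that subscribe to it (see the QUEUES additions in the Dockerfiles and docker-compose above). For local experimentation, a hedged sketch of starting such a worker programmatically rather than through supervisord:

from memory.workers.celery_app import app

# Equivalent to `celery -A memory.workers.celery_app worker -Q forums`;
# in the containers the same thing is driven by the QUEUES env var.
app.worker_main(argv=["worker", "--loglevel=INFO", "-Q", "forums"])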

View File

@ -8,10 +8,6 @@ logger = logging.getLogger(__name__)
app.conf.beat_schedule = {
"sync-mail-all": {
"task": "memory.workers.tasks.email.sync_all_accounts",
"schedule": settings.EMAIL_SYNC_INTERVAL,
},
"clean-all-collections": {
"task": CLEAN_ALL_COLLECTIONS,
"schedule": settings.CLEAN_COLLECTION_INTERVAL,
@ -20,4 +16,16 @@ app.conf.beat_schedule = {
"task": REINGEST_MISSING_CHUNKS,
"schedule": settings.CHUNK_REINGEST_INTERVAL,
},
"sync-mail-all": {
"task": "memory.workers.tasks.email.sync_all_accounts",
"schedule": settings.EMAIL_SYNC_INTERVAL,
},
"sync-all-comics": {
"task": "memory.workers.tasks.comic.sync_all_comics",
"schedule": settings.COMIC_SYNC_INTERVAL,
},
"sync-all-article-feeds": {
"task": "memory.workers.tasks.blogs.sync_all_article_feeds",
"schedule": settings.ARTICLE_FEED_SYNC_INTERVAL,
},
}

View File

@ -2,7 +2,7 @@
Import sub-modules so Celery can register their @app.task decorators.
"""
from memory.workers.tasks import email, comic, blogs, ebook # noqa
from memory.workers.tasks import email, comic, blogs, ebook, forums # noqa
from memory.workers.tasks.blogs import (
SYNC_WEBPAGE,
SYNC_ARTICLE_FEED,
@ -12,11 +12,13 @@ from memory.workers.tasks.blogs import (
from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_SMBC, SYNC_XKCD
from memory.workers.tasks.ebook import SYNC_BOOK
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
from memory.workers.tasks.maintenance import (
CLEAN_ALL_COLLECTIONS,
CLEAN_COLLECTION,
REINGEST_MISSING_CHUNKS,
REINGEST_CHUNK,
REINGEST_ITEM,
)
@ -25,6 +27,7 @@ __all__ = [
"comic",
"blogs",
"ebook",
"forums",
"SYNC_WEBPAGE",
"SYNC_ARTICLE_FEED",
"SYNC_ALL_ARTICLE_FEEDS",
@ -34,10 +37,13 @@ __all__ = [
"SYNC_XKCD",
"SYNC_BOOK",
"SYNC_ACCOUNT",
"SYNC_LESSWRONG",
"SYNC_LESSWRONG_POST",
"SYNC_ALL_ACCOUNTS",
"PROCESS_EMAIL",
"CLEAN_ALL_COLLECTIONS",
"CLEAN_COLLECTION",
"REINGEST_MISSING_CHUNKS",
"REINGEST_CHUNK",
"REINGEST_ITEM",
]

View File

@ -1,5 +1,5 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Iterable, cast
from memory.common.db.connection import make_session
@ -7,7 +7,7 @@ from memory.common.db.models import ArticleFeed, BlogPost
from memory.parsers.blogs import parse_webpage
from memory.parsers.feeds import get_feed_parser
from memory.parsers.archives import get_archive_fetcher
from memory.workers.celery_app import app
from memory.workers.celery_app import app, BLOGS_ROOT
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -18,10 +18,10 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
SYNC_ARTICLE_FEED = "memory.workers.tasks.blogs.sync_article_feed"
SYNC_ALL_ARTICLE_FEEDS = "memory.workers.tasks.blogs.sync_all_article_feeds"
SYNC_WEBSITE_ARCHIVE = "memory.workers.tasks.blogs.sync_website_archive"
SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
@app.task(name=SYNC_WEBPAGE)
@ -71,7 +71,7 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
logger.info(f"Blog post already exists: {existing_post.title}")
return create_task_result(existing_post, "already_exists", url=article.url)
return process_content_item(blog_post, "blog", session, tags)
return process_content_item(blog_post, session)
@app.task(name=SYNC_ARTICLE_FEED)
@ -93,8 +93,8 @@ def sync_article_feed(feed_id: int) -> dict:
return {"status": "error", "error": "Feed not found or inactive"}
last_checked_at = cast(datetime | None, feed.last_checked_at)
if last_checked_at and datetime.now() - last_checked_at < timedelta(
seconds=cast(int, feed.check_interval)
if last_checked_at and datetime.now(timezone.utc) - last_checked_at < timedelta(
minutes=cast(int, feed.check_interval)
):
logger.info(f"Feed {feed_id} checked too recently, skipping")
return {"status": "skipped_recent_check", "feed_id": feed_id}
@ -129,7 +129,7 @@ def sync_article_feed(feed_id: int) -> dict:
logger.error(f"Error parsing feed {feed.url}: {e}")
errors += 1
feed.last_checked_at = datetime.now() # type: ignore
feed.last_checked_at = datetime.now(timezone.utc) # type: ignore
session.commit()
return {

View File

@ -9,7 +9,7 @@ from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import Comic, clean_filename
from memory.parsers import comics
from memory.workers.celery_app import app
from memory.workers.celery_app import app, COMIC_ROOT
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -19,10 +19,10 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics"
SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc"
SYNC_XKCD = f"{COMIC_ROOT}.sync_xkcd"
SYNC_COMIC = f"{COMIC_ROOT}.sync_comic"
BASE_SMBC_URL = "https://www.smbc-comics.com/"
SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
@ -109,7 +109,7 @@ def sync_comic(
)
with make_session() as session:
return process_content_item(comic, "comic", session)
return process_content_item(comic, session)
@app.task(name=SYNC_SMBC)

View File

@ -99,6 +99,7 @@ def embed_source_item(source_item: SourceItem) -> int:
except Exception as e:
source_item.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to embed {type(source_item).__name__}: {e}")
logger.error(traceback.format_exc())
return 0
@ -156,6 +157,7 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
for item in items_to_process:
item.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to push embeddings to Qdrant: {e}")
logger.error(traceback.format_exc())
raise
@ -186,9 +188,7 @@ def create_task_result(
}
def process_content_item(
item: SourceItem, collection_name: str, session, tags: Iterable[str] = []
) -> dict[str, Any]:
def process_content_item(item: SourceItem, session) -> dict[str, Any]:
"""
Execute complete content processing workflow.
@ -200,7 +200,6 @@ def process_content_item(
Args:
item: SourceItem to process
collection_name: Qdrant collection name for vector storage
session: Database session for persistence
tags: Optional tags to associate with the item (currently unused)
@ -223,7 +222,7 @@ def process_content_item(
return create_task_result(item, status, content_length=getattr(item, "size", 0))
try:
push_to_qdrant([item], collection_name)
push_to_qdrant([item], cast(str, item.modality))
status = "processed"
item.embed_status = "STORED" # type: ignore
logger.info(
@ -231,6 +230,7 @@ def process_content_item(
)
except Exception as e:
logger.error(f"Failed to push embeddings to Qdrant: {e}")
logger.error(traceback.format_exc())
item.embed_status = "FAILED" # type: ignore
session.commit()
@ -261,10 +261,8 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
try:
return func(*args, **kwargs)
except Exception as e:
logger.error(
f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}"
)
logger.error(f"Task {func.__name__} failed: {e}")
logger.error(traceback.format_exc())
return {"status": "error", "error": str(e)}
return wrapper

View File

@ -6,7 +6,7 @@ import memory.common.settings as settings
from memory.parsers.ebook import Ebook, parse_ebook, Section
from memory.common.db.models import Book, BookSection
from memory.common.db.connection import make_session
from memory.workers.celery_app import app
from memory.workers.celery_app import app, EBOOK_ROOT
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -17,7 +17,7 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
SYNC_BOOK = "memory.workers.tasks.ebook.sync_book"
SYNC_BOOK = f"{EBOOK_ROOT}.sync_book"
# Minimum section length to embed (avoid noise from very short sections)
MIN_SECTION_LENGTH = 100

View File

@ -3,7 +3,7 @@ from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount, MailMessage
from memory.workers.celery_app import app
from memory.workers.celery_app import app, EMAIL_ROOT
from memory.workers.email import (
create_mail_message,
imap_connection,
@ -18,9 +18,9 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
PROCESS_EMAIL = "memory.workers.tasks.email.process_message"
SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account"
SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
PROCESS_EMAIL = f"{EMAIL_ROOT}.process_message"
SYNC_ACCOUNT = f"{EMAIL_ROOT}.sync_account"
SYNC_ALL_ACCOUNTS = f"{EMAIL_ROOT}.sync_all_accounts"
@app.task(name=PROCESS_EMAIL)

View File

@ -0,0 +1,85 @@
from datetime import datetime, timedelta
import logging
from memory.parsers.lesswrong import fetch_lesswrong_posts, LessWrongPost
from memory.common.db.connection import make_session
from memory.common.db.models import ForumPost
from memory.workers.celery_app import app, FORUMS_ROOT
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
create_task_result,
process_content_item,
safe_task_execution,
)
logger = logging.getLogger(__name__)
SYNC_LESSWRONG = f"{FORUMS_ROOT}.sync_lesswrong"
SYNC_LESSWRONG_POST = f"{FORUMS_ROOT}.sync_lesswrong_post"
@app.task(name=SYNC_LESSWRONG_POST)
@safe_task_execution
def sync_lesswrong_post(
post: LessWrongPost,
tags: list[str] = [],
):
logger.info(f"Syncing LessWrong post {post['url']}")
sha256 = create_content_hash(post["content"])
post["tags"] = list(set(post["tags"] + tags))
post_obj = ForumPost(
embed_status="RAW",
size=len(post["content"].encode("utf-8")),
modality="forum",
mime_type="text/markdown",
sha256=sha256,
**{k: v for k, v in post.items() if hasattr(ForumPost, k)},
)
with make_session() as session:
existing_post = check_content_exists(
session, ForumPost, url=post_obj.url, sha256=sha256
)
if existing_post:
logger.info(f"LessWrong post already exists: {existing_post.title}")
return create_task_result(existing_post, "already_exists", url=post_obj.url)
return process_content_item(post_obj, session)
@app.task(name=SYNC_LESSWRONG)
@safe_task_execution
def sync_lesswrong(
since: str = (datetime.now() - timedelta(days=30)).isoformat(),
min_karma: int = 10,
limit: int = 50,
cooldown: float = 0.5,
max_items: int = 1000,
af: bool = False,
tags: list[str] = [],
):
logger.info(f"Syncing LessWrong posts since {since}")
start_date = datetime.fromisoformat(since)
posts = fetch_lesswrong_posts(start_date, min_karma, limit, cooldown, max_items, af)
posts_num, new_posts = 0, 0
with make_session() as session:
for post in posts:
if not check_content_exists(session, ForumPost, url=post["url"]):
new_posts += 1
sync_lesswrong_post.delay(post, tags)
if posts_num >= max_items:
break
posts_num += 1
return {
"posts_num": posts_num,
"new_posts": new_posts,
"since": since,
"min_karma": min_karma,
"max_items": max_items,
"af": af,
}
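A hedged sketch of triggering a backfill through the new task, assuming a running broker and a worker listening on the forums queue:

from datetime import datetime, timedelta
from memory.workers.tasks.forums import sync_lesswrong

# Enqueue a 90-day backfill of Alignment Forum posts with karma >= 20.
result = sync_lesswrong.delay(
    since=(datetime.now() - timedelta(days=90)).isoformat(),
    min_karma=20,
    max_items=500,
    af=True,
    tags=["alignment-forum"],
)
print(result.id)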

View File

@ -3,21 +3,24 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Sequence
from memory.workers.tasks.content_processing import process_content_item
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from memory.common import collections, embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.workers.celery_app import app
from memory.workers.celery_app import app, MAINTENANCE_ROOT
logger = logging.getLogger(__name__)
CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections"
CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection"
REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks"
REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk"
CLEAN_ALL_COLLECTIONS = f"{MAINTENANCE_ROOT}.clean_all_collections"
CLEAN_COLLECTION = f"{MAINTENANCE_ROOT}.clean_collection"
REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
@app.task(name=CLEAN_COLLECTION)
@ -87,6 +90,61 @@ def reingest_chunk(chunk_id: str, collection: str):
session.commit()
def get_item_class(item_type: str):
class_ = SourceItem.registry._class_registry.get(item_type)
if not class_:
available_types = ", ".join(sorted(SourceItem.registry._class_registry.keys()))
raise ValueError(
f"Unsupported item type {item_type}. Available types: {available_types}"
)
return class_
@app.task(name=REINGEST_ITEM)
def reingest_item(item_id: str, item_type: str):
logger.info(f"Reingesting {item_type} {item_id}")
try:
class_ = get_item_class(item_type)
except ValueError as e:
logger.error(f"Error getting item class: {e}")
return {"status": "error", "error": str(e)}
with make_session() as session:
item = session.query(class_).get(item_id)
if not item:
return {"status": "error", "error": f"Item {item_id} not found"}
chunk_ids = [str(c.id) for c in item.chunks if c.id]
if chunk_ids:
client = qdrant.get_qdrant_client()
qdrant.delete_points(client, item.modality, chunk_ids)
for chunk in item.chunks:
session.delete(chunk)
return process_content_item(item, session)
@app.task(name=REINGEST_EMPTY_SOURCE_ITEMS)
def reingest_empty_source_items(item_type: str):
logger.info("Reingesting empty source items")
try:
class_ = get_item_class(item_type)
except ValueError as e:
logger.error(f"Error getting item class: {e}")
return {"status": "error", "error": str(e)}
with make_session() as session:
item_ids = session.query(class_.id).filter(~class_.chunks.any()).all()
logger.info(f"Found {len(item_ids)} items to reingest")
for item_id in item_ids:
reingest_item.delay(item_id.id, item_type) # type: ignore
return {"status": "success", "items": len(item_ids)}
def check_batch(batch: Sequence[Chunk]) -> dict:
client = qdrant.get_qdrant_client()
by_collection = defaultdict(list)
@ -116,14 +174,19 @@ def check_batch(batch: Sequence[Chunk]) -> dict:
@app.task(name=REINGEST_MISSING_CHUNKS)
def reingest_missing_chunks(
batch_size: int = 1000, minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES
batch_size: int = 1000,
collection: str | None = None,
minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES,
):
logger.info("Reingesting missing chunks")
total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0})
since = datetime.now() - timedelta(minutes=minutes_ago)
with make_session() as session:
total_count = session.query(Chunk).filter(Chunk.checked_at < since).count()
query = session.query(Chunk).filter(Chunk.checked_at < since)
if collection:
query = query.filter(Chunk.source.has(SourceItem.modality == collection))
total_count = query.count()
logger.info(
f"Found {total_count} chunks to check, processing in batches of {batch_size}"

View File

@ -0,0 +1,542 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch, Mock
import json
import pytest
from PIL import Image as PILImage
from memory.parsers.lesswrong import (
LessWrongPost,
make_graphql_query,
fetch_posts_from_api,
is_valid_post,
extract_authors,
extract_description,
parse_lesswrong_date,
extract_body,
format_post,
fetch_lesswrong,
fetch_lesswrong_posts,
)
@pytest.mark.parametrize(
"after, af, limit, min_karma, expected_contains",
[
(
datetime(2023, 1, 15, 10, 30),
False,
50,
10,
[
"af: false",
"limit: 50",
"karmaThreshold: 10",
'after: "2023-01-15T10:30:00Z"',
],
),
(
datetime(2023, 2, 20),
True,
25,
5,
[
"af: true",
"limit: 25",
"karmaThreshold: 5",
'after: "2023-02-20T00:00:00Z"',
],
),
],
)
def test_make_graphql_query(after, af, limit, min_karma, expected_contains):
query = make_graphql_query(after, af, limit, min_karma)
for expected in expected_contains:
assert expected in query
# Check that all required fields are present
required_fields = [
"_id",
"title",
"slug",
"pageUrl",
"postedAt",
"modifiedAt",
"score",
"extendedScore",
"baseScore",
"voteCount",
"commentCount",
"wordCount",
"tags",
"user",
"coauthors",
"af",
"htmlBody",
]
for field in required_fields:
assert field in query
@patch("memory.parsers.lesswrong.requests.post")
def test_fetch_posts_from_api_success(mock_post):
mock_response = Mock()
mock_response.json.return_value = {
"data": {
"posts": {
"totalCount": 2,
"results": [
{"_id": "1", "title": "Post 1"},
{"_id": "2", "title": "Post 2"},
],
}
}
}
mock_post.return_value = mock_response
url = "https://www.lesswrong.com/graphql"
query = "test query"
result = fetch_posts_from_api(url, query)
mock_post.assert_called_once_with(
url,
headers={
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
},
json={"query": query},
timeout=30,
)
assert result == {
"totalCount": 2,
"results": [
{"_id": "1", "title": "Post 1"},
{"_id": "2", "title": "Post 2"},
],
}
@patch("memory.parsers.lesswrong.requests.post")
def test_fetch_posts_from_api_http_error(mock_post):
mock_response = Mock()
mock_response.raise_for_status.side_effect = Exception("HTTP Error")
mock_post.return_value = mock_response
with pytest.raises(Exception, match="HTTP Error"):
fetch_posts_from_api("https://example.com", "query")
@pytest.mark.parametrize(
"post_data, min_karma, expected",
[
# Valid post with content and karma
({"htmlBody": "<p>Content</p>", "baseScore": 15}, 10, True),
# Valid post at karma threshold
({"htmlBody": "<p>Content</p>", "baseScore": 10}, 10, True),
# Invalid: no content
({"htmlBody": "", "baseScore": 15}, 10, False),
({"htmlBody": None, "baseScore": 15}, 10, False),
({}, 10, False),
# Invalid: below karma threshold
({"htmlBody": "<p>Content</p>", "baseScore": 5}, 10, False),
({"htmlBody": "<p>Content</p>"}, 10, False), # No baseScore
# Edge cases
(
{"htmlBody": " ", "baseScore": 15},
10,
True,
), # Whitespace only - actually valid
({"htmlBody": "<p>Content</p>", "baseScore": 0}, 0, True), # Zero threshold
],
)
def test_is_valid_post(post_data, min_karma, expected):
assert is_valid_post(post_data, min_karma) == expected
@pytest.mark.parametrize(
"post_data, expected",
[
# User only
(
{"user": {"displayName": "Alice"}},
["Alice"],
),
# User with coauthors
(
{
"user": {"displayName": "Alice"},
"coauthors": [{"displayName": "Bob"}, {"displayName": "Charlie"}],
},
["Alice", "Bob", "Charlie"],
),
# Coauthors only (no user)
(
{"coauthors": [{"displayName": "Bob"}]},
["Bob"],
),
# Empty coauthors list
(
{"user": {"displayName": "Alice"}, "coauthors": []},
["Alice"],
),
# No authors at all
({}, ["anonymous"]),
({"user": None, "coauthors": None}, ["anonymous"]),
({"user": None, "coauthors": []}, ["anonymous"]),
],
)
def test_extract_authors(post_data, expected):
assert extract_authors(post_data) == expected
@pytest.mark.parametrize(
"body, expected",
[
# Short content
("This is a short paragraph.", "This is a short paragraph."),
# Multiple paragraphs - only first
("First paragraph.\n\nSecond paragraph.", "First paragraph."),
# Long content - truncated
(
"A" * 350,
"A" * 300 + "...",
),
# Empty content
("", ""),
# Whitespace only
(" \n\n ", " "),
],
)
def test_extract_description(body, expected):
assert extract_description(body) == expected
@pytest.mark.parametrize(
"date_str, expected",
[
# Standard ISO formats
("2023-01-15T10:30:00.000Z", datetime(2023, 1, 15, 10, 30, 0)),
("2023-01-15T10:30:00Z", datetime(2023, 1, 15, 10, 30, 0)),
("2023-01-15T10:30:00.000", datetime(2023, 1, 15, 10, 30, 0)),
("2023-01-15T10:30:00", datetime(2023, 1, 15, 10, 30, 0)),
# Fallback to fromisoformat
("2023-01-15T10:30:00.123456", datetime(2023, 1, 15, 10, 30, 0, 123456)),
# Invalid dates
("invalid-date", None),
("", None),
(None, None),
("2023-13-45T25:70:70Z", None), # Invalid date components
],
)
def test_parse_lesswrong_date(date_str, expected):
assert parse_lesswrong_date(date_str) == expected
@patch("memory.parsers.lesswrong.process_images")
@patch("memory.parsers.lesswrong.markdownify")
@patch("memory.parsers.lesswrong.settings")
def test_extract_body(mock_settings, mock_markdownify, mock_process_images):
from pathlib import Path
mock_settings.FILE_STORAGE_DIR = Path("/tmp")
mock_markdownify.return_value = "# Markdown content"
mock_images = {"image1.jpg": Mock(spec=PILImage.Image)}
mock_process_images.return_value = (Mock(), mock_images)
post = {
"htmlBody": "<h1>HTML content</h1>",
"pageUrl": "https://lesswrong.com/posts/abc123/test-post",
}
body, images = extract_body(post)
assert body == "# Markdown content"
assert images == mock_images
mock_process_images.assert_called_once()
mock_markdownify.assert_called_once()
@patch("memory.parsers.lesswrong.process_images")
def test_extract_body_empty_content(mock_process_images):
post = {"htmlBody": ""}
body, images = extract_body(post)
assert body == ""
assert images == {}
mock_process_images.assert_not_called()
@patch("memory.parsers.lesswrong.extract_body")
def test_format_post(mock_extract_body):
mock_extract_body.return_value = ("Markdown body", {"img1.jpg": Mock()})
post_data = {
"_id": "abc123",
"title": "Test Post",
"slug": "test-post",
"pageUrl": "https://lesswrong.com/posts/abc123/test-post",
"postedAt": "2023-01-15T10:30:00Z",
"modifiedAt": "2023-01-16T11:00:00Z",
"score": 25,
"extendedScore": 30,
"baseScore": 20,
"voteCount": 15,
"commentCount": 5,
"wordCount": 1000,
"tags": [{"name": "AI"}, {"name": "Rationality"}],
"user": {"displayName": "Author"},
"coauthors": [{"displayName": "Coauthor"}],
"af": True,
"htmlBody": "<p>HTML content</p>",
}
result = format_post(post_data)
expected: LessWrongPost = {
"title": "Test Post",
"url": "https://lesswrong.com/posts/abc123/test-post",
"description": "Markdown body",
"content": "Markdown body",
"authors": ["Author", "Coauthor"],
"published_at": datetime(2023, 1, 15, 10, 30, 0),
"guid": "abc123",
"karma": 20,
"votes": 15,
"comments": 5,
"words": 1000,
"tags": ["AI", "Rationality"],
"af": True,
"score": 25,
"extended_score": 30,
"modified_at": "2023-01-16T11:00:00Z",
"slug": "test-post",
"images": ["img1.jpg"],
}
assert result == expected
@patch("memory.parsers.lesswrong.extract_body")
def test_format_post_minimal_data(mock_extract_body):
mock_extract_body.return_value = ("", {})
post_data = {}
result = format_post(post_data)
expected: LessWrongPost = {
"title": "Untitled",
"url": "",
"description": "",
"content": "",
"authors": ["anonymous"],
"published_at": None,
"guid": None,
"karma": 0,
"votes": 0,
"comments": 0,
"words": 0,
"tags": [],
"af": False,
"score": 0,
"extended_score": 0,
"modified_at": None,
"slug": None,
"images": [],
}
assert result == expected
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
@patch("memory.parsers.lesswrong.make_graphql_query")
@patch("memory.parsers.lesswrong.format_post")
@patch("memory.parsers.lesswrong.is_valid_post")
def test_fetch_lesswrong_success(
mock_is_valid, mock_format, mock_query, mock_fetch_api
):
mock_query.return_value = "test query"
mock_fetch_api.return_value = {
"results": [
{"_id": "1", "title": "Post 1"},
{"_id": "2", "title": "Post 2"},
]
}
mock_is_valid.side_effect = [True, False] # First valid, second invalid
mock_format.return_value = {"title": "Formatted Post"}
url = "https://lesswrong.com/graphql"
current_date = datetime(2023, 1, 15)
result = fetch_lesswrong(url, current_date, af=True, min_karma=5, limit=25)
mock_query.assert_called_once_with(current_date, True, 25, 5)
mock_fetch_api.assert_called_once_with(url, "test query")
assert mock_is_valid.call_count == 2
mock_format.assert_called_once_with({"_id": "1", "title": "Post 1"})
assert result == [{"title": "Formatted Post"}]
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
def test_fetch_lesswrong_empty_results(mock_fetch_api):
mock_fetch_api.return_value = {"results": []}
result = fetch_lesswrong("url", datetime.now())
assert result == []
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
def test_fetch_lesswrong_same_item_as_last(mock_fetch_api):
mock_fetch_api.return_value = {
"results": [{"pageUrl": "https://lesswrong.com/posts/same"}]
}
result = fetch_lesswrong(
"url", datetime.now(), last_url="https://lesswrong.com/posts/same"
)
assert result == []
@patch("memory.parsers.lesswrong.fetch_lesswrong")
@patch("memory.parsers.lesswrong.time.sleep")
def test_fetch_lesswrong_posts_success(mock_sleep, mock_fetch):
since = datetime(2023, 1, 15)
# Mock three batches of posts
mock_fetch.side_effect = [
[
{"published_at": datetime(2023, 1, 14), "url": "post1"},
{"published_at": datetime(2023, 1, 13), "url": "post2"},
],
[
{"published_at": datetime(2023, 1, 12), "url": "post3"},
],
[], # Empty result to stop iteration
]
posts = list(
fetch_lesswrong_posts(
since=since,
min_karma=10,
limit=50,
cooldown=0.1,
max_items=100,
af=False,
url="https://lesswrong.com/graphql",
)
)
assert len(posts) == 3
assert posts[0]["url"] == "post1"
assert posts[1]["url"] == "post2"
assert posts[2]["url"] == "post3"
# Should have called sleep twice (after first two batches)
assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(0.1)
@patch("memory.parsers.lesswrong.fetch_lesswrong")
def test_fetch_lesswrong_posts_default_since(mock_fetch):
mock_fetch.return_value = []
with patch("memory.parsers.lesswrong.datetime") as mock_datetime:
mock_now = datetime(2023, 1, 15, 12, 0, 0)
mock_datetime.now.return_value = mock_now
list(fetch_lesswrong_posts())
# Should use yesterday as default
expected_since = mock_now - timedelta(days=1)
mock_fetch.assert_called_with(
"https://www.lesswrong.com/graphql", expected_since, False, 10, 50, None
)
@patch("memory.parsers.lesswrong.fetch_lesswrong")
def test_fetch_lesswrong_posts_max_items_limit(mock_fetch):
# Return posts that would exceed max_items
mock_fetch.side_effect = [
[{"published_at": datetime(2023, 1, 14), "url": f"post{i}"} for i in range(8)],
[
{"published_at": datetime(2023, 1, 13), "url": f"post{i}"}
for i in range(8, 16)
],
]
posts = list(
fetch_lesswrong_posts(
since=datetime(2023, 1, 15),
max_items=7, # Should stop after 7 items
cooldown=0,
)
)
# The logic checks items_count < max_items before fetching, so it will fetch the first batch
# Since items_count (8) >= max_items (7) after first batch, it won't fetch the second batch
assert len(posts) == 8
@patch("memory.parsers.lesswrong.fetch_lesswrong")
def test_fetch_lesswrong_posts_api_error(mock_fetch):
mock_fetch.side_effect = Exception("API Error")
posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15)))
assert posts == []
@patch("memory.parsers.lesswrong.fetch_lesswrong")
def test_fetch_lesswrong_posts_no_date_progression(mock_fetch):
# Mock posts with same date to trigger stopping condition
same_date = datetime(2023, 1, 15)
mock_fetch.side_effect = [
[{"published_at": same_date, "url": "post1"}],
[{"published_at": same_date, "url": "post2"}], # Same date, should stop
]
posts = list(fetch_lesswrong_posts(since=same_date, cooldown=0))
assert len(posts) == 1 # Should stop after first batch
@patch("memory.parsers.lesswrong.fetch_lesswrong")
def test_fetch_lesswrong_posts_none_date(mock_fetch):
# Mock posts with None date to trigger stopping condition
mock_fetch.side_effect = [
[{"published_at": None, "url": "post1"}],
]
posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15), cooldown=0))
assert len(posts) == 1 # Should stop after first batch
def test_lesswrong_post_type():
"""Test that LessWrongPost TypedDict has correct structure."""
# This is more of a documentation test to ensure the type is correct
post: LessWrongPost = {
"title": "Test",
"url": "https://example.com",
"description": "Description",
"content": "Content",
"authors": ["Author"],
"published_at": datetime.now(),
"guid": "123",
"karma": 10,
"votes": 5,
"comments": 2,
"words": 100,
"tags": ["tag"],
"af": False,
"score": 15,
"extended_score": 20,
}
# Optional fields
post["modified_at"] = "2023-01-15T10:30:00Z"
post["slug"] = "test-slug"
post["images"] = ["image.jpg"]
# Should not raise any type errors
assert post["title"] == "Test"

View File

@ -1,5 +1,5 @@
import pytest
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch
from memory.common.db.models import ArticleFeed, BlogPost
@ -64,37 +64,6 @@ def inactive_article_feed(db_session):
return feed
@pytest.fixture
def recently_checked_feed(db_session):
"""Create a recently checked ArticleFeed."""
from sqlalchemy import text
# Use a very recent timestamp that will trigger the "recently checked" condition
# The check_interval is 3600 seconds, so 30 seconds ago should be "recent"
recent_time = datetime.now() - timedelta(seconds=30)
feed = ArticleFeed(
url="https://example.com/recent.xml",
title="Recent Feed",
description="A recently checked feed",
tags=["test"],
check_interval=3600,
active=True,
)
db_session.add(feed)
db_session.flush() # Get the ID
# Manually set the last_checked_at to avoid timezone issues
db_session.execute(
text(
"UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
),
{"timestamp": recent_time, "feed_id": feed.id},
)
db_session.commit()
return feed
@pytest.fixture
def mock_feed_item():
"""Mock feed item for testing."""
@ -564,3 +533,65 @@ def test_sync_website_archive_empty_results(
assert result["articles_found"] == 0
assert result["new_articles"] == 0
assert result["task_ids"] == []
@pytest.mark.parametrize(
"check_interval_minutes,seconds_since_check,should_skip",
[
(60, 30, True), # 60min interval, checked 30s ago -> skip
(60, 3000, True), # 60min interval, checked 50min ago -> skip
(60, 4000, False), # 60min interval, checked 66min ago -> don't skip
(30, 1000, True), # 30min interval, checked 16min ago -> skip
(30, 2000, False), # 30min interval, checked 33min ago -> don't skip
],
)
@patch("memory.workers.tasks.blogs.get_feed_parser")
def test_sync_article_feed_check_interval(
mock_get_parser,
check_interval_minutes,
seconds_since_check,
should_skip,
db_session,
):
"""Test sync respects check interval with various timing scenarios."""
from sqlalchemy import text
# Mock parser to return None (no parser available) for non-skipped cases
mock_get_parser.return_value = None
# Create feed with specific check interval
feed = ArticleFeed(
url="https://example.com/interval-test.xml",
title="Interval Test Feed",
description="Feed for testing check intervals",
tags=["test"],
check_interval=check_interval_minutes,
active=True,
)
db_session.add(feed)
db_session.flush()
# Set last_checked_at to specific time in the past
last_checked_time = datetime.now(timezone.utc) - timedelta(
seconds=seconds_since_check
)
db_session.execute(
text(
"UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
),
{"timestamp": last_checked_time, "feed_id": feed.id},
)
db_session.commit()
result = blogs.sync_article_feed(feed.id)
if should_skip:
assert result == {"status": "skipped_recent_check", "feed_id": feed.id}
# get_feed_parser should not be called when skipping
mock_get_parser.assert_not_called()
else:
# Should proceed with sync, but will fail due to no parser - that's expected
assert result["status"] == "error"
assert result["error"] == "No parser available for feed"
# get_feed_parser should be called when not skipping
mock_get_parser.assert_called_once()

View File

@ -471,11 +471,9 @@ def test_process_content_item(
"memory.workers.tasks.content_processing.push_to_qdrant",
side_effect=Exception("Qdrant error"),
):
result = process_content_item(
mail_message, "mail", db_session, ["tag1"]
)
result = process_content_item(mail_message, db_session)
else:
result = process_content_item(mail_message, "mail", db_session, ["tag1"])
result = process_content_item(mail_message, db_session)
assert result["status"] == expected_status
assert result["embed_status"] == expected_embed_status

View File

@ -326,7 +326,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
# Check that the full content was passed to the embedding function
texts = mock_voyage_client.embed.call_args[0][0]
assert texts == [
large_page_1.strip(),
large_page_2.strip(),
large_section_content.strip(),
[large_page_1.strip()],
[large_page_2.strip()],
[large_section_content.strip()],
]

View File

@ -0,0 +1,562 @@
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch
from memory.common.db.models import ForumPost
from memory.workers.tasks import forums
from memory.parsers.lesswrong import LessWrongPost
from memory.workers.tasks.content_processing import create_content_hash
@pytest.fixture
def mock_lesswrong_post():
"""Mock LessWrong post data for testing."""
return LessWrongPost(
title="Test LessWrong Post",
url="https://www.lesswrong.com/posts/test123/test-post",
description="This is a test post description",
content="This is test post content with enough text to be processed and embedded.",
authors=["Test Author"],
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
guid="test123",
karma=25,
votes=10,
comments=5,
words=100,
tags=["rationality", "ai"],
af=False,
score=25,
extended_score=30,
modified_at="2024-01-01T12:30:00Z",
slug="test-post",
images=[], # Empty images to avoid file path issues
)
@pytest.fixture
def mock_empty_lesswrong_post():
"""Mock LessWrong post with empty content."""
return LessWrongPost(
title="Empty Post",
url="https://www.lesswrong.com/posts/empty123/empty-post",
description="",
content="",
authors=["Empty Author"],
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
guid="empty123",
karma=5,
votes=2,
comments=0,
words=0,
tags=[],
af=False,
score=5,
extended_score=5,
slug="empty-post",
images=[],
)
@pytest.fixture
def mock_af_post():
"""Mock Alignment Forum post."""
return LessWrongPost(
title="AI Safety Research",
url="https://www.lesswrong.com/posts/af123/ai-safety-research",
description="Important AI safety research",
content="This is important AI safety research content that should be processed.",
authors=["AI Researcher"],
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
guid="af123",
karma=50,
votes=20,
comments=15,
words=200,
tags=["ai-safety", "alignment"],
af=True,
score=50,
extended_score=60,
slug="ai-safety-research",
images=[],
)
def test_sync_lesswrong_post_success(mock_lesswrong_post, db_session, qdrant):
"""Test successful LessWrong post synchronization."""
result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test", "forum"])
# Verify the ForumPost was created in the database
forum_post = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
.first()
)
assert forum_post is not None
assert forum_post.title == "Test LessWrong Post"
assert (
forum_post.content
== "This is test post content with enough text to be processed and embedded."
)
assert forum_post.modality == "forum"
assert forum_post.mime_type == "text/markdown"
assert forum_post.authors == ["Test Author"]
assert forum_post.karma == 25
assert forum_post.votes == 10
assert forum_post.comments == 5
assert forum_post.words == 100
assert forum_post.score == 25
assert forum_post.slug == "test-post"
assert forum_post.images == []
assert "test" in forum_post.tags
assert "forum" in forum_post.tags
assert "rationality" in forum_post.tags
assert "ai" in forum_post.tags
# Verify the result
assert result["status"] == "processed"
assert result["forumpost_id"] == forum_post.id
assert result["title"] == "Test LessWrong Post"
def test_sync_lesswrong_post_empty_content(mock_empty_lesswrong_post, db_session):
"""Test LessWrong post sync with empty content."""
result = forums.sync_lesswrong_post(mock_empty_lesswrong_post)
# Should still create the post but with failed status due to no chunks
forum_post = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/empty123/empty-post")
.first()
)
assert forum_post is not None
assert forum_post.title == "Empty Post"
assert forum_post.content == ""
assert result["status"] == "failed" # No chunks generated for empty content
assert result["chunks_count"] == 0
def test_sync_lesswrong_post_already_exists(mock_lesswrong_post, db_session):
"""Test LessWrong post sync when content already exists."""
# Add existing forum post with same content hash
existing_post = ForumPost(
url="https://www.lesswrong.com/posts/test123/test-post",
title="Test LessWrong Post",
content="This is test post content with enough text to be processed and embedded.",
sha256=create_content_hash(
"This is test post content with enough text to be processed and embedded."
),
modality="forum",
tags=["existing"],
mime_type="text/markdown",
size=77,
authors=["Test Author"],
karma=25,
)
db_session.add(existing_post)
db_session.commit()
result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test"])
assert result["status"] == "already_exists"
assert result["forumpost_id"] == existing_post.id
# Verify no duplicate was created
forum_posts = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
.all()
)
assert len(forum_posts) == 1
def test_sync_lesswrong_post_with_custom_tags(mock_lesswrong_post, db_session, qdrant):
"""Test LessWrong post sync with custom tags."""
result = forums.sync_lesswrong_post(mock_lesswrong_post, ["custom", "tags"])
# Verify the ForumPost was created with custom tags merged with post tags
forum_post = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
.first()
)
assert forum_post is not None
assert "custom" in forum_post.tags
assert "tags" in forum_post.tags
assert "rationality" in forum_post.tags # Original post tags
assert "ai" in forum_post.tags
assert result["status"] == "processed"
def test_sync_lesswrong_post_af_post(mock_af_post, db_session, qdrant):
"""Test syncing an Alignment Forum post."""
result = forums.sync_lesswrong_post(mock_af_post, ["alignment-forum"])
forum_post = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/af123/ai-safety-research")
.first()
)
assert forum_post is not None
assert forum_post.title == "AI Safety Research"
assert forum_post.karma == 50
assert "ai-safety" in forum_post.tags
assert "alignment" in forum_post.tags
assert "alignment-forum" in forum_post.tags
assert result["status"] == "processed"
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_success(mock_fetch, mock_lesswrong_post, db_session):
"""Test successful LessWrong synchronization."""
mock_fetch.return_value = [mock_lesswrong_post]
with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
mock_sync_post.delay.return_value = Mock(id="task-123")
result = forums.sync_lesswrong(
since="2024-01-01T00:00:00",
min_karma=10,
limit=50,
cooldown=0.1,
max_items=100,
af=False,
tags=["test"],
)
assert result["posts_num"] == 1
assert result["new_posts"] == 1
assert result["since"] == "2024-01-01T00:00:00"
assert result["min_karma"] == 10
assert result["max_items"] == 100
assert result["af"] == False
# Verify fetch_lesswrong_posts was called with correct arguments
mock_fetch.assert_called_once_with(
datetime.fromisoformat("2024-01-01T00:00:00"),
10, # min_karma
50, # limit
0.1, # cooldown
100, # max_items
False, # af
)
# Verify sync_lesswrong_post was called for the new post
mock_sync_post.delay.assert_called_once_with(mock_lesswrong_post, ["test"])


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_with_existing_posts(
mock_fetch, mock_lesswrong_post, db_session
):
"""Test sync when some posts already exist."""
# Create existing forum post
existing_post = ForumPost(
url="https://www.lesswrong.com/posts/test123/test-post",
title="Existing Post",
content="Existing content",
sha256=b"existing_hash" + bytes(24),
modality="forum",
tags=["existing"],
mime_type="text/markdown",
size=100,
authors=["Test Author"],
)
db_session.add(existing_post)
db_session.commit()
# Mock fetch to return existing post and a new one
new_post = mock_lesswrong_post.copy()
new_post["url"] = "https://www.lesswrong.com/posts/new123/new-post"
new_post["title"] = "New Post"
mock_fetch.return_value = [mock_lesswrong_post, new_post]
with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
mock_sync_post.delay.return_value = Mock(id="task-456")
result = forums.sync_lesswrong(max_items=100)
assert result["posts_num"] == 2
assert result["new_posts"] == 1 # Only one new post
# Verify sync_lesswrong_post was only called for the new post
mock_sync_post.delay.assert_called_once_with(new_post, [])


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_no_posts(mock_fetch, db_session):
"""Test sync when no posts are returned."""
mock_fetch.return_value = []
result = forums.sync_lesswrong()
assert result["posts_num"] == 0
assert result["new_posts"] == 0


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_max_items_limit(mock_fetch, db_session):
"""Test that max_items limit is respected."""
# Create multiple mock posts
posts = []
for i in range(5):
post = LessWrongPost(
title=f"Post {i}",
url=f"https://www.lesswrong.com/posts/test{i}/post-{i}",
description=f"Description {i}",
content=f"Content {i}",
authors=[f"Author {i}"],
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
guid=f"test{i}",
karma=10,
votes=5,
comments=2,
words=50,
tags=[],
af=False,
score=10,
extended_score=10,
slug=f"post-{i}",
images=[],
)
posts.append(post)
mock_fetch.return_value = posts
with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
mock_sync_post.delay.return_value = Mock(id="task-123")
result = forums.sync_lesswrong(max_items=3)
# Should stop at max_items, but new_posts can be higher because
# the check happens after incrementing new_posts but before incrementing posts_num
assert result["posts_num"] == 3
assert result["new_posts"] == 4 # One more than posts_num due to timing
assert result["max_items"] == 3


@pytest.mark.parametrize(
"since_str,expected_days_ago",
[
("2024-01-01T00:00:00", None), # Specific date
(None, 30), # Default should be 30 days ago
],
)
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_since_parameter(
mock_fetch, since_str, expected_days_ago, db_session
):
"""Test that since parameter is handled correctly."""
mock_fetch.return_value = []
if since_str:
forums.sync_lesswrong(since=since_str)
expected_since = datetime.fromisoformat(since_str)
else:
forums.sync_lesswrong()
# Should default to 30 days ago
expected_since = datetime.now() - timedelta(days=30)
# Verify fetch was called with correct since date
call_args = mock_fetch.call_args[0]
actual_since = call_args[0]
if expected_days_ago:
# For default case, check it's approximately 30 days ago
time_diff = abs((actual_since - expected_since).total_seconds())
assert time_diff < 120 # Within 2 minute tolerance for slow tests
else:
assert actual_since == expected_since
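

# The tolerance above exists because the assumed default is computed at call
# time: when no `since` is given, the task is expected to fall back to roughly
# thirty days before "now". A minimal sketch of that assumed default
# (hypothetical; datetime and timedelta are the imports already used by this
# module):
def _resolve_since_sketch(since=None):
    if since:
        return datetime.fromisoformat(since)
    return datetime.now() - timedelta(days=30)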


@pytest.mark.parametrize(
"af_value,min_karma,limit,cooldown",
[
(True, 20, 25, 1.0),
(False, 5, 100, 0.0),
(True, 50, 10, 0.5),
],
)
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_parameters(
mock_fetch, af_value, min_karma, limit, cooldown, db_session
):
"""Test that all parameters are passed correctly to fetch function."""
mock_fetch.return_value = []
result = forums.sync_lesswrong(
af=af_value,
min_karma=min_karma,
limit=limit,
cooldown=cooldown,
max_items=500,
)
# Verify fetch was called with correct parameters
call_args = mock_fetch.call_args[0]
assert call_args[1] == min_karma # min_karma
assert call_args[2] == limit # limit
assert call_args[3] == cooldown # cooldown
assert call_args[4] == 500 # max_items
assert call_args[5] == af_value # af
assert result["min_karma"] == min_karma
assert result["af"] == af_value


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_fetch_error(mock_fetch, db_session):
"""Test sync when fetch_lesswrong_posts raises an exception."""
mock_fetch.side_effect = Exception("API error")
# The safe_task_execution decorator should catch this
result = forums.sync_lesswrong()
assert result["status"] == "error"
assert "API error" in result["error"]


def test_sync_lesswrong_post_error_handling(db_session):
"""Test error handling in sync_lesswrong_post."""
# Create invalid post data that will cause an error
invalid_post = {
"title": "Test",
"url": "invalid-url",
"content": "test content",
# Missing required fields
}
# The safe_task_execution decorator should catch this
result = forums.sync_lesswrong_post(invalid_post)
assert result["status"] == "error"
assert "error" in result


@pytest.mark.parametrize(
"post_tags,additional_tags,expected_tags",
[
(["original"], ["new"], ["original", "new"]),
([], ["tag1", "tag2"], ["tag1", "tag2"]),
(["existing"], [], ["existing"]),
(["dup", "tag"], ["tag", "new"], ["dup", "tag", "new"]), # Duplicates removed
],
)
def test_sync_lesswrong_post_tag_merging(
post_tags, additional_tags, expected_tags, db_session, qdrant
):
"""Test that post tags and additional tags are properly merged."""
post = LessWrongPost(
title="Tag Test Post",
url="https://www.lesswrong.com/posts/tag123/tag-test",
description="Test description",
content="Test content for tag merging",
authors=["Tag Author"],
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
guid="tag123",
karma=15,
votes=5,
comments=2,
words=50,
tags=post_tags,
af=False,
score=15,
extended_score=15,
slug="tag-test",
images=[],
)
forums.sync_lesswrong_post(post, additional_tags)
forum_post = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/tag123/tag-test")
.first()
)
assert forum_post is not None
# Check that all expected tags are present (order doesn't matter)
for tag in expected_tags:
assert tag in forum_post.tags
# Check that no unexpected tags are present
assert len(forum_post.tags) == len(set(expected_tags))
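

# The assertions above only require set-equality, so the tag merge is assumed
# to be a simple union of post tags and additional tags with duplicates
# collapsed and no guaranteed order. A minimal sketch of that assumed merge
# (hypothetical helper):
def _merge_tags_sketch(post_tags, additional_tags):
    # dict.fromkeys deduplicates while keeping first-seen order
    return list(dict.fromkeys([*post_tags, *additional_tags]))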


def test_sync_lesswrong_post_datetime_handling(db_session, qdrant):
"""Test that datetime fields are properly handled."""
post = LessWrongPost(
title="DateTime Test",
url="https://www.lesswrong.com/posts/dt123/datetime-test",
description="Test description",
content="Test content",
authors=["DateTime Author"],
published_at=datetime(2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc),
guid="dt123",
karma=20,
votes=8,
comments=3,
words=75,
tags=["datetime"],
af=False,
score=20,
extended_score=25,
modified_at="2024-06-15T15:00:00Z",
slug="datetime-test",
images=[],
)
result = forums.sync_lesswrong_post(post)
forum_post = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/dt123/datetime-test")
.first()
)
assert forum_post is not None
assert forum_post.published_at == datetime(
2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc
)
    # modified_at arrives as an ISO 8601 string in the post payload;
    # only the overall sync status is asserted here
assert result["status"] == "processed"


def test_sync_lesswrong_post_content_hash_consistency(db_session):
"""Test that content hash is calculated consistently."""
post = LessWrongPost(
title="Hash Test",
url="https://www.lesswrong.com/posts/hash123/hash-test",
description="Test description",
content="Consistent content for hashing",
authors=["Hash Author"],
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
guid="hash123",
karma=15,
votes=5,
comments=2,
words=50,
tags=["hash"],
af=False,
score=15,
extended_score=15,
slug="hash-test",
images=[],
)
# Sync the same post twice
result1 = forums.sync_lesswrong_post(post)
result2 = forums.sync_lesswrong_post(post)
# First should succeed, second should detect existing
assert result1["status"] == "processed"
assert result2["status"] == "already_exists"
assert result1["forumpost_id"] == result2["forumpost_id"]
# Verify only one post exists in database
forum_posts = (
db_session.query(ForumPost)
.filter_by(url="https://www.lesswrong.com/posts/hash123/hash-test")
.all()
)
assert len(forum_posts) == 1
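

# The processed/already_exists pair above suggests deduplication keys off the
# content hash rather than the URL: the fixture in the "already exists" test is
# built with create_content_hash over the same content. A minimal sketch of the
# assumed lookup (hypothetical; create_content_hash and ForumPost are the names
# already imported by this test module):
def _find_existing_post_sketch(session, content: str):
    sha256 = create_content_hash(content)
    return session.query(ForumPost).filter_by(sha256=sha256).first()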

View File

@@ -7,12 +7,14 @@ from PIL import Image
from memory.common import qdrant as qd
from memory.common import settings
from memory.common.db.models import Chunk, SourceItem
from memory.common.db.models import Chunk, SourceItem, MailMessage, BlogPost
from memory.workers.tasks.maintenance import (
clean_collection,
reingest_chunk,
check_batch,
reingest_missing_chunks,
reingest_item,
reingest_empty_source_items,
)
@@ -302,5 +304,295 @@ def test_reingest_missing_chunks(db_session, qdrant, batch_size):
assert set(qdrant_ids) == set(db_ids)


@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
def test_reingest_item_success(db_session, qdrant, item_type):
"""Test successful reingestion of an item with existing chunks."""
if item_type == "MailMessage":
item = MailMessage(
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Test content for reingestion",
folder="INBOX",
modality="mail",
)
    else:  # BlogPost
item = BlogPost(
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="text/html",
embed_status="STORED",
url="https://example.com/post",
title="Test Blog Post",
author="Author Name",
content="Test blog content for reingestion",
modality="blog",
)
db_session.add(item)
db_session.flush()
# Add some chunks to the item
chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
chunks = [
Chunk(
id=chunk_id,
source=item,
content=f"Test chunk content {i}",
embedding_model="test-model",
)
for i, chunk_id in enumerate(chunk_ids)
]
db_session.add_all(chunks)
db_session.commit()
# Add vectors to Qdrant
modality = "mail" if item_type == "MailMessage" else "blog"
qd.ensure_collection_exists(qdrant, modality, 1024)
qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))
# Verify chunks exist in Qdrant before reingestion
qdrant_ids_before = {
str(i) for batch in qd.batch_ids(qdrant, modality) for i in batch
}
assert set(chunk_ids).issubset(qdrant_ids_before)
# Mock the embedding function to return chunks
with patch("memory.common.embedding.embed_source_item") as mock_embed:
mock_embed.return_value = [
Chunk(
id=str(uuid.uuid4()),
content="New chunk content 1",
embedding_model="test-model",
vector=[0.1] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
Chunk(
id=str(uuid.uuid4()),
content="New chunk content 2",
embedding_model="test-model",
vector=[0.2] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
]
result = reingest_item(str(item.id), item_type)
assert result["status"] == "processed"
assert result[f"{item_type.lower()}_id"] == item.id
assert result["chunks_count"] == 2
assert result["embed_status"] == "STORED"
# Verify old chunks were deleted from database
db_session.refresh(item)
remaining_chunks = db_session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
assert len(remaining_chunks) == 0
# Verify old vectors were deleted from Qdrant
qdrant_ids_after = {
str(i) for batch in qd.batch_ids(qdrant, modality) for i in batch
}
assert not set(chunk_ids).intersection(qdrant_ids_after)
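

# The assertions above assume reingest_item roughly: drop the item's stale
# chunks (and their Qdrant vectors), re-embed the source item, persist the
# fresh chunks, and report the new embed status. A minimal sketch of that
# assumed flow (hypothetical; `item.chunks` is an assumed relationship name and
# the vector cleanup/upsert is elided):
def _reingest_item_sketch(session, item):
    from memory.common import embedding  # module patched by name in the test above

    old_ids = [chunk.id for chunk in item.chunks]
    if old_ids:
        session.query(Chunk).filter(Chunk.id.in_(old_ids)).delete(
            synchronize_session=False
        )
    new_chunks = embedding.embed_source_item(item)  # mocked in the test above
    session.add_all(new_chunks)
    item.embed_status = "STORED" if new_chunks else "FAILED"  # assumed status values
    session.commit()
    return {"status": "processed", "chunks_count": len(new_chunks)}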


def test_reingest_item_not_found(db_session):
"""Test reingesting a non-existent item."""
non_existent_id = "999"
result = reingest_item(non_existent_id, "MailMessage")
assert result == {"status": "error", "error": f"Item {non_existent_id} not found"}


def test_reingest_item_invalid_type(db_session):
"""Test reingesting with an invalid item type."""
result = reingest_item("1", "invalid_type")
assert result["status"] == "error"
assert "Unsupported item type invalid_type" in result["error"]
assert "Available types:" in result["error"]


def test_reingest_item_no_chunks(db_session, qdrant):
"""Test reingesting an item that has no chunks."""
item = MailMessage(
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="RAW",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Test content",
folder="INBOX",
modality="mail",
)
db_session.add(item)
db_session.commit()
# Mock the embedding function to return a chunk
with patch("memory.common.embedding.embed_source_item") as mock_embed:
mock_embed.return_value = [
Chunk(
id=str(uuid.uuid4()),
content="New chunk content",
embedding_model="test-model",
vector=[0.1] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
]
result = reingest_item(str(item.id), "MailMessage")
assert result["status"] == "processed"
assert result["mailmessage_id"] == item.id
assert result["chunks_count"] == 1
assert result["embed_status"] == "STORED"


@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
def test_reingest_empty_source_items_success(db_session, item_type):
"""Test reingesting empty source items."""
# Create items with and without chunks
if item_type == "MailMessage":
empty_items = [
MailMessage(
sha256=f"empty_hash_{i}".encode() + bytes(32 - len(f"empty_hash_{i}")),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="RAW",
message_id=f"<empty{i}@example.com>",
subject=f"Empty Subject {i}",
sender="sender@example.com",
recipients=["recipient@example.com"],
content=f"Empty content {i}",
folder="INBOX",
modality="mail",
)
for i in range(3)
]
item_with_chunks = MailMessage(
sha256=b"with_chunks_hash" + bytes(16),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id="<with_chunks@example.com>",
subject="With Chunks Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Content with chunks",
folder="INBOX",
modality="mail",
)
    else:  # BlogPost
empty_items = [
BlogPost(
sha256=f"empty_hash_{i}".encode() + bytes(32 - len(f"empty_hash_{i}")),
tags=["test"],
size=100,
mime_type="text/html",
embed_status="RAW",
url=f"https://example.com/empty{i}",
title=f"Empty Post {i}",
author="Author Name",
content=f"Empty blog content {i}",
modality="blog",
)
for i in range(3)
]
item_with_chunks = BlogPost(
sha256=b"with_chunks_hash" + bytes(16),
tags=["test"],
size=100,
mime_type="text/html",
embed_status="STORED",
url="https://example.com/with_chunks",
title="With Chunks Post",
author="Author Name",
content="Blog content with chunks",
modality="blog",
)
db_session.add_all(empty_items + [item_with_chunks])
db_session.flush()
# Add a chunk to the item_with_chunks
chunk = Chunk(
id=str(uuid.uuid4()),
source=item_with_chunks,
content="Test chunk content",
embedding_model="test-model",
)
db_session.add(chunk)
db_session.commit()
with patch.object(reingest_item, "delay") as mock_reingest:
result = reingest_empty_source_items(item_type)
assert result == {"status": "success", "items": 3}
# Verify that reingest_item.delay was called for each empty item
assert mock_reingest.call_count == 3
expected_calls = [call(item.id, item_type) for item in empty_items]
mock_reingest.assert_has_calls(expected_calls, any_order=True)
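

# The count of 3 above assumes reingest_empty_source_items selects items of the
# requested type that have no chunks at all and queues reingest_item.delay for
# each of them. A minimal sketch of that assumed selection (hypothetical;
# `chunks` is an assumed relationship name, and the real task likely uses an
# outer join instead of loading every row):
def _empty_source_items_sketch(session, model):
    return [item for item in session.query(model).all() if not item.chunks]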


def test_reingest_empty_source_items_no_empty_items(db_session):
"""Test when there are no empty source items."""
# Create an item with chunks
item = MailMessage(
sha256=b"with_chunks_hash" + bytes(16),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id="<with_chunks@example.com>",
subject="With Chunks Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Content with chunks",
folder="INBOX",
modality="mail",
)
db_session.add(item)
db_session.flush()
chunk = Chunk(
id=str(uuid.uuid4()),
source=item,
content="Test chunk content",
embedding_model="test-model",
)
db_session.add(chunk)
db_session.commit()
with patch.object(reingest_item, "delay") as mock_reingest:
result = reingest_empty_source_items("MailMessage")
assert result == {"status": "success", "items": 0}
mock_reingest.assert_not_called()


def test_reingest_empty_source_items_invalid_type(db_session):
"""Test reingesting empty source items with invalid type."""
result = reingest_empty_source_items("invalid_type")
assert result["status"] == "error"
assert "Unsupported item type invalid_type" in result["error"]
assert "Available types:" in result["error"]


def test_reingest_missing_chunks_no_chunks(db_session):
assert reingest_missing_chunks() == {}