diff --git a/db/migrations/versions/20250528_012300_add_forum_posts.py b/db/migrations/versions/20250528_012300_add_forum_posts.py new file mode 100644 index 0000000..8eddfc3 --- /dev/null +++ b/db/migrations/versions/20250528_012300_add_forum_posts.py @@ -0,0 +1,52 @@ +"""Add forum posts + +Revision ID: 2524646f56f6 +Revises: 1b535e1b044e +Create Date: 2025-05-28 01:23:00.079366 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "2524646f56f6" +down_revision: Union[str, None] = "1b535e1b044e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "forum_post", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("url", sa.Text(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("authors", sa.ARRAY(sa.Text()), nullable=True), + sa.Column("published_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("modified_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("slug", sa.Text(), nullable=True), + sa.Column("karma", sa.Integer(), nullable=True), + sa.Column("votes", sa.Integer(), nullable=True), + sa.Column("comments", sa.Integer(), nullable=True), + sa.Column("words", sa.Integer(), nullable=True), + sa.Column("score", sa.Integer(), nullable=True), + sa.Column("images", sa.ARRAY(sa.Text()), nullable=True), + sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("url"), + ) + op.create_index("forum_post_slug_idx", "forum_post", ["slug"], unique=False) + op.create_index("forum_post_title_idx", "forum_post", ["title"], unique=False) + op.create_index("forum_post_url_idx", "forum_post", ["url"], unique=False) + + +def downgrade() -> None: + op.drop_index("forum_post_url_idx", table_name="forum_post") + op.drop_index("forum_post_title_idx", table_name="forum_post") + op.drop_index("forum_post_slug_idx", table_name="forum_post") + op.drop_table("forum_post") diff --git a/docker-compose.yaml b/docker-compose.yaml index 63f9406..c5094ad 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -207,6 +207,12 @@ services: <<: *worker-env QUEUES: "blogs" + worker-forums: + <<: *worker-base + environment: + <<: *worker-env + QUEUES: "forums" + worker-photo: <<: *worker-base environment: diff --git a/docker/ingest_hub/Dockerfile b/docker/ingest_hub/Dockerfile index 2c88d17..699c8c6 100644 --- a/docker/ingest_hub/Dockerfile +++ b/docker/ingest_hub/Dockerfile @@ -33,7 +33,7 @@ RUN mkdir -p /var/log/supervisor /var/run/supervisor RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files USER kb -ENV QUEUES="docs,email,maintenance" +ENV QUEUES="maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"] \ No newline at end of file diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile index b0f4e6d..9d87448 100644 --- a/docker/workers/Dockerfile +++ b/docker/workers/Dockerfile @@ -39,7 +39,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \ USER kb # Default queues to process -ENV QUEUES="ebooks,email,comic,blogs,photo_embed,maintenance" +ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["./entry.sh"] \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 69a11ff..445b8ba 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,4 +3,5 @@ pytest-cov==4.1.0 black==23.12.1 mypy==1.8.0 isort==5.13.2 -testcontainers[qdrant]==4.10.0 \ No newline at end of file +testcontainers[qdrant]==4.10.0 +click==8.1.7 \ No newline at end of file diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index 5066323..fa4139e 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -17,6 +17,7 @@ from memory.common.db.models import ( MiscDoc, ArticleFeed, EmailAccount, + ForumPost, ) @@ -33,7 +34,11 @@ DEFAULT_COLUMNS = ( def source_columns(model: type[SourceItem], *columns: str): - return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)] + return [ + getattr(model, c) + for c in ("id",) + columns + DEFAULT_COLUMNS + if hasattr(model, c) + ] # Create admin views for all models @@ -86,6 +91,21 @@ class BlogPostAdmin(ModelView, model=BlogPost): column_searchable_list = ["title", "author", "domain"] +class ForumPostAdmin(ModelView, model=ForumPost): + column_list = source_columns( + ForumPost, + "title", + "authors", + "published_at", + "url", + "karma", + "votes", + "comments", + "score", + ) + column_searchable_list = ["title", "authors"] + + class PhotoAdmin(ModelView, model=Photo): column_list = source_columns(Photo, "exif_taken_at", "camera") @@ -166,5 +186,6 @@ def setup_admin(admin: Admin): admin.add_view(MiscDocAdmin) admin.add_view(ArticleFeedAdmin) admin.add_view(BlogPostAdmin) + admin.add_view(ForumPostAdmin) admin.add_view(ComicAdmin) admin.add_view(PhotoAdmin) diff --git a/src/memory/common/collections.py b/src/memory/common/collections.py index 8a2c074..ae3e78b 100644 --- a/src/memory/common/collections.py +++ b/src/memory/common/collections.py @@ -43,7 +43,12 @@ ALL_COLLECTIONS: dict[str, Collection] = { "blog": { "dimension": 1024, "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, + "model": settings.MIXED_EMBEDDING_MODEL, + }, + "forum": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, }, "text": { "dimension": 1024, diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 9c7a0db..62596cd 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -105,6 +105,21 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh ] +def chunk_mixed( + content: str, image_paths: Sequence[str] +) -> list[list[extract.MulitmodalChunk]]: + images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths] + full_text: list[extract.MulitmodalChunk] = [content.strip(), *images] + + chunks = [] + tokens = chunker.approx_token_count(content) + if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: + chunks = [add_pics(c, images) for c in chunker.chunk_text(content)] + + all_chunks = [full_text] + chunks + return [c for c in all_chunks if c and all(i for i in c)] + + class Chunk(Base): """Stores content chunks with their vector embeddings.""" @@ -134,6 +149,19 @@ class Chunk(Base): Index("chunk_source_idx", "source_id"), ) + @property + def chunks(self) -> list[extract.MulitmodalChunk]: + chunks: list[extract.MulitmodalChunk] = [] + if cast(str | None, self.content): + chunks = [cast(str, self.content)] + if self.images: + chunks += self.images + elif cast(Sequence[str] | None, self.file_paths): + chunks += [ + Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths + ] + return chunks + @property def data(self) -> list[bytes | str | Image.Image]: if self.file_paths is None: @@ -638,20 +666,57 @@ class BlogPost(SourceItem): return {k: v for k, v in payload.items() if v} def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]: - images = [ - Image.open(settings.FILE_STORAGE_DIR / image) for image in self.images - ] + return chunk_mixed(cast(str, self.content), cast(list[str], self.images)) - content = cast(str, self.content) - full_text = [content.strip(), *images] - chunks = [] - tokens = chunker.approx_token_count(content) - if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: - chunks = [add_pics(c, images) for c in chunker.chunk_text(content)] +class ForumPost(SourceItem): + __tablename__ = "forum_post" - all_chunks = [full_text] + chunks - return [c for c in all_chunks if c and all(i for i in c)] + id = Column( + BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True + ) + url = Column(Text, unique=True) + title = Column(Text) + description = Column(Text, nullable=True) + authors = Column(ARRAY(Text), nullable=True) + published_at = Column(DateTime(timezone=True), nullable=True) + modified_at = Column(DateTime(timezone=True), nullable=True) + slug = Column(Text, nullable=True) + karma = Column(Integer, nullable=True) + votes = Column(Integer, nullable=True) + comments = Column(Integer, nullable=True) + words = Column(Integer, nullable=True) + score = Column(Integer, nullable=True) + images = Column(ARRAY(Text), nullable=True) + + __mapper_args__ = { + "polymorphic_identity": "forum_post", + } + + __table_args__ = ( + Index("forum_post_url_idx", "url"), + Index("forum_post_slug_idx", "slug"), + Index("forum_post_title_idx", "title"), + ) + + def as_payload(self) -> dict: + return { + "source_id": self.id, + "url": self.url, + "title": self.title, + "description": self.description, + "authors": self.authors, + "published_at": self.published_at, + "slug": self.slug, + "karma": self.karma, + "votes": self.votes, + "score": self.score, + "comments": self.comments, + "tags": self.tags, + } + + def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]: + return chunk_mixed(cast(str, self.content), cast(list[str], self.images)) class MiscDoc(SourceItem): @@ -710,7 +775,7 @@ class ArticleFeed(Base): description = Column(Text) tags = Column(ARRAY(Text), nullable=False, server_default="{}") check_interval = Column( - Integer, nullable=False, server_default="3600", doc="Seconds between checks" + Integer, nullable=False, server_default="60", doc="Minutes between checks" ) last_checked_at = Column(DateTime(timezone=True)) active = Column(Boolean, nullable=False, server_default="true") diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index 38d3741..caddbf8 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -20,7 +20,7 @@ def embed_chunks( model: str = settings.TEXT_EMBEDDING_MODEL, input_type: Literal["document", "query"] = "document", ) -> list[Vector]: - logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]}") + logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]} {len(chunks)}") vo = voyageai.Client() # type: ignore if model == settings.MIXED_EMBEDDING_MODEL: return vo.multimodal_embed( @@ -79,7 +79,7 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]: if not model_chunks: return [] - vectors = embed_chunks([chunk.content for chunk in model_chunks], model) # type: ignore + vectors = embed_chunks([chunk.chunks for chunk in model_chunks], model) for chunk, vector in zip(model_chunks, vectors): chunk.vector = vector return model_chunks diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 51f145f..05d6d3a 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -86,9 +86,11 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60")) # Worker settings # Intervals are in seconds -EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600)) -CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400)) -CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600)) +EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 60 * 60)) +COMIC_SYNC_INTERVAL = int(os.getenv("COMIC_SYNC_INTERVAL", 60 * 60)) +ARTICLE_FEED_SYNC_INTERVAL = int(os.getenv("ARTICLE_FEED_SYNC_INTERVAL", 30 * 60)) +CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 * 60)) +CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60)) CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24)) diff --git a/src/memory/parsers/lesswrong.py b/src/memory/parsers/lesswrong.py new file mode 100644 index 0000000..125c5c0 --- /dev/null +++ b/src/memory/parsers/lesswrong.py @@ -0,0 +1,304 @@ +from dataclasses import dataclass, field +import logging +import time +from datetime import datetime, timedelta +from typing import Any, Generator, TypedDict, NotRequired + +from bs4 import BeautifulSoup +from PIL import Image as PILImage +from memory.common import settings +import requests +from markdownify import markdownify + +from memory.parsers.html import parse_date, process_images + +logger = logging.getLogger(__name__) + + +class LessWrongPost(TypedDict): + """Represents a post from LessWrong.""" + + title: str + url: str + description: str + content: str + authors: list[str] + published_at: datetime | None + guid: str | None + karma: int + votes: int + comments: int + words: int + tags: list[str] + af: bool + score: int + extended_score: int + modified_at: NotRequired[str | None] + slug: NotRequired[str | None] + images: NotRequired[list[str]] + + +def make_graphql_query( + after: datetime, af: bool = False, limit: int = 50, min_karma: int = 10 +) -> str: + """Create GraphQL query for fetching posts.""" + return f""" + {{ + posts(input: {{ + terms: {{ + excludeEvents: true + view: "old" + af: {str(af).lower()} + limit: {limit} + karmaThreshold: {min_karma} + after: "{after.isoformat()}Z" + filter: "tagged" + }} + }}) {{ + totalCount + results {{ + _id + title + slug + pageUrl + postedAt + modifiedAt + score + extendedScore + baseScore + voteCount + commentCount + wordCount + tags {{ + name + }} + user {{ + displayName + }} + coauthors {{ + displayName + }} + af + htmlBody + }} + }} + }} + """ + + +def fetch_posts_from_api(url: str, query: str) -> dict[str, Any]: + """Fetch posts from LessWrong GraphQL API.""" + response = requests.post( + url, + headers={ + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0" + }, + json={"query": query}, + timeout=30, + ) + response.raise_for_status() + return response.json()["data"]["posts"] + + +def is_valid_post(post: dict[str, Any], min_karma: int = 10) -> bool: + """Check if post should be included.""" + # Must have content + if not post.get("htmlBody"): + return False + + # Must meet karma threshold + if post.get("baseScore", 0) < min_karma: + return False + + return True + + +def extract_authors(post: dict[str, Any]) -> list[str]: + """Extract authors from post data.""" + authors = post.get("coauthors", []) or [] + if post.get("user"): + authors = [post["user"]] + authors + return [a["displayName"] for a in authors] or ["anonymous"] + + +def extract_description(body: str) -> str: + """Extract description from post HTML content.""" + first_paragraph = body.split("\n\n")[0] + + # Truncate if too long + if len(first_paragraph) > 300: + first_paragraph = first_paragraph[:300] + "..." + + return first_paragraph + + +def parse_lesswrong_date(date_str: str | None) -> datetime | None: + """Parse ISO date string from LessWrong API to datetime.""" + if not date_str: + return None + + # Try multiple ISO formats that LessWrong might use + formats = [ + "%Y-%m-%dT%H:%M:%S.%fZ", # 2023-01-15T10:30:00.000Z + "%Y-%m-%dT%H:%M:%SZ", # 2023-01-15T10:30:00Z + "%Y-%m-%dT%H:%M:%S.%f", # 2023-01-15T10:30:00.000 + "%Y-%m-%dT%H:%M:%S", # 2023-01-15T10:30:00 + ] + + for fmt in formats: + if result := parse_date(date_str, fmt): + return result + + # Fallback: try removing 'Z' and using fromisoformat + try: + clean_date = date_str.rstrip("Z") + return datetime.fromisoformat(clean_date) + except (ValueError, TypeError): + logger.warning(f"Could not parse date: {date_str}") + return None + + +def extract_body(post: dict[str, Any]) -> tuple[str, dict[str, PILImage.Image]]: + """Extract body from post data.""" + if not (body := post.get("htmlBody", "").strip()): + return "", {} + + url = post.get("pageUrl", "") + image_dir = settings.FILE_STORAGE_DIR / "lesswrong" / url + + soup = BeautifulSoup(body, "html.parser") + soup, images = process_images(soup, url, image_dir) + body = markdownify(str(soup)).strip() + return body, images + + +def format_post(post: dict[str, Any]) -> LessWrongPost: + """Convert raw API post data to GreaterWrongPost.""" + body, images = extract_body(post) + + result: LessWrongPost = { + "title": post.get("title", "Untitled"), + "url": post.get("pageUrl", ""), + "description": extract_description(body), + "content": body, + "authors": extract_authors(post), + "published_at": parse_lesswrong_date(post.get("postedAt")), + "guid": post.get("_id"), + "karma": post.get("baseScore", 0), + "votes": post.get("voteCount", 0), + "comments": post.get("commentCount", 0), + "words": post.get("wordCount", 0), + "tags": [tag["name"] for tag in post.get("tags", [])], + "af": post.get("af", False), + "score": post.get("score", 0), + "extended_score": post.get("extendedScore", 0), + "modified_at": post.get("modifiedAt"), + "slug": post.get("slug"), + "images": list(images.keys()), + } + + return result + + +def fetch_lesswrong( + url: str, + current_date: datetime, + af: bool = False, + min_karma: int = 10, + limit: int = 50, + last_url: str | None = None, +) -> list[LessWrongPost]: + """ + Fetch a batch of posts from LessWrong. + + Returns: + (posts, next_date, last_item) where next_date is None if iteration should stop + """ + query = make_graphql_query(current_date, af, limit, min_karma) + api_response = fetch_posts_from_api(url, query) + + if not api_response["results"]: + return [] + + # If we only get the same item we started with, we're done + if ( + len(api_response["results"]) == 1 + and last_url + and api_response["results"][0]["pageUrl"] == last_url + ): + return [] + + return [ + format_post(post) + for post in api_response["results"] + if is_valid_post(post, min_karma) + ] + + +def fetch_lesswrong_posts( + since: datetime | None = None, + min_karma: int = 10, + limit: int = 50, + cooldown: float = 0.5, + max_items: int = 1000, + af: bool = False, + url: str = "https://www.lesswrong.com/graphql", +) -> Generator[LessWrongPost, None, None]: + """ + Fetch posts from LessWrong. + + Args: + url: GraphQL endpoint URL + af: Whether to fetch Alignment Forum posts + min_karma: Minimum karma threshold for posts + limit: Number of posts per API request + start_year: Default start year if no since date provided + since: Start date for fetching posts + cooldown: Delay between API requests in seconds + max_pages: Maximum number of pages to fetch + + Returns: + List of GreaterWrongPost objects + """ + if not since: + since = datetime.now() - timedelta(days=1) + + logger.info(f"Starting from {since}") + + last_url = None + next_date = since + items_count = 0 + + while next_date and items_count < max_items: + try: + page_posts = fetch_lesswrong(url, next_date, af, min_karma, limit, last_url) + except Exception as e: + logger.error(f"Error fetching posts: {e}") + break + + if not page_posts or next_date is None: + break + + for post in page_posts: + yield post + + last_item = page_posts[-1] + prev_date = next_date + next_date = last_item.get("published_at") + + if not next_date or prev_date == next_date: + logger.warning( + f"Could not advance through dataset, stopping at {next_date}" + ) + break + + # The articles are paged by date (inclusive) so passing the last date as + # is will return the same article again. + next_date += timedelta(seconds=1) + items_count += len(page_posts) + last_url = last_item["url"] + + if cooldown > 0: + time.sleep(cooldown) + + logger.info(f"Fetched {items_count} items") diff --git a/src/memory/workers/celery_app.py b/src/memory/workers/celery_app.py index 82f61b2..394da0e 100644 --- a/src/memory/workers/celery_app.py +++ b/src/memory/workers/celery_app.py @@ -1,6 +1,14 @@ from celery import Celery from memory.common import settings +EMAIL_ROOT = "memory.workers.tasks.email" +FORUMS_ROOT = "memory.workers.tasks.forums" +BLOGS_ROOT = "memory.workers.tasks.blogs" +PHOTO_ROOT = "memory.workers.tasks.photo" +COMIC_ROOT = "memory.workers.tasks.comic" +EBOOK_ROOT = "memory.workers.tasks.ebook" +MAINTENANCE_ROOT = "memory.workers.tasks.maintenance" + def rabbit_url() -> str: return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//" @@ -21,12 +29,13 @@ app.conf.update( task_reject_on_worker_lost=True, worker_prefetch_multiplier=1, task_routes={ - "memory.workers.tasks.email.*": {"queue": "email"}, - "memory.workers.tasks.photo.*": {"queue": "photo_embed"}, - "memory.workers.tasks.comic.*": {"queue": "comic"}, - "memory.workers.tasks.ebook.*": {"queue": "ebooks"}, - "memory.workers.tasks.blogs.*": {"queue": "blogs"}, - "memory.workers.tasks.maintenance.*": {"queue": "maintenance"}, + f"{EMAIL_ROOT}.*": {"queue": "email"}, + f"{PHOTO_ROOT}.*": {"queue": "photo_embed"}, + f"{COMIC_ROOT}.*": {"queue": "comic"}, + f"{EBOOK_ROOT}.*": {"queue": "ebooks"}, + f"{BLOGS_ROOT}.*": {"queue": "blogs"}, + f"{FORUMS_ROOT}.*": {"queue": "forums"}, + f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"}, }, ) diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py index 98dcd9f..47a4354 100644 --- a/src/memory/workers/ingest.py +++ b/src/memory/workers/ingest.py @@ -8,10 +8,6 @@ logger = logging.getLogger(__name__) app.conf.beat_schedule = { - "sync-mail-all": { - "task": "memory.workers.tasks.email.sync_all_accounts", - "schedule": settings.EMAIL_SYNC_INTERVAL, - }, "clean-all-collections": { "task": CLEAN_ALL_COLLECTIONS, "schedule": settings.CLEAN_COLLECTION_INTERVAL, @@ -20,4 +16,16 @@ app.conf.beat_schedule = { "task": REINGEST_MISSING_CHUNKS, "schedule": settings.CHUNK_REINGEST_INTERVAL, }, + "sync-mail-all": { + "task": "memory.workers.tasks.email.sync_all_accounts", + "schedule": settings.EMAIL_SYNC_INTERVAL, + }, + "sync-all-comics": { + "task": "memory.workers.tasks.comic.sync_all_comics", + "schedule": settings.COMIC_SYNC_INTERVAL, + }, + "sync-all-article-feeds": { + "task": "memory.workers.tasks.blogs.sync_all_article_feeds", + "schedule": settings.ARTICLE_FEED_SYNC_INTERVAL, + }, } diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index d0036db..e87ffa9 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -2,7 +2,7 @@ Import sub-modules so Celery can register their @app.task decorators. """ -from memory.workers.tasks import email, comic, blogs, ebook # noqa +from memory.workers.tasks import email, comic, blogs, ebook, forums # noqa from memory.workers.tasks.blogs import ( SYNC_WEBPAGE, SYNC_ARTICLE_FEED, @@ -12,11 +12,13 @@ from memory.workers.tasks.blogs import ( from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_SMBC, SYNC_XKCD from memory.workers.tasks.ebook import SYNC_BOOK from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL +from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST from memory.workers.tasks.maintenance import ( CLEAN_ALL_COLLECTIONS, CLEAN_COLLECTION, REINGEST_MISSING_CHUNKS, REINGEST_CHUNK, + REINGEST_ITEM, ) @@ -25,6 +27,7 @@ __all__ = [ "comic", "blogs", "ebook", + "forums", "SYNC_WEBPAGE", "SYNC_ARTICLE_FEED", "SYNC_ALL_ARTICLE_FEEDS", @@ -34,10 +37,13 @@ __all__ = [ "SYNC_XKCD", "SYNC_BOOK", "SYNC_ACCOUNT", + "SYNC_LESSWRONG", + "SYNC_LESSWRONG_POST", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL", "CLEAN_ALL_COLLECTIONS", "CLEAN_COLLECTION", "REINGEST_MISSING_CHUNKS", "REINGEST_CHUNK", + "REINGEST_ITEM", ] diff --git a/src/memory/workers/tasks/blogs.py b/src/memory/workers/tasks/blogs.py index 7b4f9cd..db484af 100644 --- a/src/memory/workers/tasks/blogs.py +++ b/src/memory/workers/tasks/blogs.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Iterable, cast from memory.common.db.connection import make_session @@ -7,7 +7,7 @@ from memory.common.db.models import ArticleFeed, BlogPost from memory.parsers.blogs import parse_webpage from memory.parsers.feeds import get_feed_parser from memory.parsers.archives import get_archive_fetcher -from memory.workers.celery_app import app +from memory.workers.celery_app import app, BLOGS_ROOT from memory.workers.tasks.content_processing import ( check_content_exists, create_content_hash, @@ -18,10 +18,10 @@ from memory.workers.tasks.content_processing import ( logger = logging.getLogger(__name__) -SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage" -SYNC_ARTICLE_FEED = "memory.workers.tasks.blogs.sync_article_feed" -SYNC_ALL_ARTICLE_FEEDS = "memory.workers.tasks.blogs.sync_all_article_feeds" -SYNC_WEBSITE_ARCHIVE = "memory.workers.tasks.blogs.sync_website_archive" +SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage" +SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed" +SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds" +SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive" @app.task(name=SYNC_WEBPAGE) @@ -71,7 +71,7 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict: logger.info(f"Blog post already exists: {existing_post.title}") return create_task_result(existing_post, "already_exists", url=article.url) - return process_content_item(blog_post, "blog", session, tags) + return process_content_item(blog_post, session) @app.task(name=SYNC_ARTICLE_FEED) @@ -93,8 +93,8 @@ def sync_article_feed(feed_id: int) -> dict: return {"status": "error", "error": "Feed not found or inactive"} last_checked_at = cast(datetime | None, feed.last_checked_at) - if last_checked_at and datetime.now() - last_checked_at < timedelta( - seconds=cast(int, feed.check_interval) + if last_checked_at and datetime.now(timezone.utc) - last_checked_at < timedelta( + minutes=cast(int, feed.check_interval) ): logger.info(f"Feed {feed_id} checked too recently, skipping") return {"status": "skipped_recent_check", "feed_id": feed_id} @@ -129,7 +129,7 @@ def sync_article_feed(feed_id: int) -> dict: logger.error(f"Error parsing feed {feed.url}: {e}") errors += 1 - feed.last_checked_at = datetime.now() # type: ignore + feed.last_checked_at = datetime.now(timezone.utc) # type: ignore session.commit() return { diff --git a/src/memory/workers/tasks/comic.py b/src/memory/workers/tasks/comic.py index 580b868..4888d88 100644 --- a/src/memory/workers/tasks/comic.py +++ b/src/memory/workers/tasks/comic.py @@ -9,7 +9,7 @@ from memory.common import settings from memory.common.db.connection import make_session from memory.common.db.models import Comic, clean_filename from memory.parsers import comics -from memory.workers.celery_app import app +from memory.workers.celery_app import app, COMIC_ROOT from memory.workers.tasks.content_processing import ( check_content_exists, create_content_hash, @@ -19,10 +19,10 @@ from memory.workers.tasks.content_processing import ( logger = logging.getLogger(__name__) -SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics" -SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc" -SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd" -SYNC_COMIC = "memory.workers.tasks.comic.sync_comic" +SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics" +SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc" +SYNC_XKCD = f"{COMIC_ROOT}.sync_xkcd" +SYNC_COMIC = f"{COMIC_ROOT}.sync_comic" BASE_SMBC_URL = "https://www.smbc-comics.com/" SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss" @@ -109,7 +109,7 @@ def sync_comic( ) with make_session() as session: - return process_content_item(comic, "comic", session) + return process_content_item(comic, session) @app.task(name=SYNC_SMBC) diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py index 8fd5ea4..51dd4cc 100644 --- a/src/memory/workers/tasks/content_processing.py +++ b/src/memory/workers/tasks/content_processing.py @@ -99,6 +99,7 @@ def embed_source_item(source_item: SourceItem) -> int: except Exception as e: source_item.embed_status = "FAILED" # type: ignore logger.error(f"Failed to embed {type(source_item).__name__}: {e}") + logger.error(traceback.format_exc()) return 0 @@ -156,6 +157,7 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str): for item in items_to_process: item.embed_status = "FAILED" # type: ignore logger.error(f"Failed to push embeddings to Qdrant: {e}") + logger.error(traceback.format_exc()) raise @@ -186,9 +188,7 @@ def create_task_result( } -def process_content_item( - item: SourceItem, collection_name: str, session, tags: Iterable[str] = [] -) -> dict[str, Any]: +def process_content_item(item: SourceItem, session) -> dict[str, Any]: """ Execute complete content processing workflow. @@ -200,7 +200,6 @@ def process_content_item( Args: item: SourceItem to process - collection_name: Qdrant collection name for vector storage session: Database session for persistence tags: Optional tags to associate with the item (currently unused) @@ -223,7 +222,7 @@ def process_content_item( return create_task_result(item, status, content_length=getattr(item, "size", 0)) try: - push_to_qdrant([item], collection_name) + push_to_qdrant([item], cast(str, item.modality)) status = "processed" item.embed_status = "STORED" # type: ignore logger.info( @@ -231,6 +230,7 @@ def process_content_item( ) except Exception as e: logger.error(f"Failed to push embeddings to Qdrant: {e}") + logger.error(traceback.format_exc()) item.embed_status = "FAILED" # type: ignore session.commit() @@ -261,10 +261,8 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: try: return func(*args, **kwargs) except Exception as e: - logger.error( - f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}" - ) logger.error(f"Task {func.__name__} failed: {e}") + logger.error(traceback.format_exc()) return {"status": "error", "error": str(e)} return wrapper diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py index 1835fbb..1eedc88 100644 --- a/src/memory/workers/tasks/ebook.py +++ b/src/memory/workers/tasks/ebook.py @@ -6,7 +6,7 @@ import memory.common.settings as settings from memory.parsers.ebook import Ebook, parse_ebook, Section from memory.common.db.models import Book, BookSection from memory.common.db.connection import make_session -from memory.workers.celery_app import app +from memory.workers.celery_app import app, EBOOK_ROOT from memory.workers.tasks.content_processing import ( check_content_exists, create_content_hash, @@ -17,7 +17,7 @@ from memory.workers.tasks.content_processing import ( logger = logging.getLogger(__name__) -SYNC_BOOK = "memory.workers.tasks.ebook.sync_book" +SYNC_BOOK = f"{EBOOK_ROOT}.sync_book" # Minimum section length to embed (avoid noise from very short sections) MIN_SECTION_LENGTH = 100 diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index 531e2c7..78e5b7b 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import cast from memory.common.db.connection import make_session from memory.common.db.models import EmailAccount, MailMessage -from memory.workers.celery_app import app +from memory.workers.celery_app import app, EMAIL_ROOT from memory.workers.email import ( create_mail_message, imap_connection, @@ -18,9 +18,9 @@ from memory.workers.tasks.content_processing import ( logger = logging.getLogger(__name__) -PROCESS_EMAIL = "memory.workers.tasks.email.process_message" -SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account" -SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts" +PROCESS_EMAIL = f"{EMAIL_ROOT}.process_message" +SYNC_ACCOUNT = f"{EMAIL_ROOT}.sync_account" +SYNC_ALL_ACCOUNTS = f"{EMAIL_ROOT}.sync_all_accounts" @app.task(name=PROCESS_EMAIL) diff --git a/src/memory/workers/tasks/forums.py b/src/memory/workers/tasks/forums.py new file mode 100644 index 0000000..f937765 --- /dev/null +++ b/src/memory/workers/tasks/forums.py @@ -0,0 +1,85 @@ +from datetime import datetime, timedelta +import logging + +from memory.parsers.lesswrong import fetch_lesswrong_posts, LessWrongPost +from memory.common.db.connection import make_session +from memory.common.db.models import ForumPost +from memory.workers.celery_app import app, FORUMS_ROOT +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + create_task_result, + process_content_item, + safe_task_execution, +) + +logger = logging.getLogger(__name__) + +SYNC_LESSWRONG = f"{FORUMS_ROOT}.sync_lesswrong" +SYNC_LESSWRONG_POST = f"{FORUMS_ROOT}.sync_lesswrong_post" + + +@app.task(name=SYNC_LESSWRONG_POST) +@safe_task_execution +def sync_lesswrong_post( + post: LessWrongPost, + tags: list[str] = [], +): + logger.info(f"Syncing LessWrong post {post['url']}") + sha256 = create_content_hash(post["content"]) + + post["tags"] = list(set(post["tags"] + tags)) + post_obj = ForumPost( + embed_status="RAW", + size=len(post["content"].encode("utf-8")), + modality="forum", + mime_type="text/markdown", + sha256=sha256, + **{k: v for k, v in post.items() if hasattr(ForumPost, k)}, + ) + + with make_session() as session: + existing_post = check_content_exists( + session, ForumPost, url=post_obj.url, sha256=sha256 + ) + if existing_post: + logger.info(f"LessWrong post already exists: {existing_post.title}") + return create_task_result(existing_post, "already_exists", url=post_obj.url) + + return process_content_item(post_obj, session) + + +@app.task(name=SYNC_LESSWRONG) +@safe_task_execution +def sync_lesswrong( + since: str = (datetime.now() - timedelta(days=30)).isoformat(), + min_karma: int = 10, + limit: int = 50, + cooldown: float = 0.5, + max_items: int = 1000, + af: bool = False, + tags: list[str] = [], +): + logger.info(f"Syncing LessWrong posts since {since}") + start_date = datetime.fromisoformat(since) + posts = fetch_lesswrong_posts(start_date, min_karma, limit, cooldown, max_items, af) + + posts_num, new_posts = 0, 0 + with make_session() as session: + for post in posts: + if not check_content_exists(session, ForumPost, url=post["url"]): + new_posts += 1 + sync_lesswrong_post.delay(post, tags) + + if posts_num >= max_items: + break + posts_num += 1 + + return { + "posts_num": posts_num, + "new_posts": new_posts, + "since": since, + "min_karma": min_karma, + "max_items": max_items, + "af": af, + } diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py index b309ead..88bd057 100644 --- a/src/memory/workers/tasks/maintenance.py +++ b/src/memory/workers/tasks/maintenance.py @@ -3,21 +3,24 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Sequence +from memory.workers.tasks.content_processing import process_content_item from sqlalchemy import select from sqlalchemy.orm import contains_eager from memory.common import collections, embedding, qdrant, settings from memory.common.db.connection import make_session from memory.common.db.models import Chunk, SourceItem -from memory.workers.celery_app import app +from memory.workers.celery_app import app, MAINTENANCE_ROOT logger = logging.getLogger(__name__) -CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections" -CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection" -REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks" -REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk" +CLEAN_ALL_COLLECTIONS = f"{MAINTENANCE_ROOT}.clean_all_collections" +CLEAN_COLLECTION = f"{MAINTENANCE_ROOT}.clean_collection" +REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks" +REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk" +REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item" +REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items" @app.task(name=CLEAN_COLLECTION) @@ -87,6 +90,61 @@ def reingest_chunk(chunk_id: str, collection: str): session.commit() +def get_item_class(item_type: str): + class_ = SourceItem.registry._class_registry.get(item_type) + if not class_: + available_types = ", ".join(sorted(SourceItem.registry._class_registry.keys())) + raise ValueError( + f"Unsupported item type {item_type}. Available types: {available_types}" + ) + return class_ + + +@app.task(name=REINGEST_ITEM) +def reingest_item(item_id: str, item_type: str): + logger.info(f"Reingesting {item_type} {item_id}") + try: + class_ = get_item_class(item_type) + except ValueError as e: + logger.error(f"Error getting item class: {e}") + return {"status": "error", "error": str(e)} + + with make_session() as session: + item = session.query(class_).get(item_id) + if not item: + return {"status": "error", "error": f"Item {item_id} not found"} + + chunk_ids = [str(c.id) for c in item.chunks if c.id] + if chunk_ids: + client = qdrant.get_qdrant_client() + qdrant.delete_points(client, item.modality, chunk_ids) + + for chunk in item.chunks: + session.delete(chunk) + + return process_content_item(item, session) + + +@app.task(name=REINGEST_EMPTY_SOURCE_ITEMS) +def reingest_empty_source_items(item_type: str): + logger.info("Reingesting empty source items") + try: + class_ = get_item_class(item_type) + except ValueError as e: + logger.error(f"Error getting item class: {e}") + return {"status": "error", "error": str(e)} + + with make_session() as session: + item_ids = session.query(class_.id).filter(~class_.chunks.any()).all() + + logger.info(f"Found {len(item_ids)} items to reingest") + + for item_id in item_ids: + reingest_item.delay(item_id.id, item_type) # type: ignore + + return {"status": "success", "items": len(item_ids)} + + def check_batch(batch: Sequence[Chunk]) -> dict: client = qdrant.get_qdrant_client() by_collection = defaultdict(list) @@ -116,14 +174,19 @@ def check_batch(batch: Sequence[Chunk]) -> dict: @app.task(name=REINGEST_MISSING_CHUNKS) def reingest_missing_chunks( - batch_size: int = 1000, minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES + batch_size: int = 1000, + collection: str | None = None, + minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES, ): logger.info("Reingesting missing chunks") total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0}) since = datetime.now() - timedelta(minutes=minutes_ago) with make_session() as session: - total_count = session.query(Chunk).filter(Chunk.checked_at < since).count() + query = session.query(Chunk).filter(Chunk.checked_at < since) + if collection: + query = query.filter(Chunk.source.has(SourceItem.modality == collection)) + total_count = query.count() logger.info( f"Found {total_count} chunks to check, processing in batches of {batch_size}" diff --git a/tests/memory/parsers/test_lesswrong.py b/tests/memory/parsers/test_lesswrong.py new file mode 100644 index 0000000..a4e60f6 --- /dev/null +++ b/tests/memory/parsers/test_lesswrong.py @@ -0,0 +1,542 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch, Mock +import json +import pytest +from PIL import Image as PILImage + +from memory.parsers.lesswrong import ( + LessWrongPost, + make_graphql_query, + fetch_posts_from_api, + is_valid_post, + extract_authors, + extract_description, + parse_lesswrong_date, + extract_body, + format_post, + fetch_lesswrong, + fetch_lesswrong_posts, +) + + +@pytest.mark.parametrize( + "after, af, limit, min_karma, expected_contains", + [ + ( + datetime(2023, 1, 15, 10, 30), + False, + 50, + 10, + [ + "af: false", + "limit: 50", + "karmaThreshold: 10", + 'after: "2023-01-15T10:30:00Z"', + ], + ), + ( + datetime(2023, 2, 20), + True, + 25, + 5, + [ + "af: true", + "limit: 25", + "karmaThreshold: 5", + 'after: "2023-02-20T00:00:00Z"', + ], + ), + ], +) +def test_make_graphql_query(after, af, limit, min_karma, expected_contains): + query = make_graphql_query(after, af, limit, min_karma) + + for expected in expected_contains: + assert expected in query + + # Check that all required fields are present + required_fields = [ + "_id", + "title", + "slug", + "pageUrl", + "postedAt", + "modifiedAt", + "score", + "extendedScore", + "baseScore", + "voteCount", + "commentCount", + "wordCount", + "tags", + "user", + "coauthors", + "af", + "htmlBody", + ] + for field in required_fields: + assert field in query + + +@patch("memory.parsers.lesswrong.requests.post") +def test_fetch_posts_from_api_success(mock_post): + mock_response = Mock() + mock_response.json.return_value = { + "data": { + "posts": { + "totalCount": 2, + "results": [ + {"_id": "1", "title": "Post 1"}, + {"_id": "2", "title": "Post 2"}, + ], + } + } + } + mock_post.return_value = mock_response + + url = "https://www.lesswrong.com/graphql" + query = "test query" + + result = fetch_posts_from_api(url, query) + + mock_post.assert_called_once_with( + url, + headers={ + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0" + }, + json={"query": query}, + timeout=30, + ) + + assert result == { + "totalCount": 2, + "results": [ + {"_id": "1", "title": "Post 1"}, + {"_id": "2", "title": "Post 2"}, + ], + } + + +@patch("memory.parsers.lesswrong.requests.post") +def test_fetch_posts_from_api_http_error(mock_post): + mock_response = Mock() + mock_response.raise_for_status.side_effect = Exception("HTTP Error") + mock_post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP Error"): + fetch_posts_from_api("https://example.com", "query") + + +@pytest.mark.parametrize( + "post_data, min_karma, expected", + [ + # Valid post with content and karma + ({"htmlBody": "
Content
", "baseScore": 15}, 10, True), + # Valid post at karma threshold + ({"htmlBody": "Content
", "baseScore": 10}, 10, True), + # Invalid: no content + ({"htmlBody": "", "baseScore": 15}, 10, False), + ({"htmlBody": None, "baseScore": 15}, 10, False), + ({}, 10, False), + # Invalid: below karma threshold + ({"htmlBody": "Content
", "baseScore": 5}, 10, False), + ({"htmlBody": "Content
"}, 10, False), # No baseScore + # Edge cases + ( + {"htmlBody": " ", "baseScore": 15}, + 10, + True, + ), # Whitespace only - actually valid + ({"htmlBody": "Content
", "baseScore": 0}, 0, True), # Zero threshold + ], +) +def test_is_valid_post(post_data, min_karma, expected): + assert is_valid_post(post_data, min_karma) == expected + + +@pytest.mark.parametrize( + "post_data, expected", + [ + # User only + ( + {"user": {"displayName": "Alice"}}, + ["Alice"], + ), + # User with coauthors + ( + { + "user": {"displayName": "Alice"}, + "coauthors": [{"displayName": "Bob"}, {"displayName": "Charlie"}], + }, + ["Alice", "Bob", "Charlie"], + ), + # Coauthors only (no user) + ( + {"coauthors": [{"displayName": "Bob"}]}, + ["Bob"], + ), + # Empty coauthors list + ( + {"user": {"displayName": "Alice"}, "coauthors": []}, + ["Alice"], + ), + # No authors at all + ({}, ["anonymous"]), + ({"user": None, "coauthors": None}, ["anonymous"]), + ({"user": None, "coauthors": []}, ["anonymous"]), + ], +) +def test_extract_authors(post_data, expected): + assert extract_authors(post_data) == expected + + +@pytest.mark.parametrize( + "body, expected", + [ + # Short content + ("This is a short paragraph.", "This is a short paragraph."), + # Multiple paragraphs - only first + ("First paragraph.\n\nSecond paragraph.", "First paragraph."), + # Long content - truncated + ( + "A" * 350, + "A" * 300 + "...", + ), + # Empty content + ("", ""), + # Whitespace only + (" \n\n ", " "), + ], +) +def test_extract_description(body, expected): + assert extract_description(body) == expected + + +@pytest.mark.parametrize( + "date_str, expected", + [ + # Standard ISO formats + ("2023-01-15T10:30:00.000Z", datetime(2023, 1, 15, 10, 30, 0)), + ("2023-01-15T10:30:00Z", datetime(2023, 1, 15, 10, 30, 0)), + ("2023-01-15T10:30:00.000", datetime(2023, 1, 15, 10, 30, 0)), + ("2023-01-15T10:30:00", datetime(2023, 1, 15, 10, 30, 0)), + # Fallback to fromisoformat + ("2023-01-15T10:30:00.123456", datetime(2023, 1, 15, 10, 30, 0, 123456)), + # Invalid dates + ("invalid-date", None), + ("", None), + (None, None), + ("2023-13-45T25:70:70Z", None), # Invalid date components + ], +) +def test_parse_lesswrong_date(date_str, expected): + assert parse_lesswrong_date(date_str) == expected + + +@patch("memory.parsers.lesswrong.process_images") +@patch("memory.parsers.lesswrong.markdownify") +@patch("memory.parsers.lesswrong.settings") +def test_extract_body(mock_settings, mock_markdownify, mock_process_images): + from pathlib import Path + + mock_settings.FILE_STORAGE_DIR = Path("/tmp") + mock_markdownify.return_value = "# Markdown content" + mock_images = {"image1.jpg": Mock(spec=PILImage.Image)} + mock_process_images.return_value = (Mock(), mock_images) + + post = { + "htmlBody": "HTML content
", + } + + result = format_post(post_data) + + expected: LessWrongPost = { + "title": "Test Post", + "url": "https://lesswrong.com/posts/abc123/test-post", + "description": "Markdown body", + "content": "Markdown body", + "authors": ["Author", "Coauthor"], + "published_at": datetime(2023, 1, 15, 10, 30, 0), + "guid": "abc123", + "karma": 20, + "votes": 15, + "comments": 5, + "words": 1000, + "tags": ["AI", "Rationality"], + "af": True, + "score": 25, + "extended_score": 30, + "modified_at": "2023-01-16T11:00:00Z", + "slug": "test-post", + "images": ["img1.jpg"], + } + + assert result == expected + + +@patch("memory.parsers.lesswrong.extract_body") +def test_format_post_minimal_data(mock_extract_body): + mock_extract_body.return_value = ("", {}) + + post_data = {} + + result = format_post(post_data) + + expected: LessWrongPost = { + "title": "Untitled", + "url": "", + "description": "", + "content": "", + "authors": ["anonymous"], + "published_at": None, + "guid": None, + "karma": 0, + "votes": 0, + "comments": 0, + "words": 0, + "tags": [], + "af": False, + "score": 0, + "extended_score": 0, + "modified_at": None, + "slug": None, + "images": [], + } + + assert result == expected + + +@patch("memory.parsers.lesswrong.fetch_posts_from_api") +@patch("memory.parsers.lesswrong.make_graphql_query") +@patch("memory.parsers.lesswrong.format_post") +@patch("memory.parsers.lesswrong.is_valid_post") +def test_fetch_lesswrong_success( + mock_is_valid, mock_format, mock_query, mock_fetch_api +): + mock_query.return_value = "test query" + mock_fetch_api.return_value = { + "results": [ + {"_id": "1", "title": "Post 1"}, + {"_id": "2", "title": "Post 2"}, + ] + } + mock_is_valid.side_effect = [True, False] # First valid, second invalid + mock_format.return_value = {"title": "Formatted Post"} + + url = "https://lesswrong.com/graphql" + current_date = datetime(2023, 1, 15) + + result = fetch_lesswrong(url, current_date, af=True, min_karma=5, limit=25) + + mock_query.assert_called_once_with(current_date, True, 25, 5) + mock_fetch_api.assert_called_once_with(url, "test query") + assert mock_is_valid.call_count == 2 + mock_format.assert_called_once_with({"_id": "1", "title": "Post 1"}) + assert result == [{"title": "Formatted Post"}] + + +@patch("memory.parsers.lesswrong.fetch_posts_from_api") +def test_fetch_lesswrong_empty_results(mock_fetch_api): + mock_fetch_api.return_value = {"results": []} + + result = fetch_lesswrong("url", datetime.now()) + assert result == [] + + +@patch("memory.parsers.lesswrong.fetch_posts_from_api") +def test_fetch_lesswrong_same_item_as_last(mock_fetch_api): + mock_fetch_api.return_value = { + "results": [{"pageUrl": "https://lesswrong.com/posts/same"}] + } + + result = fetch_lesswrong( + "url", datetime.now(), last_url="https://lesswrong.com/posts/same" + ) + assert result == [] + + +@patch("memory.parsers.lesswrong.fetch_lesswrong") +@patch("memory.parsers.lesswrong.time.sleep") +def test_fetch_lesswrong_posts_success(mock_sleep, mock_fetch): + since = datetime(2023, 1, 15) + + # Mock three batches of posts + mock_fetch.side_effect = [ + [ + {"published_at": datetime(2023, 1, 14), "url": "post1"}, + {"published_at": datetime(2023, 1, 13), "url": "post2"}, + ], + [ + {"published_at": datetime(2023, 1, 12), "url": "post3"}, + ], + [], # Empty result to stop iteration + ] + + posts = list( + fetch_lesswrong_posts( + since=since, + min_karma=10, + limit=50, + cooldown=0.1, + max_items=100, + af=False, + url="https://lesswrong.com/graphql", + ) + ) + + assert len(posts) == 3 + assert posts[0]["url"] == "post1" + assert posts[1]["url"] == "post2" + assert posts[2]["url"] == "post3" + + # Should have called sleep twice (after first two batches) + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(0.1) + + +@patch("memory.parsers.lesswrong.fetch_lesswrong") +def test_fetch_lesswrong_posts_default_since(mock_fetch): + mock_fetch.return_value = [] + + with patch("memory.parsers.lesswrong.datetime") as mock_datetime: + mock_now = datetime(2023, 1, 15, 12, 0, 0) + mock_datetime.now.return_value = mock_now + + list(fetch_lesswrong_posts()) + + # Should use yesterday as default + expected_since = mock_now - timedelta(days=1) + mock_fetch.assert_called_with( + "https://www.lesswrong.com/graphql", expected_since, False, 10, 50, None + ) + + +@patch("memory.parsers.lesswrong.fetch_lesswrong") +def test_fetch_lesswrong_posts_max_items_limit(mock_fetch): + # Return posts that would exceed max_items + mock_fetch.side_effect = [ + [{"published_at": datetime(2023, 1, 14), "url": f"post{i}"} for i in range(8)], + [ + {"published_at": datetime(2023, 1, 13), "url": f"post{i}"} + for i in range(8, 16) + ], + ] + + posts = list( + fetch_lesswrong_posts( + since=datetime(2023, 1, 15), + max_items=7, # Should stop after 7 items + cooldown=0, + ) + ) + + # The logic checks items_count < max_items before fetching, so it will fetch the first batch + # Since items_count (8) >= max_items (7) after first batch, it won't fetch the second batch + assert len(posts) == 8 + + +@patch("memory.parsers.lesswrong.fetch_lesswrong") +def test_fetch_lesswrong_posts_api_error(mock_fetch): + mock_fetch.side_effect = Exception("API Error") + + posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15))) + assert posts == [] + + +@patch("memory.parsers.lesswrong.fetch_lesswrong") +def test_fetch_lesswrong_posts_no_date_progression(mock_fetch): + # Mock posts with same date to trigger stopping condition + same_date = datetime(2023, 1, 15) + mock_fetch.side_effect = [ + [{"published_at": same_date, "url": "post1"}], + [{"published_at": same_date, "url": "post2"}], # Same date, should stop + ] + + posts = list(fetch_lesswrong_posts(since=same_date, cooldown=0)) + + assert len(posts) == 1 # Should stop after first batch + + +@patch("memory.parsers.lesswrong.fetch_lesswrong") +def test_fetch_lesswrong_posts_none_date(mock_fetch): + # Mock posts with None date to trigger stopping condition + mock_fetch.side_effect = [ + [{"published_at": None, "url": "post1"}], + ] + + posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15), cooldown=0)) + + assert len(posts) == 1 # Should stop after first batch + + +def test_lesswrong_post_type(): + """Test that LessWrongPost TypedDict has correct structure.""" + # This is more of a documentation test to ensure the type is correct + post: LessWrongPost = { + "title": "Test", + "url": "https://example.com", + "description": "Description", + "content": "Content", + "authors": ["Author"], + "published_at": datetime.now(), + "guid": "123", + "karma": 10, + "votes": 5, + "comments": 2, + "words": 100, + "tags": ["tag"], + "af": False, + "score": 15, + "extended_score": 20, + } + + # Optional fields + post["modified_at"] = "2023-01-15T10:30:00Z" + post["slug"] = "test-slug" + post["images"] = ["image.jpg"] + + # Should not raise any type errors + assert post["title"] == "Test" diff --git a/tests/memory/workers/tasks/test_blogs_tasks.py b/tests/memory/workers/tasks/test_blogs_tasks.py index 14f1e78..63ed90d 100644 --- a/tests/memory/workers/tasks/test_blogs_tasks.py +++ b/tests/memory/workers/tasks/test_blogs_tasks.py @@ -1,5 +1,5 @@ import pytest -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch from memory.common.db.models import ArticleFeed, BlogPost @@ -64,37 +64,6 @@ def inactive_article_feed(db_session): return feed -@pytest.fixture -def recently_checked_feed(db_session): - """Create a recently checked ArticleFeed.""" - from sqlalchemy import text - - # Use a very recent timestamp that will trigger the "recently checked" condition - # The check_interval is 3600 seconds, so 30 seconds ago should be "recent" - recent_time = datetime.now() - timedelta(seconds=30) - - feed = ArticleFeed( - url="https://example.com/recent.xml", - title="Recent Feed", - description="A recently checked feed", - tags=["test"], - check_interval=3600, - active=True, - ) - db_session.add(feed) - db_session.flush() # Get the ID - - # Manually set the last_checked_at to avoid timezone issues - db_session.execute( - text( - "UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id" - ), - {"timestamp": recent_time, "feed_id": feed.id}, - ) - db_session.commit() - return feed - - @pytest.fixture def mock_feed_item(): """Mock feed item for testing.""" @@ -564,3 +533,65 @@ def test_sync_website_archive_empty_results( assert result["articles_found"] == 0 assert result["new_articles"] == 0 assert result["task_ids"] == [] + + +@pytest.mark.parametrize( + "check_interval_minutes,seconds_since_check,should_skip", + [ + (60, 30, True), # 60min interval, checked 30s ago -> skip + (60, 3000, True), # 60min interval, checked 50min ago -> skip + (60, 4000, False), # 60min interval, checked 66min ago -> don't skip + (30, 1000, True), # 30min interval, checked 16min ago -> skip + (30, 2000, False), # 30min interval, checked 33min ago -> don't skip + ], +) +@patch("memory.workers.tasks.blogs.get_feed_parser") +def test_sync_article_feed_check_interval( + mock_get_parser, + check_interval_minutes, + seconds_since_check, + should_skip, + db_session, +): + """Test sync respects check interval with various timing scenarios.""" + from sqlalchemy import text + + # Mock parser to return None (no parser available) for non-skipped cases + mock_get_parser.return_value = None + + # Create feed with specific check interval + feed = ArticleFeed( + url="https://example.com/interval-test.xml", + title="Interval Test Feed", + description="Feed for testing check intervals", + tags=["test"], + check_interval=check_interval_minutes, + active=True, + ) + db_session.add(feed) + db_session.flush() + + # Set last_checked_at to specific time in the past + last_checked_time = datetime.now(timezone.utc) - timedelta( + seconds=seconds_since_check + ) + db_session.execute( + text( + "UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id" + ), + {"timestamp": last_checked_time, "feed_id": feed.id}, + ) + db_session.commit() + + result = blogs.sync_article_feed(feed.id) + + if should_skip: + assert result == {"status": "skipped_recent_check", "feed_id": feed.id} + # get_feed_parser should not be called when skipping + mock_get_parser.assert_not_called() + else: + # Should proceed with sync, but will fail due to no parser - that's expected + assert result["status"] == "error" + assert result["error"] == "No parser available for feed" + # get_feed_parser should be called when not skipping + mock_get_parser.assert_called_once() diff --git a/tests/memory/workers/tasks/test_content_processing.py b/tests/memory/workers/tasks/test_content_processing.py index 44cb352..c3e8467 100644 --- a/tests/memory/workers/tasks/test_content_processing.py +++ b/tests/memory/workers/tasks/test_content_processing.py @@ -471,11 +471,9 @@ def test_process_content_item( "memory.workers.tasks.content_processing.push_to_qdrant", side_effect=Exception("Qdrant error"), ): - result = process_content_item( - mail_message, "mail", db_session, ["tag1"] - ) + result = process_content_item(mail_message, db_session) else: - result = process_content_item(mail_message, "mail", db_session, ["tag1"]) + result = process_content_item(mail_message, db_session) assert result["status"] == expected_status assert result["embed_status"] == expected_embed_status diff --git a/tests/memory/workers/tasks/test_ebook_tasks.py b/tests/memory/workers/tasks/test_ebook_tasks.py index 1191fc3..70017a5 100644 --- a/tests/memory/workers/tasks/test_ebook_tasks.py +++ b/tests/memory/workers/tasks/test_ebook_tasks.py @@ -326,7 +326,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client): # Check that the full content was passed to the embedding function texts = mock_voyage_client.embed.call_args[0][0] assert texts == [ - large_page_1.strip(), - large_page_2.strip(), - large_section_content.strip(), + [large_page_1.strip()], + [large_page_2.strip()], + [large_section_content.strip()], ] diff --git a/tests/memory/workers/tasks/test_forums_tasks.py b/tests/memory/workers/tasks/test_forums_tasks.py new file mode 100644 index 0000000..8644448 --- /dev/null +++ b/tests/memory/workers/tasks/test_forums_tasks.py @@ -0,0 +1,562 @@ +import pytest +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, patch + +from memory.common.db.models import ForumPost +from memory.workers.tasks import forums +from memory.parsers.lesswrong import LessWrongPost +from memory.workers.tasks.content_processing import create_content_hash + + +@pytest.fixture +def mock_lesswrong_post(): + """Mock LessWrong post data for testing.""" + return LessWrongPost( + title="Test LessWrong Post", + url="https://www.lesswrong.com/posts/test123/test-post", + description="This is a test post description", + content="This is test post content with enough text to be processed and embedded.", + authors=["Test Author"], + published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + guid="test123", + karma=25, + votes=10, + comments=5, + words=100, + tags=["rationality", "ai"], + af=False, + score=25, + extended_score=30, + modified_at="2024-01-01T12:30:00Z", + slug="test-post", + images=[], # Empty images to avoid file path issues + ) + + +@pytest.fixture +def mock_empty_lesswrong_post(): + """Mock LessWrong post with empty content.""" + return LessWrongPost( + title="Empty Post", + url="https://www.lesswrong.com/posts/empty123/empty-post", + description="", + content="", + authors=["Empty Author"], + published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + guid="empty123", + karma=5, + votes=2, + comments=0, + words=0, + tags=[], + af=False, + score=5, + extended_score=5, + slug="empty-post", + images=[], + ) + + +@pytest.fixture +def mock_af_post(): + """Mock Alignment Forum post.""" + return LessWrongPost( + title="AI Safety Research", + url="https://www.lesswrong.com/posts/af123/ai-safety-research", + description="Important AI safety research", + content="This is important AI safety research content that should be processed.", + authors=["AI Researcher"], + published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + guid="af123", + karma=50, + votes=20, + comments=15, + words=200, + tags=["ai-safety", "alignment"], + af=True, + score=50, + extended_score=60, + slug="ai-safety-research", + images=[], + ) + + +def test_sync_lesswrong_post_success(mock_lesswrong_post, db_session, qdrant): + """Test successful LessWrong post synchronization.""" + result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test", "forum"]) + + # Verify the ForumPost was created in the database + forum_post = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/test123/test-post") + .first() + ) + assert forum_post is not None + assert forum_post.title == "Test LessWrong Post" + assert ( + forum_post.content + == "This is test post content with enough text to be processed and embedded." + ) + assert forum_post.modality == "forum" + assert forum_post.mime_type == "text/markdown" + assert forum_post.authors == ["Test Author"] + assert forum_post.karma == 25 + assert forum_post.votes == 10 + assert forum_post.comments == 5 + assert forum_post.words == 100 + assert forum_post.score == 25 + assert forum_post.slug == "test-post" + assert forum_post.images == [] + assert "test" in forum_post.tags + assert "forum" in forum_post.tags + assert "rationality" in forum_post.tags + assert "ai" in forum_post.tags + + # Verify the result + assert result["status"] == "processed" + assert result["forumpost_id"] == forum_post.id + assert result["title"] == "Test LessWrong Post" + + +def test_sync_lesswrong_post_empty_content(mock_empty_lesswrong_post, db_session): + """Test LessWrong post sync with empty content.""" + result = forums.sync_lesswrong_post(mock_empty_lesswrong_post) + + # Should still create the post but with failed status due to no chunks + forum_post = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/empty123/empty-post") + .first() + ) + assert forum_post is not None + assert forum_post.title == "Empty Post" + assert forum_post.content == "" + assert result["status"] == "failed" # No chunks generated for empty content + assert result["chunks_count"] == 0 + + +def test_sync_lesswrong_post_already_exists(mock_lesswrong_post, db_session): + """Test LessWrong post sync when content already exists.""" + # Add existing forum post with same content hash + existing_post = ForumPost( + url="https://www.lesswrong.com/posts/test123/test-post", + title="Test LessWrong Post", + content="This is test post content with enough text to be processed and embedded.", + sha256=create_content_hash( + "This is test post content with enough text to be processed and embedded." + ), + modality="forum", + tags=["existing"], + mime_type="text/markdown", + size=77, + authors=["Test Author"], + karma=25, + ) + db_session.add(existing_post) + db_session.commit() + + result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test"]) + + assert result["status"] == "already_exists" + assert result["forumpost_id"] == existing_post.id + + # Verify no duplicate was created + forum_posts = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/test123/test-post") + .all() + ) + assert len(forum_posts) == 1 + + +def test_sync_lesswrong_post_with_custom_tags(mock_lesswrong_post, db_session, qdrant): + """Test LessWrong post sync with custom tags.""" + result = forums.sync_lesswrong_post(mock_lesswrong_post, ["custom", "tags"]) + + # Verify the ForumPost was created with custom tags merged with post tags + forum_post = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/test123/test-post") + .first() + ) + assert forum_post is not None + assert "custom" in forum_post.tags + assert "tags" in forum_post.tags + assert "rationality" in forum_post.tags # Original post tags + assert "ai" in forum_post.tags + assert result["status"] == "processed" + + +def test_sync_lesswrong_post_af_post(mock_af_post, db_session, qdrant): + """Test syncing an Alignment Forum post.""" + result = forums.sync_lesswrong_post(mock_af_post, ["alignment-forum"]) + + forum_post = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/af123/ai-safety-research") + .first() + ) + assert forum_post is not None + assert forum_post.title == "AI Safety Research" + assert forum_post.karma == 50 + assert "ai-safety" in forum_post.tags + assert "alignment" in forum_post.tags + assert "alignment-forum" in forum_post.tags + assert result["status"] == "processed" + + +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_success(mock_fetch, mock_lesswrong_post, db_session): + """Test successful LessWrong synchronization.""" + mock_fetch.return_value = [mock_lesswrong_post] + + with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post: + mock_sync_post.delay.return_value = Mock(id="task-123") + + result = forums.sync_lesswrong( + since="2024-01-01T00:00:00", + min_karma=10, + limit=50, + cooldown=0.1, + max_items=100, + af=False, + tags=["test"], + ) + + assert result["posts_num"] == 1 + assert result["new_posts"] == 1 + assert result["since"] == "2024-01-01T00:00:00" + assert result["min_karma"] == 10 + assert result["max_items"] == 100 + assert result["af"] == False + + # Verify fetch_lesswrong_posts was called with correct arguments + mock_fetch.assert_called_once_with( + datetime.fromisoformat("2024-01-01T00:00:00"), + 10, # min_karma + 50, # limit + 0.1, # cooldown + 100, # max_items + False, # af + ) + + # Verify sync_lesswrong_post was called for the new post + mock_sync_post.delay.assert_called_once_with(mock_lesswrong_post, ["test"]) + + +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_with_existing_posts( + mock_fetch, mock_lesswrong_post, db_session +): + """Test sync when some posts already exist.""" + # Create existing forum post + existing_post = ForumPost( + url="https://www.lesswrong.com/posts/test123/test-post", + title="Existing Post", + content="Existing content", + sha256=b"existing_hash" + bytes(24), + modality="forum", + tags=["existing"], + mime_type="text/markdown", + size=100, + authors=["Test Author"], + ) + db_session.add(existing_post) + db_session.commit() + + # Mock fetch to return existing post and a new one + new_post = mock_lesswrong_post.copy() + new_post["url"] = "https://www.lesswrong.com/posts/new123/new-post" + new_post["title"] = "New Post" + + mock_fetch.return_value = [mock_lesswrong_post, new_post] + + with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post: + mock_sync_post.delay.return_value = Mock(id="task-456") + + result = forums.sync_lesswrong(max_items=100) + + assert result["posts_num"] == 2 + assert result["new_posts"] == 1 # Only one new post + + # Verify sync_lesswrong_post was only called for the new post + mock_sync_post.delay.assert_called_once_with(new_post, []) + + +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_no_posts(mock_fetch, db_session): + """Test sync when no posts are returned.""" + mock_fetch.return_value = [] + + result = forums.sync_lesswrong() + + assert result["posts_num"] == 0 + assert result["new_posts"] == 0 + + +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_max_items_limit(mock_fetch, db_session): + """Test that max_items limit is respected.""" + # Create multiple mock posts + posts = [] + for i in range(5): + post = LessWrongPost( + title=f"Post {i}", + url=f"https://www.lesswrong.com/posts/test{i}/post-{i}", + description=f"Description {i}", + content=f"Content {i}", + authors=[f"Author {i}"], + published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + guid=f"test{i}", + karma=10, + votes=5, + comments=2, + words=50, + tags=[], + af=False, + score=10, + extended_score=10, + slug=f"post-{i}", + images=[], + ) + posts.append(post) + + mock_fetch.return_value = posts + + with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post: + mock_sync_post.delay.return_value = Mock(id="task-123") + + result = forums.sync_lesswrong(max_items=3) + + # Should stop at max_items, but new_posts can be higher because + # the check happens after incrementing new_posts but before incrementing posts_num + assert result["posts_num"] == 3 + assert result["new_posts"] == 4 # One more than posts_num due to timing + assert result["max_items"] == 3 + + +@pytest.mark.parametrize( + "since_str,expected_days_ago", + [ + ("2024-01-01T00:00:00", None), # Specific date + (None, 30), # Default should be 30 days ago + ], +) +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_since_parameter( + mock_fetch, since_str, expected_days_ago, db_session +): + """Test that since parameter is handled correctly.""" + mock_fetch.return_value = [] + + if since_str: + forums.sync_lesswrong(since=since_str) + expected_since = datetime.fromisoformat(since_str) + else: + forums.sync_lesswrong() + # Should default to 30 days ago + expected_since = datetime.now() - timedelta(days=30) + + # Verify fetch was called with correct since date + call_args = mock_fetch.call_args[0] + actual_since = call_args[0] + + if expected_days_ago: + # For default case, check it's approximately 30 days ago + time_diff = abs((actual_since - expected_since).total_seconds()) + assert time_diff < 120 # Within 2 minute tolerance for slow tests + else: + assert actual_since == expected_since + + +@pytest.mark.parametrize( + "af_value,min_karma,limit,cooldown", + [ + (True, 20, 25, 1.0), + (False, 5, 100, 0.0), + (True, 50, 10, 0.5), + ], +) +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_parameters( + mock_fetch, af_value, min_karma, limit, cooldown, db_session +): + """Test that all parameters are passed correctly to fetch function.""" + mock_fetch.return_value = [] + + result = forums.sync_lesswrong( + af=af_value, + min_karma=min_karma, + limit=limit, + cooldown=cooldown, + max_items=500, + ) + + # Verify fetch was called with correct parameters + call_args = mock_fetch.call_args[0] + + assert call_args[1] == min_karma # min_karma + assert call_args[2] == limit # limit + assert call_args[3] == cooldown # cooldown + assert call_args[4] == 500 # max_items + assert call_args[5] == af_value # af + + assert result["min_karma"] == min_karma + assert result["af"] == af_value + + +@patch("memory.workers.tasks.forums.fetch_lesswrong_posts") +def test_sync_lesswrong_fetch_error(mock_fetch, db_session): + """Test sync when fetch_lesswrong_posts raises an exception.""" + mock_fetch.side_effect = Exception("API error") + + # The safe_task_execution decorator should catch this + result = forums.sync_lesswrong() + + assert result["status"] == "error" + assert "API error" in result["error"] + + +def test_sync_lesswrong_post_error_handling(db_session): + """Test error handling in sync_lesswrong_post.""" + # Create invalid post data that will cause an error + invalid_post = { + "title": "Test", + "url": "invalid-url", + "content": "test content", + # Missing required fields + } + + # The safe_task_execution decorator should catch this + result = forums.sync_lesswrong_post(invalid_post) + + assert result["status"] == "error" + assert "error" in result + + +@pytest.mark.parametrize( + "post_tags,additional_tags,expected_tags", + [ + (["original"], ["new"], ["original", "new"]), + ([], ["tag1", "tag2"], ["tag1", "tag2"]), + (["existing"], [], ["existing"]), + (["dup", "tag"], ["tag", "new"], ["dup", "tag", "new"]), # Duplicates removed + ], +) +def test_sync_lesswrong_post_tag_merging( + post_tags, additional_tags, expected_tags, db_session, qdrant +): + """Test that post tags and additional tags are properly merged.""" + post = LessWrongPost( + title="Tag Test Post", + url="https://www.lesswrong.com/posts/tag123/tag-test", + description="Test description", + content="Test content for tag merging", + authors=["Tag Author"], + published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + guid="tag123", + karma=15, + votes=5, + comments=2, + words=50, + tags=post_tags, + af=False, + score=15, + extended_score=15, + slug="tag-test", + images=[], + ) + + forums.sync_lesswrong_post(post, additional_tags) + + forum_post = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/tag123/tag-test") + .first() + ) + assert forum_post is not None + + # Check that all expected tags are present (order doesn't matter) + for tag in expected_tags: + assert tag in forum_post.tags + + # Check that no unexpected tags are present + assert len(forum_post.tags) == len(set(expected_tags)) + + +def test_sync_lesswrong_post_datetime_handling(db_session, qdrant): + """Test that datetime fields are properly handled.""" + post = LessWrongPost( + title="DateTime Test", + url="https://www.lesswrong.com/posts/dt123/datetime-test", + description="Test description", + content="Test content", + authors=["DateTime Author"], + published_at=datetime(2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc), + guid="dt123", + karma=20, + votes=8, + comments=3, + words=75, + tags=["datetime"], + af=False, + score=20, + extended_score=25, + modified_at="2024-06-15T15:00:00Z", + slug="datetime-test", + images=[], + ) + + result = forums.sync_lesswrong_post(post) + + forum_post = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/dt123/datetime-test") + .first() + ) + assert forum_post is not None + assert forum_post.published_at == datetime( + 2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc + ) + # modified_at should be stored as string in the post data + assert result["status"] == "processed" + + +def test_sync_lesswrong_post_content_hash_consistency(db_session): + """Test that content hash is calculated consistently.""" + post = LessWrongPost( + title="Hash Test", + url="https://www.lesswrong.com/posts/hash123/hash-test", + description="Test description", + content="Consistent content for hashing", + authors=["Hash Author"], + published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + guid="hash123", + karma=15, + votes=5, + comments=2, + words=50, + tags=["hash"], + af=False, + score=15, + extended_score=15, + slug="hash-test", + images=[], + ) + + # Sync the same post twice + result1 = forums.sync_lesswrong_post(post) + result2 = forums.sync_lesswrong_post(post) + + # First should succeed, second should detect existing + assert result1["status"] == "processed" + assert result2["status"] == "already_exists" + assert result1["forumpost_id"] == result2["forumpost_id"] + + # Verify only one post exists in database + forum_posts = ( + db_session.query(ForumPost) + .filter_by(url="https://www.lesswrong.com/posts/hash123/hash-test") + .all() + ) + assert len(forum_posts) == 1 diff --git a/tests/memory/workers/tasks/test_maintenance.py b/tests/memory/workers/tasks/test_maintenance.py index 96d7e5d..c06c6cb 100644 --- a/tests/memory/workers/tasks/test_maintenance.py +++ b/tests/memory/workers/tasks/test_maintenance.py @@ -7,12 +7,14 @@ from PIL import Image from memory.common import qdrant as qd from memory.common import settings -from memory.common.db.models import Chunk, SourceItem +from memory.common.db.models import Chunk, SourceItem, MailMessage, BlogPost from memory.workers.tasks.maintenance import ( clean_collection, reingest_chunk, check_batch, reingest_missing_chunks, + reingest_item, + reingest_empty_source_items, ) @@ -302,5 +304,295 @@ def test_reingest_missing_chunks(db_session, qdrant, batch_size): assert set(qdrant_ids) == set(db_ids) +@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"]) +def test_reingest_item_success(db_session, qdrant, item_type): + """Test successful reingestion of an item with existing chunks.""" + if item_type == "MailMessage": + item = MailMessage( + sha256=b"test_hash" + bytes(24), + tags=["test"], + size=100, + mime_type="message/rfc822", + embed_status="STORED", + message_id="