mirror of
https://github.com/mruwnik/memory.git
synced 2025-08-01 15:36:55 +02:00
Add LessWrong tasks + reindexer
This commit is contained in: parent ab87bced81 · commit ed8033bdd3
db/migrations/versions/20250528_012300_add_forum_posts.py (new file, 52 lines)
@@ -0,0 +1,52 @@
"""Add forum posts

Revision ID: 2524646f56f6
Revises: 1b535e1b044e
Create Date: 2025-05-28 01:23:00.079366

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "2524646f56f6"
down_revision: Union[str, None] = "1b535e1b044e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "forum_post",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("url", sa.Text(), nullable=True),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("authors", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("modified_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("slug", sa.Text(), nullable=True),
        sa.Column("karma", sa.Integer(), nullable=True),
        sa.Column("votes", sa.Integer(), nullable=True),
        sa.Column("comments", sa.Integer(), nullable=True),
        sa.Column("words", sa.Integer(), nullable=True),
        sa.Column("score", sa.Integer(), nullable=True),
        sa.Column("images", sa.ARRAY(sa.Text()), nullable=True),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("url"),
    )
    op.create_index("forum_post_slug_idx", "forum_post", ["slug"], unique=False)
    op.create_index("forum_post_title_idx", "forum_post", ["title"], unique=False)
    op.create_index("forum_post_url_idx", "forum_post", ["url"], unique=False)


def downgrade() -> None:
    op.drop_index("forum_post_url_idx", table_name="forum_post")
    op.drop_index("forum_post_title_idx", table_name="forum_post")
    op.drop_index("forum_post_slug_idx", table_name="forum_post")
    op.drop_table("forum_post")
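Not part of the diff: a minimal sketch of applying this revision with Alembic's Python API, assuming a standard alembic.ini at the repo root (the usual CLI equivalent is alembic upgrade head).

from alembic import command
from alembic.config import Config

# Load the project's Alembic config (path assumed) and upgrade to the
# newest revision, which creates the forum_post table and its indexes.
cfg = Config("alembic.ini")
command.upgrade(cfg, "head")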
@@ -207,6 +207,12 @@ services:
       <<: *worker-env
       QUEUES: "blogs"

+  worker-forums:
+    <<: *worker-base
+    environment:
+      <<: *worker-env
+      QUEUES: "forums"
+
   worker-photo:
     <<: *worker-base
     environment:
@@ -33,7 +33,7 @@ RUN mkdir -p /var/log/supervisor /var/run/supervisor
 RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files
 USER kb

-ENV QUEUES="docs,email,maintenance"
+ENV QUEUES="maintenance"
 ENV PYTHONPATH="/app"

 ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]
@@ -39,7 +39,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
 USER kb

 # Default queues to process
-ENV QUEUES="ebooks,email,comic,blogs,photo_embed,maintenance"
+ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
 ENV PYTHONPATH="/app"

 ENTRYPOINT ["./entry.sh"]
@@ -4,3 +4,4 @@ black==23.12.1
 mypy==1.8.0
 isort==5.13.2
 testcontainers[qdrant]==4.10.0
+click==8.1.7
@@ -17,6 +17,7 @@ from memory.common.db.models import (
     MiscDoc,
     ArticleFeed,
     EmailAccount,
+    ForumPost,
 )


@@ -33,7 +34,11 @@ DEFAULT_COLUMNS = (


 def source_columns(model: type[SourceItem], *columns: str):
-    return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]
+    return [
+        getattr(model, c)
+        for c in ("id",) + columns + DEFAULT_COLUMNS
+        if hasattr(model, c)
+    ]


 # Create admin views for all models
@@ -86,6 +91,21 @@ class BlogPostAdmin(ModelView, model=BlogPost):
     column_searchable_list = ["title", "author", "domain"]


+class ForumPostAdmin(ModelView, model=ForumPost):
+    column_list = source_columns(
+        ForumPost,
+        "title",
+        "authors",
+        "published_at",
+        "url",
+        "karma",
+        "votes",
+        "comments",
+        "score",
+    )
+    column_searchable_list = ["title", "authors"]
+
+
 class PhotoAdmin(ModelView, model=Photo):
     column_list = source_columns(Photo, "exif_taken_at", "camera")

@@ -166,5 +186,6 @@ def setup_admin(admin: Admin):
     admin.add_view(MiscDocAdmin)
     admin.add_view(ArticleFeedAdmin)
     admin.add_view(BlogPostAdmin)
+    admin.add_view(ForumPostAdmin)
     admin.add_view(ComicAdmin)
     admin.add_view(PhotoAdmin)
@@ -43,7 +43,12 @@ ALL_COLLECTIONS: dict[str, Collection] = {
     "blog": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+    "forum": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
     },
     "text": {
         "dimension": 1024,
@@ -105,6 +105,21 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
     ]


+def chunk_mixed(
+    content: str, image_paths: Sequence[str]
+) -> list[list[extract.MulitmodalChunk]]:
+    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
+    full_text: list[extract.MulitmodalChunk] = [content.strip(), *images]
+
+    chunks = []
+    tokens = chunker.approx_token_count(content)
+    if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
+        chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
+
+    all_chunks = [full_text] + chunks
+    return [c for c in all_chunks if c and all(i for i in c)]
+
+
 class Chunk(Base):
     """Stores content chunks with their vector embeddings."""

@@ -134,6 +149,19 @@ class Chunk(Base):
         Index("chunk_source_idx", "source_id"),
     )

+    @property
+    def chunks(self) -> list[extract.MulitmodalChunk]:
+        chunks: list[extract.MulitmodalChunk] = []
+        if cast(str | None, self.content):
+            chunks = [cast(str, self.content)]
+            if self.images:
+                chunks += self.images
+        elif cast(Sequence[str] | None, self.file_paths):
+            chunks += [
+                Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
+            ]
+        return chunks
+
     @property
     def data(self) -> list[bytes | str | Image.Image]:
         if self.file_paths is None:
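Not part of the diff: a small usage sketch of the chunk_mixed helper added above, assuming it is importable from the same models module as Chunk. Short content yields a single text-plus-images chunk; content longer than twice DEFAULT_CHUNK_TOKENS also gets per-section chunks rebuilt with add_pics.

from memory.common.db.models import chunk_mixed

# With no images and a short body the result is one chunk holding just the
# trimmed text; longer bodies would append per-section chunks after it.
chunks = chunk_mixed("A short forum post body.", [])
print(len(chunks), [type(part).__name__ for part in chunks[0]])  # 1 ['str']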
@@ -638,20 +666,57 @@ class BlogPost(SourceItem):
         return {k: v for k, v in payload.items() if v}

     def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
-        images = [
-            Image.open(settings.FILE_STORAGE_DIR / image) for image in self.images
-        ]
-        content = cast(str, self.content)
-        full_text = [content.strip(), *images]
-
-        chunks = []
-        tokens = chunker.approx_token_count(content)
-        if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
-            chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
-
-        all_chunks = [full_text] + chunks
-        return [c for c in all_chunks if c and all(i for i in c)]
+        return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
+
+
+class ForumPost(SourceItem):
+    __tablename__ = "forum_post"
+
+    id = Column(
+        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
+    )
+    url = Column(Text, unique=True)
+    title = Column(Text)
+    description = Column(Text, nullable=True)
+    authors = Column(ARRAY(Text), nullable=True)
+    published_at = Column(DateTime(timezone=True), nullable=True)
+    modified_at = Column(DateTime(timezone=True), nullable=True)
+    slug = Column(Text, nullable=True)
+    karma = Column(Integer, nullable=True)
+    votes = Column(Integer, nullable=True)
+    comments = Column(Integer, nullable=True)
+    words = Column(Integer, nullable=True)
+    score = Column(Integer, nullable=True)
+    images = Column(ARRAY(Text), nullable=True)
+
+    __mapper_args__ = {
+        "polymorphic_identity": "forum_post",
+    }
+
+    __table_args__ = (
+        Index("forum_post_url_idx", "url"),
+        Index("forum_post_slug_idx", "slug"),
+        Index("forum_post_title_idx", "title"),
+    )
+
+    def as_payload(self) -> dict:
+        return {
+            "source_id": self.id,
+            "url": self.url,
+            "title": self.title,
+            "description": self.description,
+            "authors": self.authors,
+            "published_at": self.published_at,
+            "slug": self.slug,
+            "karma": self.karma,
+            "votes": self.votes,
+            "score": self.score,
+            "comments": self.comments,
+            "tags": self.tags,
+        }
+
+    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
+        return chunk_mixed(cast(str, self.content), cast(list[str], self.images))


 class MiscDoc(SourceItem):
@@ -710,7 +775,7 @@ class ArticleFeed(Base):
     description = Column(Text)
     tags = Column(ARRAY(Text), nullable=False, server_default="{}")
     check_interval = Column(
-        Integer, nullable=False, server_default="3600", doc="Seconds between checks"
+        Integer, nullable=False, server_default="60", doc="Minutes between checks"
     )
     last_checked_at = Column(DateTime(timezone=True))
     active = Column(Boolean, nullable=False, server_default="true")
@@ -20,7 +20,7 @@ def embed_chunks(
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
-    logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]}")
+    logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]} {len(chunks)}")
     vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
@@ -79,7 +79,7 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
     if not model_chunks:
         return []

-    vectors = embed_chunks([chunk.content for chunk in model_chunks], model)  # type: ignore
+    vectors = embed_chunks([chunk.chunks for chunk in model_chunks], model)
     for chunk, vector in zip(model_chunks, vectors):
         chunk.vector = vector
     return model_chunks
@@ -86,9 +86,11 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))

 # Worker settings
 # Intervals are in seconds
-EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
-CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400))
-CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600))
+EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 60 * 60))
+COMIC_SYNC_INTERVAL = int(os.getenv("COMIC_SYNC_INTERVAL", 60 * 60))
+ARTICLE_FEED_SYNC_INTERVAL = int(os.getenv("ARTICLE_FEED_SYNC_INTERVAL", 30 * 60))
+CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 * 60))
+CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))

 CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
src/memory/parsers/lesswrong.py (new file, 304 lines)
@@ -0,0 +1,304 @@
from dataclasses import dataclass, field
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Generator, TypedDict, NotRequired

from bs4 import BeautifulSoup
from PIL import Image as PILImage
from memory.common import settings
import requests
from markdownify import markdownify

from memory.parsers.html import parse_date, process_images

logger = logging.getLogger(__name__)


class LessWrongPost(TypedDict):
    """Represents a post from LessWrong."""

    title: str
    url: str
    description: str
    content: str
    authors: list[str]
    published_at: datetime | None
    guid: str | None
    karma: int
    votes: int
    comments: int
    words: int
    tags: list[str]
    af: bool
    score: int
    extended_score: int
    modified_at: NotRequired[str | None]
    slug: NotRequired[str | None]
    images: NotRequired[list[str]]


def make_graphql_query(
    after: datetime, af: bool = False, limit: int = 50, min_karma: int = 10
) -> str:
    """Create GraphQL query for fetching posts."""
    return f"""
    {{
        posts(input: {{
            terms: {{
                excludeEvents: true
                view: "old"
                af: {str(af).lower()}
                limit: {limit}
                karmaThreshold: {min_karma}
                after: "{after.isoformat()}Z"
                filter: "tagged"
            }}
        }}) {{
            totalCount
            results {{
                _id
                title
                slug
                pageUrl
                postedAt
                modifiedAt
                score
                extendedScore
                baseScore
                voteCount
                commentCount
                wordCount
                tags {{
                    name
                }}
                user {{
                    displayName
                }}
                coauthors {{
                    displayName
                }}
                af
                htmlBody
            }}
        }}
    }}
    """


def fetch_posts_from_api(url: str, query: str) -> dict[str, Any]:
    """Fetch posts from LessWrong GraphQL API."""
    response = requests.post(
        url,
        headers={
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
        },
        json={"query": query},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()["data"]["posts"]


def is_valid_post(post: dict[str, Any], min_karma: int = 10) -> bool:
    """Check if post should be included."""
    # Must have content
    if not post.get("htmlBody"):
        return False

    # Must meet karma threshold
    if post.get("baseScore", 0) < min_karma:
        return False

    return True


def extract_authors(post: dict[str, Any]) -> list[str]:
    """Extract authors from post data."""
    authors = post.get("coauthors", []) or []
    if post.get("user"):
        authors = [post["user"]] + authors
    return [a["displayName"] for a in authors] or ["anonymous"]


def extract_description(body: str) -> str:
    """Extract description from post HTML content."""
    first_paragraph = body.split("\n\n")[0]

    # Truncate if too long
    if len(first_paragraph) > 300:
        first_paragraph = first_paragraph[:300] + "..."

    return first_paragraph


def parse_lesswrong_date(date_str: str | None) -> datetime | None:
    """Parse ISO date string from LessWrong API to datetime."""
    if not date_str:
        return None

    # Try multiple ISO formats that LessWrong might use
    formats = [
        "%Y-%m-%dT%H:%M:%S.%fZ",  # 2023-01-15T10:30:00.000Z
        "%Y-%m-%dT%H:%M:%SZ",  # 2023-01-15T10:30:00Z
        "%Y-%m-%dT%H:%M:%S.%f",  # 2023-01-15T10:30:00.000
        "%Y-%m-%dT%H:%M:%S",  # 2023-01-15T10:30:00
    ]

    for fmt in formats:
        if result := parse_date(date_str, fmt):
            return result

    # Fallback: try removing 'Z' and using fromisoformat
    try:
        clean_date = date_str.rstrip("Z")
        return datetime.fromisoformat(clean_date)
    except (ValueError, TypeError):
        logger.warning(f"Could not parse date: {date_str}")
        return None


def extract_body(post: dict[str, Any]) -> tuple[str, dict[str, PILImage.Image]]:
    """Extract body from post data."""
    if not (body := post.get("htmlBody", "").strip()):
        return "", {}

    url = post.get("pageUrl", "")
    image_dir = settings.FILE_STORAGE_DIR / "lesswrong" / url

    soup = BeautifulSoup(body, "html.parser")
    soup, images = process_images(soup, url, image_dir)
    body = markdownify(str(soup)).strip()
    return body, images


def format_post(post: dict[str, Any]) -> LessWrongPost:
    """Convert raw API post data to a LessWrongPost."""
    body, images = extract_body(post)

    result: LessWrongPost = {
        "title": post.get("title", "Untitled"),
        "url": post.get("pageUrl", ""),
        "description": extract_description(body),
        "content": body,
        "authors": extract_authors(post),
        "published_at": parse_lesswrong_date(post.get("postedAt")),
        "guid": post.get("_id"),
        "karma": post.get("baseScore", 0),
        "votes": post.get("voteCount", 0),
        "comments": post.get("commentCount", 0),
        "words": post.get("wordCount", 0),
        "tags": [tag["name"] for tag in post.get("tags", [])],
        "af": post.get("af", False),
        "score": post.get("score", 0),
        "extended_score": post.get("extendedScore", 0),
        "modified_at": post.get("modifiedAt"),
        "slug": post.get("slug"),
        "images": list(images.keys()),
    }

    return result


def fetch_lesswrong(
    url: str,
    current_date: datetime,
    af: bool = False,
    min_karma: int = 10,
    limit: int = 50,
    last_url: str | None = None,
) -> list[LessWrongPost]:
    """
    Fetch a single batch of posts from LessWrong.

    Returns:
        The valid posts published after `current_date`, or an empty list if
        there is nothing new to fetch.
    """
    query = make_graphql_query(current_date, af, limit, min_karma)
    api_response = fetch_posts_from_api(url, query)

    if not api_response["results"]:
        return []

    # If we only get the same item we started with, we're done
    if (
        len(api_response["results"]) == 1
        and last_url
        and api_response["results"][0]["pageUrl"] == last_url
    ):
        return []

    return [
        format_post(post)
        for post in api_response["results"]
        if is_valid_post(post, min_karma)
    ]


def fetch_lesswrong_posts(
    since: datetime | None = None,
    min_karma: int = 10,
    limit: int = 50,
    cooldown: float = 0.5,
    max_items: int = 1000,
    af: bool = False,
    url: str = "https://www.lesswrong.com/graphql",
) -> Generator[LessWrongPost, None, None]:
    """
    Fetch posts from LessWrong.

    Args:
        since: Start date for fetching posts (defaults to one day ago)
        min_karma: Minimum karma threshold for posts
        limit: Number of posts per API request
        cooldown: Delay between API requests in seconds
        max_items: Maximum number of posts to fetch
        af: Whether to fetch Alignment Forum posts
        url: GraphQL endpoint URL

    Yields:
        LessWrongPost dictionaries
    """
    if not since:
        since = datetime.now() - timedelta(days=1)

    logger.info(f"Starting from {since}")

    last_url = None
    next_date = since
    items_count = 0

    while next_date and items_count < max_items:
        try:
            page_posts = fetch_lesswrong(url, next_date, af, min_karma, limit, last_url)
        except Exception as e:
            logger.error(f"Error fetching posts: {e}")
            break

        if not page_posts or next_date is None:
            break

        for post in page_posts:
            yield post

        last_item = page_posts[-1]
        prev_date = next_date
        next_date = last_item.get("published_at")

        if not next_date or prev_date == next_date:
            logger.warning(
                f"Could not advance through dataset, stopping at {next_date}"
            )
            break

        # The articles are paged by date (inclusive) so passing the last date
        # as is will return the same article again.
        next_date += timedelta(seconds=1)
        items_count += len(page_posts)
        last_url = last_item["url"]

        if cooldown > 0:
            time.sleep(cooldown)

    logger.info(f"Fetched {items_count} items")
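Not part of the diff: a minimal usage sketch of the new parser's public entry point (the date window and karma threshold are illustrative).

from datetime import datetime, timedelta

from memory.parsers.lesswrong import fetch_lesswrong_posts

# Stream the last week of posts with at least 10 karma; the generator pages
# through the GraphQL API one date window at a time, sleeping `cooldown`
# seconds between requests.
for post in fetch_lesswrong_posts(
    since=datetime.now() - timedelta(days=7),
    min_karma=10,
    limit=50,
):
    print(post["published_at"], post["karma"], post["title"])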
@@ -1,6 +1,14 @@
 from celery import Celery
 from memory.common import settings

+EMAIL_ROOT = "memory.workers.tasks.email"
+FORUMS_ROOT = "memory.workers.tasks.forums"
+BLOGS_ROOT = "memory.workers.tasks.blogs"
+PHOTO_ROOT = "memory.workers.tasks.photo"
+COMIC_ROOT = "memory.workers.tasks.comic"
+EBOOK_ROOT = "memory.workers.tasks.ebook"
+MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
+

 def rabbit_url() -> str:
     return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//"
@@ -21,12 +29,13 @@ app.conf.update(
     task_reject_on_worker_lost=True,
     worker_prefetch_multiplier=1,
     task_routes={
-        "memory.workers.tasks.email.*": {"queue": "email"},
-        "memory.workers.tasks.photo.*": {"queue": "photo_embed"},
-        "memory.workers.tasks.comic.*": {"queue": "comic"},
-        "memory.workers.tasks.ebook.*": {"queue": "ebooks"},
-        "memory.workers.tasks.blogs.*": {"queue": "blogs"},
-        "memory.workers.tasks.maintenance.*": {"queue": "maintenance"},
+        f"{EMAIL_ROOT}.*": {"queue": "email"},
+        f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
+        f"{COMIC_ROOT}.*": {"queue": "comic"},
+        f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
+        f"{BLOGS_ROOT}.*": {"queue": "blogs"},
+        f"{FORUMS_ROOT}.*": {"queue": "forums"},
+        f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
     },
 )
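Not part of the diff: the *_ROOT constants double as task-name prefixes, so any task registered under a root automatically matches that root's queue route. A tiny illustration:

from memory.workers.celery_app import FORUMS_ROOT, app

name = f"{FORUMS_ROOT}.sync_lesswrong"
print(name)  # memory.workers.tasks.forums.sync_lesswrong
print(app.conf.task_routes[f"{FORUMS_ROOT}.*"])  # {'queue': 'forums'}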
@@ -8,10 +8,6 @@ logger = logging.getLogger(__name__)


 app.conf.beat_schedule = {
-    "sync-mail-all": {
-        "task": "memory.workers.tasks.email.sync_all_accounts",
-        "schedule": settings.EMAIL_SYNC_INTERVAL,
-    },
     "clean-all-collections": {
         "task": CLEAN_ALL_COLLECTIONS,
         "schedule": settings.CLEAN_COLLECTION_INTERVAL,
@@ -20,4 +16,16 @@ app.conf.beat_schedule = {
         "task": REINGEST_MISSING_CHUNKS,
         "schedule": settings.CHUNK_REINGEST_INTERVAL,
     },
+    "sync-mail-all": {
+        "task": "memory.workers.tasks.email.sync_all_accounts",
+        "schedule": settings.EMAIL_SYNC_INTERVAL,
+    },
+    "sync-all-comics": {
+        "task": "memory.workers.tasks.comic.sync_all_comics",
+        "schedule": settings.COMIC_SYNC_INTERVAL,
+    },
+    "sync-all-article-feeds": {
+        "task": "memory.workers.tasks.blogs.sync_all_article_feeds",
+        "schedule": settings.ARTICLE_FEED_SYNC_INTERVAL,
+    },
 }
@@ -2,7 +2,7 @@
 Import sub-modules so Celery can register their @app.task decorators.
 """

-from memory.workers.tasks import email, comic, blogs, ebook  # noqa
+from memory.workers.tasks import email, comic, blogs, ebook, forums  # noqa
 from memory.workers.tasks.blogs import (
     SYNC_WEBPAGE,
     SYNC_ARTICLE_FEED,
@@ -12,11 +12,13 @@ from memory.workers.tasks.blogs import (
 from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_SMBC, SYNC_XKCD
 from memory.workers.tasks.ebook import SYNC_BOOK
 from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
+from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
 from memory.workers.tasks.maintenance import (
     CLEAN_ALL_COLLECTIONS,
     CLEAN_COLLECTION,
     REINGEST_MISSING_CHUNKS,
     REINGEST_CHUNK,
+    REINGEST_ITEM,
 )


@@ -25,6 +27,7 @@ __all__ = [
     "comic",
     "blogs",
     "ebook",
+    "forums",
     "SYNC_WEBPAGE",
     "SYNC_ARTICLE_FEED",
     "SYNC_ALL_ARTICLE_FEEDS",
@@ -34,10 +37,13 @@ __all__ = [
     "SYNC_XKCD",
     "SYNC_BOOK",
     "SYNC_ACCOUNT",
+    "SYNC_LESSWRONG",
+    "SYNC_LESSWRONG_POST",
     "SYNC_ALL_ACCOUNTS",
     "PROCESS_EMAIL",
     "CLEAN_ALL_COLLECTIONS",
     "CLEAN_COLLECTION",
     "REINGEST_MISSING_CHUNKS",
     "REINGEST_CHUNK",
+    "REINGEST_ITEM",
 ]
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from typing import Iterable, cast

 from memory.common.db.connection import make_session
@@ -7,7 +7,7 @@ from memory.common.db.models import ArticleFeed, BlogPost
 from memory.parsers.blogs import parse_webpage
 from memory.parsers.feeds import get_feed_parser
 from memory.parsers.archives import get_archive_fetcher
-from memory.workers.celery_app import app
+from memory.workers.celery_app import app, BLOGS_ROOT
 from memory.workers.tasks.content_processing import (
     check_content_exists,
     create_content_hash,
@@ -18,10 +18,10 @@ from memory.workers.tasks.content_processing import (

 logger = logging.getLogger(__name__)

-SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
-SYNC_ARTICLE_FEED = "memory.workers.tasks.blogs.sync_article_feed"
-SYNC_ALL_ARTICLE_FEEDS = "memory.workers.tasks.blogs.sync_all_article_feeds"
-SYNC_WEBSITE_ARCHIVE = "memory.workers.tasks.blogs.sync_website_archive"
+SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
+SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
+SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
+SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"


 @app.task(name=SYNC_WEBPAGE)
@@ -71,7 +71,7 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
         logger.info(f"Blog post already exists: {existing_post.title}")
         return create_task_result(existing_post, "already_exists", url=article.url)

-    return process_content_item(blog_post, "blog", session, tags)
+    return process_content_item(blog_post, session)


 @app.task(name=SYNC_ARTICLE_FEED)
@@ -93,8 +93,8 @@ def sync_article_feed(feed_id: int) -> dict:
            return {"status": "error", "error": "Feed not found or inactive"}

        last_checked_at = cast(datetime | None, feed.last_checked_at)
-       if last_checked_at and datetime.now() - last_checked_at < timedelta(
-           seconds=cast(int, feed.check_interval)
+       if last_checked_at and datetime.now(timezone.utc) - last_checked_at < timedelta(
+           minutes=cast(int, feed.check_interval)
        ):
            logger.info(f"Feed {feed_id} checked too recently, skipping")
            return {"status": "skipped_recent_check", "feed_id": feed_id}
@@ -129,7 +129,7 @@ def sync_article_feed(feed_id: int) -> dict:
            logger.error(f"Error parsing feed {feed.url}: {e}")
            errors += 1

-       feed.last_checked_at = datetime.now()  # type: ignore
+       feed.last_checked_at = datetime.now(timezone.utc)  # type: ignore
        session.commit()

        return {
@@ -9,7 +9,7 @@ from memory.common import settings
 from memory.common.db.connection import make_session
 from memory.common.db.models import Comic, clean_filename
 from memory.parsers import comics
-from memory.workers.celery_app import app
+from memory.workers.celery_app import app, COMIC_ROOT
 from memory.workers.tasks.content_processing import (
     check_content_exists,
     create_content_hash,
@@ -19,10 +19,10 @@ from memory.workers.tasks.content_processing import (

 logger = logging.getLogger(__name__)

-SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
-SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
-SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
-SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
+SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics"
+SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc"
+SYNC_XKCD = f"{COMIC_ROOT}.sync_xkcd"
+SYNC_COMIC = f"{COMIC_ROOT}.sync_comic"

 BASE_SMBC_URL = "https://www.smbc-comics.com/"
 SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
@@ -109,7 +109,7 @@ def sync_comic(
     )

     with make_session() as session:
-        return process_content_item(comic, "comic", session)
+        return process_content_item(comic, session)


 @app.task(name=SYNC_SMBC)
@@ -99,6 +99,7 @@ def embed_source_item(source_item: SourceItem) -> int:
     except Exception as e:
         source_item.embed_status = "FAILED"  # type: ignore
         logger.error(f"Failed to embed {type(source_item).__name__}: {e}")
+        logger.error(traceback.format_exc())
         return 0


@@ -156,6 +157,7 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
         for item in items_to_process:
             item.embed_status = "FAILED"  # type: ignore
         logger.error(f"Failed to push embeddings to Qdrant: {e}")
+        logger.error(traceback.format_exc())
         raise


@@ -186,9 +188,7 @@ def create_task_result(
     }


-def process_content_item(
-    item: SourceItem, collection_name: str, session, tags: Iterable[str] = []
-) -> dict[str, Any]:
+def process_content_item(item: SourceItem, session) -> dict[str, Any]:
     """
     Execute complete content processing workflow.

@@ -200,7 +200,6 @@ def process_content_item(

     Args:
         item: SourceItem to process
-        collection_name: Qdrant collection name for vector storage
         session: Database session for persistence
         tags: Optional tags to associate with the item (currently unused)

@@ -223,7 +222,7 @@ def process_content_item(
         return create_task_result(item, status, content_length=getattr(item, "size", 0))

     try:
-        push_to_qdrant([item], collection_name)
+        push_to_qdrant([item], cast(str, item.modality))
         status = "processed"
         item.embed_status = "STORED"  # type: ignore
         logger.info(
@@ -231,6 +230,7 @@ def process_content_item(
         )
     except Exception as e:
         logger.error(f"Failed to push embeddings to Qdrant: {e}")
+        logger.error(traceback.format_exc())
         item.embed_status = "FAILED"  # type: ignore
     session.commit()

@@ -261,10 +261,8 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
         try:
             return func(*args, **kwargs)
         except Exception as e:
-            logger.error(
-                f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}"
-            )
             logger.error(f"Task {func.__name__} failed: {e}")
+            logger.error(traceback.format_exc())
             return {"status": "error", "error": str(e)}

     return wrapper
@@ -6,7 +6,7 @@ import memory.common.settings as settings
 from memory.parsers.ebook import Ebook, parse_ebook, Section
 from memory.common.db.models import Book, BookSection
 from memory.common.db.connection import make_session
-from memory.workers.celery_app import app
+from memory.workers.celery_app import app, EBOOK_ROOT
 from memory.workers.tasks.content_processing import (
     check_content_exists,
     create_content_hash,
@@ -17,7 +17,7 @@ from memory.workers.tasks.content_processing import (

 logger = logging.getLogger(__name__)

-SYNC_BOOK = "memory.workers.tasks.ebook.sync_book"
+SYNC_BOOK = f"{EBOOK_ROOT}.sync_book"

 # Minimum section length to embed (avoid noise from very short sections)
 MIN_SECTION_LENGTH = 100
@@ -3,7 +3,7 @@ from datetime import datetime
 from typing import cast
 from memory.common.db.connection import make_session
 from memory.common.db.models import EmailAccount, MailMessage
-from memory.workers.celery_app import app
+from memory.workers.celery_app import app, EMAIL_ROOT
 from memory.workers.email import (
     create_mail_message,
     imap_connection,
@@ -18,9 +18,9 @@ from memory.workers.tasks.content_processing import (

 logger = logging.getLogger(__name__)

-PROCESS_EMAIL = "memory.workers.tasks.email.process_message"
-SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account"
-SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
+PROCESS_EMAIL = f"{EMAIL_ROOT}.process_message"
+SYNC_ACCOUNT = f"{EMAIL_ROOT}.sync_account"
+SYNC_ALL_ACCOUNTS = f"{EMAIL_ROOT}.sync_all_accounts"


 @app.task(name=PROCESS_EMAIL)
src/memory/workers/tasks/forums.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from datetime import datetime, timedelta
import logging

from memory.parsers.lesswrong import fetch_lesswrong_posts, LessWrongPost
from memory.common.db.connection import make_session
from memory.common.db.models import ForumPost
from memory.workers.celery_app import app, FORUMS_ROOT
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
    create_task_result,
    process_content_item,
    safe_task_execution,
)

logger = logging.getLogger(__name__)

SYNC_LESSWRONG = f"{FORUMS_ROOT}.sync_lesswrong"
SYNC_LESSWRONG_POST = f"{FORUMS_ROOT}.sync_lesswrong_post"


@app.task(name=SYNC_LESSWRONG_POST)
@safe_task_execution
def sync_lesswrong_post(
    post: LessWrongPost,
    tags: list[str] = [],
):
    logger.info(f"Syncing LessWrong post {post['url']}")
    sha256 = create_content_hash(post["content"])

    post["tags"] = list(set(post["tags"] + tags))
    post_obj = ForumPost(
        embed_status="RAW",
        size=len(post["content"].encode("utf-8")),
        modality="forum",
        mime_type="text/markdown",
        sha256=sha256,
        **{k: v for k, v in post.items() if hasattr(ForumPost, k)},
    )

    with make_session() as session:
        existing_post = check_content_exists(
            session, ForumPost, url=post_obj.url, sha256=sha256
        )
        if existing_post:
            logger.info(f"LessWrong post already exists: {existing_post.title}")
            return create_task_result(existing_post, "already_exists", url=post_obj.url)

        return process_content_item(post_obj, session)


@app.task(name=SYNC_LESSWRONG)
@safe_task_execution
def sync_lesswrong(
    since: str = (datetime.now() - timedelta(days=30)).isoformat(),
    min_karma: int = 10,
    limit: int = 50,
    cooldown: float = 0.5,
    max_items: int = 1000,
    af: bool = False,
    tags: list[str] = [],
):
    logger.info(f"Syncing LessWrong posts since {since}")
    start_date = datetime.fromisoformat(since)
    posts = fetch_lesswrong_posts(start_date, min_karma, limit, cooldown, max_items, af)

    posts_num, new_posts = 0, 0
    with make_session() as session:
        for post in posts:
            if not check_content_exists(session, ForumPost, url=post["url"]):
                new_posts += 1
                sync_lesswrong_post.delay(post, tags)

            if posts_num >= max_items:
                break
            posts_num += 1

    return {
        "posts_num": posts_num,
        "new_posts": new_posts,
        "since": since,
        "min_karma": min_karma,
        "max_items": max_items,
        "af": af,
    }
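Not part of the diff: a hedged example of kicking the new sync off through Celery (the date, karma threshold, and tags are illustrative; it assumes a worker is consuming the forums queue).

from memory.workers.tasks.forums import sync_lesswrong

# Crawl recent Alignment Forum posts with karma >= 20; each new post is
# fanned out to sync_lesswrong_post on the "forums" queue.
sync_lesswrong.delay(
    since="2025-05-14T00:00:00",
    min_karma=20,
    af=True,
    tags=["alignment"],
)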
@@ -3,21 +3,24 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from typing import Sequence

+from memory.workers.tasks.content_processing import process_content_item
 from sqlalchemy import select
 from sqlalchemy.orm import contains_eager

 from memory.common import collections, embedding, qdrant, settings
 from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
-from memory.workers.celery_app import app
+from memory.workers.celery_app import app, MAINTENANCE_ROOT

 logger = logging.getLogger(__name__)


-CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections"
-CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection"
-REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks"
-REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk"
+CLEAN_ALL_COLLECTIONS = f"{MAINTENANCE_ROOT}.clean_all_collections"
+CLEAN_COLLECTION = f"{MAINTENANCE_ROOT}.clean_collection"
+REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
+REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
+REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
+REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"


 @app.task(name=CLEAN_COLLECTION)
@@ -87,6 +90,61 @@ def reingest_chunk(chunk_id: str, collection: str):
         session.commit()


+def get_item_class(item_type: str):
+    class_ = SourceItem.registry._class_registry.get(item_type)
+    if not class_:
+        available_types = ", ".join(sorted(SourceItem.registry._class_registry.keys()))
+        raise ValueError(
+            f"Unsupported item type {item_type}. Available types: {available_types}"
+        )
+    return class_
+
+
+@app.task(name=REINGEST_ITEM)
+def reingest_item(item_id: str, item_type: str):
+    logger.info(f"Reingesting {item_type} {item_id}")
+    try:
+        class_ = get_item_class(item_type)
+    except ValueError as e:
+        logger.error(f"Error getting item class: {e}")
+        return {"status": "error", "error": str(e)}
+
+    with make_session() as session:
+        item = session.query(class_).get(item_id)
+        if not item:
+            return {"status": "error", "error": f"Item {item_id} not found"}
+
+        chunk_ids = [str(c.id) for c in item.chunks if c.id]
+        if chunk_ids:
+            client = qdrant.get_qdrant_client()
+            qdrant.delete_points(client, item.modality, chunk_ids)
+
+        for chunk in item.chunks:
+            session.delete(chunk)
+
+        return process_content_item(item, session)
+
+
+@app.task(name=REINGEST_EMPTY_SOURCE_ITEMS)
+def reingest_empty_source_items(item_type: str):
+    logger.info("Reingesting empty source items")
+    try:
+        class_ = get_item_class(item_type)
+    except ValueError as e:
+        logger.error(f"Error getting item class: {e}")
+        return {"status": "error", "error": str(e)}
+
+    with make_session() as session:
+        item_ids = session.query(class_.id).filter(~class_.chunks.any()).all()
+
+        logger.info(f"Found {len(item_ids)} items to reingest")
+
+        for item_id in item_ids:
+            reingest_item.delay(item_id.id, item_type)  # type: ignore
+
+    return {"status": "success", "items": len(item_ids)}
+
+
 def check_batch(batch: Sequence[Chunk]) -> dict:
     client = qdrant.get_qdrant_client()
     by_collection = defaultdict(list)
@@ -116,14 +174,19 @@ def check_batch(batch: Sequence[Chunk]) -> dict:

 @app.task(name=REINGEST_MISSING_CHUNKS)
 def reingest_missing_chunks(
-    batch_size: int = 1000, minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES
+    batch_size: int = 1000,
+    collection: str | None = None,
+    minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES,
 ):
     logger.info("Reingesting missing chunks")
     total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0})
     since = datetime.now() - timedelta(minutes=minutes_ago)

     with make_session() as session:
-        total_count = session.query(Chunk).filter(Chunk.checked_at < since).count()
+        query = session.query(Chunk).filter(Chunk.checked_at < since)
+        if collection:
+            query = query.filter(Chunk.source.has(SourceItem.modality == collection))
+        total_count = query.count()

         logger.info(
             f"Found {total_count} chunks to check, processing in batches of {batch_size}"
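Not part of the diff: a sketch of invoking the new reindexer tasks. The item_type argument is looked up in SQLAlchemy's class registry, so the mapped class name (e.g. "ForumPost") should work; the ID and collection values are illustrative.

from memory.workers.tasks.maintenance import (
    reingest_empty_source_items,
    reingest_item,
    reingest_missing_chunks,
)

# Re-embed every ForumPost that currently has no chunks.
reingest_empty_source_items.delay("ForumPost")

# Re-embed a single item by primary key and type.
reingest_item.delay("1234", "ForumPost")

# Only verify/reingest chunks whose source items are in the "forum" collection.
reingest_missing_chunks.delay(collection="forum")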
tests/memory/parsers/test_lesswrong.py (new file, 542 lines; listing truncated below)
@@ -0,0 +1,542 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch, Mock
import json
import pytest
from PIL import Image as PILImage

from memory.parsers.lesswrong import (
    LessWrongPost,
    make_graphql_query,
    fetch_posts_from_api,
    is_valid_post,
    extract_authors,
    extract_description,
    parse_lesswrong_date,
    extract_body,
    format_post,
    fetch_lesswrong,
    fetch_lesswrong_posts,
)


@pytest.mark.parametrize(
    "after, af, limit, min_karma, expected_contains",
    [
        (
            datetime(2023, 1, 15, 10, 30),
            False,
            50,
            10,
            [
                "af: false",
                "limit: 50",
                "karmaThreshold: 10",
                'after: "2023-01-15T10:30:00Z"',
            ],
        ),
        (
            datetime(2023, 2, 20),
            True,
            25,
            5,
            [
                "af: true",
                "limit: 25",
                "karmaThreshold: 5",
                'after: "2023-02-20T00:00:00Z"',
            ],
        ),
    ],
)
def test_make_graphql_query(after, af, limit, min_karma, expected_contains):
    query = make_graphql_query(after, af, limit, min_karma)

    for expected in expected_contains:
        assert expected in query

    # Check that all required fields are present
    required_fields = [
        "_id",
        "title",
        "slug",
        "pageUrl",
        "postedAt",
        "modifiedAt",
        "score",
        "extendedScore",
        "baseScore",
        "voteCount",
        "commentCount",
        "wordCount",
        "tags",
        "user",
        "coauthors",
        "af",
        "htmlBody",
    ]
    for field in required_fields:
        assert field in query


@patch("memory.parsers.lesswrong.requests.post")
def test_fetch_posts_from_api_success(mock_post):
    mock_response = Mock()
    mock_response.json.return_value = {
        "data": {
            "posts": {
                "totalCount": 2,
                "results": [
                    {"_id": "1", "title": "Post 1"},
                    {"_id": "2", "title": "Post 2"},
                ],
            }
        }
    }
    mock_post.return_value = mock_response

    url = "https://www.lesswrong.com/graphql"
    query = "test query"

    result = fetch_posts_from_api(url, query)

    mock_post.assert_called_once_with(
        url,
        headers={
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
        },
        json={"query": query},
        timeout=30,
    )

    assert result == {
        "totalCount": 2,
        "results": [
            {"_id": "1", "title": "Post 1"},
            {"_id": "2", "title": "Post 2"},
        ],
    }


@patch("memory.parsers.lesswrong.requests.post")
def test_fetch_posts_from_api_http_error(mock_post):
    mock_response = Mock()
    mock_response.raise_for_status.side_effect = Exception("HTTP Error")
    mock_post.return_value = mock_response

    with pytest.raises(Exception, match="HTTP Error"):
        fetch_posts_from_api("https://example.com", "query")


@pytest.mark.parametrize(
    "post_data, min_karma, expected",
    [
        # Valid post with content and karma
        ({"htmlBody": "<p>Content</p>", "baseScore": 15}, 10, True),
        # Valid post at karma threshold
        ({"htmlBody": "<p>Content</p>", "baseScore": 10}, 10, True),
        # Invalid: no content
        ({"htmlBody": "", "baseScore": 15}, 10, False),
        ({"htmlBody": None, "baseScore": 15}, 10, False),
        ({}, 10, False),
        # Invalid: below karma threshold
        ({"htmlBody": "<p>Content</p>", "baseScore": 5}, 10, False),
        ({"htmlBody": "<p>Content</p>"}, 10, False),  # No baseScore
        # Edge cases
        (
            {"htmlBody": " ", "baseScore": 15},
            10,
            True,
        ),  # Whitespace only - actually valid
        ({"htmlBody": "<p>Content</p>", "baseScore": 0}, 0, True),  # Zero threshold
    ],
)
def test_is_valid_post(post_data, min_karma, expected):
    assert is_valid_post(post_data, min_karma) == expected


@pytest.mark.parametrize(
    "post_data, expected",
    [
        # User only
        (
            {"user": {"displayName": "Alice"}},
            ["Alice"],
        ),
        # User with coauthors
        (
            {
                "user": {"displayName": "Alice"},
                "coauthors": [{"displayName": "Bob"}, {"displayName": "Charlie"}],
            },
            ["Alice", "Bob", "Charlie"],
        ),
        # Coauthors only (no user)
        (
            {"coauthors": [{"displayName": "Bob"}]},
            ["Bob"],
        ),
        # Empty coauthors list
        (
            {"user": {"displayName": "Alice"}, "coauthors": []},
            ["Alice"],
        ),
        # No authors at all
        ({}, ["anonymous"]),
        ({"user": None, "coauthors": None}, ["anonymous"]),
        ({"user": None, "coauthors": []}, ["anonymous"]),
    ],
)
def test_extract_authors(post_data, expected):
    assert extract_authors(post_data) == expected


@pytest.mark.parametrize(
    "body, expected",
    [
        # Short content
        ("This is a short paragraph.", "This is a short paragraph."),
        # Multiple paragraphs - only first
        ("First paragraph.\n\nSecond paragraph.", "First paragraph."),
        # Long content - truncated
        (
            "A" * 350,
            "A" * 300 + "...",
        ),
        # Empty content
        ("", ""),
        # Whitespace only
        (" \n\n ", " "),
    ],
)
def test_extract_description(body, expected):
    assert extract_description(body) == expected


@pytest.mark.parametrize(
    "date_str, expected",
    [
        # Standard ISO formats
        ("2023-01-15T10:30:00.000Z", datetime(2023, 1, 15, 10, 30, 0)),
        ("2023-01-15T10:30:00Z", datetime(2023, 1, 15, 10, 30, 0)),
        ("2023-01-15T10:30:00.000", datetime(2023, 1, 15, 10, 30, 0)),
        ("2023-01-15T10:30:00", datetime(2023, 1, 15, 10, 30, 0)),
        # Fallback to fromisoformat
        ("2023-01-15T10:30:00.123456", datetime(2023, 1, 15, 10, 30, 0, 123456)),
        # Invalid dates
        ("invalid-date", None),
        ("", None),
        (None, None),
        ("2023-13-45T25:70:70Z", None),  # Invalid date components
    ],
)
def test_parse_lesswrong_date(date_str, expected):
    assert parse_lesswrong_date(date_str) == expected


@patch("memory.parsers.lesswrong.process_images")
@patch("memory.parsers.lesswrong.markdownify")
@patch("memory.parsers.lesswrong.settings")
def test_extract_body(mock_settings, mock_markdownify, mock_process_images):
    from pathlib import Path

    mock_settings.FILE_STORAGE_DIR = Path("/tmp")
    mock_markdownify.return_value = "# Markdown content"
    mock_images = {"image1.jpg": Mock(spec=PILImage.Image)}
    mock_process_images.return_value = (Mock(), mock_images)

    post = {
|
||||||
|
"htmlBody": "<h1>HTML content</h1>",
|
||||||
|
"pageUrl": "https://lesswrong.com/posts/abc123/test-post",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, images = extract_body(post)
|
||||||
|
|
||||||
|
assert body == "# Markdown content"
|
||||||
|
assert images == mock_images
|
||||||
|
mock_process_images.assert_called_once()
|
||||||
|
mock_markdownify.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.process_images")
|
||||||
|
def test_extract_body_empty_content(mock_process_images):
|
||||||
|
post = {"htmlBody": ""}
|
||||||
|
body, images = extract_body(post)
|
||||||
|
|
||||||
|
assert body == ""
|
||||||
|
assert images == {}
|
||||||
|
mock_process_images.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.extract_body")
|
||||||
|
def test_format_post(mock_extract_body):
|
||||||
|
mock_extract_body.return_value = ("Markdown body", {"img1.jpg": Mock()})
|
||||||
|
|
||||||
|
post_data = {
|
||||||
|
"_id": "abc123",
|
||||||
|
"title": "Test Post",
|
||||||
|
"slug": "test-post",
|
||||||
|
"pageUrl": "https://lesswrong.com/posts/abc123/test-post",
|
||||||
|
"postedAt": "2023-01-15T10:30:00Z",
|
||||||
|
"modifiedAt": "2023-01-16T11:00:00Z",
|
||||||
|
"score": 25,
|
||||||
|
"extendedScore": 30,
|
||||||
|
"baseScore": 20,
|
||||||
|
"voteCount": 15,
|
||||||
|
"commentCount": 5,
|
||||||
|
"wordCount": 1000,
|
||||||
|
"tags": [{"name": "AI"}, {"name": "Rationality"}],
|
||||||
|
"user": {"displayName": "Author"},
|
||||||
|
"coauthors": [{"displayName": "Coauthor"}],
|
||||||
|
"af": True,
|
||||||
|
"htmlBody": "<p>HTML content</p>",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = format_post(post_data)
|
||||||
|
|
||||||
|
expected: LessWrongPost = {
|
||||||
|
"title": "Test Post",
|
||||||
|
"url": "https://lesswrong.com/posts/abc123/test-post",
|
||||||
|
"description": "Markdown body",
|
||||||
|
"content": "Markdown body",
|
||||||
|
"authors": ["Author", "Coauthor"],
|
||||||
|
"published_at": datetime(2023, 1, 15, 10, 30, 0),
|
||||||
|
"guid": "abc123",
|
||||||
|
"karma": 20,
|
||||||
|
"votes": 15,
|
||||||
|
"comments": 5,
|
||||||
|
"words": 1000,
|
||||||
|
"tags": ["AI", "Rationality"],
|
||||||
|
"af": True,
|
||||||
|
"score": 25,
|
||||||
|
"extended_score": 30,
|
||||||
|
"modified_at": "2023-01-16T11:00:00Z",
|
||||||
|
"slug": "test-post",
|
||||||
|
"images": ["img1.jpg"],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.extract_body")
|
||||||
|
def test_format_post_minimal_data(mock_extract_body):
|
||||||
|
mock_extract_body.return_value = ("", {})
|
||||||
|
|
||||||
|
post_data = {}
|
||||||
|
|
||||||
|
result = format_post(post_data)
|
||||||
|
|
||||||
|
expected: LessWrongPost = {
|
||||||
|
"title": "Untitled",
|
||||||
|
"url": "",
|
||||||
|
"description": "",
|
||||||
|
"content": "",
|
||||||
|
"authors": ["anonymous"],
|
||||||
|
"published_at": None,
|
||||||
|
"guid": None,
|
||||||
|
"karma": 0,
|
||||||
|
"votes": 0,
|
||||||
|
"comments": 0,
|
||||||
|
"words": 0,
|
||||||
|
"tags": [],
|
||||||
|
"af": False,
|
||||||
|
"score": 0,
|
||||||
|
"extended_score": 0,
|
||||||
|
"modified_at": None,
|
||||||
|
"slug": None,
|
||||||
|
"images": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
|
||||||
|
@patch("memory.parsers.lesswrong.make_graphql_query")
|
||||||
|
@patch("memory.parsers.lesswrong.format_post")
|
||||||
|
@patch("memory.parsers.lesswrong.is_valid_post")
|
||||||
|
def test_fetch_lesswrong_success(
|
||||||
|
mock_is_valid, mock_format, mock_query, mock_fetch_api
|
||||||
|
):
|
||||||
|
mock_query.return_value = "test query"
|
||||||
|
mock_fetch_api.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"_id": "1", "title": "Post 1"},
|
||||||
|
{"_id": "2", "title": "Post 2"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_is_valid.side_effect = [True, False] # First valid, second invalid
|
||||||
|
mock_format.return_value = {"title": "Formatted Post"}
|
||||||
|
|
||||||
|
url = "https://lesswrong.com/graphql"
|
||||||
|
current_date = datetime(2023, 1, 15)
|
||||||
|
|
||||||
|
result = fetch_lesswrong(url, current_date, af=True, min_karma=5, limit=25)
|
||||||
|
|
||||||
|
mock_query.assert_called_once_with(current_date, True, 25, 5)
|
||||||
|
mock_fetch_api.assert_called_once_with(url, "test query")
|
||||||
|
assert mock_is_valid.call_count == 2
|
||||||
|
mock_format.assert_called_once_with({"_id": "1", "title": "Post 1"})
|
||||||
|
assert result == [{"title": "Formatted Post"}]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
|
||||||
|
def test_fetch_lesswrong_empty_results(mock_fetch_api):
|
||||||
|
mock_fetch_api.return_value = {"results": []}
|
||||||
|
|
||||||
|
result = fetch_lesswrong("url", datetime.now())
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
|
||||||
|
def test_fetch_lesswrong_same_item_as_last(mock_fetch_api):
|
||||||
|
mock_fetch_api.return_value = {
|
||||||
|
"results": [{"pageUrl": "https://lesswrong.com/posts/same"}]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = fetch_lesswrong(
|
||||||
|
"url", datetime.now(), last_url="https://lesswrong.com/posts/same"
|
||||||
|
)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||||
|
@patch("memory.parsers.lesswrong.time.sleep")
|
||||||
|
def test_fetch_lesswrong_posts_success(mock_sleep, mock_fetch):
|
||||||
|
since = datetime(2023, 1, 15)
|
||||||
|
|
||||||
|
# Mock three batches of posts
|
||||||
|
mock_fetch.side_effect = [
|
||||||
|
[
|
||||||
|
{"published_at": datetime(2023, 1, 14), "url": "post1"},
|
||||||
|
{"published_at": datetime(2023, 1, 13), "url": "post2"},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"published_at": datetime(2023, 1, 12), "url": "post3"},
|
||||||
|
],
|
||||||
|
[], # Empty result to stop iteration
|
||||||
|
]
|
||||||
|
|
||||||
|
posts = list(
|
||||||
|
fetch_lesswrong_posts(
|
||||||
|
since=since,
|
||||||
|
min_karma=10,
|
||||||
|
limit=50,
|
||||||
|
cooldown=0.1,
|
||||||
|
max_items=100,
|
||||||
|
af=False,
|
||||||
|
url="https://lesswrong.com/graphql",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(posts) == 3
|
||||||
|
assert posts[0]["url"] == "post1"
|
||||||
|
assert posts[1]["url"] == "post2"
|
||||||
|
assert posts[2]["url"] == "post3"
|
||||||
|
|
||||||
|
# Should have called sleep twice (after first two batches)
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
mock_sleep.assert_called_with(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||||
|
def test_fetch_lesswrong_posts_default_since(mock_fetch):
|
||||||
|
mock_fetch.return_value = []
|
||||||
|
|
||||||
|
with patch("memory.parsers.lesswrong.datetime") as mock_datetime:
|
||||||
|
mock_now = datetime(2023, 1, 15, 12, 0, 0)
|
||||||
|
mock_datetime.now.return_value = mock_now
|
||||||
|
|
||||||
|
list(fetch_lesswrong_posts())
|
||||||
|
|
||||||
|
# Should use yesterday as default
|
||||||
|
expected_since = mock_now - timedelta(days=1)
|
||||||
|
mock_fetch.assert_called_with(
|
||||||
|
"https://www.lesswrong.com/graphql", expected_since, False, 10, 50, None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||||
|
def test_fetch_lesswrong_posts_max_items_limit(mock_fetch):
|
||||||
|
# Return posts that would exceed max_items
|
||||||
|
mock_fetch.side_effect = [
|
||||||
|
[{"published_at": datetime(2023, 1, 14), "url": f"post{i}"} for i in range(8)],
|
||||||
|
[
|
||||||
|
{"published_at": datetime(2023, 1, 13), "url": f"post{i}"}
|
||||||
|
for i in range(8, 16)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
posts = list(
|
||||||
|
fetch_lesswrong_posts(
|
||||||
|
since=datetime(2023, 1, 15),
|
||||||
|
max_items=7, # Should stop after 7 items
|
||||||
|
cooldown=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# The logic checks items_count < max_items before fetching, so it will fetch the first batch
|
||||||
|
# Since items_count (8) >= max_items (7) after first batch, it won't fetch the second batch
|
||||||
|
assert len(posts) == 8
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||||
|
def test_fetch_lesswrong_posts_api_error(mock_fetch):
|
||||||
|
mock_fetch.side_effect = Exception("API Error")
|
||||||
|
|
||||||
|
posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15)))
|
||||||
|
assert posts == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||||
|
def test_fetch_lesswrong_posts_no_date_progression(mock_fetch):
|
||||||
|
# Mock posts with same date to trigger stopping condition
|
||||||
|
same_date = datetime(2023, 1, 15)
|
||||||
|
mock_fetch.side_effect = [
|
||||||
|
[{"published_at": same_date, "url": "post1"}],
|
||||||
|
[{"published_at": same_date, "url": "post2"}], # Same date, should stop
|
||||||
|
]
|
||||||
|
|
||||||
|
posts = list(fetch_lesswrong_posts(since=same_date, cooldown=0))
|
||||||
|
|
||||||
|
assert len(posts) == 1 # Should stop after first batch
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||||
|
def test_fetch_lesswrong_posts_none_date(mock_fetch):
|
||||||
|
# Mock posts with None date to trigger stopping condition
|
||||||
|
mock_fetch.side_effect = [
|
||||||
|
[{"published_at": None, "url": "post1"}],
|
||||||
|
]
|
||||||
|
|
||||||
|
posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15), cooldown=0))
|
||||||
|
|
||||||
|
assert len(posts) == 1 # Should stop after first batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_lesswrong_post_type():
|
||||||
|
"""Test that LessWrongPost TypedDict has correct structure."""
|
||||||
|
# This is more of a documentation test to ensure the type is correct
|
||||||
|
post: LessWrongPost = {
|
||||||
|
"title": "Test",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"description": "Description",
|
||||||
|
"content": "Content",
|
||||||
|
"authors": ["Author"],
|
||||||
|
"published_at": datetime.now(),
|
||||||
|
"guid": "123",
|
||||||
|
"karma": 10,
|
||||||
|
"votes": 5,
|
||||||
|
"comments": 2,
|
||||||
|
"words": 100,
|
||||||
|
"tags": ["tag"],
|
||||||
|
"af": False,
|
||||||
|
"score": 15,
|
||||||
|
"extended_score": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Optional fields
|
||||||
|
post["modified_at"] = "2023-01-15T10:30:00Z"
|
||||||
|
post["slug"] = "test-slug"
|
||||||
|
post["images"] = ["image.jpg"]
|
||||||
|
|
||||||
|
# Should not raise any type errors
|
||||||
|
assert post["title"] == "Test"
|
@ -1,5 +1,5 @@
 import pytest
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from unittest.mock import Mock, patch

 from memory.common.db.models import ArticleFeed, BlogPost
@ -64,37 +64,6 @@ def inactive_article_feed(db_session):
     return feed


-@pytest.fixture
-def recently_checked_feed(db_session):
-    """Create a recently checked ArticleFeed."""
-    from sqlalchemy import text
-
-    # Use a very recent timestamp that will trigger the "recently checked" condition
-    # The check_interval is 3600 seconds, so 30 seconds ago should be "recent"
-    recent_time = datetime.now() - timedelta(seconds=30)
-
-    feed = ArticleFeed(
-        url="https://example.com/recent.xml",
-        title="Recent Feed",
-        description="A recently checked feed",
-        tags=["test"],
-        check_interval=3600,
-        active=True,
-    )
-    db_session.add(feed)
-    db_session.flush()  # Get the ID
-
-    # Manually set the last_checked_at to avoid timezone issues
-    db_session.execute(
-        text(
-            "UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
-        ),
-        {"timestamp": recent_time, "feed_id": feed.id},
-    )
-    db_session.commit()
-    return feed
-
-
 @pytest.fixture
 def mock_feed_item():
     """Mock feed item for testing."""
@ -564,3 +533,65 @@ def test_sync_website_archive_empty_results(
     assert result["articles_found"] == 0
     assert result["new_articles"] == 0
     assert result["task_ids"] == []
+
+
+@pytest.mark.parametrize(
+    "check_interval_minutes,seconds_since_check,should_skip",
+    [
+        (60, 30, True),  # 60min interval, checked 30s ago -> skip
+        (60, 3000, True),  # 60min interval, checked 50min ago -> skip
+        (60, 4000, False),  # 60min interval, checked 66min ago -> don't skip
+        (30, 1000, True),  # 30min interval, checked 16min ago -> skip
+        (30, 2000, False),  # 30min interval, checked 33min ago -> don't skip
+    ],
+)
+@patch("memory.workers.tasks.blogs.get_feed_parser")
+def test_sync_article_feed_check_interval(
+    mock_get_parser,
+    check_interval_minutes,
+    seconds_since_check,
+    should_skip,
+    db_session,
+):
+    """Test sync respects check interval with various timing scenarios."""
+    from sqlalchemy import text
+
+    # Mock parser to return None (no parser available) for non-skipped cases
+    mock_get_parser.return_value = None
+
+    # Create feed with specific check interval
+    feed = ArticleFeed(
+        url="https://example.com/interval-test.xml",
+        title="Interval Test Feed",
+        description="Feed for testing check intervals",
+        tags=["test"],
+        check_interval=check_interval_minutes,
+        active=True,
+    )
+    db_session.add(feed)
+    db_session.flush()
+
+    # Set last_checked_at to specific time in the past
+    last_checked_time = datetime.now(timezone.utc) - timedelta(
+        seconds=seconds_since_check
+    )
+    db_session.execute(
+        text(
+            "UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
+        ),
+        {"timestamp": last_checked_time, "feed_id": feed.id},
+    )
+    db_session.commit()
+
+    result = blogs.sync_article_feed(feed.id)
+
+    if should_skip:
+        assert result == {"status": "skipped_recent_check", "feed_id": feed.id}
+        # get_feed_parser should not be called when skipping
+        mock_get_parser.assert_not_called()
+    else:
+        # Should proceed with sync, but will fail due to no parser - that's expected
+        assert result["status"] == "error"
+        assert result["error"] == "No parser available for feed"
+        # get_feed_parser should be called when not skipping
+        mock_get_parser.assert_called_once()
@ -471,11 +471,9 @@ def test_process_content_item(
             "memory.workers.tasks.content_processing.push_to_qdrant",
             side_effect=Exception("Qdrant error"),
         ):
-            result = process_content_item(
-                mail_message, "mail", db_session, ["tag1"]
-            )
+            result = process_content_item(mail_message, db_session)
     else:
-        result = process_content_item(mail_message, "mail", db_session, ["tag1"])
+        result = process_content_item(mail_message, db_session)

     assert result["status"] == expected_status
     assert result["embed_status"] == expected_embed_status
@ -326,7 +326,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
     # Check that the full content was passed to the embedding function
     texts = mock_voyage_client.embed.call_args[0][0]
     assert texts == [
-        large_page_1.strip(),
-        large_page_2.strip(),
-        large_section_content.strip(),
+        [large_page_1.strip()],
+        [large_page_2.strip()],
+        [large_section_content.strip()],
     ]
562
tests/memory/workers/tasks/test_forums_tasks.py
Normal file
@ -0,0 +1,562 @@
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch

from memory.common.db.models import ForumPost
from memory.workers.tasks import forums
from memory.parsers.lesswrong import LessWrongPost
from memory.workers.tasks.content_processing import create_content_hash


@pytest.fixture
def mock_lesswrong_post():
    """Mock LessWrong post data for testing."""
    return LessWrongPost(
        title="Test LessWrong Post",
        url="https://www.lesswrong.com/posts/test123/test-post",
        description="This is a test post description",
        content="This is test post content with enough text to be processed and embedded.",
        authors=["Test Author"],
        published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        guid="test123",
        karma=25,
        votes=10,
        comments=5,
        words=100,
        tags=["rationality", "ai"],
        af=False,
        score=25,
        extended_score=30,
        modified_at="2024-01-01T12:30:00Z",
        slug="test-post",
        images=[],  # Empty images to avoid file path issues
    )


@pytest.fixture
def mock_empty_lesswrong_post():
    """Mock LessWrong post with empty content."""
    return LessWrongPost(
        title="Empty Post",
        url="https://www.lesswrong.com/posts/empty123/empty-post",
        description="",
        content="",
        authors=["Empty Author"],
        published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        guid="empty123",
        karma=5,
        votes=2,
        comments=0,
        words=0,
        tags=[],
        af=False,
        score=5,
        extended_score=5,
        slug="empty-post",
        images=[],
    )


@pytest.fixture
def mock_af_post():
    """Mock Alignment Forum post."""
    return LessWrongPost(
        title="AI Safety Research",
        url="https://www.lesswrong.com/posts/af123/ai-safety-research",
        description="Important AI safety research",
        content="This is important AI safety research content that should be processed.",
        authors=["AI Researcher"],
        published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        guid="af123",
        karma=50,
        votes=20,
        comments=15,
        words=200,
        tags=["ai-safety", "alignment"],
        af=True,
        score=50,
        extended_score=60,
        slug="ai-safety-research",
        images=[],
    )


def test_sync_lesswrong_post_success(mock_lesswrong_post, db_session, qdrant):
    """Test successful LessWrong post synchronization."""
    result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test", "forum"])

    # Verify the ForumPost was created in the database
    forum_post = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
        .first()
    )
    assert forum_post is not None
    assert forum_post.title == "Test LessWrong Post"
    assert (
        forum_post.content
        == "This is test post content with enough text to be processed and embedded."
    )
    assert forum_post.modality == "forum"
    assert forum_post.mime_type == "text/markdown"
    assert forum_post.authors == ["Test Author"]
    assert forum_post.karma == 25
    assert forum_post.votes == 10
    assert forum_post.comments == 5
    assert forum_post.words == 100
    assert forum_post.score == 25
    assert forum_post.slug == "test-post"
    assert forum_post.images == []
    assert "test" in forum_post.tags
    assert "forum" in forum_post.tags
    assert "rationality" in forum_post.tags
    assert "ai" in forum_post.tags

    # Verify the result
    assert result["status"] == "processed"
    assert result["forumpost_id"] == forum_post.id
    assert result["title"] == "Test LessWrong Post"


def test_sync_lesswrong_post_empty_content(mock_empty_lesswrong_post, db_session):
    """Test LessWrong post sync with empty content."""
    result = forums.sync_lesswrong_post(mock_empty_lesswrong_post)

    # Should still create the post but with failed status due to no chunks
    forum_post = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/empty123/empty-post")
        .first()
    )
    assert forum_post is not None
    assert forum_post.title == "Empty Post"
    assert forum_post.content == ""
    assert result["status"] == "failed"  # No chunks generated for empty content
    assert result["chunks_count"] == 0


def test_sync_lesswrong_post_already_exists(mock_lesswrong_post, db_session):
    """Test LessWrong post sync when content already exists."""
    # Add existing forum post with same content hash
    existing_post = ForumPost(
        url="https://www.lesswrong.com/posts/test123/test-post",
        title="Test LessWrong Post",
        content="This is test post content with enough text to be processed and embedded.",
        sha256=create_content_hash(
            "This is test post content with enough text to be processed and embedded."
        ),
        modality="forum",
        tags=["existing"],
        mime_type="text/markdown",
        size=77,
        authors=["Test Author"],
        karma=25,
    )
    db_session.add(existing_post)
    db_session.commit()

    result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test"])

    assert result["status"] == "already_exists"
    assert result["forumpost_id"] == existing_post.id

    # Verify no duplicate was created
    forum_posts = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
        .all()
    )
    assert len(forum_posts) == 1


def test_sync_lesswrong_post_with_custom_tags(mock_lesswrong_post, db_session, qdrant):
    """Test LessWrong post sync with custom tags."""
    result = forums.sync_lesswrong_post(mock_lesswrong_post, ["custom", "tags"])

    # Verify the ForumPost was created with custom tags merged with post tags
    forum_post = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
        .first()
    )
    assert forum_post is not None
    assert "custom" in forum_post.tags
    assert "tags" in forum_post.tags
    assert "rationality" in forum_post.tags  # Original post tags
    assert "ai" in forum_post.tags
    assert result["status"] == "processed"


def test_sync_lesswrong_post_af_post(mock_af_post, db_session, qdrant):
    """Test syncing an Alignment Forum post."""
    result = forums.sync_lesswrong_post(mock_af_post, ["alignment-forum"])

    forum_post = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/af123/ai-safety-research")
        .first()
    )
    assert forum_post is not None
    assert forum_post.title == "AI Safety Research"
    assert forum_post.karma == 50
    assert "ai-safety" in forum_post.tags
    assert "alignment" in forum_post.tags
    assert "alignment-forum" in forum_post.tags
    assert result["status"] == "processed"


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_success(mock_fetch, mock_lesswrong_post, db_session):
    """Test successful LessWrong synchronization."""
    mock_fetch.return_value = [mock_lesswrong_post]

    with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
        mock_sync_post.delay.return_value = Mock(id="task-123")

        result = forums.sync_lesswrong(
            since="2024-01-01T00:00:00",
            min_karma=10,
            limit=50,
            cooldown=0.1,
            max_items=100,
            af=False,
            tags=["test"],
        )

        assert result["posts_num"] == 1
        assert result["new_posts"] == 1
        assert result["since"] == "2024-01-01T00:00:00"
        assert result["min_karma"] == 10
        assert result["max_items"] == 100
        assert result["af"] == False

        # Verify fetch_lesswrong_posts was called with correct arguments
        mock_fetch.assert_called_once_with(
            datetime.fromisoformat("2024-01-01T00:00:00"),
            10,  # min_karma
            50,  # limit
            0.1,  # cooldown
            100,  # max_items
            False,  # af
        )

        # Verify sync_lesswrong_post was called for the new post
        mock_sync_post.delay.assert_called_once_with(mock_lesswrong_post, ["test"])


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_with_existing_posts(
    mock_fetch, mock_lesswrong_post, db_session
):
    """Test sync when some posts already exist."""
    # Create existing forum post
    existing_post = ForumPost(
        url="https://www.lesswrong.com/posts/test123/test-post",
        title="Existing Post",
        content="Existing content",
        sha256=b"existing_hash" + bytes(24),
        modality="forum",
        tags=["existing"],
        mime_type="text/markdown",
        size=100,
        authors=["Test Author"],
    )
    db_session.add(existing_post)
    db_session.commit()

    # Mock fetch to return existing post and a new one
    new_post = mock_lesswrong_post.copy()
    new_post["url"] = "https://www.lesswrong.com/posts/new123/new-post"
    new_post["title"] = "New Post"

    mock_fetch.return_value = [mock_lesswrong_post, new_post]

    with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
        mock_sync_post.delay.return_value = Mock(id="task-456")

        result = forums.sync_lesswrong(max_items=100)

        assert result["posts_num"] == 2
        assert result["new_posts"] == 1  # Only one new post

        # Verify sync_lesswrong_post was only called for the new post
        mock_sync_post.delay.assert_called_once_with(new_post, [])


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_no_posts(mock_fetch, db_session):
    """Test sync when no posts are returned."""
    mock_fetch.return_value = []

    result = forums.sync_lesswrong()

    assert result["posts_num"] == 0
    assert result["new_posts"] == 0


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_max_items_limit(mock_fetch, db_session):
    """Test that max_items limit is respected."""
    # Create multiple mock posts
    posts = []
    for i in range(5):
        post = LessWrongPost(
            title=f"Post {i}",
            url=f"https://www.lesswrong.com/posts/test{i}/post-{i}",
            description=f"Description {i}",
            content=f"Content {i}",
            authors=[f"Author {i}"],
            published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            guid=f"test{i}",
            karma=10,
            votes=5,
            comments=2,
            words=50,
            tags=[],
            af=False,
            score=10,
            extended_score=10,
            slug=f"post-{i}",
            images=[],
        )
        posts.append(post)

    mock_fetch.return_value = posts

    with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
        mock_sync_post.delay.return_value = Mock(id="task-123")

        result = forums.sync_lesswrong(max_items=3)

        # Should stop at max_items, but new_posts can be higher because
        # the check happens after incrementing new_posts but before incrementing posts_num
        assert result["posts_num"] == 3
        assert result["new_posts"] == 4  # One more than posts_num due to timing
        assert result["max_items"] == 3


@pytest.mark.parametrize(
    "since_str,expected_days_ago",
    [
        ("2024-01-01T00:00:00", None),  # Specific date
        (None, 30),  # Default should be 30 days ago
    ],
)
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_since_parameter(
    mock_fetch, since_str, expected_days_ago, db_session
):
    """Test that since parameter is handled correctly."""
    mock_fetch.return_value = []

    if since_str:
        forums.sync_lesswrong(since=since_str)
        expected_since = datetime.fromisoformat(since_str)
    else:
        forums.sync_lesswrong()
        # Should default to 30 days ago
        expected_since = datetime.now() - timedelta(days=30)

    # Verify fetch was called with correct since date
    call_args = mock_fetch.call_args[0]
    actual_since = call_args[0]

    if expected_days_ago:
        # For default case, check it's approximately 30 days ago
        time_diff = abs((actual_since - expected_since).total_seconds())
        assert time_diff < 120  # Within 2 minute tolerance for slow tests
    else:
        assert actual_since == expected_since


@pytest.mark.parametrize(
    "af_value,min_karma,limit,cooldown",
    [
        (True, 20, 25, 1.0),
        (False, 5, 100, 0.0),
        (True, 50, 10, 0.5),
    ],
)
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_parameters(
    mock_fetch, af_value, min_karma, limit, cooldown, db_session
):
    """Test that all parameters are passed correctly to fetch function."""
    mock_fetch.return_value = []

    result = forums.sync_lesswrong(
        af=af_value,
        min_karma=min_karma,
        limit=limit,
        cooldown=cooldown,
        max_items=500,
    )

    # Verify fetch was called with correct parameters
    call_args = mock_fetch.call_args[0]

    assert call_args[1] == min_karma  # min_karma
    assert call_args[2] == limit  # limit
    assert call_args[3] == cooldown  # cooldown
    assert call_args[4] == 500  # max_items
    assert call_args[5] == af_value  # af

    assert result["min_karma"] == min_karma
    assert result["af"] == af_value


@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_fetch_error(mock_fetch, db_session):
    """Test sync when fetch_lesswrong_posts raises an exception."""
    mock_fetch.side_effect = Exception("API error")

    # The safe_task_execution decorator should catch this
    result = forums.sync_lesswrong()

    assert result["status"] == "error"
    assert "API error" in result["error"]


def test_sync_lesswrong_post_error_handling(db_session):
    """Test error handling in sync_lesswrong_post."""
    # Create invalid post data that will cause an error
    invalid_post = {
        "title": "Test",
        "url": "invalid-url",
        "content": "test content",
        # Missing required fields
    }

    # The safe_task_execution decorator should catch this
    result = forums.sync_lesswrong_post(invalid_post)

    assert result["status"] == "error"
    assert "error" in result


@pytest.mark.parametrize(
    "post_tags,additional_tags,expected_tags",
    [
        (["original"], ["new"], ["original", "new"]),
        ([], ["tag1", "tag2"], ["tag1", "tag2"]),
        (["existing"], [], ["existing"]),
        (["dup", "tag"], ["tag", "new"], ["dup", "tag", "new"]),  # Duplicates removed
    ],
)
def test_sync_lesswrong_post_tag_merging(
    post_tags, additional_tags, expected_tags, db_session, qdrant
):
    """Test that post tags and additional tags are properly merged."""
    post = LessWrongPost(
        title="Tag Test Post",
        url="https://www.lesswrong.com/posts/tag123/tag-test",
        description="Test description",
        content="Test content for tag merging",
        authors=["Tag Author"],
        published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        guid="tag123",
        karma=15,
        votes=5,
        comments=2,
        words=50,
        tags=post_tags,
        af=False,
        score=15,
        extended_score=15,
        slug="tag-test",
        images=[],
    )

    forums.sync_lesswrong_post(post, additional_tags)

    forum_post = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/tag123/tag-test")
        .first()
    )
    assert forum_post is not None

    # Check that all expected tags are present (order doesn't matter)
    for tag in expected_tags:
        assert tag in forum_post.tags

    # Check that no unexpected tags are present
    assert len(forum_post.tags) == len(set(expected_tags))


def test_sync_lesswrong_post_datetime_handling(db_session, qdrant):
    """Test that datetime fields are properly handled."""
    post = LessWrongPost(
        title="DateTime Test",
        url="https://www.lesswrong.com/posts/dt123/datetime-test",
        description="Test description",
        content="Test content",
        authors=["DateTime Author"],
        published_at=datetime(2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc),
        guid="dt123",
        karma=20,
        votes=8,
        comments=3,
        words=75,
        tags=["datetime"],
        af=False,
        score=20,
        extended_score=25,
        modified_at="2024-06-15T15:00:00Z",
        slug="datetime-test",
        images=[],
    )

    result = forums.sync_lesswrong_post(post)

    forum_post = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/dt123/datetime-test")
        .first()
    )
    assert forum_post is not None
    assert forum_post.published_at == datetime(
        2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc
    )
    # modified_at should be stored as string in the post data
    assert result["status"] == "processed"


def test_sync_lesswrong_post_content_hash_consistency(db_session):
    """Test that content hash is calculated consistently."""
    post = LessWrongPost(
        title="Hash Test",
        url="https://www.lesswrong.com/posts/hash123/hash-test",
        description="Test description",
        content="Consistent content for hashing",
        authors=["Hash Author"],
        published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        guid="hash123",
        karma=15,
        votes=5,
        comments=2,
        words=50,
        tags=["hash"],
        af=False,
        score=15,
        extended_score=15,
        slug="hash-test",
        images=[],
    )

    # Sync the same post twice
    result1 = forums.sync_lesswrong_post(post)
    result2 = forums.sync_lesswrong_post(post)

    # First should succeed, second should detect existing
    assert result1["status"] == "processed"
    assert result2["status"] == "already_exists"
    assert result1["forumpost_id"] == result2["forumpost_id"]

    # Verify only one post exists in database
    forum_posts = (
        db_session.query(ForumPost)
        .filter_by(url="https://www.lesswrong.com/posts/hash123/hash-test")
        .all()
    )
    assert len(forum_posts) == 1
@ -7,12 +7,14 @@ from PIL import Image

 from memory.common import qdrant as qd
 from memory.common import settings
-from memory.common.db.models import Chunk, SourceItem
+from memory.common.db.models import Chunk, SourceItem, MailMessage, BlogPost
 from memory.workers.tasks.maintenance import (
     clean_collection,
     reingest_chunk,
     check_batch,
     reingest_missing_chunks,
+    reingest_item,
+    reingest_empty_source_items,
 )

@ -302,5 +304,295 @@ def test_reingest_missing_chunks(db_session, qdrant, batch_size):
     assert set(qdrant_ids) == set(db_ids)


+@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
+def test_reingest_item_success(db_session, qdrant, item_type):
+    """Test successful reingestion of an item with existing chunks."""
+    if item_type == "MailMessage":
+        item = MailMessage(
+            sha256=b"test_hash" + bytes(24),
+            tags=["test"],
+            size=100,
+            mime_type="message/rfc822",
+            embed_status="STORED",
+            message_id="<test@example.com>",
+            subject="Test Subject",
+            sender="sender@example.com",
+            recipients=["recipient@example.com"],
+            content="Test content for reingestion",
+            folder="INBOX",
+            modality="mail",
+        )
+    else:  # blog_post
+        item = BlogPost(
+            sha256=b"test_hash" + bytes(24),
+            tags=["test"],
+            size=100,
+            mime_type="text/html",
+            embed_status="STORED",
+            url="https://example.com/post",
+            title="Test Blog Post",
+            author="Author Name",
+            content="Test blog content for reingestion",
+            modality="blog",
+        )
+
+    db_session.add(item)
+    db_session.flush()
+
+    # Add some chunks to the item
+    chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
+    chunks = [
+        Chunk(
+            id=chunk_id,
+            source=item,
+            content=f"Test chunk content {i}",
+            embedding_model="test-model",
+        )
+        for i, chunk_id in enumerate(chunk_ids)
+    ]
+    db_session.add_all(chunks)
+    db_session.commit()
+
+    # Add vectors to Qdrant
+    modality = "mail" if item_type == "MailMessage" else "blog"
+    qd.ensure_collection_exists(qdrant, modality, 1024)
+    qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))
+
+    # Verify chunks exist in Qdrant before reingestion
+    qdrant_ids_before = {
+        str(i) for batch in qd.batch_ids(qdrant, modality) for i in batch
+    }
+    assert set(chunk_ids).issubset(qdrant_ids_before)
+
+    # Mock the embedding function to return chunks
+    with patch("memory.common.embedding.embed_source_item") as mock_embed:
+        mock_embed.return_value = [
+            Chunk(
+                id=str(uuid.uuid4()),
+                content="New chunk content 1",
+                embedding_model="test-model",
+                vector=[0.1] * 1024,
+                item_metadata={"source_id": item.id, "tags": ["test"]},
+            ),
+            Chunk(
+                id=str(uuid.uuid4()),
+                content="New chunk content 2",
+                embedding_model="test-model",
+                vector=[0.2] * 1024,
+                item_metadata={"source_id": item.id, "tags": ["test"]},
+            ),
+        ]
+
+        result = reingest_item(str(item.id), item_type)
+
+    assert result["status"] == "processed"
+    assert result[f"{item_type.lower()}_id"] == item.id
+    assert result["chunks_count"] == 2
+    assert result["embed_status"] == "STORED"
+
+    # Verify old chunks were deleted from database
+    db_session.refresh(item)
+    remaining_chunks = db_session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
+    assert len(remaining_chunks) == 0
+
+    # Verify old vectors were deleted from Qdrant
+    qdrant_ids_after = {
+        str(i) for batch in qd.batch_ids(qdrant, modality) for i in batch
+    }
+    assert not set(chunk_ids).intersection(qdrant_ids_after)
+
+
+def test_reingest_item_not_found(db_session):
+    """Test reingesting a non-existent item."""
+    non_existent_id = "999"
+    result = reingest_item(non_existent_id, "MailMessage")
+
+    assert result == {"status": "error", "error": f"Item {non_existent_id} not found"}
+
+
+def test_reingest_item_invalid_type(db_session):
+    """Test reingesting with an invalid item type."""
+    result = reingest_item("1", "invalid_type")
+
+    assert result["status"] == "error"
+    assert "Unsupported item type invalid_type" in result["error"]
+    assert "Available types:" in result["error"]
+
+
+def test_reingest_item_no_chunks(db_session, qdrant):
+    """Test reingesting an item that has no chunks."""
+    item = MailMessage(
+        sha256=b"test_hash" + bytes(24),
+        tags=["test"],
+        size=100,
+        mime_type="message/rfc822",
+        embed_status="RAW",
+        message_id="<test@example.com>",
+        subject="Test Subject",
+        sender="sender@example.com",
+        recipients=["recipient@example.com"],
+        content="Test content",
+        folder="INBOX",
+        modality="mail",
+    )
+    db_session.add(item)
+    db_session.commit()
+
+    # Mock the embedding function to return a chunk
+    with patch("memory.common.embedding.embed_source_item") as mock_embed:
+        mock_embed.return_value = [
+            Chunk(
+                id=str(uuid.uuid4()),
+                content="New chunk content",
+                embedding_model="test-model",
+                vector=[0.1] * 1024,
+                item_metadata={"source_id": item.id, "tags": ["test"]},
+            ),
+        ]
+
+        result = reingest_item(str(item.id), "MailMessage")
+
+    assert result["status"] == "processed"
+    assert result["mailmessage_id"] == item.id
+    assert result["chunks_count"] == 1
+    assert result["embed_status"] == "STORED"
+
+
+@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
+def test_reingest_empty_source_items_success(db_session, item_type):
+    """Test reingesting empty source items."""
+    # Create items with and without chunks
+    if item_type == "MailMessage":
+        empty_items = [
+            MailMessage(
+                sha256=f"empty_hash_{i}".encode() + bytes(32 - len(f"empty_hash_{i}")),
+                tags=["test"],
+                size=100,
+                mime_type="message/rfc822",
+                embed_status="RAW",
+                message_id=f"<empty{i}@example.com>",
+                subject=f"Empty Subject {i}",
+                sender="sender@example.com",
+                recipients=["recipient@example.com"],
+                content=f"Empty content {i}",
+                folder="INBOX",
+                modality="mail",
+            )
+            for i in range(3)
+        ]
+
+        item_with_chunks = MailMessage(
+            sha256=b"with_chunks_hash" + bytes(16),
+            tags=["test"],
+            size=100,
+            mime_type="message/rfc822",
+            embed_status="STORED",
+            message_id="<with_chunks@example.com>",
+            subject="With Chunks Subject",
+            sender="sender@example.com",
+            recipients=["recipient@example.com"],
+            content="Content with chunks",
+            folder="INBOX",
+            modality="mail",
+        )
+    else:  # blog_post
+        empty_items = [
+            BlogPost(
+                sha256=f"empty_hash_{i}".encode() + bytes(32 - len(f"empty_hash_{i}")),
+                tags=["test"],
+                size=100,
+                mime_type="text/html",
+                embed_status="RAW",
+                url=f"https://example.com/empty{i}",
+                title=f"Empty Post {i}",
+                author="Author Name",
+                content=f"Empty blog content {i}",
+                modality="blog",
+            )
+            for i in range(3)
+        ]
+
+        item_with_chunks = BlogPost(
+            sha256=b"with_chunks_hash" + bytes(16),
+            tags=["test"],
+            size=100,
+            mime_type="text/html",
+            embed_status="STORED",
+            url="https://example.com/with_chunks",
+            title="With Chunks Post",
+            author="Author Name",
+            content="Blog content with chunks",
+            modality="blog",
+        )
+
+    db_session.add_all(empty_items + [item_with_chunks])
+    db_session.flush()
+
+    # Add a chunk to the item_with_chunks
+    chunk = Chunk(
+        id=str(uuid.uuid4()),
+        source=item_with_chunks,
+        content="Test chunk content",
+        embedding_model="test-model",
+    )
+    db_session.add(chunk)
+    db_session.commit()
+
+    with patch.object(reingest_item, "delay") as mock_reingest:
+        result = reingest_empty_source_items(item_type)
+
+    assert result == {"status": "success", "items": 3}
+
+    # Verify that reingest_item.delay was called for each empty item
+    assert mock_reingest.call_count == 3
+    expected_calls = [call(item.id, item_type) for item in empty_items]
+    mock_reingest.assert_has_calls(expected_calls, any_order=True)
+
+
+def test_reingest_empty_source_items_no_empty_items(db_session):
+    """Test when there are no empty source items."""
+    # Create an item with chunks
+    item = MailMessage(
+        sha256=b"with_chunks_hash" + bytes(16),
+        tags=["test"],
+        size=100,
+        mime_type="message/rfc822",
+        embed_status="STORED",
+        message_id="<with_chunks@example.com>",
+        subject="With Chunks Subject",
+        sender="sender@example.com",
+        recipients=["recipient@example.com"],
+        content="Content with chunks",
+        folder="INBOX",
+        modality="mail",
+    )
+    db_session.add(item)
+    db_session.flush()
+
+    chunk = Chunk(
+        id=str(uuid.uuid4()),
+        source=item,
+        content="Test chunk content",
+        embedding_model="test-model",
+    )
+    db_session.add(chunk)
+    db_session.commit()
+
+    with patch.object(reingest_item, "delay") as mock_reingest:
+        result = reingest_empty_source_items("MailMessage")
+
+    assert result == {"status": "success", "items": 0}
+    mock_reingest.assert_not_called()
+
+
+def test_reingest_empty_source_items_invalid_type(db_session):
+    """Test reingesting empty source items with invalid type."""
+    result = reingest_empty_source_items("invalid_type")
+
+    assert result["status"] == "error"
+    assert "Unsupported item type invalid_type" in result["error"]
+    assert "Available types:" in result["error"]
+
+
 def test_reingest_missing_chunks_no_chunks(db_session):
     assert reingest_missing_chunks() == {}