mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
Add less wrong tasks + reindexer
This commit is contained in:
parent
ab87bced81
commit
ed8033bdd3
52  db/migrations/versions/20250528_012300_add_forum_posts.py  Normal file
@ -0,0 +1,52 @@
"""Add forum posts

Revision ID: 2524646f56f6
Revises: 1b535e1b044e
Create Date: 2025-05-28 01:23:00.079366

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "2524646f56f6"
down_revision: Union[str, None] = "1b535e1b044e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "forum_post",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("url", sa.Text(), nullable=True),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("authors", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("modified_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("slug", sa.Text(), nullable=True),
        sa.Column("karma", sa.Integer(), nullable=True),
        sa.Column("votes", sa.Integer(), nullable=True),
        sa.Column("comments", sa.Integer(), nullable=True),
        sa.Column("words", sa.Integer(), nullable=True),
        sa.Column("score", sa.Integer(), nullable=True),
        sa.Column("images", sa.ARRAY(sa.Text()), nullable=True),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("url"),
    )
    op.create_index("forum_post_slug_idx", "forum_post", ["slug"], unique=False)
    op.create_index("forum_post_title_idx", "forum_post", ["title"], unique=False)
    op.create_index("forum_post_url_idx", "forum_post", ["url"], unique=False)


def downgrade() -> None:
    op.drop_index("forum_post_url_idx", table_name="forum_post")
    op.drop_index("forum_post_title_idx", table_name="forum_post")
    op.drop_index("forum_post_slug_idx", table_name="forum_post")
    op.drop_table("forum_post")
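For context, a minimal sketch of applying this migration through Alembic's Python API rather than the CLI. This helper is not part of the commit; it assumes an alembic.ini at the project root that points at db/migrations:

```python
# Hypothetical helper, not part of this commit: apply the forum_post migration.
from alembic import command
from alembic.config import Config


def run_migrations() -> None:
    # Assumes alembic.ini is configured for db/migrations and the target database.
    cfg = Config("alembic.ini")
    command.upgrade(cfg, "head")  # includes revision 2524646f56f6 ("Add forum posts")


if __name__ == "__main__":
    run_migrations()
```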
@ -207,6 +207,12 @@ services:
      <<: *worker-env
      QUEUES: "blogs"

  worker-forums:
    <<: *worker-base
    environment:
      <<: *worker-env
      QUEUES: "forums"

  worker-photo:
    <<: *worker-base
    environment:
@ -33,7 +33,7 @@ RUN mkdir -p /var/log/supervisor /var/run/supervisor
RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files
USER kb

ENV QUEUES="docs,email,maintenance"
ENV QUEUES="maintenance"
ENV PYTHONPATH="/app"

ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]
@ -39,7 +39,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
USER kb

# Default queues to process
ENV QUEUES="ebooks,email,comic,blogs,photo_embed,maintenance"
ENV QUEUES="ebooks,email,comic,blogs,forums,photo_embed,maintenance"
ENV PYTHONPATH="/app"

ENTRYPOINT ["./entry.sh"]
@ -3,4 +3,5 @@ pytest-cov==4.1.0
black==23.12.1
mypy==1.8.0
isort==5.13.2
testcontainers[qdrant]==4.10.0
click==8.1.7
@ -17,6 +17,7 @@ from memory.common.db.models import (
    MiscDoc,
    ArticleFeed,
    EmailAccount,
    ForumPost,
)


@ -33,7 +34,11 @@ DEFAULT_COLUMNS = (


def source_columns(model: type[SourceItem], *columns: str):
    return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]
    return [
        getattr(model, c)
        for c in ("id",) + columns + DEFAULT_COLUMNS
        if hasattr(model, c)
    ]


# Create admin views for all models
@ -86,6 +91,21 @@ class BlogPostAdmin(ModelView, model=BlogPost):
    column_searchable_list = ["title", "author", "domain"]


class ForumPostAdmin(ModelView, model=ForumPost):
    column_list = source_columns(
        ForumPost,
        "title",
        "authors",
        "published_at",
        "url",
        "karma",
        "votes",
        "comments",
        "score",
    )
    column_searchable_list = ["title", "authors"]


class PhotoAdmin(ModelView, model=Photo):
    column_list = source_columns(Photo, "exif_taken_at", "camera")

@ -166,5 +186,6 @@ def setup_admin(admin: Admin):
    admin.add_view(MiscDocAdmin)
    admin.add_view(ArticleFeedAdmin)
    admin.add_view(BlogPostAdmin)
    admin.add_view(ForumPostAdmin)
    admin.add_view(ComicAdmin)
    admin.add_view(PhotoAdmin)
@ -43,7 +43,12 @@ ALL_COLLECTIONS: dict[str, Collection] = {
    "blog": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.TEXT_EMBEDDING_MODEL,
        "model": settings.MIXED_EMBEDDING_MODEL,
    },
    "forum": {
        "dimension": 1024,
        "distance": "Cosine",
        "model": settings.MIXED_EMBEDDING_MODEL,
    },
    "text": {
        "dimension": 1024,
@ -105,6 +105,21 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
    ]


def chunk_mixed(
    content: str, image_paths: Sequence[str]
) -> list[list[extract.MulitmodalChunk]]:
    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
    full_text: list[extract.MulitmodalChunk] = [content.strip(), *images]

    chunks = []
    tokens = chunker.approx_token_count(content)
    if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
        chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]

    all_chunks = [full_text] + chunks
    return [c for c in all_chunks if c and all(i for i in c)]


class Chunk(Base):
    """Stores content chunks with their vector embeddings."""

@ -134,6 +149,19 @@ class Chunk(Base):
        Index("chunk_source_idx", "source_id"),
    )

    @property
    def chunks(self) -> list[extract.MulitmodalChunk]:
        chunks: list[extract.MulitmodalChunk] = []
        if cast(str | None, self.content):
            chunks = [cast(str, self.content)]
        if self.images:
            chunks += self.images
        elif cast(Sequence[str] | None, self.file_paths):
            chunks += [
                Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
            ]
        return chunks

    @property
    def data(self) -> list[bytes | str | Image.Image]:
        if self.file_paths is None:
@ -638,20 +666,57 @@ class BlogPost(SourceItem):
        return {k: v for k, v in payload.items() if v}

    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
        images = [
            Image.open(settings.FILE_STORAGE_DIR / image) for image in self.images
        ]

        content = cast(str, self.content)
        full_text = [content.strip(), *images]

        chunks = []
        tokens = chunker.approx_token_count(content)
        if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
            chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]

        all_chunks = [full_text] + chunks
        return [c for c in all_chunks if c and all(i for i in c)]
        return chunk_mixed(cast(str, self.content), cast(list[str], self.images))


class ForumPost(SourceItem):
    __tablename__ = "forum_post"

    id = Column(
        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
    )
    url = Column(Text, unique=True)
    title = Column(Text)
    description = Column(Text, nullable=True)
    authors = Column(ARRAY(Text), nullable=True)
    published_at = Column(DateTime(timezone=True), nullable=True)
    modified_at = Column(DateTime(timezone=True), nullable=True)
    slug = Column(Text, nullable=True)
    karma = Column(Integer, nullable=True)
    votes = Column(Integer, nullable=True)
    comments = Column(Integer, nullable=True)
    words = Column(Integer, nullable=True)
    score = Column(Integer, nullable=True)
    images = Column(ARRAY(Text), nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "forum_post",
    }

    __table_args__ = (
        Index("forum_post_url_idx", "url"),
        Index("forum_post_slug_idx", "slug"),
        Index("forum_post_title_idx", "title"),
    )

    def as_payload(self) -> dict:
        return {
            "source_id": self.id,
            "url": self.url,
            "title": self.title,
            "description": self.description,
            "authors": self.authors,
            "published_at": self.published_at,
            "slug": self.slug,
            "karma": self.karma,
            "votes": self.votes,
            "score": self.score,
            "comments": self.comments,
            "tags": self.tags,
        }

    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
        return chunk_mixed(cast(str, self.content), cast(list[str], self.images))


class MiscDoc(SourceItem):

@ -710,7 +775,7 @@ class ArticleFeed(Base):
    description = Column(Text)
    tags = Column(ARRAY(Text), nullable=False, server_default="{}")
    check_interval = Column(
        Integer, nullable=False, server_default="3600", doc="Seconds between checks"
        Integer, nullable=False, server_default="60", doc="Minutes between checks"
    )
    last_checked_at = Column(DateTime(timezone=True))
    active = Column(Boolean, nullable=False, server_default="true")
@ -20,7 +20,7 @@ def embed_chunks(
    model: str = settings.TEXT_EMBEDDING_MODEL,
    input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
    logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]}")
    logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]} {len(chunks)}")
    vo = voyageai.Client()  # type: ignore
    if model == settings.MIXED_EMBEDDING_MODEL:
        return vo.multimodal_embed(
@ -79,7 +79,7 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
    if not model_chunks:
        return []

    vectors = embed_chunks([chunk.content for chunk in model_chunks], model)  # type: ignore
    vectors = embed_chunks([chunk.chunks for chunk in model_chunks], model)
    for chunk, vector in zip(model_chunks, vectors):
        chunk.vector = vector
    return model_chunks
@ -86,9 +86,11 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))

# Worker settings
# Intervals are in seconds
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400))
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600))
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 60 * 60))
COMIC_SYNC_INTERVAL = int(os.getenv("COMIC_SYNC_INTERVAL", 60 * 60))
ARTICLE_FEED_SYNC_INTERVAL = int(os.getenv("ARTICLE_FEED_SYNC_INTERVAL", 30 * 60))
CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 24 * 60 * 60))
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 60 * 60))

CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
304  src/memory/parsers/lesswrong.py  Normal file
@ -0,0 +1,304 @@
from dataclasses import dataclass, field
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Generator, TypedDict, NotRequired

from bs4 import BeautifulSoup
from PIL import Image as PILImage
from memory.common import settings
import requests
from markdownify import markdownify

from memory.parsers.html import parse_date, process_images

logger = logging.getLogger(__name__)


class LessWrongPost(TypedDict):
    """Represents a post from LessWrong."""

    title: str
    url: str
    description: str
    content: str
    authors: list[str]
    published_at: datetime | None
    guid: str | None
    karma: int
    votes: int
    comments: int
    words: int
    tags: list[str]
    af: bool
    score: int
    extended_score: int
    modified_at: NotRequired[str | None]
    slug: NotRequired[str | None]
    images: NotRequired[list[str]]


def make_graphql_query(
    after: datetime, af: bool = False, limit: int = 50, min_karma: int = 10
) -> str:
    """Create GraphQL query for fetching posts."""
    return f"""
    {{
      posts(input: {{
        terms: {{
          excludeEvents: true
          view: "old"
          af: {str(af).lower()}
          limit: {limit}
          karmaThreshold: {min_karma}
          after: "{after.isoformat()}Z"
          filter: "tagged"
        }}
      }}) {{
        totalCount
        results {{
          _id
          title
          slug
          pageUrl
          postedAt
          modifiedAt
          score
          extendedScore
          baseScore
          voteCount
          commentCount
          wordCount
          tags {{
            name
          }}
          user {{
            displayName
          }}
          coauthors {{
            displayName
          }}
          af
          htmlBody
        }}
      }}
    }}
    """


def fetch_posts_from_api(url: str, query: str) -> dict[str, Any]:
    """Fetch posts from LessWrong GraphQL API."""
    response = requests.post(
        url,
        headers={
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
        },
        json={"query": query},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()["data"]["posts"]


def is_valid_post(post: dict[str, Any], min_karma: int = 10) -> bool:
    """Check if post should be included."""
    # Must have content
    if not post.get("htmlBody"):
        return False

    # Must meet karma threshold
    if post.get("baseScore", 0) < min_karma:
        return False

    return True


def extract_authors(post: dict[str, Any]) -> list[str]:
    """Extract authors from post data."""
    authors = post.get("coauthors", []) or []
    if post.get("user"):
        authors = [post["user"]] + authors
    return [a["displayName"] for a in authors] or ["anonymous"]


def extract_description(body: str) -> str:
    """Extract description from post HTML content."""
    first_paragraph = body.split("\n\n")[0]

    # Truncate if too long
    if len(first_paragraph) > 300:
        first_paragraph = first_paragraph[:300] + "..."

    return first_paragraph


def parse_lesswrong_date(date_str: str | None) -> datetime | None:
    """Parse ISO date string from LessWrong API to datetime."""
    if not date_str:
        return None

    # Try multiple ISO formats that LessWrong might use
    formats = [
        "%Y-%m-%dT%H:%M:%S.%fZ",  # 2023-01-15T10:30:00.000Z
        "%Y-%m-%dT%H:%M:%SZ",  # 2023-01-15T10:30:00Z
        "%Y-%m-%dT%H:%M:%S.%f",  # 2023-01-15T10:30:00.000
        "%Y-%m-%dT%H:%M:%S",  # 2023-01-15T10:30:00
    ]

    for fmt in formats:
        if result := parse_date(date_str, fmt):
            return result

    # Fallback: try removing 'Z' and using fromisoformat
    try:
        clean_date = date_str.rstrip("Z")
        return datetime.fromisoformat(clean_date)
    except (ValueError, TypeError):
        logger.warning(f"Could not parse date: {date_str}")
        return None


def extract_body(post: dict[str, Any]) -> tuple[str, dict[str, PILImage.Image]]:
    """Extract body from post data."""
    if not (body := post.get("htmlBody", "").strip()):
        return "", {}

    url = post.get("pageUrl", "")
    image_dir = settings.FILE_STORAGE_DIR / "lesswrong" / url

    soup = BeautifulSoup(body, "html.parser")
    soup, images = process_images(soup, url, image_dir)
    body = markdownify(str(soup)).strip()
    return body, images


def format_post(post: dict[str, Any]) -> LessWrongPost:
    """Convert raw API post data to a LessWrongPost."""
    body, images = extract_body(post)

    result: LessWrongPost = {
        "title": post.get("title", "Untitled"),
        "url": post.get("pageUrl", ""),
        "description": extract_description(body),
        "content": body,
        "authors": extract_authors(post),
        "published_at": parse_lesswrong_date(post.get("postedAt")),
        "guid": post.get("_id"),
        "karma": post.get("baseScore", 0),
        "votes": post.get("voteCount", 0),
        "comments": post.get("commentCount", 0),
        "words": post.get("wordCount", 0),
        "tags": [tag["name"] for tag in post.get("tags", [])],
        "af": post.get("af", False),
        "score": post.get("score", 0),
        "extended_score": post.get("extendedScore", 0),
        "modified_at": post.get("modifiedAt"),
        "slug": post.get("slug"),
        "images": list(images.keys()),
    }

    return result


def fetch_lesswrong(
    url: str,
    current_date: datetime,
    af: bool = False,
    min_karma: int = 10,
    limit: int = 50,
    last_url: str | None = None,
) -> list[LessWrongPost]:
    """
    Fetch a batch of posts from LessWrong.

    Returns:
        The formatted posts in the batch; an empty list means iteration should stop.
    """
    query = make_graphql_query(current_date, af, limit, min_karma)
    api_response = fetch_posts_from_api(url, query)

    if not api_response["results"]:
        return []

    # If we only get the same item we started with, we're done
    if (
        len(api_response["results"]) == 1
        and last_url
        and api_response["results"][0]["pageUrl"] == last_url
    ):
        return []

    return [
        format_post(post)
        for post in api_response["results"]
        if is_valid_post(post, min_karma)
    ]


def fetch_lesswrong_posts(
    since: datetime | None = None,
    min_karma: int = 10,
    limit: int = 50,
    cooldown: float = 0.5,
    max_items: int = 1000,
    af: bool = False,
    url: str = "https://www.lesswrong.com/graphql",
) -> Generator[LessWrongPost, None, None]:
    """
    Fetch posts from LessWrong.

    Args:
        since: Start date for fetching posts (defaults to one day ago)
        min_karma: Minimum karma threshold for posts
        limit: Number of posts per API request
        cooldown: Delay between API requests in seconds
        max_items: Maximum number of posts to fetch
        af: Whether to fetch Alignment Forum posts
        url: GraphQL endpoint URL

    Yields:
        LessWrongPost dicts, one per post.
    """
    if not since:
        since = datetime.now() - timedelta(days=1)

    logger.info(f"Starting from {since}")

    last_url = None
    next_date = since
    items_count = 0

    while next_date and items_count < max_items:
        try:
            page_posts = fetch_lesswrong(url, next_date, af, min_karma, limit, last_url)
        except Exception as e:
            logger.error(f"Error fetching posts: {e}")
            break

        if not page_posts or next_date is None:
            break

        for post in page_posts:
            yield post

        last_item = page_posts[-1]
        prev_date = next_date
        next_date = last_item.get("published_at")

        if not next_date or prev_date == next_date:
            logger.warning(
                f"Could not advance through dataset, stopping at {next_date}"
            )
            break

        # The articles are paged by date (inclusive), so passing the last date
        # as-is would return the same article again.
        next_date += timedelta(seconds=1)
        items_count += len(page_posts)
        last_url = last_item["url"]

        if cooldown > 0:
            time.sleep(cooldown)

    logger.info(f"Fetched {items_count} items")
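A rough usage sketch of the new parser, not part of the commit; the argument values below are illustrative only:

```python
# Hypothetical example: pull recent high-karma posts via the generator above.
from datetime import datetime, timedelta

from memory.parsers.lesswrong import fetch_lesswrong_posts

for post in fetch_lesswrong_posts(
    since=datetime.now() - timedelta(days=7),
    min_karma=10,
    af=False,         # set True to fetch Alignment Forum posts instead
    max_items=100,
):
    print(post["title"], post["karma"], post["url"])
```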
@ -1,6 +1,14 @@
from celery import Celery
from memory.common import settings

EMAIL_ROOT = "memory.workers.tasks.email"
FORUMS_ROOT = "memory.workers.tasks.forums"
BLOGS_ROOT = "memory.workers.tasks.blogs"
PHOTO_ROOT = "memory.workers.tasks.photo"
COMIC_ROOT = "memory.workers.tasks.comic"
EBOOK_ROOT = "memory.workers.tasks.ebook"
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"


def rabbit_url() -> str:
    return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//"
@ -21,12 +29,13 @@ app.conf.update(
    task_reject_on_worker_lost=True,
    worker_prefetch_multiplier=1,
    task_routes={
        "memory.workers.tasks.email.*": {"queue": "email"},
        "memory.workers.tasks.photo.*": {"queue": "photo_embed"},
        "memory.workers.tasks.comic.*": {"queue": "comic"},
        "memory.workers.tasks.ebook.*": {"queue": "ebooks"},
        "memory.workers.tasks.blogs.*": {"queue": "blogs"},
        "memory.workers.tasks.maintenance.*": {"queue": "maintenance"},
        f"{EMAIL_ROOT}.*": {"queue": "email"},
        f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
        f"{COMIC_ROOT}.*": {"queue": "comic"},
        f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
        f"{BLOGS_ROOT}.*": {"queue": "blogs"},
        f"{FORUMS_ROOT}.*": {"queue": "forums"},
        f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
    },
)
@ -8,10 +8,6 @@ logger = logging.getLogger(__name__)


app.conf.beat_schedule = {
    "sync-mail-all": {
        "task": "memory.workers.tasks.email.sync_all_accounts",
        "schedule": settings.EMAIL_SYNC_INTERVAL,
    },
    "clean-all-collections": {
        "task": CLEAN_ALL_COLLECTIONS,
        "schedule": settings.CLEAN_COLLECTION_INTERVAL,
@ -20,4 +16,16 @@ app.conf.beat_schedule = {
        "task": REINGEST_MISSING_CHUNKS,
        "schedule": settings.CHUNK_REINGEST_INTERVAL,
    },
    "sync-mail-all": {
        "task": "memory.workers.tasks.email.sync_all_accounts",
        "schedule": settings.EMAIL_SYNC_INTERVAL,
    },
    "sync-all-comics": {
        "task": "memory.workers.tasks.comic.sync_all_comics",
        "schedule": settings.COMIC_SYNC_INTERVAL,
    },
    "sync-all-article-feeds": {
        "task": "memory.workers.tasks.blogs.sync_all_article_feeds",
        "schedule": settings.ARTICLE_FEED_SYNC_INTERVAL,
    },
}
@ -2,7 +2,7 @@
Import sub-modules so Celery can register their @app.task decorators.
"""

from memory.workers.tasks import email, comic, blogs, ebook  # noqa
from memory.workers.tasks import email, comic, blogs, ebook, forums  # noqa
from memory.workers.tasks.blogs import (
    SYNC_WEBPAGE,
    SYNC_ARTICLE_FEED,
@ -12,11 +12,13 @@ from memory.workers.tasks.blogs import (
from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_SMBC, SYNC_XKCD
from memory.workers.tasks.ebook import SYNC_BOOK
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
from memory.workers.tasks.maintenance import (
    CLEAN_ALL_COLLECTIONS,
    CLEAN_COLLECTION,
    REINGEST_MISSING_CHUNKS,
    REINGEST_CHUNK,
    REINGEST_ITEM,
)


@ -25,6 +27,7 @@ __all__ = [
    "comic",
    "blogs",
    "ebook",
    "forums",
    "SYNC_WEBPAGE",
    "SYNC_ARTICLE_FEED",
    "SYNC_ALL_ARTICLE_FEEDS",
@ -34,10 +37,13 @@ __all__ = [
    "SYNC_XKCD",
    "SYNC_BOOK",
    "SYNC_ACCOUNT",
    "SYNC_LESSWRONG",
    "SYNC_LESSWRONG_POST",
    "SYNC_ALL_ACCOUNTS",
    "PROCESS_EMAIL",
    "CLEAN_ALL_COLLECTIONS",
    "CLEAN_COLLECTION",
    "REINGEST_MISSING_CHUNKS",
    "REINGEST_CHUNK",
    "REINGEST_ITEM",
]
@ -1,5 +1,5 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Iterable, cast

from memory.common.db.connection import make_session
@ -7,7 +7,7 @@ from memory.common.db.models import ArticleFeed, BlogPost
from memory.parsers.blogs import parse_webpage
from memory.parsers.feeds import get_feed_parser
from memory.parsers.archives import get_archive_fetcher
from memory.workers.celery_app import app
from memory.workers.celery_app import app, BLOGS_ROOT
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
@ -18,10 +18,10 @@ from memory.workers.tasks.content_processing import (

logger = logging.getLogger(__name__)

SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
SYNC_ARTICLE_FEED = "memory.workers.tasks.blogs.sync_article_feed"
SYNC_ALL_ARTICLE_FEEDS = "memory.workers.tasks.blogs.sync_all_article_feeds"
SYNC_WEBSITE_ARCHIVE = "memory.workers.tasks.blogs.sync_website_archive"
SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"


@app.task(name=SYNC_WEBPAGE)
@ -71,7 +71,7 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
            logger.info(f"Blog post already exists: {existing_post.title}")
            return create_task_result(existing_post, "already_exists", url=article.url)

        return process_content_item(blog_post, "blog", session, tags)
        return process_content_item(blog_post, session)


@app.task(name=SYNC_ARTICLE_FEED)
@ -93,8 +93,8 @@ def sync_article_feed(feed_id: int) -> dict:
            return {"status": "error", "error": "Feed not found or inactive"}

        last_checked_at = cast(datetime | None, feed.last_checked_at)
        if last_checked_at and datetime.now() - last_checked_at < timedelta(
            seconds=cast(int, feed.check_interval)
        if last_checked_at and datetime.now(timezone.utc) - last_checked_at < timedelta(
            minutes=cast(int, feed.check_interval)
        ):
            logger.info(f"Feed {feed_id} checked too recently, skipping")
            return {"status": "skipped_recent_check", "feed_id": feed_id}
@ -129,7 +129,7 @@ def sync_article_feed(feed_id: int) -> dict:
                logger.error(f"Error parsing feed {feed.url}: {e}")
                errors += 1

        feed.last_checked_at = datetime.now()  # type: ignore
        feed.last_checked_at = datetime.now(timezone.utc)  # type: ignore
        session.commit()

        return {
@ -9,7 +9,7 @@ from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import Comic, clean_filename
from memory.parsers import comics
from memory.workers.celery_app import app
from memory.workers.celery_app import app, COMIC_ROOT
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
@ -19,10 +19,10 @@ from memory.workers.tasks.content_processing import (

logger = logging.getLogger(__name__)

SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics"
SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc"
SYNC_XKCD = f"{COMIC_ROOT}.sync_xkcd"
SYNC_COMIC = f"{COMIC_ROOT}.sync_comic"

BASE_SMBC_URL = "https://www.smbc-comics.com/"
SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
@ -109,7 +109,7 @@ def sync_comic(
    )

    with make_session() as session:
        return process_content_item(comic, "comic", session)
        return process_content_item(comic, session)


@app.task(name=SYNC_SMBC)
@ -99,6 +99,7 @@ def embed_source_item(source_item: SourceItem) -> int:
    except Exception as e:
        source_item.embed_status = "FAILED"  # type: ignore
        logger.error(f"Failed to embed {type(source_item).__name__}: {e}")
        logger.error(traceback.format_exc())
        return 0


@ -156,6 +157,7 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
        for item in items_to_process:
            item.embed_status = "FAILED"  # type: ignore
        logger.error(f"Failed to push embeddings to Qdrant: {e}")
        logger.error(traceback.format_exc())
        raise


@ -186,9 +188,7 @@ def create_task_result(
    }


def process_content_item(
    item: SourceItem, collection_name: str, session, tags: Iterable[str] = []
) -> dict[str, Any]:
def process_content_item(item: SourceItem, session) -> dict[str, Any]:
    """
    Execute complete content processing workflow.

@ -200,7 +200,6 @@ def process_content_item(

    Args:
        item: SourceItem to process
        collection_name: Qdrant collection name for vector storage
        session: Database session for persistence
        tags: Optional tags to associate with the item (currently unused)

@ -223,7 +222,7 @@ def process_content_item(
        return create_task_result(item, status, content_length=getattr(item, "size", 0))

    try:
        push_to_qdrant([item], collection_name)
        push_to_qdrant([item], cast(str, item.modality))
        status = "processed"
        item.embed_status = "STORED"  # type: ignore
        logger.info(
@ -231,6 +230,7 @@ def process_content_item(
        )
    except Exception as e:
        logger.error(f"Failed to push embeddings to Qdrant: {e}")
        logger.error(traceback.format_exc())
        item.embed_status = "FAILED"  # type: ignore
        session.commit()

@ -261,10 +261,8 @@ def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(
                f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}"
            )
            logger.error(f"Task {func.__name__} failed: {e}")
            logger.error(traceback.format_exc())
            return {"status": "error", "error": str(e)}

    return wrapper
@ -6,7 +6,7 @@ import memory.common.settings as settings
from memory.parsers.ebook import Ebook, parse_ebook, Section
from memory.common.db.models import Book, BookSection
from memory.common.db.connection import make_session
from memory.workers.celery_app import app
from memory.workers.celery_app import app, EBOOK_ROOT
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
@ -17,7 +17,7 @@ from memory.workers.tasks.content_processing import (

logger = logging.getLogger(__name__)

SYNC_BOOK = "memory.workers.tasks.ebook.sync_book"
SYNC_BOOK = f"{EBOOK_ROOT}.sync_book"

# Minimum section length to embed (avoid noise from very short sections)
MIN_SECTION_LENGTH = 100
@ -3,7 +3,7 @@ from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount, MailMessage
from memory.workers.celery_app import app
from memory.workers.celery_app import app, EMAIL_ROOT
from memory.workers.email import (
    create_mail_message,
    imap_connection,
@ -18,9 +18,9 @@ from memory.workers.tasks.content_processing import (

logger = logging.getLogger(__name__)

PROCESS_EMAIL = "memory.workers.tasks.email.process_message"
SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account"
SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
PROCESS_EMAIL = f"{EMAIL_ROOT}.process_message"
SYNC_ACCOUNT = f"{EMAIL_ROOT}.sync_account"
SYNC_ALL_ACCOUNTS = f"{EMAIL_ROOT}.sync_all_accounts"


@app.task(name=PROCESS_EMAIL)
85  src/memory/workers/tasks/forums.py  Normal file
@ -0,0 +1,85 @@
from datetime import datetime, timedelta
import logging

from memory.parsers.lesswrong import fetch_lesswrong_posts, LessWrongPost
from memory.common.db.connection import make_session
from memory.common.db.models import ForumPost
from memory.workers.celery_app import app, FORUMS_ROOT
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
    create_task_result,
    process_content_item,
    safe_task_execution,
)

logger = logging.getLogger(__name__)

SYNC_LESSWRONG = f"{FORUMS_ROOT}.sync_lesswrong"
SYNC_LESSWRONG_POST = f"{FORUMS_ROOT}.sync_lesswrong_post"


@app.task(name=SYNC_LESSWRONG_POST)
@safe_task_execution
def sync_lesswrong_post(
    post: LessWrongPost,
    tags: list[str] = [],
):
    logger.info(f"Syncing LessWrong post {post['url']}")
    sha256 = create_content_hash(post["content"])

    post["tags"] = list(set(post["tags"] + tags))
    post_obj = ForumPost(
        embed_status="RAW",
        size=len(post["content"].encode("utf-8")),
        modality="forum",
        mime_type="text/markdown",
        sha256=sha256,
        **{k: v for k, v in post.items() if hasattr(ForumPost, k)},
    )

    with make_session() as session:
        existing_post = check_content_exists(
            session, ForumPost, url=post_obj.url, sha256=sha256
        )
        if existing_post:
            logger.info(f"LessWrong post already exists: {existing_post.title}")
            return create_task_result(existing_post, "already_exists", url=post_obj.url)

        return process_content_item(post_obj, session)


@app.task(name=SYNC_LESSWRONG)
@safe_task_execution
def sync_lesswrong(
    since: str = (datetime.now() - timedelta(days=30)).isoformat(),
    min_karma: int = 10,
    limit: int = 50,
    cooldown: float = 0.5,
    max_items: int = 1000,
    af: bool = False,
    tags: list[str] = [],
):
    logger.info(f"Syncing LessWrong posts since {since}")
    start_date = datetime.fromisoformat(since)
    posts = fetch_lesswrong_posts(start_date, min_karma, limit, cooldown, max_items, af)

    posts_num, new_posts = 0, 0
    with make_session() as session:
        for post in posts:
            if not check_content_exists(session, ForumPost, url=post["url"]):
                new_posts += 1
                sync_lesswrong_post.delay(post, tags)

            if posts_num >= max_items:
                break
            posts_num += 1

    return {
        "posts_num": posts_num,
        "new_posts": new_posts,
        "since": since,
        "min_karma": min_karma,
        "max_items": max_items,
        "af": af,
    }
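A hedged sketch of how the new forum tasks might be kicked off once a worker is consuming the `forums` queue. The task names come from the code above; the scheduling values are illustrative only:

```python
# Hypothetical example: enqueue a LessWrong backfill onto the "forums" queue.
from datetime import datetime, timedelta

from memory.workers.tasks.forums import sync_lesswrong

result = sync_lesswrong.delay(
    since=(datetime.now() - timedelta(days=90)).isoformat(),
    min_karma=10,
    max_items=500,
    tags=["lesswrong"],
)
print(result.id)  # Celery task id; individual posts are handled by sync_lesswrong_post
```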
@ -3,21 +3,24 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Sequence

from memory.workers.tasks.content_processing import process_content_item
from sqlalchemy import select
from sqlalchemy.orm import contains_eager

from memory.common import collections, embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.workers.celery_app import app
from memory.workers.celery_app import app, MAINTENANCE_ROOT

logger = logging.getLogger(__name__)


CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections"
CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection"
REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks"
REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk"
CLEAN_ALL_COLLECTIONS = f"{MAINTENANCE_ROOT}.clean_all_collections"
CLEAN_COLLECTION = f"{MAINTENANCE_ROOT}.clean_collection"
REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"


@app.task(name=CLEAN_COLLECTION)
@ -87,6 +90,61 @@ def reingest_chunk(chunk_id: str, collection: str):
        session.commit()


def get_item_class(item_type: str):
    class_ = SourceItem.registry._class_registry.get(item_type)
    if not class_:
        available_types = ", ".join(sorted(SourceItem.registry._class_registry.keys()))
        raise ValueError(
            f"Unsupported item type {item_type}. Available types: {available_types}"
        )
    return class_


@app.task(name=REINGEST_ITEM)
def reingest_item(item_id: str, item_type: str):
    logger.info(f"Reingesting {item_type} {item_id}")
    try:
        class_ = get_item_class(item_type)
    except ValueError as e:
        logger.error(f"Error getting item class: {e}")
        return {"status": "error", "error": str(e)}

    with make_session() as session:
        item = session.query(class_).get(item_id)
        if not item:
            return {"status": "error", "error": f"Item {item_id} not found"}

        chunk_ids = [str(c.id) for c in item.chunks if c.id]
        if chunk_ids:
            client = qdrant.get_qdrant_client()
            qdrant.delete_points(client, item.modality, chunk_ids)

        for chunk in item.chunks:
            session.delete(chunk)

        return process_content_item(item, session)


@app.task(name=REINGEST_EMPTY_SOURCE_ITEMS)
def reingest_empty_source_items(item_type: str):
    logger.info("Reingesting empty source items")
    try:
        class_ = get_item_class(item_type)
    except ValueError as e:
        logger.error(f"Error getting item class: {e}")
        return {"status": "error", "error": str(e)}

    with make_session() as session:
        item_ids = session.query(class_.id).filter(~class_.chunks.any()).all()

    logger.info(f"Found {len(item_ids)} items to reingest")

    for item_id in item_ids:
        reingest_item.delay(item_id.id, item_type)  # type: ignore

    return {"status": "success", "items": len(item_ids)}


def check_batch(batch: Sequence[Chunk]) -> dict:
    client = qdrant.get_qdrant_client()
    by_collection = defaultdict(list)
@ -116,14 +174,19 @@ def check_batch(batch: Sequence[Chunk]) -> dict:

@app.task(name=REINGEST_MISSING_CHUNKS)
def reingest_missing_chunks(
    batch_size: int = 1000, minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES
    batch_size: int = 1000,
    collection: str | None = None,
    minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES,
):
    logger.info("Reingesting missing chunks")
    total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0})
    since = datetime.now() - timedelta(minutes=minutes_ago)

    with make_session() as session:
        total_count = session.query(Chunk).filter(Chunk.checked_at < since).count()
        query = session.query(Chunk).filter(Chunk.checked_at < since)
        if collection:
            query = query.filter(Chunk.source.has(SourceItem.modality == collection))
        total_count = query.count()

        logger.info(
            f"Found {total_count} chunks to check, processing in batches of {batch_size}"
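The "reindexer" in the commit title refers to the maintenance tasks added above. A brief sketch of triggering a reindex for the new forum content (illustrative values; this assumes the SQLAlchemy registry key is simply the class name, as `get_item_class` above suggests, and that a worker is listening on the `maintenance` queue):

```python
# Hypothetical example: re-embed every ForumPost that has no chunks yet.
from memory.workers.tasks.maintenance import (
    reingest_empty_source_items,
    reingest_missing_chunks,
)

reingest_empty_source_items.delay("ForumPost")

# Or re-check stale chunks for a single collection only:
reingest_missing_chunks.delay(collection="forum", minutes_ago=60 * 24)
```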
542
tests/memory/parsers/test_lesswrong.py
Normal file
542
tests/memory/parsers/test_lesswrong.py
Normal file
@ -0,0 +1,542 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch, Mock
|
||||
import json
|
||||
import pytest
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from memory.parsers.lesswrong import (
|
||||
LessWrongPost,
|
||||
make_graphql_query,
|
||||
fetch_posts_from_api,
|
||||
is_valid_post,
|
||||
extract_authors,
|
||||
extract_description,
|
||||
parse_lesswrong_date,
|
||||
extract_body,
|
||||
format_post,
|
||||
fetch_lesswrong,
|
||||
fetch_lesswrong_posts,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"after, af, limit, min_karma, expected_contains",
|
||||
[
|
||||
(
|
||||
datetime(2023, 1, 15, 10, 30),
|
||||
False,
|
||||
50,
|
||||
10,
|
||||
[
|
||||
"af: false",
|
||||
"limit: 50",
|
||||
"karmaThreshold: 10",
|
||||
'after: "2023-01-15T10:30:00Z"',
|
||||
],
|
||||
),
|
||||
(
|
||||
datetime(2023, 2, 20),
|
||||
True,
|
||||
25,
|
||||
5,
|
||||
[
|
||||
"af: true",
|
||||
"limit: 25",
|
||||
"karmaThreshold: 5",
|
||||
'after: "2023-02-20T00:00:00Z"',
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_make_graphql_query(after, af, limit, min_karma, expected_contains):
|
||||
query = make_graphql_query(after, af, limit, min_karma)
|
||||
|
||||
for expected in expected_contains:
|
||||
assert expected in query
|
||||
|
||||
# Check that all required fields are present
|
||||
required_fields = [
|
||||
"_id",
|
||||
"title",
|
||||
"slug",
|
||||
"pageUrl",
|
||||
"postedAt",
|
||||
"modifiedAt",
|
||||
"score",
|
||||
"extendedScore",
|
||||
"baseScore",
|
||||
"voteCount",
|
||||
"commentCount",
|
||||
"wordCount",
|
||||
"tags",
|
||||
"user",
|
||||
"coauthors",
|
||||
"af",
|
||||
"htmlBody",
|
||||
]
|
||||
for field in required_fields:
|
||||
assert field in query
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.requests.post")
|
||||
def test_fetch_posts_from_api_success(mock_post):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"data": {
|
||||
"posts": {
|
||||
"totalCount": 2,
|
||||
"results": [
|
||||
{"_id": "1", "title": "Post 1"},
|
||||
{"_id": "2", "title": "Post 2"},
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
url = "https://www.lesswrong.com/graphql"
|
||||
query = "test query"
|
||||
|
||||
result = fetch_posts_from_api(url, query)
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
|
||||
},
|
||||
json={"query": query},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"totalCount": 2,
|
||||
"results": [
|
||||
{"_id": "1", "title": "Post 1"},
|
||||
{"_id": "2", "title": "Post 2"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.requests.post")
|
||||
def test_fetch_posts_from_api_http_error(mock_post):
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP Error")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception, match="HTTP Error"):
|
||||
fetch_posts_from_api("https://example.com", "query")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"post_data, min_karma, expected",
|
||||
[
|
||||
# Valid post with content and karma
|
||||
({"htmlBody": "<p>Content</p>", "baseScore": 15}, 10, True),
|
||||
# Valid post at karma threshold
|
||||
({"htmlBody": "<p>Content</p>", "baseScore": 10}, 10, True),
|
||||
# Invalid: no content
|
||||
({"htmlBody": "", "baseScore": 15}, 10, False),
|
||||
({"htmlBody": None, "baseScore": 15}, 10, False),
|
||||
({}, 10, False),
|
||||
# Invalid: below karma threshold
|
||||
({"htmlBody": "<p>Content</p>", "baseScore": 5}, 10, False),
|
||||
({"htmlBody": "<p>Content</p>"}, 10, False), # No baseScore
|
||||
# Edge cases
|
||||
(
|
||||
{"htmlBody": " ", "baseScore": 15},
|
||||
10,
|
||||
True,
|
||||
), # Whitespace only - actually valid
|
||||
({"htmlBody": "<p>Content</p>", "baseScore": 0}, 0, True), # Zero threshold
|
||||
],
|
||||
)
|
||||
def test_is_valid_post(post_data, min_karma, expected):
|
||||
assert is_valid_post(post_data, min_karma) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"post_data, expected",
|
||||
[
|
||||
# User only
|
||||
(
|
||||
{"user": {"displayName": "Alice"}},
|
||||
["Alice"],
|
||||
),
|
||||
# User with coauthors
|
||||
(
|
||||
{
|
||||
"user": {"displayName": "Alice"},
|
||||
"coauthors": [{"displayName": "Bob"}, {"displayName": "Charlie"}],
|
||||
},
|
||||
["Alice", "Bob", "Charlie"],
|
||||
),
|
||||
# Coauthors only (no user)
|
||||
(
|
||||
{"coauthors": [{"displayName": "Bob"}]},
|
||||
["Bob"],
|
||||
),
|
||||
# Empty coauthors list
|
||||
(
|
||||
{"user": {"displayName": "Alice"}, "coauthors": []},
|
||||
["Alice"],
|
||||
),
|
||||
# No authors at all
|
||||
({}, ["anonymous"]),
|
||||
({"user": None, "coauthors": None}, ["anonymous"]),
|
||||
({"user": None, "coauthors": []}, ["anonymous"]),
|
||||
],
|
||||
)
|
||||
def test_extract_authors(post_data, expected):
|
||||
assert extract_authors(post_data) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"body, expected",
|
||||
[
|
||||
# Short content
|
||||
("This is a short paragraph.", "This is a short paragraph."),
|
||||
# Multiple paragraphs - only first
|
||||
("First paragraph.\n\nSecond paragraph.", "First paragraph."),
|
||||
# Long content - truncated
|
||||
(
|
||||
"A" * 350,
|
||||
"A" * 300 + "...",
|
||||
),
|
||||
# Empty content
|
||||
("", ""),
|
||||
# Whitespace only
|
||||
(" \n\n ", " "),
|
||||
],
|
||||
)
|
||||
def test_extract_description(body, expected):
|
||||
assert extract_description(body) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"date_str, expected",
|
||||
[
|
||||
# Standard ISO formats
|
||||
("2023-01-15T10:30:00.000Z", datetime(2023, 1, 15, 10, 30, 0)),
|
||||
("2023-01-15T10:30:00Z", datetime(2023, 1, 15, 10, 30, 0)),
|
||||
("2023-01-15T10:30:00.000", datetime(2023, 1, 15, 10, 30, 0)),
|
||||
("2023-01-15T10:30:00", datetime(2023, 1, 15, 10, 30, 0)),
|
||||
# Fallback to fromisoformat
|
||||
("2023-01-15T10:30:00.123456", datetime(2023, 1, 15, 10, 30, 0, 123456)),
|
||||
# Invalid dates
|
||||
("invalid-date", None),
|
||||
("", None),
|
||||
(None, None),
|
||||
("2023-13-45T25:70:70Z", None), # Invalid date components
|
||||
],
|
||||
)
|
||||
def test_parse_lesswrong_date(date_str, expected):
|
||||
assert parse_lesswrong_date(date_str) == expected
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.process_images")
|
||||
@patch("memory.parsers.lesswrong.markdownify")
|
||||
@patch("memory.parsers.lesswrong.settings")
|
||||
def test_extract_body(mock_settings, mock_markdownify, mock_process_images):
|
||||
from pathlib import Path
|
||||
|
||||
mock_settings.FILE_STORAGE_DIR = Path("/tmp")
|
||||
mock_markdownify.return_value = "# Markdown content"
|
||||
mock_images = {"image1.jpg": Mock(spec=PILImage.Image)}
|
||||
mock_process_images.return_value = (Mock(), mock_images)
|
||||
|
||||
post = {
|
||||
"htmlBody": "<h1>HTML content</h1>",
|
||||
"pageUrl": "https://lesswrong.com/posts/abc123/test-post",
|
||||
}
|
||||
|
||||
body, images = extract_body(post)
|
||||
|
||||
assert body == "# Markdown content"
|
||||
assert images == mock_images
|
||||
mock_process_images.assert_called_once()
|
||||
mock_markdownify.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.process_images")
|
||||
def test_extract_body_empty_content(mock_process_images):
|
||||
post = {"htmlBody": ""}
|
||||
body, images = extract_body(post)
|
||||
|
||||
assert body == ""
|
||||
assert images == {}
|
||||
mock_process_images.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.extract_body")
|
||||
def test_format_post(mock_extract_body):
|
||||
mock_extract_body.return_value = ("Markdown body", {"img1.jpg": Mock()})
|
||||
|
||||
post_data = {
|
||||
"_id": "abc123",
|
||||
"title": "Test Post",
|
||||
"slug": "test-post",
|
||||
"pageUrl": "https://lesswrong.com/posts/abc123/test-post",
|
||||
"postedAt": "2023-01-15T10:30:00Z",
|
||||
"modifiedAt": "2023-01-16T11:00:00Z",
|
||||
"score": 25,
|
||||
"extendedScore": 30,
|
||||
"baseScore": 20,
|
||||
"voteCount": 15,
|
||||
"commentCount": 5,
|
||||
"wordCount": 1000,
|
||||
"tags": [{"name": "AI"}, {"name": "Rationality"}],
|
||||
"user": {"displayName": "Author"},
|
||||
"coauthors": [{"displayName": "Coauthor"}],
|
||||
"af": True,
|
||||
"htmlBody": "<p>HTML content</p>",
|
||||
}
|
||||
|
||||
result = format_post(post_data)
|
||||
|
||||
expected: LessWrongPost = {
|
||||
"title": "Test Post",
|
||||
"url": "https://lesswrong.com/posts/abc123/test-post",
|
||||
"description": "Markdown body",
|
||||
"content": "Markdown body",
|
||||
"authors": ["Author", "Coauthor"],
|
||||
"published_at": datetime(2023, 1, 15, 10, 30, 0),
|
||||
"guid": "abc123",
|
||||
"karma": 20,
|
||||
"votes": 15,
|
||||
"comments": 5,
|
||||
"words": 1000,
|
||||
"tags": ["AI", "Rationality"],
|
||||
"af": True,
|
||||
"score": 25,
|
||||
"extended_score": 30,
|
||||
"modified_at": "2023-01-16T11:00:00Z",
|
||||
"slug": "test-post",
|
||||
"images": ["img1.jpg"],
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.extract_body")
|
||||
def test_format_post_minimal_data(mock_extract_body):
|
||||
mock_extract_body.return_value = ("", {})
|
||||
|
||||
post_data = {}
|
||||
|
||||
result = format_post(post_data)
|
||||
|
||||
expected: LessWrongPost = {
|
||||
"title": "Untitled",
|
||||
"url": "",
|
||||
"description": "",
|
||||
"content": "",
|
||||
"authors": ["anonymous"],
|
||||
"published_at": None,
|
||||
"guid": None,
|
||||
"karma": 0,
|
||||
"votes": 0,
|
||||
"comments": 0,
|
||||
"words": 0,
|
||||
"tags": [],
|
||||
"af": False,
|
||||
"score": 0,
|
||||
"extended_score": 0,
|
||||
"modified_at": None,
|
||||
"slug": None,
|
||||
"images": [],
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
|
||||
@patch("memory.parsers.lesswrong.make_graphql_query")
|
||||
@patch("memory.parsers.lesswrong.format_post")
|
||||
@patch("memory.parsers.lesswrong.is_valid_post")
|
||||
def test_fetch_lesswrong_success(
|
||||
mock_is_valid, mock_format, mock_query, mock_fetch_api
|
||||
):
|
||||
mock_query.return_value = "test query"
|
||||
mock_fetch_api.return_value = {
|
||||
"results": [
|
||||
{"_id": "1", "title": "Post 1"},
|
||||
{"_id": "2", "title": "Post 2"},
|
||||
]
|
||||
}
|
||||
mock_is_valid.side_effect = [True, False] # First valid, second invalid
|
||||
mock_format.return_value = {"title": "Formatted Post"}
|
||||
|
||||
url = "https://lesswrong.com/graphql"
|
||||
current_date = datetime(2023, 1, 15)
|
||||
|
||||
result = fetch_lesswrong(url, current_date, af=True, min_karma=5, limit=25)
|
||||
|
||||
mock_query.assert_called_once_with(current_date, True, 25, 5)
|
||||
mock_fetch_api.assert_called_once_with(url, "test query")
|
||||
assert mock_is_valid.call_count == 2
|
||||
mock_format.assert_called_once_with({"_id": "1", "title": "Post 1"})
|
||||
assert result == [{"title": "Formatted Post"}]
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
|
||||
def test_fetch_lesswrong_empty_results(mock_fetch_api):
|
||||
mock_fetch_api.return_value = {"results": []}
|
||||
|
||||
result = fetch_lesswrong("url", datetime.now())
|
||||
assert result == []
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_posts_from_api")
|
||||
def test_fetch_lesswrong_same_item_as_last(mock_fetch_api):
|
||||
mock_fetch_api.return_value = {
|
||||
"results": [{"pageUrl": "https://lesswrong.com/posts/same"}]
|
||||
}
|
||||
|
||||
result = fetch_lesswrong(
|
||||
"url", datetime.now(), last_url="https://lesswrong.com/posts/same"
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||
@patch("memory.parsers.lesswrong.time.sleep")
|
||||
def test_fetch_lesswrong_posts_success(mock_sleep, mock_fetch):
|
||||
since = datetime(2023, 1, 15)
|
||||
|
||||
# Mock three batches of posts
|
||||
mock_fetch.side_effect = [
|
||||
[
|
||||
{"published_at": datetime(2023, 1, 14), "url": "post1"},
|
||||
{"published_at": datetime(2023, 1, 13), "url": "post2"},
|
||||
],
|
||||
[
|
||||
{"published_at": datetime(2023, 1, 12), "url": "post3"},
|
||||
],
|
||||
[], # Empty result to stop iteration
|
||||
]
|
||||
|
||||
posts = list(
|
||||
fetch_lesswrong_posts(
|
||||
since=since,
|
||||
min_karma=10,
|
||||
limit=50,
|
||||
cooldown=0.1,
|
||||
max_items=100,
|
||||
af=False,
|
||||
url="https://lesswrong.com/graphql",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(posts) == 3
|
||||
assert posts[0]["url"] == "post1"
|
||||
assert posts[1]["url"] == "post2"
|
||||
assert posts[2]["url"] == "post3"
|
||||
|
||||
# Should have called sleep twice (after first two batches)
|
||||
assert mock_sleep.call_count == 2
|
||||
mock_sleep.assert_called_with(0.1)
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||
def test_fetch_lesswrong_posts_default_since(mock_fetch):
|
||||
mock_fetch.return_value = []
|
||||
|
||||
with patch("memory.parsers.lesswrong.datetime") as mock_datetime:
|
||||
mock_now = datetime(2023, 1, 15, 12, 0, 0)
|
||||
mock_datetime.now.return_value = mock_now
|
||||
|
||||
list(fetch_lesswrong_posts())
|
||||
|
||||
# Should use yesterday as default
|
||||
expected_since = mock_now - timedelta(days=1)
|
||||
mock_fetch.assert_called_with(
|
||||
"https://www.lesswrong.com/graphql", expected_since, False, 10, 50, None
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||
def test_fetch_lesswrong_posts_max_items_limit(mock_fetch):
|
||||
# Return posts that would exceed max_items
|
||||
mock_fetch.side_effect = [
|
||||
[{"published_at": datetime(2023, 1, 14), "url": f"post{i}"} for i in range(8)],
|
||||
[
|
||||
{"published_at": datetime(2023, 1, 13), "url": f"post{i}"}
|
||||
for i in range(8, 16)
|
||||
],
|
||||
]
|
||||
|
||||
posts = list(
|
||||
fetch_lesswrong_posts(
|
||||
since=datetime(2023, 1, 15),
|
||||
max_items=7, # Should stop after 7 items
|
||||
cooldown=0,
|
||||
)
|
||||
)
|
||||
|
||||
# The logic checks items_count < max_items before fetching, so it will fetch the first batch
|
||||
# Since items_count (8) >= max_items (7) after first batch, it won't fetch the second batch
|
||||
assert len(posts) == 8
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||
def test_fetch_lesswrong_posts_api_error(mock_fetch):
|
||||
mock_fetch.side_effect = Exception("API Error")
|
||||
|
||||
posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15)))
|
||||
assert posts == []
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||
def test_fetch_lesswrong_posts_no_date_progression(mock_fetch):
|
||||
# Mock posts with same date to trigger stopping condition
|
||||
same_date = datetime(2023, 1, 15)
|
||||
mock_fetch.side_effect = [
|
||||
[{"published_at": same_date, "url": "post1"}],
|
||||
[{"published_at": same_date, "url": "post2"}], # Same date, should stop
|
||||
]
|
||||
|
||||
posts = list(fetch_lesswrong_posts(since=same_date, cooldown=0))
|
||||
|
||||
assert len(posts) == 1 # Should stop after first batch
|
||||
|
||||
|
||||
@patch("memory.parsers.lesswrong.fetch_lesswrong")
|
||||
def test_fetch_lesswrong_posts_none_date(mock_fetch):
|
||||
# Mock posts with None date to trigger stopping condition
|
||||
mock_fetch.side_effect = [
|
||||
[{"published_at": None, "url": "post1"}],
|
||||
]
|
||||
|
||||
posts = list(fetch_lesswrong_posts(since=datetime(2023, 1, 15), cooldown=0))
|
||||
|
||||
assert len(posts) == 1 # Should stop after first batch
|
||||
|
||||
|
||||
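The generator tests above pin down a specific pagination contract: the item count is checked before each batch (so the final batch may overshoot max_items), an empty batch or an API error ends the stream, and iteration stops once publication dates stop moving backwards. A rough sketch of a loop with that behaviour, offered only as an illustration (all names are assumptions; the real implementation lives in memory.parsers.lesswrong and is not part of this diff):

import time
from datetime import datetime


def _paginate_sketch(fetch_batch, since: datetime, max_items: int, cooldown: float = 0.0):
    items_count, cutoff = 0, since
    while items_count < max_items:  # checked before fetching, so the last batch may overshoot
        try:
            batch = fetch_batch()
        except Exception:  # an API error just ends the stream with whatever was already yielded
            break
        if not batch:
            break
        yield from batch
        items_count += len(batch)
        oldest = batch[-1]["published_at"]
        if oldest is None or oldest >= cutoff:  # dates must keep moving back, otherwise stop
            break
        cutoff = oldest
        time.sleep(cooldown)  # consistent with the two sleep(0.1) calls asserted above
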
def test_lesswrong_post_type():
|
||||
"""Test that LessWrongPost TypedDict has correct structure."""
|
||||
# This is more of a documentation test to ensure the type is correct
|
||||
post: LessWrongPost = {
|
||||
"title": "Test",
|
||||
"url": "https://example.com",
|
||||
"description": "Description",
|
||||
"content": "Content",
|
||||
"authors": ["Author"],
|
||||
"published_at": datetime.now(),
|
||||
"guid": "123",
|
||||
"karma": 10,
|
||||
"votes": 5,
|
||||
"comments": 2,
|
||||
"words": 100,
|
||||
"tags": ["tag"],
|
||||
"af": False,
|
||||
"score": 15,
|
||||
"extended_score": 20,
|
||||
}
|
||||
|
||||
# Optional fields
|
||||
post["modified_at"] = "2023-01-15T10:30:00Z"
|
||||
post["slug"] = "test-slug"
|
||||
post["images"] = ["image.jpg"]
|
||||
|
||||
# Should not raise any type errors
|
||||
assert post["title"] == "Test"
|
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import ArticleFeed, BlogPost
|
||||
@ -64,37 +64,6 @@ def inactive_article_feed(db_session):
|
||||
return feed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recently_checked_feed(db_session):
|
||||
"""Create a recently checked ArticleFeed."""
|
||||
from sqlalchemy import text
|
||||
|
||||
# Use a very recent timestamp that will trigger the "recently checked" condition
|
||||
# With check_interval=3600, a timestamp from 30 seconds ago is comfortably within the "recently checked" window
|
||||
recent_time = datetime.now() - timedelta(seconds=30)
|
||||
|
||||
feed = ArticleFeed(
|
||||
url="https://example.com/recent.xml",
|
||||
title="Recent Feed",
|
||||
description="A recently checked feed",
|
||||
tags=["test"],
|
||||
check_interval=3600,
|
||||
active=True,
|
||||
)
|
||||
db_session.add(feed)
|
||||
db_session.flush() # Get the ID
|
||||
|
||||
# Manually set the last_checked_at to avoid timezone issues
|
||||
db_session.execute(
|
||||
text(
|
||||
"UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
|
||||
),
|
||||
{"timestamp": recent_time, "feed_id": feed.id},
|
||||
)
|
||||
db_session.commit()
|
||||
return feed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feed_item():
|
||||
"""Mock feed item for testing."""
|
||||
@ -564,3 +533,65 @@ def test_sync_website_archive_empty_results(
|
||||
assert result["articles_found"] == 0
|
||||
assert result["new_articles"] == 0
|
||||
assert result["task_ids"] == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"check_interval_minutes,seconds_since_check,should_skip",
|
||||
[
|
||||
(60, 30, True), # 60min interval, checked 30s ago -> skip
|
||||
(60, 3000, True), # 60min interval, checked 50min ago -> skip
|
||||
(60, 4000, False), # 60min interval, checked 66min ago -> don't skip
|
||||
(30, 1000, True), # 30min interval, checked 16min ago -> skip
|
||||
(30, 2000, False), # 30min interval, checked 33min ago -> don't skip
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.blogs.get_feed_parser")
|
||||
def test_sync_article_feed_check_interval(
|
||||
mock_get_parser,
|
||||
check_interval_minutes,
|
||||
seconds_since_check,
|
||||
should_skip,
|
||||
db_session,
|
||||
):
|
||||
"""Test sync respects check interval with various timing scenarios."""
|
||||
from sqlalchemy import text
|
||||
|
||||
# Mock parser to return None (no parser available) for non-skipped cases
|
||||
mock_get_parser.return_value = None
|
||||
|
||||
# Create feed with specific check interval
|
||||
feed = ArticleFeed(
|
||||
url="https://example.com/interval-test.xml",
|
||||
title="Interval Test Feed",
|
||||
description="Feed for testing check intervals",
|
||||
tags=["test"],
|
||||
check_interval=check_interval_minutes,
|
||||
active=True,
|
||||
)
|
||||
db_session.add(feed)
|
||||
db_session.flush()
|
||||
|
||||
# Set last_checked_at to specific time in the past
|
||||
last_checked_time = datetime.now(timezone.utc) - timedelta(
|
||||
seconds=seconds_since_check
|
||||
)
|
||||
db_session.execute(
|
||||
text(
|
||||
"UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
|
||||
),
|
||||
{"timestamp": last_checked_time, "feed_id": feed.id},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
result = blogs.sync_article_feed(feed.id)
|
||||
|
||||
if should_skip:
|
||||
assert result == {"status": "skipped_recent_check", "feed_id": feed.id}
|
||||
# get_feed_parser should not be called when skipping
|
||||
mock_get_parser.assert_not_called()
|
||||
else:
|
||||
# Should proceed with sync, but will fail due to no parser - that's expected
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "No parser available for feed"
|
||||
# get_feed_parser should be called when not skipping
|
||||
mock_get_parser.assert_called_once()
|
||||
|
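The parametrized cases above effectively specify the skip gate at the top of sync_article_feed. A minimal sketch of that gate, assuming check_interval is interpreted as minutes and last_checked_at is timezone-aware (assumptions taken from the test comments, not from the task implementation):

from datetime import datetime, timedelta, timezone


def _recently_checked(feed, now=None) -> bool:
    # True when the feed was already checked within its configured interval.
    if feed.last_checked_at is None:
        return False
    now = now or datetime.now(timezone.utc)
    return now - feed.last_checked_at < timedelta(minutes=feed.check_interval)

When this returns True, the task would answer {"status": "skipped_recent_check", "feed_id": feed.id} before ever looking up a parser, which is exactly what mock_get_parser.assert_not_called() verifies.
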
@ -471,11 +471,9 @@ def test_process_content_item(
|
||||
"memory.workers.tasks.content_processing.push_to_qdrant",
|
||||
side_effect=Exception("Qdrant error"),
|
||||
):
|
||||
result = process_content_item(
|
||||
mail_message, "mail", db_session, ["tag1"]
|
||||
)
|
||||
result = process_content_item(mail_message, db_session)
|
||||
else:
|
||||
result = process_content_item(mail_message, "mail", db_session, ["tag1"])
|
||||
result = process_content_item(mail_message, db_session)
|
||||
|
||||
assert result["status"] == expected_status
|
||||
assert result["embed_status"] == expected_embed_status
|
||||
|
@ -326,7 +326,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
|
||||
# Check that the full content was passed to the embedding function
|
||||
texts = mock_voyage_client.embed.call_args[0][0]
|
||||
assert texts == [
|
||||
large_page_1.strip(),
|
||||
large_page_2.strip(),
|
||||
large_section_content.strip(),
|
||||
[large_page_1.strip()],
|
||||
[large_page_2.strip()],
|
||||
[large_section_content.strip()],
|
||||
]
|
||||
|
562
tests/memory/workers/tasks/test_forums_tasks.py
Normal file
562
tests/memory/workers/tasks/test_forums_tasks.py
Normal file
@ -0,0 +1,562 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import ForumPost
|
||||
from memory.workers.tasks import forums
|
||||
from memory.parsers.lesswrong import LessWrongPost
|
||||
from memory.workers.tasks.content_processing import create_content_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lesswrong_post():
|
||||
"""Mock LessWrong post data for testing."""
|
||||
return LessWrongPost(
|
||||
title="Test LessWrong Post",
|
||||
url="https://www.lesswrong.com/posts/test123/test-post",
|
||||
description="This is a test post description",
|
||||
content="This is test post content with enough text to be processed and embedded.",
|
||||
authors=["Test Author"],
|
||||
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
guid="test123",
|
||||
karma=25,
|
||||
votes=10,
|
||||
comments=5,
|
||||
words=100,
|
||||
tags=["rationality", "ai"],
|
||||
af=False,
|
||||
score=25,
|
||||
extended_score=30,
|
||||
modified_at="2024-01-01T12:30:00Z",
|
||||
slug="test-post",
|
||||
images=[], # Empty images to avoid file path issues
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_empty_lesswrong_post():
|
||||
"""Mock LessWrong post with empty content."""
|
||||
return LessWrongPost(
|
||||
title="Empty Post",
|
||||
url="https://www.lesswrong.com/posts/empty123/empty-post",
|
||||
description="",
|
||||
content="",
|
||||
authors=["Empty Author"],
|
||||
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
guid="empty123",
|
||||
karma=5,
|
||||
votes=2,
|
||||
comments=0,
|
||||
words=0,
|
||||
tags=[],
|
||||
af=False,
|
||||
score=5,
|
||||
extended_score=5,
|
||||
slug="empty-post",
|
||||
images=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_af_post():
|
||||
"""Mock Alignment Forum post."""
|
||||
return LessWrongPost(
|
||||
title="AI Safety Research",
|
||||
url="https://www.lesswrong.com/posts/af123/ai-safety-research",
|
||||
description="Important AI safety research",
|
||||
content="This is important AI safety research content that should be processed.",
|
||||
authors=["AI Researcher"],
|
||||
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
guid="af123",
|
||||
karma=50,
|
||||
votes=20,
|
||||
comments=15,
|
||||
words=200,
|
||||
tags=["ai-safety", "alignment"],
|
||||
af=True,
|
||||
score=50,
|
||||
extended_score=60,
|
||||
slug="ai-safety-research",
|
||||
images=[],
|
||||
)
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_success(mock_lesswrong_post, db_session, qdrant):
|
||||
"""Test successful LessWrong post synchronization."""
|
||||
result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test", "forum"])
|
||||
|
||||
# Verify the ForumPost was created in the database
|
||||
forum_post = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
|
||||
.first()
|
||||
)
|
||||
assert forum_post is not None
|
||||
assert forum_post.title == "Test LessWrong Post"
|
||||
assert (
|
||||
forum_post.content
|
||||
== "This is test post content with enough text to be processed and embedded."
|
||||
)
|
||||
assert forum_post.modality == "forum"
|
||||
assert forum_post.mime_type == "text/markdown"
|
||||
assert forum_post.authors == ["Test Author"]
|
||||
assert forum_post.karma == 25
|
||||
assert forum_post.votes == 10
|
||||
assert forum_post.comments == 5
|
||||
assert forum_post.words == 100
|
||||
assert forum_post.score == 25
|
||||
assert forum_post.slug == "test-post"
|
||||
assert forum_post.images == []
|
||||
assert "test" in forum_post.tags
|
||||
assert "forum" in forum_post.tags
|
||||
assert "rationality" in forum_post.tags
|
||||
assert "ai" in forum_post.tags
|
||||
|
||||
# Verify the result
|
||||
assert result["status"] == "processed"
|
||||
assert result["forumpost_id"] == forum_post.id
|
||||
assert result["title"] == "Test LessWrong Post"
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_empty_content(mock_empty_lesswrong_post, db_session):
|
||||
"""Test LessWrong post sync with empty content."""
|
||||
result = forums.sync_lesswrong_post(mock_empty_lesswrong_post)
|
||||
|
||||
# Should still create the post but with failed status due to no chunks
|
||||
forum_post = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/empty123/empty-post")
|
||||
.first()
|
||||
)
|
||||
assert forum_post is not None
|
||||
assert forum_post.title == "Empty Post"
|
||||
assert forum_post.content == ""
|
||||
assert result["status"] == "failed" # No chunks generated for empty content
|
||||
assert result["chunks_count"] == 0
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_already_exists(mock_lesswrong_post, db_session):
|
||||
"""Test LessWrong post sync when content already exists."""
|
||||
# Add existing forum post with same content hash
|
||||
existing_post = ForumPost(
|
||||
url="https://www.lesswrong.com/posts/test123/test-post",
|
||||
title="Test LessWrong Post",
|
||||
content="This is test post content with enough text to be processed and embedded.",
|
||||
sha256=create_content_hash(
|
||||
"This is test post content with enough text to be processed and embedded."
|
||||
),
|
||||
modality="forum",
|
||||
tags=["existing"],
|
||||
mime_type="text/markdown",
|
||||
size=77,
|
||||
authors=["Test Author"],
|
||||
karma=25,
|
||||
)
|
||||
db_session.add(existing_post)
|
||||
db_session.commit()
|
||||
|
||||
result = forums.sync_lesswrong_post(mock_lesswrong_post, ["test"])
|
||||
|
||||
assert result["status"] == "already_exists"
|
||||
assert result["forumpost_id"] == existing_post.id
|
||||
|
||||
# Verify no duplicate was created
|
||||
forum_posts = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
|
||||
.all()
|
||||
)
|
||||
assert len(forum_posts) == 1
|
||||
|
||||
|
||||
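The already-exists branch exercised here presumably hinges on the content hash; a hedged sketch of that lookup (the query shape and session handling are assumptions, while ForumPost and create_content_hash are the imports already used at the top of this file):

def _find_existing_sketch(session, post):
    # Hash only the post body, so re-fetching identical content maps to the same row.
    sha256 = create_content_hash(post["content"])
    return session.query(ForumPost).filter_by(sha256=sha256).first()

When that lookup returns a row, the task evidently reports {"status": "already_exists", "forumpost_id": existing.id} instead of inserting a duplicate, as the single-row assertion above checks.
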
def test_sync_lesswrong_post_with_custom_tags(mock_lesswrong_post, db_session, qdrant):
|
||||
"""Test LessWrong post sync with custom tags."""
|
||||
result = forums.sync_lesswrong_post(mock_lesswrong_post, ["custom", "tags"])
|
||||
|
||||
# Verify the ForumPost was created with custom tags merged with post tags
|
||||
forum_post = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/test123/test-post")
|
||||
.first()
|
||||
)
|
||||
assert forum_post is not None
|
||||
assert "custom" in forum_post.tags
|
||||
assert "tags" in forum_post.tags
|
||||
assert "rationality" in forum_post.tags # Original post tags
|
||||
assert "ai" in forum_post.tags
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_af_post(mock_af_post, db_session, qdrant):
|
||||
"""Test syncing an Alignment Forum post."""
|
||||
result = forums.sync_lesswrong_post(mock_af_post, ["alignment-forum"])
|
||||
|
||||
forum_post = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/af123/ai-safety-research")
|
||||
.first()
|
||||
)
|
||||
assert forum_post is not None
|
||||
assert forum_post.title == "AI Safety Research"
|
||||
assert forum_post.karma == 50
|
||||
assert "ai-safety" in forum_post.tags
|
||||
assert "alignment" in forum_post.tags
|
||||
assert "alignment-forum" in forum_post.tags
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_success(mock_fetch, mock_lesswrong_post, db_session):
|
||||
"""Test successful LessWrong synchronization."""
|
||||
mock_fetch.return_value = [mock_lesswrong_post]
|
||||
|
||||
with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
|
||||
mock_sync_post.delay.return_value = Mock(id="task-123")
|
||||
|
||||
result = forums.sync_lesswrong(
|
||||
since="2024-01-01T00:00:00",
|
||||
min_karma=10,
|
||||
limit=50,
|
||||
cooldown=0.1,
|
||||
max_items=100,
|
||||
af=False,
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
assert result["posts_num"] == 1
|
||||
assert result["new_posts"] == 1
|
||||
assert result["since"] == "2024-01-01T00:00:00"
|
||||
assert result["min_karma"] == 10
|
||||
assert result["max_items"] == 100
|
||||
assert result["af"] == False
|
||||
|
||||
# Verify fetch_lesswrong_posts was called with correct arguments
|
||||
mock_fetch.assert_called_once_with(
|
||||
datetime.fromisoformat("2024-01-01T00:00:00"),
|
||||
10, # min_karma
|
||||
50, # limit
|
||||
0.1, # cooldown
|
||||
100, # max_items
|
||||
False, # af
|
||||
)
|
||||
|
||||
# Verify sync_lesswrong_post was called for the new post
|
||||
mock_sync_post.delay.assert_called_once_with(mock_lesswrong_post, ["test"])
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_with_existing_posts(
|
||||
mock_fetch, mock_lesswrong_post, db_session
|
||||
):
|
||||
"""Test sync when some posts already exist."""
|
||||
# Create existing forum post
|
||||
existing_post = ForumPost(
|
||||
url="https://www.lesswrong.com/posts/test123/test-post",
|
||||
title="Existing Post",
|
||||
content="Existing content",
|
||||
sha256=b"existing_hash" + bytes(24),
|
||||
modality="forum",
|
||||
tags=["existing"],
|
||||
mime_type="text/markdown",
|
||||
size=100,
|
||||
authors=["Test Author"],
|
||||
)
|
||||
db_session.add(existing_post)
|
||||
db_session.commit()
|
||||
|
||||
# Mock fetch to return existing post and a new one
|
||||
new_post = mock_lesswrong_post.copy()
|
||||
new_post["url"] = "https://www.lesswrong.com/posts/new123/new-post"
|
||||
new_post["title"] = "New Post"
|
||||
|
||||
mock_fetch.return_value = [mock_lesswrong_post, new_post]
|
||||
|
||||
with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
|
||||
mock_sync_post.delay.return_value = Mock(id="task-456")
|
||||
|
||||
result = forums.sync_lesswrong(max_items=100)
|
||||
|
||||
assert result["posts_num"] == 2
|
||||
assert result["new_posts"] == 1 # Only one new post
|
||||
|
||||
# Verify sync_lesswrong_post was only called for the new post
|
||||
mock_sync_post.delay.assert_called_once_with(new_post, [])
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_no_posts(mock_fetch, db_session):
|
||||
"""Test sync when no posts are returned."""
|
||||
mock_fetch.return_value = []
|
||||
|
||||
result = forums.sync_lesswrong()
|
||||
|
||||
assert result["posts_num"] == 0
|
||||
assert result["new_posts"] == 0
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_max_items_limit(mock_fetch, db_session):
|
||||
"""Test that max_items limit is respected."""
|
||||
# Create multiple mock posts
|
||||
posts = []
|
||||
for i in range(5):
|
||||
post = LessWrongPost(
|
||||
title=f"Post {i}",
|
||||
url=f"https://www.lesswrong.com/posts/test{i}/post-{i}",
|
||||
description=f"Description {i}",
|
||||
content=f"Content {i}",
|
||||
authors=[f"Author {i}"],
|
||||
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
guid=f"test{i}",
|
||||
karma=10,
|
||||
votes=5,
|
||||
comments=2,
|
||||
words=50,
|
||||
tags=[],
|
||||
af=False,
|
||||
score=10,
|
||||
extended_score=10,
|
||||
slug=f"post-{i}",
|
||||
images=[],
|
||||
)
|
||||
posts.append(post)
|
||||
|
||||
mock_fetch.return_value = posts
|
||||
|
||||
with patch("memory.workers.tasks.forums.sync_lesswrong_post") as mock_sync_post:
|
||||
mock_sync_post.delay.return_value = Mock(id="task-123")
|
||||
|
||||
result = forums.sync_lesswrong(max_items=3)
|
||||
|
||||
# Should stop at max_items, but new_posts can be higher because
|
||||
# the check happens after incrementing new_posts but before incrementing posts_num
|
||||
assert result["posts_num"] == 3
|
||||
assert result["new_posts"] == 4 # One more than posts_num due to timing
|
||||
assert result["max_items"] == 3
|
||||
|
||||
|
||||
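The off-by-one asserted above suggests a dispatch loop shaped roughly like the sketch below; it reflects only what the tests pin down (a URL-based existence check, one Celery task per new post, and counters incremented on either side of the max_items check) and is not the actual task body:

def _sync_lesswrong_sketch(session, posts, tags, max_items):
    posts_num = new_posts = 0
    for post in posts:
        exists = session.query(ForumPost).filter_by(url=post["url"]).first()
        if not exists:
            sync_lesswrong_post.delay(post, tags)  # patched out in these tests
            new_posts += 1
        if posts_num >= max_items:  # evaluated between the two increments,
            break                   # so new_posts can finish one ahead of posts_num
        posts_num += 1
    return {"posts_num": posts_num, "new_posts": new_posts}  # plus the echoed input parameters

With five fresh posts and max_items=3 this yields posts_num == 3 and new_posts == 4, matching the assertions above.
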
@pytest.mark.parametrize(
|
||||
"since_str,expected_days_ago",
|
||||
[
|
||||
("2024-01-01T00:00:00", None), # Specific date
|
||||
(None, 30), # Default should be 30 days ago
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_since_parameter(
|
||||
mock_fetch, since_str, expected_days_ago, db_session
|
||||
):
|
||||
"""Test that since parameter is handled correctly."""
|
||||
mock_fetch.return_value = []
|
||||
|
||||
if since_str:
|
||||
forums.sync_lesswrong(since=since_str)
|
||||
expected_since = datetime.fromisoformat(since_str)
|
||||
else:
|
||||
forums.sync_lesswrong()
|
||||
# Should default to 30 days ago
|
||||
expected_since = datetime.now() - timedelta(days=30)
|
||||
|
||||
# Verify fetch was called with correct since date
|
||||
call_args = mock_fetch.call_args[0]
|
||||
actual_since = call_args[0]
|
||||
|
||||
if expected_days_ago:
|
||||
# For default case, check it's approximately 30 days ago
|
||||
time_diff = abs((actual_since - expected_since).total_seconds())
|
||||
assert time_diff < 120 # Within 2 minute tolerance for slow tests
|
||||
else:
|
||||
assert actual_since == expected_since
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"af_value,min_karma,limit,cooldown",
|
||||
[
|
||||
(True, 20, 25, 1.0),
|
||||
(False, 5, 100, 0.0),
|
||||
(True, 50, 10, 0.5),
|
||||
],
|
||||
)
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_parameters(
|
||||
mock_fetch, af_value, min_karma, limit, cooldown, db_session
|
||||
):
|
||||
"""Test that all parameters are passed correctly to fetch function."""
|
||||
mock_fetch.return_value = []
|
||||
|
||||
result = forums.sync_lesswrong(
|
||||
af=af_value,
|
||||
min_karma=min_karma,
|
||||
limit=limit,
|
||||
cooldown=cooldown,
|
||||
max_items=500,
|
||||
)
|
||||
|
||||
# Verify fetch was called with correct parameters
|
||||
call_args = mock_fetch.call_args[0]
|
||||
|
||||
assert call_args[1] == min_karma # min_karma
|
||||
assert call_args[2] == limit # limit
|
||||
assert call_args[3] == cooldown # cooldown
|
||||
assert call_args[4] == 500 # max_items
|
||||
assert call_args[5] == af_value # af
|
||||
|
||||
assert result["min_karma"] == min_karma
|
||||
assert result["af"] == af_value
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
|
||||
def test_sync_lesswrong_fetch_error(mock_fetch, db_session):
|
||||
"""Test sync when fetch_lesswrong_posts raises an exception."""
|
||||
mock_fetch.side_effect = Exception("API error")
|
||||
|
||||
# The safe_task_execution decorator should catch this
|
||||
result = forums.sync_lesswrong()
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "API error" in result["error"]
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_error_handling(db_session):
|
||||
"""Test error handling in sync_lesswrong_post."""
|
||||
# Create invalid post data that will cause an error
|
||||
invalid_post = {
|
||||
"title": "Test",
|
||||
"url": "invalid-url",
|
||||
"content": "test content",
|
||||
# Missing required fields
|
||||
}
|
||||
|
||||
# The safe_task_execution decorator should catch this
|
||||
result = forums.sync_lesswrong_post(invalid_post)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "error" in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"post_tags,additional_tags,expected_tags",
|
||||
[
|
||||
(["original"], ["new"], ["original", "new"]),
|
||||
([], ["tag1", "tag2"], ["tag1", "tag2"]),
|
||||
(["existing"], [], ["existing"]),
|
||||
(["dup", "tag"], ["tag", "new"], ["dup", "tag", "new"]), # Duplicates removed
|
||||
],
|
||||
)
|
||||
def test_sync_lesswrong_post_tag_merging(
|
||||
post_tags, additional_tags, expected_tags, db_session, qdrant
|
||||
):
|
||||
"""Test that post tags and additional tags are properly merged."""
|
||||
post = LessWrongPost(
|
||||
title="Tag Test Post",
|
||||
url="https://www.lesswrong.com/posts/tag123/tag-test",
|
||||
description="Test description",
|
||||
content="Test content for tag merging",
|
||||
authors=["Tag Author"],
|
||||
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
guid="tag123",
|
||||
karma=15,
|
||||
votes=5,
|
||||
comments=2,
|
||||
words=50,
|
||||
tags=post_tags,
|
||||
af=False,
|
||||
score=15,
|
||||
extended_score=15,
|
||||
slug="tag-test",
|
||||
images=[],
|
||||
)
|
||||
|
||||
forums.sync_lesswrong_post(post, additional_tags)
|
||||
|
||||
forum_post = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/tag123/tag-test")
|
||||
.first()
|
||||
)
|
||||
assert forum_post is not None
|
||||
|
||||
# Check that all expected tags are present (order doesn't matter)
|
||||
for tag in expected_tags:
|
||||
assert tag in forum_post.tags
|
||||
|
||||
# Check that no unexpected tags are present
|
||||
assert len(forum_post.tags) == len(set(expected_tags))
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_datetime_handling(db_session, qdrant):
|
||||
"""Test that datetime fields are properly handled."""
|
||||
post = LessWrongPost(
|
||||
title="DateTime Test",
|
||||
url="https://www.lesswrong.com/posts/dt123/datetime-test",
|
||||
description="Test description",
|
||||
content="Test content",
|
||||
authors=["DateTime Author"],
|
||||
published_at=datetime(2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc),
|
||||
guid="dt123",
|
||||
karma=20,
|
||||
votes=8,
|
||||
comments=3,
|
||||
words=75,
|
||||
tags=["datetime"],
|
||||
af=False,
|
||||
score=20,
|
||||
extended_score=25,
|
||||
modified_at="2024-06-15T15:00:00Z",
|
||||
slug="datetime-test",
|
||||
images=[],
|
||||
)
|
||||
|
||||
result = forums.sync_lesswrong_post(post)
|
||||
|
||||
forum_post = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/dt123/datetime-test")
|
||||
.first()
|
||||
)
|
||||
assert forum_post is not None
|
||||
assert forum_post.published_at == datetime(
|
||||
2024, 6, 15, 14, 30, 45, tzinfo=timezone.utc
|
||||
)
|
||||
# modified_at should be stored as string in the post data
|
||||
assert result["status"] == "processed"
|
||||
|
||||
|
||||
def test_sync_lesswrong_post_content_hash_consistency(db_session):
|
||||
"""Test that content hash is calculated consistently."""
|
||||
post = LessWrongPost(
|
||||
title="Hash Test",
|
||||
url="https://www.lesswrong.com/posts/hash123/hash-test",
|
||||
description="Test description",
|
||||
content="Consistent content for hashing",
|
||||
authors=["Hash Author"],
|
||||
published_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
guid="hash123",
|
||||
karma=15,
|
||||
votes=5,
|
||||
comments=2,
|
||||
words=50,
|
||||
tags=["hash"],
|
||||
af=False,
|
||||
score=15,
|
||||
extended_score=15,
|
||||
slug="hash-test",
|
||||
images=[],
|
||||
)
|
||||
|
||||
# Sync the same post twice
|
||||
result1 = forums.sync_lesswrong_post(post)
|
||||
result2 = forums.sync_lesswrong_post(post)
|
||||
|
||||
# First should succeed, second should detect existing
|
||||
assert result1["status"] == "processed"
|
||||
assert result2["status"] == "already_exists"
|
||||
assert result1["forumpost_id"] == result2["forumpost_id"]
|
||||
|
||||
# Verify only one post exists in database
|
||||
forum_posts = (
|
||||
db_session.query(ForumPost)
|
||||
.filter_by(url="https://www.lesswrong.com/posts/hash123/hash-test")
|
||||
.all()
|
||||
)
|
||||
assert len(forum_posts) == 1
|
@ -7,12 +7,14 @@ from PIL import Image
|
||||
|
||||
from memory.common import qdrant as qd
|
||||
from memory.common import settings
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common.db.models import Chunk, SourceItem, MailMessage, BlogPost
|
||||
from memory.workers.tasks.maintenance import (
|
||||
clean_collection,
|
||||
reingest_chunk,
|
||||
check_batch,
|
||||
reingest_missing_chunks,
|
||||
reingest_item,
|
||||
reingest_empty_source_items,
|
||||
)
|
||||
|
||||
|
||||
@ -302,5 +304,295 @@ def test_reingest_missing_chunks(db_session, qdrant, batch_size):
|
||||
assert set(qdrant_ids) == set(db_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
|
||||
def test_reingest_item_success(db_session, qdrant, item_type):
|
||||
"""Test successful reingestion of an item with existing chunks."""
|
||||
if item_type == "MailMessage":
|
||||
item = MailMessage(
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="STORED",
|
||||
message_id="<test@example.com>",
|
||||
subject="Test Subject",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
content="Test content for reingestion",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
else:  # BlogPost
|
||||
item = BlogPost(
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="text/html",
|
||||
embed_status="STORED",
|
||||
url="https://example.com/post",
|
||||
title="Test Blog Post",
|
||||
author="Author Name",
|
||||
content="Test blog content for reingestion",
|
||||
modality="blog",
|
||||
)
|
||||
|
||||
db_session.add(item)
|
||||
db_session.flush()
|
||||
|
||||
# Add some chunks to the item
|
||||
chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
chunks = [
|
||||
Chunk(
|
||||
id=chunk_id,
|
||||
source=item,
|
||||
content=f"Test chunk content {i}",
|
||||
embedding_model="test-model",
|
||||
)
|
||||
for i, chunk_id in enumerate(chunk_ids)
|
||||
]
|
||||
db_session.add_all(chunks)
|
||||
db_session.commit()
|
||||
|
||||
# Add vectors to Qdrant
|
||||
modality = "mail" if item_type == "MailMessage" else "blog"
|
||||
qd.ensure_collection_exists(qdrant, modality, 1024)
|
||||
qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))
|
||||
|
||||
# Verify chunks exist in Qdrant before reingestion
|
||||
qdrant_ids_before = {
|
||||
str(i) for batch in qd.batch_ids(qdrant, modality) for i in batch
|
||||
}
|
||||
assert set(chunk_ids).issubset(qdrant_ids_before)
|
||||
|
||||
# Mock the embedding function to return chunks
|
||||
with patch("memory.common.embedding.embed_source_item") as mock_embed:
|
||||
mock_embed.return_value = [
|
||||
Chunk(
|
||||
id=str(uuid.uuid4()),
|
||||
content="New chunk content 1",
|
||||
embedding_model="test-model",
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"source_id": item.id, "tags": ["test"]},
|
||||
),
|
||||
Chunk(
|
||||
id=str(uuid.uuid4()),
|
||||
content="New chunk content 2",
|
||||
embedding_model="test-model",
|
||||
vector=[0.2] * 1024,
|
||||
item_metadata={"source_id": item.id, "tags": ["test"]},
|
||||
),
|
||||
]
|
||||
|
||||
result = reingest_item(str(item.id), item_type)
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert result[f"{item_type.lower()}_id"] == item.id
|
||||
assert result["chunks_count"] == 2
|
||||
assert result["embed_status"] == "STORED"
|
||||
|
||||
# Verify old chunks were deleted from database
|
||||
db_session.refresh(item)
|
||||
remaining_chunks = db_session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
|
||||
assert len(remaining_chunks) == 0
|
||||
|
||||
# Verify old vectors were deleted from Qdrant
|
||||
qdrant_ids_after = {
|
||||
str(i) for batch in qd.batch_ids(qdrant, modality) for i in batch
|
||||
}
|
||||
assert not set(chunk_ids).intersection(qdrant_ids_after)
|
||||
|
||||
|
||||
def test_reingest_item_not_found(db_session):
|
||||
"""Test reingesting a non-existent item."""
|
||||
non_existent_id = "999"
|
||||
result = reingest_item(non_existent_id, "MailMessage")
|
||||
|
||||
assert result == {"status": "error", "error": f"Item {non_existent_id} not found"}
|
||||
|
||||
|
||||
def test_reingest_item_invalid_type(db_session):
|
||||
"""Test reingesting with an invalid item type."""
|
||||
result = reingest_item("1", "invalid_type")
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "Unsupported item type invalid_type" in result["error"]
|
||||
assert "Available types:" in result["error"]
|
||||
|
||||
|
||||
def test_reingest_item_no_chunks(db_session, qdrant):
|
||||
"""Test reingesting an item that has no chunks."""
|
||||
item = MailMessage(
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
message_id="<test@example.com>",
|
||||
subject="Test Subject",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
content="Test content",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
db_session.add(item)
|
||||
db_session.commit()
|
||||
|
||||
# Mock the embedding function to return a chunk
|
||||
with patch("memory.common.embedding.embed_source_item") as mock_embed:
|
||||
mock_embed.return_value = [
|
||||
Chunk(
|
||||
id=str(uuid.uuid4()),
|
||||
content="New chunk content",
|
||||
embedding_model="test-model",
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"source_id": item.id, "tags": ["test"]},
|
||||
),
|
||||
]
|
||||
|
||||
result = reingest_item(str(item.id), "MailMessage")
|
||||
|
||||
assert result["status"] == "processed"
|
||||
assert result["mailmessage_id"] == item.id
|
||||
assert result["chunks_count"] == 1
|
||||
assert result["embed_status"] == "STORED"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
|
||||
def test_reingest_empty_source_items_success(db_session, item_type):
|
||||
"""Test reingesting empty source items."""
|
||||
# Create items with and without chunks
|
||||
if item_type == "MailMessage":
|
||||
empty_items = [
|
||||
MailMessage(
|
||||
sha256=f"empty_hash_{i}".encode() + bytes(32 - len(f"empty_hash_{i}")),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
message_id=f"<empty{i}@example.com>",
|
||||
subject=f"Empty Subject {i}",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
content=f"Empty content {i}",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
item_with_chunks = MailMessage(
|
||||
sha256=b"with_chunks_hash" + bytes(16),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="STORED",
|
||||
message_id="<with_chunks@example.com>",
|
||||
subject="With Chunks Subject",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
content="Content with chunks",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
else:  # BlogPost
|
||||
empty_items = [
|
||||
BlogPost(
|
||||
sha256=f"empty_hash_{i}".encode() + bytes(32 - len(f"empty_hash_{i}")),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="text/html",
|
||||
embed_status="RAW",
|
||||
url=f"https://example.com/empty{i}",
|
||||
title=f"Empty Post {i}",
|
||||
author="Author Name",
|
||||
content=f"Empty blog content {i}",
|
||||
modality="blog",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
item_with_chunks = BlogPost(
|
||||
sha256=b"with_chunks_hash" + bytes(16),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="text/html",
|
||||
embed_status="STORED",
|
||||
url="https://example.com/with_chunks",
|
||||
title="With Chunks Post",
|
||||
author="Author Name",
|
||||
content="Blog content with chunks",
|
||||
modality="blog",
|
||||
)
|
||||
|
||||
db_session.add_all(empty_items + [item_with_chunks])
|
||||
db_session.flush()
|
||||
|
||||
# Add a chunk to the item_with_chunks
|
||||
chunk = Chunk(
|
||||
id=str(uuid.uuid4()),
|
||||
source=item_with_chunks,
|
||||
content="Test chunk content",
|
||||
embedding_model="test-model",
|
||||
)
|
||||
db_session.add(chunk)
|
||||
db_session.commit()
|
||||
|
||||
with patch.object(reingest_item, "delay") as mock_reingest:
|
||||
result = reingest_empty_source_items(item_type)
|
||||
|
||||
assert result == {"status": "success", "items": 3}
|
||||
|
||||
# Verify that reingest_item.delay was called for each empty item
|
||||
assert mock_reingest.call_count == 3
|
||||
expected_calls = [call(item.id, item_type) for item in empty_items]
|
||||
mock_reingest.assert_has_calls(expected_calls, any_order=True)
|
||||
|
||||
|
||||
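For reference, reingest_empty_source_items as exercised here reduces to something like the sketch below; the type registry and the chunk-emptiness query are assumptions (the chunks relationship name in particular is a guess), so treat it as an illustration rather than the implementation:

def _reingest_empty_sketch(session, item_type, types):
    cls = types.get(item_type)
    if cls is None:
        return {
            "status": "error",
            "error": f"Unsupported item type {item_type}. Available types: {sorted(types)}",
        }
    empty_items = session.query(cls).filter(~cls.chunks.any()).all()  # items with no chunks at all
    for item in empty_items:
        reingest_item.delay(item.id, item_type)  # one maintenance task per empty item
    return {"status": "success", "items": len(empty_items)}
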
def test_reingest_empty_source_items_no_empty_items(db_session):
|
||||
"""Test when there are no empty source items."""
|
||||
# Create an item with chunks
|
||||
item = MailMessage(
|
||||
sha256=b"with_chunks_hash" + bytes(16),
|
||||
tags=["test"],
|
||||
size=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="STORED",
|
||||
message_id="<with_chunks@example.com>",
|
||||
subject="With Chunks Subject",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
content="Content with chunks",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
db_session.add(item)
|
||||
db_session.flush()
|
||||
|
||||
chunk = Chunk(
|
||||
id=str(uuid.uuid4()),
|
||||
source=item,
|
||||
content="Test chunk content",
|
||||
embedding_model="test-model",
|
||||
)
|
||||
db_session.add(chunk)
|
||||
db_session.commit()
|
||||
|
||||
with patch.object(reingest_item, "delay") as mock_reingest:
|
||||
result = reingest_empty_source_items("MailMessage")
|
||||
|
||||
assert result == {"status": "success", "items": 0}
|
||||
mock_reingest.assert_not_called()
|
||||
|
||||
|
||||
def test_reingest_empty_source_items_invalid_type(db_session):
|
||||
"""Test reingesting empty source items with invalid type."""
|
||||
result = reingest_empty_source_items("invalid_type")
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "Unsupported item type invalid_type" in result["error"]
|
||||
assert "Available types:" in result["error"]
|
||||
|
||||
|
||||
def test_reingest_missing_chunks_no_chunks(db_session):
|
||||
assert reingest_missing_chunks() == {}
|
||||
|