From 4aaa45e09cc1cb44846b54aafaaac98aab159427 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 25 May 2025 20:02:47 +0200 Subject: [PATCH] unify tasks --- src/memory/api/app.py | 4 +- src/memory/common/collections.py | 107 +++++ src/memory/common/db/models.py | 75 ++- src/memory/common/embedding.py | 214 ++------- src/memory/common/extract.py | 26 +- src/memory/common/parsers/email.py | 2 + src/memory/common/parsers/html.py | 12 +- src/memory/common/qdrant.py | 7 +- src/memory/workers/email.py | 76 +-- src/memory/workers/tasks/blogs.py | 162 ++----- src/memory/workers/tasks/comic.py | 49 +- .../workers/tasks/content_processing.py | 267 +++++++++++ src/memory/workers/tasks/ebook.py | 108 +---- src/memory/workers/tasks/email.py | 53 ++- src/memory/workers/tasks/maintenance.py | 6 +- tests/conftest.py | 11 +- .../common/parsers/test_email_parsers.py | 1 + tests/memory/common/parsers/test_html.py | 132 +++--- tests/memory/common/test_embedding.py | 249 +++++++--- tests/memory/common/test_extract.py | 47 -- .../memory/workers/tasks/test_comic_tasks.py | 431 ++++++++++++++++++ .../memory/workers/tasks/test_ebook_tasks.py | 64 +-- .../memory/workers/tasks/test_email_tasks.py | 22 +- tests/memory/workers/test_email.py | 8 +- 24 files changed, 1360 insertions(+), 773 deletions(-) create mode 100644 src/memory/common/collections.py create mode 100644 src/memory/workers/tasks/content_processing.py create mode 100644 tests/memory/workers/tasks/test_comic_tasks.py diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 41f2b79..30dbeb3 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -15,7 +15,7 @@ from PIL import Image from pydantic import BaseModel from memory.common import embedding, qdrant, extract, settings -from memory.common.embedding import get_modality +from memory.common.collections import get_modality, TEXT_COLLECTIONS from memory.common.db.connection import make_session from memory.common.db.models import Chunk, SourceItem @@ -189,7 +189,7 @@ async def search( text_results = query_chunks( client, upload_data, - allowed_modalities & embedding.TEXT_COLLECTIONS, + allowed_modalities & TEXT_COLLECTIONS, embedding.embed_text, min_score=min_text_score, limit=limit, diff --git a/src/memory/common/collections.py b/src/memory/common/collections.py new file mode 100644 index 0000000..8a2c074 --- /dev/null +++ b/src/memory/common/collections.py @@ -0,0 +1,107 @@ +import logging +from typing import Literal, NotRequired, TypedDict + + +from memory.common import settings + +logger = logging.getLogger(__name__) + + +DistanceType = Literal["Cosine", "Dot", "Euclidean"] +Vector = list[float] + + +class Collection(TypedDict): + dimension: int + distance: DistanceType + model: str + on_disk: NotRequired[bool] + shards: NotRequired[int] + + +ALL_COLLECTIONS: dict[str, Collection] = { + "mail": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "chat": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "git": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "book": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "blog": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "text": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + # Multimodal + "photo": { + "dimension": 1024, + "distance": "Cosine", + 
"model": settings.MIXED_EMBEDDING_MODEL, + }, + "comic": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, + "doc": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, +} +TEXT_COLLECTIONS = { + coll + for coll, params in ALL_COLLECTIONS.items() + if params["model"] == settings.TEXT_EMBEDDING_MODEL +} +MULTIMODAL_COLLECTIONS = { + coll + for coll, params in ALL_COLLECTIONS.items() + if params["model"] == settings.MIXED_EMBEDDING_MODEL +} + +TYPES = { + "doc": ["application/pdf", "application/docx", "application/msword"], + "text": ["text/*"], + "blog": ["text/markdown", "text/html"], + "photo": ["image/*"], + "book": [ + "application/epub+zip", + "application/mobi", + "application/x-mobipocket-ebook", + ], +} + + +def get_modality(mime_type: str) -> str: + for type, mime_types in TYPES.items(): + if mime_type in mime_types: + return type + stem = mime_type.split("/")[0] + + for type, mime_types in TYPES.items(): + if any(mime_type.startswith(stem) for mime_type in mime_types): + return type + return "unknown" + + +def collection_model(collection: str) -> str | None: + return ALL_COLLECTIONS.get(collection, {}).get("model") diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index c329856..b7cca39 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -2,11 +2,12 @@ Database models for the knowledge base system. """ +from dataclasses import dataclass import pathlib import re import textwrap from datetime import datetime -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, Iterable, Sequence, cast from PIL import Image from sqlalchemy import ( @@ -32,6 +33,9 @@ from sqlalchemy.orm import Session, relationship from memory.common import settings from memory.common.parsers.email import EmailMessage, parse_email_message +import memory.common.extract as extract +import memory.common.collections as collections +import memory.common.chunker as chunker Base = declarative_base() @@ -137,6 +141,7 @@ class SourceItem(Base): """Base class for all content in the system using SQLAlchemy's joined table inheritance.""" __tablename__ = "source_item" + __allow_unmapped__ = True id = Column(BigInteger, primary_key=True) modality = Column(Text, nullable=False) @@ -174,6 +179,14 @@ class SourceItem(Base): """Get vector IDs from associated chunks.""" return [chunk.id for chunk in self.chunks] + def data_chunks(self) -> Iterable[extract.DataChunk]: + return [ + extract.DataChunk( + data=cast(str, self.content), + collection=cast(str, self.modality), + ) + ] + def as_payload(self) -> dict: return { "source_id": self.id, @@ -306,6 +319,19 @@ class EmailAttachment(SourceItem): "tags": self.tags, } + def data_chunks(self) -> Iterable[extract.DataChunk]: + if cast(str | None, self.filename): + contents = pathlib.Path(cast(str, self.filename)).read_bytes() + else: + contents = cast(str, self.content) + + return extract.extract_data_chunks( + cast(str, self.mime_type), + contents, + collection=cast(str, self.modality), + embedding_model=collections.collection_model(cast(str, self.modality)), + ) + # Add indexes __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),) @@ -407,6 +433,15 @@ class Comic(SourceItem): } return {k: v for k, v in payload.items() if v is not None} + def data_chunks(self) -> Iterable[extract.DataChunk]: + image = Image.open(pathlib.Path(cast(str, self.filename))) + return [ + extract.DataChunk( + data=[image, 
cast(str, self.title), cast(str, self.author)], + collection=cast(str, self.modality), + ) + ] + class Book(Base): """Book-level metadata table""" @@ -503,6 +538,19 @@ class BookSection(SourceItem): "tags": self.tags, } + def data_chunks(self) -> Iterable[extract.DataChunk]: + texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)] + texts += [(cast(str, self.content), self.start_page)] + return [ + extract.DataChunk( + data=[text], + collection=cast(str, self.modality), + metadata={"page": page_number}, + max_size=chunker.EMBEDDING_MAX_TOKENS, + ) + for text, page_number in texts + ] + class BlogPost(SourceItem): __tablename__ = "blog_post" @@ -519,6 +567,7 @@ class BlogPost(SourceItem): description = Column(Text, nullable=True) # Meta description or excerpt domain = Column(Text, nullable=True) # Domain of the source website word_count = Column(Integer, nullable=True) # Approximate word count + images = Column(ARRAY(Text), nullable=True) # List of image URLs # Store original metadata from parser webpage_metadata = Column(JSONB, nullable=True) @@ -552,6 +601,30 @@ class BlogPost(SourceItem): } return {k: v for k, v in payload.items() if v} + def data_chunks(self) -> Iterable[extract.DataChunk]: + images = [Image.open(image) for image in self.images] + data = [cast(str, self.content)] + images + + # Always embed the full content as a single chunk (if possible, of course) + chunks = [ + extract.DataChunk( + data=data, + collection=cast(str, self.modality), + max_size=chunker.EMBEDDING_MAX_TOKENS, + ) + ] + + # If the content is long enough, also embed it as chunks of the default size. + tokens = chunker.approx_token_count(cast(str, self.content)) + if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: + chunks += [ + extract.DataChunk( + data=data, + collection=cast(str, self.modality), + ) + ] + return chunks + class MiscDoc(SourceItem): __tablename__ = "misc_doc" diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index be88ac1..6eb26d0 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,130 +1,21 @@ +from collections.abc import Sequence import logging import pathlib import uuid -from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast +from typing import Any, Iterable, Literal, cast import voyageai from PIL import Image from memory.common import extract, settings from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS -from memory.common.db.models import Chunk +from memory.common.collections import ALL_COLLECTIONS, Vector +from memory.common.db.models import Chunk, SourceItem +from memory.common.extract import DataChunk logger = logging.getLogger(__name__) -DistanceType = Literal["Cosine", "Dot", "Euclidean"] -Vector = list[float] - - -class Collection(TypedDict): - dimension: int - distance: DistanceType - model: str - on_disk: NotRequired[bool] - shards: NotRequired[int] - - -ALL_COLLECTIONS: dict[str, Collection] = { - "mail": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "chat": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "git": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "book": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "blog": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "text": { - "dimension": 1024, - "distance": 
"Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - # Multimodal - "photo": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.MIXED_EMBEDDING_MODEL, - }, - "comic": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.MIXED_EMBEDDING_MODEL, - }, - "doc": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.MIXED_EMBEDDING_MODEL, - }, -} -TEXT_COLLECTIONS = { - coll - for coll, params in ALL_COLLECTIONS.items() - if params["model"] == settings.TEXT_EMBEDDING_MODEL -} -MULTIMODAL_COLLECTIONS = { - coll - for coll, params in ALL_COLLECTIONS.items() - if params["model"] == settings.MIXED_EMBEDDING_MODEL -} - -TYPES = { - "doc": ["application/pdf", "application/docx", "application/msword"], - "text": ["text/*"], - "blog": ["text/markdown", "text/html"], - "photo": ["image/*"], - "book": [ - "application/epub+zip", - "application/mobi", - "application/x-mobipocket-ebook", - ], -} - - -def get_mimetype(image: Image.Image) -> str | None: - format_to_mime = { - "JPEG": "image/jpeg", - "PNG": "image/png", - "GIF": "image/gif", - "BMP": "image/bmp", - "TIFF": "image/tiff", - "WEBP": "image/webp", - } - - if not image.format: - return None - - return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}") - - -def get_modality(mime_type: str) -> str: - for type, mime_types in TYPES.items(): - if mime_type in mime_types: - return type - stem = mime_type.split("/")[0] - - for type, mime_types in TYPES.items(): - if any(mime_type.startswith(stem) for mime_type in mime_types): - return type - return "unknown" - - def embed_chunks( chunks: list[str] | list[list[extract.MulitmodalChunk]], model: str = settings.TEXT_EMBEDDING_MODEL, @@ -164,12 +55,6 @@ def embed_text( raise -def embed_file( - file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL -) -> list[Vector]: - return embed_text([file_path.read_text()], model) - - def embed_mixed( items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL, @@ -187,22 +72,6 @@ def embed_mixed( return embed_chunks([chunks], model, input_type) -def embed_page(page: extract.Page) -> list[Vector]: - contents = page["contents"] - chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS) - if all(isinstance(c, str) for c in contents): - return embed_text( - cast(list[str], contents), - model=settings.TEXT_EMBEDDING_MODEL, - chunk_size=chunk_size, - ) - return embed_mixed( - cast(list[extract.MulitmodalChunk], contents), - model=settings.MIXED_EMBEDDING_MODEL, - chunk_size=chunk_size, - ) - - def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: if isinstance(item, str): filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt" @@ -219,7 +88,9 @@ def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: def make_chunk( - page: extract.Page, vector: Vector, metadata: dict[str, Any] = {} + contents: Sequence[extract.MulitmodalChunk], + vector: Vector, + metadata: dict[str, Any] = {}, ) -> Chunk: """Create a Chunk object from a page and a vector. @@ -227,7 +98,6 @@ def make_chunk( a single image, or a list of strings and images. 
""" chunk_id = str(uuid.uuid4()) - contents = page["contents"] content, filename = None, None if all(isinstance(c, str) for c in contents): content = "\n\n".join(cast(list[str], contents)) @@ -251,38 +121,42 @@ def make_chunk( ) -def embed( - mime_type: str, - content: bytes | str | pathlib.Path, +def embed_data_chunk( + chunk: DataChunk, metadata: dict[str, Any] = {}, chunk_size: int | None = None, -) -> tuple[str, list[Chunk]]: - modality = get_modality(mime_type) - pages = extract.extract_content(mime_type, content, chunk_size=chunk_size) - chunks = [ - make_chunk(page, vector, metadata) - for page in pages - for vector in embed_page(page) +) -> list[Chunk]: + chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS + + model = chunk.embedding_model + if not model and chunk.collection: + model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model") + if not model: + model = settings.TEXT_EMBEDDING_MODEL + + if model == settings.TEXT_EMBEDDING_MODEL: + vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size) + elif model == settings.MIXED_EMBEDDING_MODEL: + vectors = embed_mixed( + cast(list[extract.MulitmodalChunk], chunk.data), + chunk_size=chunk_size, + ) + else: + raise ValueError(f"Unsupported model: {model}") + + metadata = metadata | chunk.metadata + return [make_chunk(chunk.data, vector, metadata) for vector in vectors] + + +def embed_source_item( + item: SourceItem, + metadata: dict[str, Any] = {}, + chunk_size: int | None = None, +) -> list[Chunk]: + return [ + chunk + for data_chunk in item.data_chunks() + for chunk in embed_data_chunk( + data_chunk, item.as_payload() | metadata, chunk_size + ) ] - return modality, chunks - - -def embed_image( - file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None -) -> Chunk: - image = Image.open(file_path) - mime_type = get_mimetype(image) - if mime_type is None: - raise ValueError("Unsupported image format") - - vector = embed_mixed( - [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS - )[0] - - return Chunk( - id=str(uuid.uuid4()), - file_path=file_path.absolute().as_posix(), - content=None, - embedding_model=settings.MIXED_EMBEDDING_MODEL, - vector=vector, - ) diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py index 83bfae0..7b7f4d2 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import io import logging import pathlib @@ -21,6 +22,15 @@ class Page(TypedDict): chunk_size: NotRequired[int] +@dataclass +class DataChunk: + data: Sequence[MulitmodalChunk] + collection: str | None = None + embedding_model: str | None = None + max_size: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + @contextmanager def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]: if isinstance(content, pathlib.Path): @@ -110,11 +120,13 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]: return [{"contents": [cast(str, content)], "metadata": {}}] -def extract_content( +def extract_data_chunks( mime_type: str, content: bytes | str | pathlib.Path, + collection: str | None = None, + embedding_model: str | None = None, chunk_size: int | None = None, -) -> list[Page]: +) -> list[DataChunk]: pages = [] logger.info(f"Extracting content from {mime_type}") if mime_type == "application/pdf": @@ -134,4 +146,12 @@ def extract_content( if chunk_size: pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages] - return 
pages + return [ + DataChunk( + data=page["contents"], + collection=collection, + embedding_model=embedding_model, + max_size=chunk_size, + ) + for page in pages + ] diff --git a/src/memory/common/parsers/email.py b/src/memory/common/parsers/email.py index ac0501b..89c7a0a 100644 --- a/src/memory/common/parsers/email.py +++ b/src/memory/common/parsers/email.py @@ -26,6 +26,7 @@ class EmailMessage(TypedDict): body: str attachments: list[Attachment] hash: bytes + raw_email: str RawEmailResponse = tuple[str | None, bytes] @@ -171,6 +172,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage: body = extract_body(msg) return EmailMessage( + raw_email=raw_email, message_id=message_id, subject=subject, sender=from_, diff --git a/src/memory/common/parsers/html.py b/src/memory/common/parsers/html.py index 286e12a..ab64d56 100644 --- a/src/memory/common/parsers/html.py +++ b/src/memory/common/parsers/html.py @@ -12,7 +12,7 @@ from bs4 import BeautifulSoup, Tag from markdownify import markdownify as md from PIL import Image as PILImage -from memory.common.settings import FILE_STORAGE_DIR, WEBPAGE_STORAGE_DIR +from memory.common import settings logger = logging.getLogger(__name__) @@ -96,9 +96,9 @@ def extract_date( datetime_attr = element.get("datetime") if datetime_attr: - date_str = str(datetime_attr) - if date := parse_date(date_str, date_format): - return date + for format in ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d", date_format]: + if date := parse_date(str(datetime_attr), format): + return date for text in element.find_all(string=True): if text and (date := parse_date(str(text).strip(), date_format)): @@ -178,7 +178,7 @@ def process_images( continue path = pathlib.Path(image.filename) # type: ignore - img_tag["src"] = str(path.relative_to(FILE_STORAGE_DIR.resolve())) + img_tag["src"] = str(path.relative_to(settings.FILE_STORAGE_DIR.resolve())) images[img_tag["src"]] = image except Exception as e: logger.warning(f"Failed to process image {src}: {e}") @@ -291,7 +291,7 @@ class BaseHTMLParser: def __init__(self, base_url: str | None = None): self.base_url = base_url - self.image_dir = WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc) + self.image_dir = settings.WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc) self.image_dir.mkdir(parents=True, exist_ok=True) def parse(self, html: str, url: str) -> Article: diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py index 9bfde5e..92b489b 100644 --- a/src/memory/common/qdrant.py +++ b/src/memory/common/qdrant.py @@ -5,12 +5,7 @@ import qdrant_client from qdrant_client.http import models as qdrant_models from qdrant_client.http.exceptions import UnexpectedResponse from memory.common import settings -from memory.common.embedding import ( - Collection, - ALL_COLLECTIONS, - DistanceType, - Vector, -) +from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector logger = logging.getLogger(__name__) diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index c9f3fbb..493637f 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -10,17 +10,16 @@ from typing import Callable, Generator, Sequence, cast from sqlalchemy.orm import Session, scoped_session -from memory.common import embedding, qdrant, settings +from memory.common import embedding, qdrant, settings, collections from memory.common.db.models import ( EmailAccount, EmailAttachment, MailMessage, - SourceItem, ) from memory.common.parsers.email import ( Attachment, + EmailMessage, RawEmailResponse, - 
parse_email_message, ) logger = logging.getLogger(__name__) @@ -54,7 +53,7 @@ def process_attachment( return None return EmailAttachment( - modality=embedding.get_modality(attachment["content_type"]), + modality=collections.get_modality(attachment["content_type"]), sha256=hashlib.sha256( real_content if real_content else str(attachment).encode() ).digest(), @@ -94,8 +93,7 @@ def create_mail_message( db_session: Session | scoped_session, tags: list[str], folder: str, - raw_email: str, - message_id: str, + parsed_email: EmailMessage, ) -> MailMessage: """ Create a new mail message record and associated attachments. @@ -109,7 +107,7 @@ def create_mail_message( Returns: Newly created MailMessage """ - parsed_email = parse_email_message(raw_email, message_id) + raw_email = parsed_email["raw_email"] mail_message = MailMessage( modality="mail", sha256=parsed_email["hash"], @@ -137,52 +135,6 @@ def create_mail_message( return mail_message -def does_message_exist( - db_session: Session | scoped_session, message_id: str, message_hash: bytes -) -> bool: - """ - Check if a message already exists in the database. - - Args: - db_session: Database session - message_id: Email message ID - message_hash: SHA-256 hash of message - - Returns: - True if message exists, False otherwise - """ - # Check by message_id first (faster) - if message_id: - mail_message = ( - db_session.query(MailMessage) - .filter(MailMessage.message_id == message_id) - .first() - ) - if mail_message is not None: - return True - - # Then check by message_hash - source_item = ( - db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() - ) - return source_item is not None - - -def check_message_exists( - db: Session | scoped_session, account_id: int, message_id: str, raw_email: str -) -> bool: - account = db.query(EmailAccount).get(account_id) - if not account: - logger.error(f"Account {account_id} not found") - return False - - parsed_email = parse_email_message(raw_email, message_id) - if "szczepalins" in raw_email.lower(): - print(parsed_email["message_id"]) - - return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"]) - - def extract_email_uid( msg_data: Sequence[tuple[bytes, bytes]], ) -> tuple[str | None, bytes]: @@ -317,11 +269,7 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, def vectorize_email(email: MailMessage): qdrant_client = qdrant.get_qdrant_client() - _, chunks = embedding.embed( - "text/plain", - email.body, - metadata=email.as_payload(), - ) + chunks = embedding.embed_source_item(email) email.chunks = chunks if chunks: vector_ids = [cast(str, c.id) for c in chunks] @@ -329,7 +277,7 @@ def vectorize_email(email: MailMessage): metadata = [c.item_metadata for c in chunks] qdrant.upsert_vectors( client=qdrant_client, - collection_name="mail", + collection_name=cast(str, email.modality), ids=vector_ids, vectors=vectors, # type: ignore payloads=metadata, # type: ignore @@ -337,18 +285,12 @@ def vectorize_email(email: MailMessage): embeds = defaultdict(list) for attachment in email.attachments: - if attachment.filename: - content = pathlib.Path(attachment.filename).read_bytes() - else: - content = attachment.content - collection, chunks = embedding.embed( - attachment.mime_type, content, metadata=attachment.as_payload() - ) + chunks = embedding.embed_source_item(attachment) if not chunks: continue attachment.chunks = chunks - embeds[collection].extend(chunks) + embeds[attachment.modality].extend(chunks) for collection, chunks in embeds.items(): 
ids = [c.id for c in chunks] diff --git a/src/memory/workers/tasks/blogs.py b/src/memory/workers/tasks/blogs.py index 14609e1..6506972 100644 --- a/src/memory/workers/tasks/blogs.py +++ b/src/memory/workers/tasks/blogs.py @@ -1,99 +1,25 @@ -import hashlib import logging -from typing import Iterable, cast +from typing import Iterable -from memory.common import chunker, embedding, qdrant from memory.common.db.connection import make_session from memory.common.db.models import BlogPost from memory.common.parsers.blogs import parse_webpage from memory.workers.celery_app import app +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + create_task_result, + process_content_item, + safe_task_execution, +) logger = logging.getLogger(__name__) - SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage" -def create_blog_post_from_article(article, tags: Iterable[str] = []) -> BlogPost: - """Create a BlogPost model from parsed article data.""" - return BlogPost( - url=article.url, - title=article.title, - published=article.published_date, - content=article.content, - sha256=hashlib.sha256(article.content.encode()).digest(), - modality="blog", - tags=tags, - mime_type="text/markdown", - size=len(article.content.encode("utf-8")), - ) - - -def embed_blog_post(blog_post: BlogPost) -> int: - """Embed blog post content and return count of successfully embedded chunks.""" - try: - # Always embed the full content - _, chunks = embedding.embed( - "text/markdown", - cast(str, blog_post.content), - metadata=blog_post.as_payload(), - chunk_size=chunker.EMBEDDING_MAX_TOKENS, - ) - # But also embed the content in chunks (unless it's really short) - if ( - chunker.approx_token_count(cast(str, blog_post.content)) - > chunker.DEFAULT_CHUNK_TOKENS * 2 - ): - _, small_chunks = embedding.embed( - "text/markdown", - cast(str, blog_post.content), - metadata=blog_post.as_payload(), - ) - chunks += small_chunks - - if chunks: - blog_post.chunks = chunks - blog_post.embed_status = "QUEUED" # type: ignore - return len(chunks) - else: - blog_post.embed_status = "FAILED" # type: ignore - logger.warning(f"No chunks generated for blog post: {blog_post.title}") - return 0 - - except Exception as e: - blog_post.embed_status = "FAILED" # type: ignore - logger.error(f"Failed to embed blog post {blog_post.title}: {e}") - return 0 - - -def push_to_qdrant(blog_post: BlogPost): - """Push embeddings to Qdrant for successfully embedded blog post.""" - if cast(str, blog_post.embed_status) != "QUEUED" or not blog_post.chunks: - return - - try: - vector_ids = [str(chunk.id) for chunk in blog_post.chunks] - vectors = [chunk.vector for chunk in blog_post.chunks] - payloads = [chunk.item_metadata for chunk in blog_post.chunks] - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name="blog", - ids=vector_ids, - vectors=vectors, - payloads=payloads, - ) - - blog_post.embed_status = "STORED" # type: ignore - logger.info(f"Successfully stored embeddings for: {blog_post.title}") - - except Exception as e: - blog_post.embed_status = "FAILED" # type: ignore - logger.error(f"Failed to push embeddings to Qdrant: {e}") - raise - - @app.task(name=SYNC_WEBPAGE) +@safe_task_execution def sync_webpage(url: str, tags: Iterable[str] = []) -> dict: """ Synchronize a webpage from a URL. 
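The helpers imported above replace the per-task embed/push/commit boilerplate this file used to carry. A minimal sketch of the result dict sync_webpage now returns, with field names taken from create_task_result and process_content_item elsewhere in this patch (the URL and numbers are illustrative):

    result = sync_webpage("https://example.com/some-post", tags=["ai"])
    # {"blogpost_id": 42, "title": "Some post", "status": "processed",
    #  "chunks_count": 3, "embed_status": "STORED", "content_length": 18432}
    # A duplicate URL or content hash yields {"status": "already_exists", ...};
    # because of @safe_task_execution, failures come back as
    # {"status": "error", "error": "..."} rather than propagating out of the worker.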
@@ -116,61 +42,25 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict: "content_length": 0, } - blog_post = create_blog_post_from_article(article, tags) + blog_post = BlogPost( + url=article.url, + title=article.title, + published=article.published_date, + content=article.content, + sha256=create_content_hash(article.content), + modality="blog", + tags=tags, + mime_type="text/markdown", + size=len(article.content.encode("utf-8")), + images=[image for image in article.images], + ) with make_session() as session: - existing_post = session.query(BlogPost).filter(BlogPost.url == url).first() - if existing_post: - logger.info(f"Blog post already exists: {existing_post.title}") - return { - "blog_post_id": existing_post.id, - "url": url, - "title": existing_post.title, - "status": "already_exists", - "chunks_count": len(existing_post.chunks), - } - - existing_post = ( - session.query(BlogPost).filter(BlogPost.sha256 == blog_post.sha256).first() + existing_post = check_content_exists( + session, BlogPost, url=url, sha256=create_content_hash(article.content) ) if existing_post: - logger.info( - f"Blog post with the same content already exists: {existing_post.title}" - ) - return { - "blog_post_id": existing_post.id, - "url": url, - "title": existing_post.title, - "status": "already_exists", - "chunks_count": len(existing_post.chunks), - } + logger.info(f"Blog post already exists: {existing_post.title}") + return create_task_result(existing_post, "already_exists", url=url) - session.add(blog_post) - session.flush() - - chunks_count = embed_blog_post(blog_post) - session.flush() - - try: - push_to_qdrant(blog_post) - logger.info( - f"Successfully processed webpage: {blog_post.title} " - f"({chunks_count} chunks embedded)" - ) - except Exception as e: - logger.error(f"Failed to push embeddings to Qdrant: {e}") - blog_post.embed_status = "FAILED" # type: ignore - - session.commit() - - return { - "blog_post_id": blog_post.id, - "url": url, - "title": blog_post.title, - "author": article.author, - "published_date": article.published_date, - "status": "processed", - "chunks_count": chunks_count, - "content_length": len(article.content), - "embed_status": blog_post.embed_status, - } + return process_content_item(blog_post, "blog", session, tags) diff --git a/src/memory/workers/tasks/comic.py b/src/memory/workers/tasks/comic.py index 0016403..ea650d4 100644 --- a/src/memory/workers/tasks/comic.py +++ b/src/memory/workers/tasks/comic.py @@ -1,4 +1,3 @@ -import hashlib import logging from datetime import datetime from typing import Callable, cast @@ -6,21 +5,25 @@ from typing import Callable, cast import feedparser import requests -from memory.common import embedding, qdrant, settings +from memory.common import settings from memory.common.db.connection import make_session from memory.common.db.models import Comic, clean_filename from memory.common.parsers import comics from memory.workers.celery_app import app +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + process_content_item, + safe_task_execution, +) logger = logging.getLogger(__name__) - SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics" SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc" SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd" SYNC_COMIC = "memory.workers.tasks.comic.sync_comic" - BASE_SMBC_URL = "https://www.smbc-comics.com/" SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss" @@ -36,6 +39,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]: 
return set() urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries} + urls = {url for url in urls if url} with make_session() as session: known = { @@ -61,6 +65,7 @@ def fetch_new_comics( @app.task(name=SYNC_COMIC) +@safe_task_execution def sync_comic( url: str, image_url: str, @@ -70,20 +75,26 @@ def sync_comic( ): """Synchronize a comic from a URL.""" with make_session() as session: - if session.query(Comic).filter(Comic.url == url).first(): - return + existing_comic = check_content_exists(session, Comic, url=url) + if existing_comic: + return {"status": "already_exists", "comic_id": existing_comic.id} response = requests.get(image_url) + if response.status_code != 200: + return { + "status": "failed", + "error": f"Failed to download image: {response.status_code}", + } + file_type = image_url.split(".")[-1] mime_type = f"image/{file_type}" filename = ( settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}" ) - if response.status_code == 200: - filename.parent.mkdir(parents=True, exist_ok=True) - filename.write_bytes(response.content) - sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest() + filename.parent.mkdir(parents=True, exist_ok=True) + filename.write_bytes(response.content) + comic = Comic( title=title, url=url, @@ -92,27 +103,13 @@ def sync_comic( filename=filename.resolve().as_posix(), mime_type=mime_type, size=len(response.content), - sha256=sha256, + sha256=create_content_hash(f"{image_url}{published_date}"), tags={"comic", author}, modality="comic", ) - chunk = embedding.embed_image(filename, [title, author]) - comic.chunks = [chunk] with make_session() as session: - session.add(comic) - session.add(chunk) - session.flush() - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name="comic", - ids=[str(chunk.id)], - vectors=[chunk.vector], # type: ignore - payloads=[comic.as_payload()], - ) - - session.commit() + return process_content_item(comic, "comic", session) @app.task(name=SYNC_SMBC) diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py new file mode 100644 index 0000000..3c3fb2c --- /dev/null +++ b/src/memory/workers/tasks/content_processing.py @@ -0,0 +1,267 @@ +""" +Content processing utilities for memory workers. + +This module provides core functionality for processing content items through +the complete workflow: existence checking, content hashing, embedding generation, +vector storage, and result tracking. +""" + +import hashlib +import traceback +import logging +from typing import Any, Callable, Iterable, Sequence, cast + +from memory.common import embedding, qdrant +from memory.common.db.models import SourceItem + +logger = logging.getLogger(__name__) + + +def check_content_exists( + session, + model_class: type[SourceItem], + **kwargs: Any, +) -> SourceItem | None: + """ + Check if content already exists in the database. + + Searches for existing content by any of the provided attributes + (typically URL, file_path, or SHA256 hash). 
+ + Args: + session: Database session for querying + model_class: The SourceItem model class to search in + **kwargs: Attribute-value pairs to search for + + Returns: + Existing SourceItem if found, None otherwise + """ + for key, value in kwargs.items(): + if not hasattr(model_class, key): + continue + + existing = ( + session.query(model_class) + .filter(getattr(model_class, key) == value) + .first() + ) + if existing: + return existing + + return None + + +def create_content_hash(content: str, *additional_data: str) -> bytes: + """ + Create SHA256 hash from content and optional additional data. + + Args: + content: Primary content to hash + *additional_data: Additional strings to include in the hash + + Returns: + SHA256 hash digest as bytes + """ + hash_input = content + "".join(additional_data) + return hashlib.sha256(hash_input.encode()).digest() + + +def embed_source_item(source_item: SourceItem) -> int: + """ + Generate embeddings for a source item's content. + + Processes the source item through the embedding pipeline, creating + chunks and their corresponding vector embeddings. Updates the item's + embed_status based on success or failure. + + Args: + source_item: The SourceItem to embed + + Returns: + Number of successfully embedded chunks + + Side effects: + - Sets source_item.chunks with generated chunks + - Sets source_item.embed_status to "QUEUED" or "FAILED" + """ + try: + chunks = embedding.embed_source_item(source_item) + if chunks: + source_item.chunks = chunks + source_item.embed_status = "QUEUED" # type: ignore + return len(chunks) + else: + source_item.embed_status = "FAILED" # type: ignore + logger.warning( + f"No chunks generated for {type(source_item).__name__}: {getattr(source_item, 'title', 'unknown')}" + ) + return 0 + except Exception as e: + source_item.embed_status = "FAILED" # type: ignore + logger.error(f"Failed to embed {type(source_item).__name__}: {e}") + return 0 + + +def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str): + """ + Push embeddings to Qdrant vector database. + + Uploads vector embeddings for all source items that have been successfully + embedded (status "QUEUED") and have chunks available. 
+ + Args: + source_items: Sequence of SourceItems to process + collection_name: Name of the Qdrant collection to store vectors in + + Raises: + Exception: If the Qdrant upsert operation fails + + Side effects: + - Updates embed_status to "STORED" for successful items + - Updates embed_status to "FAILED" for failed items + """ + items_to_process = [ + item + for item in source_items + if cast(str, getattr(item, "embed_status", None)) == "QUEUED" and item.chunks + ] + + if not items_to_process: + return + + all_chunks = [chunk for item in items_to_process for chunk in item.chunks] + if not all_chunks: + return + + try: + vector_ids = [str(chunk.id) for chunk in all_chunks] + vectors = [chunk.vector for chunk in all_chunks] + payloads = [chunk.item_metadata for chunk in all_chunks] + + qdrant.upsert_vectors( + client=qdrant.get_qdrant_client(), + collection_name=collection_name, + ids=vector_ids, + vectors=vectors, + payloads=payloads, + ) + + for item in items_to_process: + item.embed_status = "STORED" # type: ignore + logger.info( + f"Successfully stored embeddings for: {getattr(item, 'title', 'unknown')}" + ) + + except Exception as e: + for item in items_to_process: + item.embed_status = "FAILED" # type: ignore + logger.error(f"Failed to push embeddings to Qdrant: {e}") + raise + + +def create_task_result( + item: SourceItem, status: str, **additional_fields: Any +) -> dict[str, Any]: + """ + Create standardized task result dictionary. + + Generates a consistent result format for task execution reporting, + including item metadata and processing status. + + Args: + item: The processed SourceItem + status: Processing status string + **additional_fields: Extra fields to include in the result + + Returns: + Dictionary with standardized task result format + """ + return { + f"{type(item).__name__.lower()}_id": item.id, + "title": getattr(item, "title", None), + "status": status, + "chunks_count": len(item.chunks), + "embed_status": item.embed_status, + **additional_fields, + } + + +def process_content_item( + item: SourceItem, collection_name: str, session, tags: Iterable[str] = [] +) -> dict[str, Any]: + """ + Execute complete content processing workflow. + + Performs the full pipeline for processing a content item: + 1. Add to database session and flush to get ID + 2. Generate embeddings and chunks + 3. Push embeddings to Qdrant vector store + 4. 
Commit transaction and return result + + Args: + item: SourceItem to process + collection_name: Qdrant collection name for vector storage + session: Database session for persistence + tags: Optional tags to associate with the item (currently unused) + + Returns: + Task result dictionary with processing status and metadata + + Side effects: + - Adds item to database session + - Commits database transaction + - Stores vectors in Qdrant + """ + session.add(item) + session.flush() + + chunks_count = embed_source_item(item) + session.flush() + + try: + push_to_qdrant([item], collection_name) + status = "processed" + logger.info( + f"Successfully processed {type(item).__name__}: {getattr(item, 'title', 'unknown')} ({chunks_count} chunks embedded)" + ) + except Exception as e: + logger.error(f"Failed to push embeddings to Qdrant: {e}") + item.embed_status = "FAILED" # type: ignore + status = "failed" + + session.commit() + + return create_task_result(item, status, content_length=getattr(item, "size", 0)) + + +def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: + """ + Decorator for safe task execution with comprehensive error handling. + + Wraps task functions to catch and log exceptions, ensuring tasks + always return a result dictionary even when they fail. + + Args: + func: Task function to wrap + + Returns: + Wrapped function that handles exceptions gracefully + + Example: + @safe_task_execution + def my_task(arg1, arg2): + # Task implementation + return {"status": "success"} + """ + + def wrapper(*args, **kwargs) -> dict: + try: + return func(*args, **kwargs) + except Exception as e: + logger.error( + f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}" + ) + logger.error(f"Task {func.__name__} failed: {e}") + return {"status": "error", "error": str(e)} + + return wrapper diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py index 919f7f8..bb2a8d9 100644 --- a/src/memory/workers/tasks/ebook.py +++ b/src/memory/workers/tasks/ebook.py @@ -1,17 +1,21 @@ -import hashlib import logging from pathlib import Path from typing import Iterable, cast -from memory.common import chunker, embedding, qdrant -from memory.common.db.connection import make_session from memory.common.db.models import Book, BookSection from memory.common.parsers.ebook import Ebook, parse_ebook, Section +from memory.common.db.connection import make_session from memory.workers.celery_app import app +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + embed_source_item, + push_to_qdrant, + safe_task_execution, +) logger = logging.getLogger(__name__) - SYNC_BOOK = "memory.workers.tasks.book.sync_book" # Minimum section length to embed (avoid noise from very short sections) @@ -46,10 +50,6 @@ def section_processor( ): content = "\n\n".join(section.pages).strip() if len(content) >= MIN_SECTION_LENGTH: - sha256 = hashlib.sha256( - f"{book.id}:{section.title}:{section.start_page}".encode() - ).digest() - book_section = BookSection( book_id=book.id, section_title=section.title, @@ -59,7 +59,9 @@ def section_processor( end_page=section.end_page, parent_section_id=None, # Will be set after flush content=content, - sha256=sha256, + sha256=create_content_hash( + f"{book.id}:{section.title}:{section.start_page}" + ), modality="book", tags=book.tags, pages=section.pages, @@ -127,76 +129,11 @@ def create_book_and_sections( def embed_sections(all_sections: list[BookSection]) -> int: """Embed all sections and return count of 
successfully embedded sections.""" - embedded_count = 0 - - def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]: - _, chunks = embedding.embed( - "text/plain", - text, - metadata=metadata, - chunk_size=chunker.EMBEDDING_MAX_TOKENS, - ) - return chunks - - for section in all_sections: - try: - section_chunks = embed_text( - cast(str, section.content), section.as_payload() - ) - page_chunks = [ - chunk - for i, page in enumerate(section.pages) - for chunk in embed_text( - page, section.as_payload() | {"page_number": i + section.start_page} - ) - ] - chunks = section_chunks + page_chunks - - if chunks: - section.chunks = chunks - section.embed_status = "QUEUED" # type: ignore - embedded_count += 1 - else: - section.embed_status = "FAILED" # type: ignore - logger.warning( - f"No chunks generated for section: {section.section_title}" - ) - - except IOError as e: - section.embed_status = "FAILED" # type: ignore - logger.error(f"Failed to embed section {section.section_title}: {e}") - - return embedded_count - - -def push_to_qdrant(all_sections: list[BookSection]): - """Push embeddings to Qdrant for all successfully embedded sections.""" - vector_ids = [] - vectors = [] - payloads = [] - - to_process = [s for s in all_sections if cast(str, s.embed_status) == "QUEUED"] - all_chunks = [chunk for section in to_process for chunk in section.chunks] - if not all_chunks: - return - - vector_ids = [str(chunk.id) for chunk in all_chunks] - vectors = [chunk.vector for chunk in all_chunks] - payloads = [chunk.item_metadata for chunk in all_chunks] - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name="book", - ids=vector_ids, - vectors=vectors, - payloads=payloads, - ) - - for section in to_process: - section.embed_status = "STORED" # type: ignore + return sum(embed_source_item(section) for section in all_sections) @app.task(name=SYNC_BOOK) +@safe_task_execution def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: """ Synchronize a book from a file path. 
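Section embedding now funnels through the shared embed_source_item helper, which relies on BookSection.data_chunks() from the models change earlier in this patch: each page becomes a DataChunk tagged with its page number, plus one chunk for the concatenated section text, all capped at EMBEDDING_MAX_TOKENS. A rough sketch, assuming each page fits in a single embedding window (field values are made up):

    from memory.common import embedding
    from memory.common.db.models import BookSection
    from memory.workers.tasks.content_processing import create_content_hash

    section = BookSection(
        section_title="Chapter 1",
        pages=["First page text.", "Second page text."],
        content="First page text.\n\nSecond page text.",
        start_page=10,
        end_page=11,
        modality="book",
        sha256=create_content_hash("1:Chapter 1:10"),
    )
    chunks = embedding.embed_source_item(section)
    # -> roughly three Chunk rows: pages 10 and 11 (metadata {"page": 10} / {"page": 11})
    #    plus the whole section, each also carrying section.as_payload() in its metadata.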
@@ -211,10 +148,8 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: with make_session() as session: # Check for existing book - existing_book = ( - session.query(Book) - .filter(Book.file_path == ebook.file_path.as_posix()) - .first() + existing_book = check_content_exists( + session, Book, file_path=ebook.file_path.as_posix() ) if existing_book: logger.info(f"Book already exists: {existing_book.title}") @@ -230,19 +165,10 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: book, all_sections = create_book_and_sections(ebook, session, tags) # Embed sections - embedded_count = embed_sections(all_sections) + embedded_count = sum(embed_source_item(section) for section in all_sections) session.flush() - # Push to Qdrant - try: - push_to_qdrant(all_sections) - except Exception as e: - logger.error(f"Failed to push embeddings to Qdrant: {e}") - # Mark sections as failed - for section in all_sections: - if getattr(section, "embed_status") == "STORED": - section.embed_status = "FAILED" # type: ignore - raise + push_to_qdrant(all_sections, "book") session.commit() diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index eb5e74a..97efd72 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -2,16 +2,19 @@ import logging from datetime import datetime from typing import cast from memory.common.db.connection import make_session -from memory.common.db.models import EmailAccount +from memory.common.db.models import EmailAccount, MailMessage from memory.workers.celery_app import app from memory.workers.email import ( - check_message_exists, create_mail_message, imap_connection, process_folder, vectorize_email, ) - +from memory.common.parsers.email import parse_email_message +from memory.workers.tasks.content_processing import ( + check_content_exists, + safe_task_execution, +) logger = logging.getLogger(__name__) @@ -21,12 +24,13 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts" @app.task(name=PROCESS_EMAIL) +@safe_task_execution def process_message( account_id: int, message_id: str, folder: str, raw_email: str, -) -> int | None: +) -> dict: """ Process a single email message and store it in the database. 
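The email tasks move to the same status-dict contract as the other sync tasks, so callers branch on "status" instead of special-casing a None return. A small sketch of the shapes produced by the hunks below (ids and counts are illustrative; raw_email is an already-fetched RFC-822 string):

    result = process_message(1, "<msg-1@example.com>", "INBOX", raw_email)
    if result["status"] == "processed":
        print(result["mail_message_id"], result["chunks_count"], result["attachments_count"])
    elif result["status"] == "already_exists":
        pass  # deduplicated by message_id / sha256 before any re-processing
    else:  # "skipped" for empty bodies, "error" when @safe_task_execution caught an exception
        print(result.get("reason") or result.get("error"))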
@@ -37,29 +41,30 @@ def process_message( raw_email: Raw email content as string Returns: - source_id if successful, None otherwise + dict with processing result """ logger.info(f"Processing message {message_id} for account {account_id}") if not raw_email.strip(): logger.warning(f"Empty email message received for account {account_id}") - return None + return {"status": "skipped", "reason": "empty_content"} with make_session() as db: - if check_message_exists(db, account_id, message_id, raw_email): - logger.debug(f"Message {message_id} already exists in database") - return None - account = db.query(EmailAccount).get(account_id) if not account: logger.error(f"Account {account_id} not found") - return None + return {"status": "error", "error": "Account not found"} - mail_message = create_mail_message( - db, account.tags, folder, raw_email, message_id - ) + parsed_email = parse_email_message(raw_email, message_id) + if check_content_exists( + db, MailMessage, message_id=message_id, sha256=parsed_email["hash"] + ): + return {"status": "already_exists", "message_id": message_id} + + mail_message = create_mail_message(db, account.tags, folder, parsed_email) db.flush() vectorize_email(mail_message) + db.commit() logger.info(f"Stored embedding for message {mail_message.message_id}") @@ -71,16 +76,24 @@ def process_message( for chunk in attachment.chunks: logger.info(f" - {chunk.id}") - return cast(int, mail_message.id) + return { + "status": "processed", + "mail_message_id": cast(int, mail_message.id), + "message_id": message_id, + "chunks_count": len(mail_message.chunks), + "attachments_count": len(mail_message.attachments), + } @app.task(name=SYNC_ACCOUNT) +@safe_task_execution def sync_account(account_id: int, since_date: str | None = None) -> dict: """ Synchronize emails from a specific account. 
Args: account_id: ID of the EmailAccount to sync + since_date: ISO format date string to sync since Returns: dict with stats about the sync operation @@ -91,7 +104,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first() if not account or not cast(bool, account.active): logger.warning(f"Account {account_id} not found or inactive") - return {"error": "Account not found or inactive"} + return {"status": "error", "error": "Account not found or inactive"} folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"] if since_date: @@ -108,7 +121,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: def process_message_wrapper( account_id: int, message_id: str, folder: str, raw_email: str ) -> int | None: - if check_message_exists(db, account_id, message_id, raw_email): # type: ignore + parsed_email = parse_email_message(raw_email, message_id) + if check_content_exists( + db, MailMessage, message_id=message_id, sha256=parsed_email["hash"] + ): return None return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore @@ -127,9 +143,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: db.commit() except Exception as e: logger.error(f"Error connecting to server {account.imap_server}: {str(e)}") - return {"error": str(e)} + return {"status": "error", "error": str(e)} return { + "status": "completed", "account": account.email_address, "since_date": cutoff_date.isoformat(), "folders_processed": len(folders_to_process), diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py index cdd8085..f8c978d 100644 --- a/src/memory/workers/tasks/maintenance.py +++ b/src/memory/workers/tasks/maintenance.py @@ -6,7 +6,7 @@ from typing import Sequence from sqlalchemy import select from sqlalchemy.orm import contains_eager -from memory.common import embedding, qdrant, settings +from memory.common import collections, embedding, qdrant, settings from memory.common.db.connection import make_session from memory.common.db.models import Chunk, SourceItem from memory.workers.celery_app import app @@ -60,11 +60,11 @@ def reingest_chunk(chunk_id: str, collection: str): logger.error(f"Chunk {chunk_id} not found") return - if collection not in embedding.ALL_COLLECTIONS: + if collection not in collections.ALL_COLLECTIONS: raise ValueError(f"Unsupported collection {collection}") data = chunk.data - if collection in embedding.MULTIMODAL_COLLECTIONS: + if collection in collections.MULTIMODAL_COLLECTIONS: vector = embedding.embed_mixed(data)[0] elif len(data) == 1 and isinstance(data[0], str): vector = embedding.embed_text([data[0]])[0] diff --git a/tests/conftest.py b/tests/conftest.py index a3edd06..1ac921a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -205,9 +205,14 @@ def email_provider(): def mock_file_storage(tmp_path: Path): chunk_storage_dir = tmp_path / "chunks" chunk_storage_dir.mkdir(parents=True, exist_ok=True) - with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path): - with patch("memory.common.settings.CHUNK_STORAGE_DIR", chunk_storage_dir): - yield + image_storage_dir = tmp_path / "images" + image_storage_dir.mkdir(parents=True, exist_ok=True) + with ( + patch.object(settings, "FILE_STORAGE_DIR", tmp_path), + patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir), + patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir), + ): + yield @pytest.fixture 
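The extra WEBPAGE_STORAGE_DIR patch matters because the HTML parser stores downloaded images under that directory and then rewrites <img> src attributes relative to FILE_STORAGE_DIR, so both must live in the same temporary tree for the parser tests further down. A minimal assumed example (the test name is made up):

    def test_storage_dirs_share_tmp_tree(mock_file_storage):
        from memory.common import settings

        # The fixture points these at tmp_path and tmp_path / "images" respectively.
        assert settings.WEBPAGE_STORAGE_DIR.is_relative_to(settings.FILE_STORAGE_DIR)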
diff --git a/tests/memory/common/parsers/test_email_parsers.py b/tests/memory/common/parsers/test_email_parsers.py index 449ea20..ccaaf63 100644 --- a/tests/memory/common/parsers/test_email_parsers.py +++ b/tests/memory/common/parsers/test_email_parsers.py @@ -249,6 +249,7 @@ def test_parse_simple_email(): "body": "Test body content\n", "attachments": [], "sent_at": ANY, + "raw_email": msg.as_string(), "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87" + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1", } diff --git a/tests/memory/common/parsers/test_html.py b/tests/memory/common/parsers/test_html.py index c0f5d76..0ad7e1f 100644 --- a/tests/memory/common/parsers/test_html.py +++ b/tests/memory/common/parsers/test_html.py @@ -12,6 +12,7 @@ import requests from bs4 import BeautifulSoup, Tag from PIL import Image as PILImage +from memory.common import settings from memory.common.parsers.html import ( Article, BaseHTMLParser, @@ -164,27 +165,28 @@ def test_parse_date(text, date_format, expected): assert parse_date(text, date_format) == expected -def test_extract_date(): +@pytest.mark.parametrize( + "selector, date_format, expected", + [ + ("time", "%B %d, %Y", datetime.fromisoformat("2023-01-15T10:30:00")), + (".named-date", "%B %d, %Y", datetime.fromisoformat("2023-01-15")), + (".date", "%Y-%m-%d", datetime.fromisoformat("2023-02-20")), + (".nonexistent", None, None), + ], +) +def test_extract_date(selector, date_format, expected): html = """
+ January 15, 2023 2023-02-20
March 10, 2023
""" soup = BeautifulSoup(html, "html.parser") - # Should extract datetime attribute from time tag - result = extract_date(soup, "time", "%Y-%m-%d") - assert result == "2023-01-15T10:30:00" - - # Should extract from text content - result = extract_date(soup, ".date", "%Y-%m-%d") - assert result == "2023-02-20T00:00:00" - - # No matching element - result = extract_date(soup, ".nonexistent", "%Y-%m-%d") - assert result is None + result = extract_date(soup, selector, date_format) + assert result == expected def test_extract_content_element(): @@ -393,8 +395,7 @@ def test_process_image_cached(mock_pil_open, mock_requests_get): @patch("memory.common.parsers.html.process_image") -@patch("memory.common.parsers.html.FILE_STORAGE_DIR") -def test_process_images_basic(mock_file_storage_dir, mock_process_image): +def test_process_images_basic(mock_process_image): html = """

Text content

@@ -409,40 +410,37 @@ def test_process_images_basic(mock_file_storage_dir, mock_process_image): content = cast(Tag, soup.find("div")) base_url = "https://example.com" - with tempfile.TemporaryDirectory() as temp_dir: - image_dir = pathlib.Path(temp_dir) - mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir) + # Mock successful image processing with proper filenames + mock_images = [] + for i in range(3): + mock_img = MagicMock(spec=PILImage.Image) + mock_img.filename = str(settings.WEBPAGE_STORAGE_DIR / f"image{i + 1}.jpg") + mock_images.append(mock_img) - # Mock successful image processing with proper filenames - mock_images = [] - for i in range(3): - mock_img = MagicMock(spec=PILImage.Image) - mock_img.filename = str(pathlib.Path(temp_dir) / f"image{i + 1}.jpg") - mock_images.append(mock_img) + mock_process_image.side_effect = mock_images - mock_process_image.side_effect = mock_images + updated_content, images = process_images( + content, base_url, settings.WEBPAGE_STORAGE_DIR + ) - updated_content, images = process_images(content, base_url, image_dir) - - # Should have processed 3 images (skipping the one without src) - assert len(images) == 3 - assert mock_process_image.call_count == 3 - - # Check that img src attributes were updated to relative paths - img_tags = [ - tag - for tag in (updated_content.find_all("img") if updated_content else []) - if isinstance(tag, Tag) - ] - src_values = [] - for img in img_tags: - src = img.get("src") - if src and isinstance(src, str): - src_values.append(src) - - # Should have relative paths to the processed images - for src in src_values[:3]: # First 3 have src - assert not src.startswith("http") # Should be relative paths + expected = BeautifulSoup( + """
+

Text content

+ Image 1 + Image 2 + Image 3 + No src +

More text

+
+ """, + "html.parser", + ) + assert updated_content.prettify() == expected.prettify() # type: ignore + assert images == { + "images/image1.jpg": mock_images[0], + "images/image2.jpg": mock_images[1], + "images/image3.jpg": mock_images[2], + } def test_process_images_empty(): @@ -454,8 +452,7 @@ def test_process_images_empty(): @patch("memory.common.parsers.html.process_image") -@patch("memory.common.parsers.html.FILE_STORAGE_DIR") -def test_process_images_with_failures(mock_file_storage_dir, mock_process_image): +def test_process_images_with_failures(mock_process_image): html = """
Good image @@ -465,22 +462,20 @@ def test_process_images_with_failures(mock_file_storage_dir, mock_process_image) soup = BeautifulSoup(html, "html.parser") content = cast(Tag, soup.find("div")) - with tempfile.TemporaryDirectory() as temp_dir: - image_dir = pathlib.Path(temp_dir) - mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir) + # First image succeeds, second fails + mock_good_image = MagicMock(spec=PILImage.Image) + mock_good_image.filename = settings.WEBPAGE_STORAGE_DIR / "good.jpg" + mock_process_image.side_effect = [mock_good_image, None] - # First image succeeds, second fails - mock_good_image = MagicMock(spec=PILImage.Image) - mock_good_image.filename = str(pathlib.Path(temp_dir) / "good.jpg") - mock_process_image.side_effect = [mock_good_image, None] + updated_content, images = process_images( + content, "https://example.com", settings.WEBPAGE_STORAGE_DIR + ) - updated_content, images = process_images( - content, "https://example.com", image_dir - ) - - # Should only return successful image - assert len(images) == 1 - assert images[0] == mock_good_image + expected = BeautifulSoup( + html.replace("good.jpg", "images/good.jpg"), "html.parser" + ).prettify() + assert updated_content.prettify() == expected # type: ignore + assert images == {"images/good.jpg": mock_good_image} @patch("memory.common.parsers.html.process_image") @@ -494,15 +489,12 @@ def test_process_images_no_filename(mock_process_image): mock_image.filename = None mock_process_image.return_value = mock_image - with tempfile.TemporaryDirectory() as temp_dir: - image_dir = pathlib.Path(temp_dir) + updated_content, images = process_images( + content, "https://example.com", settings.WEBPAGE_STORAGE_DIR + ) - updated_content, images = process_images( - content, "https://example.com", image_dir - ) - - # Should skip image without filename - assert len(images) == 0 + # Should skip image without filename + assert not images class TestBaseHTMLParser: @@ -541,7 +533,7 @@ class TestBaseHTMLParser: assert article.title == "Article Title" assert article.author == "John Smith" # Should prefer content over meta - assert article.published_date == "2023-01-15T00:00:00" + assert article.published_date == datetime(2023, 1, 15) assert article.url == "https://example.com/article" assert "This is the main content" in article.content assert article.metadata["author"] == "Jane Doe" diff --git a/tests/memory/common/test_embedding.py b/tests/memory/common/test_embedding.py index adf3602..f921802 100644 --- a/tests/memory/common/test_embedding.py +++ b/tests/memory/common/test_embedding.py @@ -5,14 +5,10 @@ from typing import cast import pytest from PIL import Image -from memory.common import settings +from memory.common import settings, collections from memory.common.embedding import ( - embed, - embed_file, embed_mixed, - embed_page, embed_text, - get_modality, make_chunk, write_to_file, ) @@ -50,7 +46,7 @@ def mock_embed(mock_voyage_client): ], ) def test_get_modality(mime_type, expected_modality): - assert get_modality(mime_type) == expected_modality + assert collections.get_modality(mime_type) == expected_modality def test_embed_text(mock_embed): @@ -58,59 +54,11 @@ def test_embed_text(mock_embed): assert embed_text(texts) == [[0], [1]] -def test_embed_file(mock_embed, tmp_path): - mock_file = tmp_path / "test.txt" - mock_file.write_text("file content") - - assert embed_file(mock_file) == [[0]] - - def test_embed_mixed(mock_embed): items = ["text", {"type": "image", "data": "base64"}] assert embed_mixed(items) == [[0]] 
-def test_embed_page_text_only(mock_embed): - page = {"contents": ["text1", "text2"]} - assert embed_page(page) == [[0], [1]] # type: ignore - - -def test_embed_page_mixed_content(mock_embed): - page = {"contents": ["text", {"type": "image", "data": "base64"}]} - assert embed_page(page) == [[0]] # type: ignore - - -def test_embed(mock_embed): - mime_type = "text/plain" - content = "sample content" - metadata = {"source": "test"} - - with patch.object(uuid, "uuid4", return_value="id1"): - modality, chunks = embed(mime_type, content, metadata) - - assert modality == "text" - assert [ - { - "id": c.id, # type: ignore - "file_path": c.file_path, # type: ignore - "content": c.content, # type: ignore - "embedding_model": c.embedding_model, # type: ignore - "vector": c.vector, # type: ignore - "item_metadata": c.item_metadata, # type: ignore - } - for c in chunks - ] == [ - { - "content": "sample content", - "embedding_model": "voyage-3-large", - "file_path": None, - "id": "id1", - "item_metadata": {"source": "test"}, - "vector": [0], - }, - ] - - def test_write_to_file_text(mock_file_storage): """Test writing a string to a file.""" chunk_id = "test-chunk-id" @@ -160,17 +108,14 @@ def test_write_to_file_unsupported_type(mock_file_storage): def test_make_chunk_text_only(mock_file_storage, db_session): """Test creating a chunk from string content.""" - page = { - "contents": ["text content 1", "text content 2"], - "metadata": {"source": "test"}, - } + contents = ["text content 1", "text content 2"] vector = [0.1, 0.2, 0.3] metadata = {"doc_type": "test", "source": "unit-test"} with patch.object( uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001") ): - chunk = make_chunk(page, vector, metadata) # type: ignore + chunk = make_chunk(contents, vector, metadata) # type: ignore assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001" assert cast(str, chunk.content) == "text content 1\n\ntext content 2" @@ -183,14 +128,14 @@ def test_make_chunk_text_only(mock_file_storage, db_session): def test_make_chunk_single_image(mock_file_storage, db_session): """Test creating a chunk from a single image.""" img = Image.new("RGB", (100, 100), color=(73, 109, 137)) - page = {"contents": [img], "metadata": {"source": "test"}} + contents = [img] vector = [0.1, 0.2, 0.3] metadata = {"doc_type": "test", "source": "unit-test"} with patch.object( uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002") ): - chunk = make_chunk(page, vector, metadata) # type: ignore + chunk = make_chunk(contents, vector, metadata) # type: ignore assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002" assert chunk.content is None @@ -208,14 +153,14 @@ def test_make_chunk_single_image(mock_file_storage, db_session): def test_make_chunk_mixed_content(mock_file_storage, db_session): """Test creating a chunk from mixed content (string and image).""" img = Image.new("RGB", (100, 100), color=(73, 109, 137)) - page = {"contents": ["text content", img], "metadata": {"source": "test"}} + contents = ["text content", img] vector = [0.1, 0.2, 0.3] metadata = {"doc_type": "test", "source": "unit-test"} with patch.object( uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003") ): - chunk = make_chunk(page, vector, metadata) # type: ignore + chunk = make_chunk(contents, vector, metadata) # type: ignore assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003" assert chunk.content is None @@ -233,3 +178,181 @@ def 
test_make_chunk_mixed_content(mock_file_storage, db_session): assert ( settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png" ).exists() + + +@pytest.mark.parametrize( + "data,embedding_model,collection,expected_model,expected_count,expected_has_content", + [ + # Text-only with default model + ( + ["text content 1", "text content 2"], + None, + None, + settings.TEXT_EMBEDDING_MODEL, + 2, + True, + ), + # Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content + ( + ["text content"], + settings.MIXED_EMBEDDING_MODEL, + None, + settings.TEXT_EMBEDDING_MODEL, + 1, + True, + ), + # Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content + (["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True), + (["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True), + (["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True), + # Unknown collection falls back to default + (["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True), + # Explicit model takes precedence over collection + ( + ["text content"], + settings.TEXT_EMBEDDING_MODEL, + "photo", + settings.TEXT_EMBEDDING_MODEL, + 1, + True, + ), + ], +) +def test_embed_data_chunk_scenarios( + data, + embedding_model, + collection, + expected_model, + expected_count, + expected_has_content, + mock_embed, + mock_file_storage, +): + """Test various embedding scenarios for data chunks.""" + from memory.common.extract import DataChunk + from memory.common.embedding import embed_data_chunk + + chunk = DataChunk( + data=data, + embedding_model=embedding_model, + collection=collection, + metadata={"source": "test"}, + ) + + result = embed_data_chunk(chunk, {"doc_type": "test"}) + + assert len(result) == expected_count + assert all(cast(str, c.embedding_model) == expected_model for c in result) + if expected_has_content: + assert all(c.content is not None for c in result) + assert all(c.file_path is None for c in result) + else: + assert all(c.content is None for c in result) + assert all(c.file_path is not None for c in result) + assert all( + c.item_metadata == {"source": "test", "doc_type": "test"} for c in result + ) + + +def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage): + """Test embedding mixed content (text and images).""" + from memory.common.extract import DataChunk + from memory.common.embedding import embed_data_chunk + + img = Image.new("RGB", (100, 100), color=(73, 109, 137)) + chunk = DataChunk( + data=["text content", img], + embedding_model=settings.MIXED_EMBEDDING_MODEL, + metadata={"source": "test"}, + ) + + result = embed_data_chunk(chunk) + + assert len(result) == 1 # Mixed content returns single vector + assert result[0].content is None # Mixed content stored in files + assert result[0].file_path is not None + assert cast(str, result[0].embedding_model) == settings.MIXED_EMBEDDING_MODEL + + +@pytest.mark.parametrize( + "chunk_max_size,chunk_size_param,expected_chunk_size", + [ + (512, 1024, 512), # chunk.max_size takes precedence + (None, 2048, 2048), # chunk_size parameter used when max_size is None + (256, None, 256), # chunk.max_size used when parameter is None + ], +) +def test_embed_data_chunk_chunk_size_handling( + chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage +): + """Test chunk size parameter handling.""" + from memory.common.extract import DataChunk + from memory.common.embedding import embed_data_chunk + + 
chunk = DataChunk( + data=["text content"], max_size=chunk_max_size, metadata={"source": "test"} + ) + + with patch("memory.common.embedding.embed_text") as mock_embed_text: + mock_embed_text.return_value = [[0.1, 0.2, 0.3]] + + result = embed_data_chunk(chunk, chunk_size=chunk_size_param) + + mock_embed_text.assert_called_once() + args, kwargs = mock_embed_text.call_args + assert kwargs["chunk_size"] == expected_chunk_size + + +def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage): + """Test that chunk metadata and parameter metadata are properly merged.""" + from memory.common.extract import DataChunk + from memory.common.embedding import embed_data_chunk + + chunk = DataChunk( + data=["text content"], metadata={"source": "test", "type": "chunk"} + ) + metadata = { + "doc_type": "test", + "source": "override", + } # chunk.metadata takes precedence over parameter metadata + + result = embed_data_chunk(chunk, metadata) + + assert len(result) == 1 + expected_metadata = { + "source": "test", + "type": "chunk", + "doc_type": "test", + } # chunk source wins + assert result[0].item_metadata == expected_metadata + + +def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage): + """Test error handling for unsupported embedding model.""" + from memory.common.extract import DataChunk + from memory.common.embedding import embed_data_chunk + + chunk = DataChunk( + data=["text content"], + embedding_model="unsupported-model", + metadata={"source": "test"}, + ) + + with pytest.raises(ValueError, match="Unsupported model: unsupported-model"): + embed_data_chunk(chunk) + + +def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage): + """Test handling of empty data.""" + from memory.common.extract import DataChunk + from memory.common.embedding import embed_data_chunk + + chunk = DataChunk(data=[], metadata={"source": "test"}) + + # Should handle empty data gracefully + with patch("memory.common.embedding.embed_text") as mock_embed_text: + mock_embed_text.return_value = [] + + result = embed_data_chunk(chunk) + + assert result == [] diff --git a/tests/memory/common/test_extract.py b/tests/memory/common/test_extract.py index 8be3aef..6a40bdc 100644 --- a/tests/memory/common/test_extract.py +++ b/tests/memory/common/test_extract.py @@ -7,7 +7,6 @@ import shutil from memory.common.extract import ( as_file, extract_text, - extract_content, doc_to_images, extract_image, docx_to_pdf, @@ -107,52 +106,6 @@ def test_extract_image_with_str(): extract_image("test") -@pytest.mark.parametrize( - "mime_type,content", - [ - ("text/plain", "Text content"), - ("text/html", "content"), - ("text/markdown", "# Heading"), - ("text/csv", "a,b,c"), - ], -) -def test_extract_content_different_text_types(mime_type, content): - assert extract_content(mime_type, content) == [ - {"contents": [content], "metadata": {}} - ] - - -def test_extract_content_pdf(): - result = extract_content("application/pdf", REGULAMIN) - - assert len(result) == 2 - assert all( - isinstance(page["contents"], list) - and all(isinstance(c, Image.Image) for c in page["contents"]) - for page in result - ) - assert all("page" in page["metadata"] for page in result) - assert all("width" in page["metadata"] for page in result) - assert all("height" in page["metadata"] for page in result) - - -def test_extract_content_image(tmp_path): - # Create a test image - img = Image.new("RGB", (100, 100), color="red") - img_path = tmp_path / "test_img.png" - img.save(img_path) - - (result,) = extract_content("image/png", 
img_path) - - assert isinstance(result["contents"][0], Image.Image) - assert result["contents"][0].size == (100, 100) - assert result["metadata"] == {} - - -def test_extract_content_unsupported_type(): - assert extract_content("unsupported/type", "content") == [] - - @pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed") def test_docx_to_pdf(tmp_path): output_path = tmp_path / "output.pdf" diff --git a/tests/memory/workers/tasks/test_comic_tasks.py b/tests/memory/workers/tasks/test_comic_tasks.py new file mode 100644 index 0000000..6b9f5db --- /dev/null +++ b/tests/memory/workers/tasks/test_comic_tasks.py @@ -0,0 +1,431 @@ +import pytest +from datetime import datetime +from unittest.mock import Mock, patch + +from memory.common.db.models import Comic +from memory.workers.tasks import comic +import requests + + +@pytest.fixture +def mock_comic_info(): + """Mock comic info data for testing.""" + return { + "title": "Test Comic", + "image_url": "https://example.com/comic.png", + "url": "https://example.com/comic/1", + "published_date": "2024-01-01T12:00:00Z", + } + + +@pytest.fixture +def mock_feed_data(): + """Mock RSS feed data.""" + return { + "entries": [ + {"link": "https://example.com/comic/1", "id": None}, + {"link": "https://example.com/comic/2", "id": None}, + {"link": None, "id": "https://example.com/comic/3"}, + ] + } + + +@pytest.fixture +def mock_image_response(): + """Mock HTTP response for comic image.""" + # 1x1 PNG image (smallest valid PNG) + png_data = ( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01" + b"\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13" + b"\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```" + b"\x00\x00\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82" + ) + response = Mock() + response.status_code = 200 + response.content = png_data + with patch.object(requests, "get", return_value=response): + yield response + + +@patch("memory.workers.tasks.comic.feedparser.parse") +def test_find_new_urls_success(mock_parse, mock_feed_data, db_session): + """Test successful URL discovery from RSS feed.""" + mock_parse.return_value = Mock(entries=mock_feed_data["entries"]) + + result = comic.find_new_urls("https://example.com", "https://example.com/rss") + + assert result == { + "https://example.com/comic/1", + "https://example.com/comic/2", + "https://example.com/comic/3", + } + mock_parse.assert_called_once_with("https://example.com/rss") + + +@patch("memory.workers.tasks.comic.feedparser.parse") +def test_find_new_urls_with_existing_comics(mock_parse, mock_feed_data, db_session): + """Test URL discovery when some comics already exist.""" + mock_parse.return_value = Mock(entries=mock_feed_data["entries"]) + + # Add existing comic to database + existing_comic = Comic( + title="Existing Comic", + url="https://example.com/comic/1", + author="https://example.com", + filename="/test/path", + sha256=b"test_hash", + modality="comic", + tags=["comic"], + ) + db_session.add(existing_comic) + db_session.commit() + + result = comic.find_new_urls("https://example.com", "https://example.com/rss") + + # Should only return URLs not in database + assert result == { + "https://example.com/comic/2", + "https://example.com/comic/3", + } + + +@patch("memory.workers.tasks.comic.feedparser.parse") +def test_find_new_urls_parse_error(mock_parse): + """Test handling of RSS feed parsing errors.""" + mock_parse.side_effect = Exception("Parse error") + + assert ( + comic.find_new_urls("https://example.com", 
"https://example.com/rss") == set() + ) + + +@patch("memory.workers.tasks.comic.feedparser.parse") +def test_find_new_urls_empty_feed(mock_parse): + """Test handling of empty RSS feed.""" + mock_parse.return_value = Mock(entries=[]) + + result = comic.find_new_urls("https://example.com", "https://example.com/rss") + + assert result == set() + + +@patch("memory.workers.tasks.comic.feedparser.parse") +def test_find_new_urls_malformed_entries(mock_parse): + """Test handling of malformed RSS entries.""" + mock_parse.return_value = Mock( + entries=[ + {"link": None, "id": None}, # Both None + {}, # Missing keys + ] + ) + + result = comic.find_new_urls("https://example.com", "https://example.com/rss") + + assert result == set() + + +@patch("memory.workers.tasks.comic.sync_comic.delay") +@patch("memory.workers.tasks.comic.find_new_urls") +def test_fetch_new_comics_success(mock_find_urls, mock_sync_delay, mock_comic_info): + """Test successful comic fetching.""" + mock_find_urls.return_value = {"https://example.com/comic/1"} + mock_parser = Mock(return_value=mock_comic_info) + + result = comic.fetch_new_comics( + "https://example.com", "https://example.com/rss", mock_parser + ) + + assert result == {"https://example.com/comic/1"} + mock_parser.assert_called_once_with("https://example.com/comic/1") + expected_call_args = { + **mock_comic_info, + "author": "https://example.com", + "url": "https://example.com/comic/1", + } + mock_sync_delay.assert_called_once_with(**expected_call_args) + + +@patch("memory.workers.tasks.comic.sync_comic.delay") +@patch("memory.workers.tasks.comic.find_new_urls") +def test_fetch_new_comics_no_new_urls(mock_find_urls, mock_sync_delay): + """Test when no new URLs are found.""" + mock_find_urls.return_value = set() + mock_parser = Mock() + + result = comic.fetch_new_comics( + "https://example.com", "https://example.com/rss", mock_parser + ) + + assert result == set() + mock_parser.assert_not_called() + mock_sync_delay.assert_not_called() + + +@patch("memory.workers.tasks.comic.sync_comic.delay") +@patch("memory.workers.tasks.comic.find_new_urls") +def test_fetch_new_comics_multiple_urls( + mock_find_urls, mock_sync_delay, mock_comic_info +): + """Test fetching multiple new comics.""" + urls = {"https://example.com/comic/1", "https://example.com/comic/2"} + mock_find_urls.return_value = urls + mock_parser = Mock(return_value=mock_comic_info) + + result = comic.fetch_new_comics( + "https://example.com", "https://example.com/rss", mock_parser + ) + + assert result == urls + assert mock_parser.call_count == 2 + assert mock_sync_delay.call_count == 2 + + +@patch("memory.workers.tasks.comic.requests.get") +def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant): + """Test successful comic synchronization.""" + mock_get.return_value = mock_image_response + + comic.sync_comic( + url="https://example.com/comic/1", + image_url="https://example.com/image.png", + title="Test Comic", + author="https://example.com", + published_date=datetime(2024, 1, 1, 12, 0, 0), + ) + + # Verify comic was created in database + saved_comic = ( + db_session.query(Comic) + .filter(Comic.url == "https://example.com/comic/1") + .first() + ) + assert saved_comic is not None + assert saved_comic.title == "Test Comic" + assert saved_comic.author == "https://example.com" + assert saved_comic.mime_type == "image/png" + assert saved_comic.size == len(mock_image_response.content) + assert "comic" in saved_comic.tags + assert "https://example.com" in saved_comic.tags + + # Verify vectors 
were added to Qdrant + vectors, _ = qdrant.scroll(collection_name="comic") + expected_vectors = [ + ( + { + "author": "https://example.com", + "published": "2024-01-01T12:00:00", + "tags": ["comic", "https://example.com"], + "title": "Test Comic", + "url": "https://example.com/comic/1", + "source_id": 1, + }, + None, + ) + ] + assert [ + ({**v.payload, "tags": sorted(v.payload["tags"])}, v.vector) for v in vectors + ] == expected_vectors + + +def test_sync_comic_already_exists(db_session): + """Test that duplicate comics are not processed.""" + # Add existing comic + existing_comic = Comic( + title="Existing Comic", + url="https://example.com/comic/1", + author="https://example.com", + filename="/test/path", + sha256=b"test_hash", + modality="comic", + tags=["comic"], + ) + db_session.add(existing_comic) + db_session.commit() + + with patch("memory.workers.tasks.comic.requests.get") as mock_get: + result = comic.sync_comic( + url="https://example.com/comic/1", + image_url="https://example.com/image.png", + title="Test Comic", + author="https://example.com", + ) + + # Should return early without making HTTP request + mock_get.assert_not_called() + assert result == {"comic_id": 1, "status": "already_exists"} + + +@patch("memory.workers.tasks.comic.requests.get") +def test_sync_comic_http_error(mock_get, db_session, qdrant): + """Test handling of HTTP errors when downloading image.""" + mock_response = Mock() + mock_response.status_code = 404 + mock_response.content = b"" + mock_get.return_value = mock_response + + comic.sync_comic( + url="https://example.com/comic/1", + image_url="https://example.com/image.png", + title="Test Comic", + author="https://example.com", + ) + + assert not ( + db_session.query(Comic) + .filter(Comic.url == "https://example.com/comic/1") + .first() + ) + + +@patch("memory.workers.tasks.comic.requests.get") +def test_sync_comic_no_published_date( + mock_get, mock_image_response, db_session, qdrant +): + """Test comic sync without published date.""" + mock_get.return_value = mock_image_response + + comic.sync_comic( + url="https://example.com/comic/1", + image_url="https://example.com/image.png", + title="Test Comic", + author="https://example.com", + published_date=None, + ) + + saved_comic = ( + db_session.query(Comic) + .filter(Comic.url == "https://example.com/comic/1") + .first() + ) + assert saved_comic is not None + assert saved_comic.published is None + + +@patch("memory.workers.tasks.comic.requests.get") +def test_sync_comic_special_characters_in_title( + mock_get, mock_image_response, db_session, qdrant +): + """Test comic sync with special characters in title.""" + mock_get.return_value = mock_image_response + + comic.sync_comic( + url="https://example.com/comic/1", + image_url="https://example.com/image.png", + title="Test/Comic: With*Special?Characters", + author="https://example.com", + ) + + # Verify comic was created with cleaned title + saved_comic = ( + db_session.query(Comic) + .filter(Comic.url == "https://example.com/comic/1") + .first() + ) + assert saved_comic is not None + assert saved_comic.title == "Test/Comic: With*Special?Characters" + + +@patch( + "memory.common.embedding.embed_source_item", + side_effect=Exception("Embedding failed"), +) +def test_sync_comic_embedding_failure( + mock_embed_source_item, mock_image_response, db_session, qdrant +): + """Test handling of embedding failures.""" + result = comic.sync_comic( + url="https://example.com/comic/1", + image_url="https://example.com/image.png", + title="Test Comic", + 
author="https://example.com", + ) + assert result == { + "comic_id": 1, + "title": "Test Comic", + "status": "processed", + "chunks_count": 0, + "embed_status": "FAILED", + "content_length": 90, + } + + +@patch("memory.workers.tasks.comic.sync_xkcd.delay") +@patch("memory.workers.tasks.comic.sync_smbc.delay") +def test_sync_all_comics(mock_smbc_delay, mock_xkcd_delay): + """Test synchronization of all comics.""" + comic.sync_all_comics() + + mock_smbc_delay.assert_called_once() + mock_xkcd_delay.assert_called_once() + + +@patch("memory.workers.tasks.comic.sync_comic.delay") +@patch("memory.workers.tasks.comic.comics.extract_xkcd") +@patch("memory.workers.tasks.comic.comics.extract_smbc") +@patch("requests.get") +def test_trigger_comic_sync_smbc_navigation( + mock_get, mock_extract_smbc, mock_extract_xkcd, mock_sync_delay, mock_comic_info +): + """Test full SMBC comic sync with navigation.""" + # Mock HTML responses for navigation + mock_responses = [ + Mock(text=''), + Mock(text=''), + Mock(text="
No prev link
"), # End of navigation + ] + mock_get.side_effect = mock_responses + mock_extract_smbc.return_value = mock_comic_info + mock_extract_xkcd.return_value = mock_comic_info + + comic.trigger_comic_sync() + + # Should have called extract_smbc for each discovered URL + assert mock_extract_smbc.call_count == 2 + mock_extract_smbc.assert_any_call("https://smbc.com/comic/2") + mock_extract_smbc.assert_any_call("https://smbc.com/comic/1") + + # Should have called extract_xkcd for range 1-307 + assert mock_extract_xkcd.call_count == 307 + + # Should have queued sync tasks + assert mock_sync_delay.call_count == 2 + 307 # SMBC + XKCD + + +@patch("memory.workers.tasks.comic.sync_comic.delay") +@patch("memory.workers.tasks.comic.comics.extract_smbc") +@patch("requests.get") +def test_trigger_comic_sync_smbc_extraction_error( + mock_get, mock_extract_smbc, mock_sync_delay +): + """Test handling of extraction errors during full sync.""" + # Mock responses: first one has a prev link, second one doesn't + mock_responses = [ + Mock(text=''), + Mock(text="
No prev link
"), + ] + mock_get.side_effect = mock_responses + mock_extract_smbc.side_effect = Exception("Extraction failed") + + # Should not raise exception, just log error + comic.trigger_comic_sync() + + mock_extract_smbc.assert_called_once_with("https://smbc.com/comic/1") + mock_sync_delay.assert_not_called() + + +@patch("memory.workers.tasks.comic.sync_comic.delay") +@patch("memory.workers.tasks.comic.comics.extract_xkcd") +@patch("requests.get") +def test_trigger_comic_sync_xkcd_extraction_error( + mock_get, mock_extract_xkcd, mock_sync_delay +): + """Test handling of XKCD extraction errors during full sync.""" + mock_get.return_value = Mock(text="
No prev link
") + mock_extract_xkcd.side_effect = Exception("XKCD extraction failed") + + # Should not raise exception, just log errors + comic.trigger_comic_sync() + + # Should attempt all 307 XKCD comics + assert mock_extract_xkcd.call_count == 307 + mock_sync_delay.assert_not_called() diff --git a/tests/memory/workers/tasks/test_ebook_tasks.py b/tests/memory/workers/tasks/test_ebook_tasks.py index 6ace439..9dd4de4 100644 --- a/tests/memory/workers/tasks/test_ebook_tasks.py +++ b/tests/memory/workers/tasks/test_ebook_tasks.py @@ -51,24 +51,6 @@ def mock_ebook(): ) -@pytest.fixture -def mock_embedding(): - """Mock the embedding function to return dummy vectors.""" - with patch("memory.workers.tasks.ebook.embedding.embed") as mock: - mock.return_value = ( - "book", - [ - Chunk( - vector=[0.1] * 1024, - item_metadata={"test": "data"}, - content="Test content", - embedding_model="model", - ) - ], - ) - yield mock - - @pytest.fixture def mock_qdrant(): """Mock Qdrant operations.""" @@ -151,7 +133,7 @@ def test_create_book_and_sections(mock_ebook, db_session): assert getattr(chapter1, "parent_section_id") is None -def test_embed_sections(db_session, mock_embedding): +def test_embed_sections(db_session): """Test basic embedding sections workflow.""" # Create a test book first book = Book( @@ -187,31 +169,6 @@ def test_embed_sections(db_session, mock_embedding): assert hasattr(sections[0], "embed_status") -def test_push_to_qdrant(qdrant): - """Test pushing embeddings to Qdrant.""" - # Create test sections with chunks - mock_chunk = Mock( - id="00000000-0000-0000-0000-000000000000", - vector=[0.1] * 1024, - item_metadata={"test": "data"}, - ) - - mock_section = Mock(spec=BookSection) - mock_section.embed_status = "QUEUED" - mock_section.chunks = [mock_chunk] - - sections = [mock_section] - - ebook.push_to_qdrant(sections) # type: ignore - - assert {r.id: r.payload for r in qdrant.scroll(collection_name="book")[0]} == { - "00000000-0000-0000-0000-000000000000": { - "test": "data", - } - } - assert mock_section.embed_status == "STORED" - - @patch("memory.workers.tasks.ebook.parse_ebook") def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path): """Test successful book synchronization.""" @@ -229,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path): "author": "Test Author", "status": "processed", "total_sections": 4, - "sections_embedded": 4, + "sections_embedded": 8, } book = db_session.query(Book).filter(Book.title == "Test Book").first() @@ -270,8 +227,9 @@ def test_sync_book_already_exists(mock_parse, mock_ebook, db_session, tmp_path): @patch("memory.workers.tasks.ebook.parse_ebook") +@patch("memory.common.embedding.embed_source_item") def test_sync_book_embedding_failure( - mock_parse, mock_ebook, db_session, tmp_path, mock_embedding + mock_embedding, mock_parse, mock_ebook, db_session, tmp_path ): """Test handling of embedding failures.""" book_file = tmp_path / "test.epub" @@ -307,14 +265,18 @@ def test_sync_book_qdrant_failure(mock_parse, mock_ebook, db_session, tmp_path): # Since embedding is already failing, this test will complete without hitting Qdrant # So let's just verify that the function completes without raising an exception with patch.object(ebook, "push_to_qdrant", side_effect=Exception("Qdrant failed")): - with pytest.raises(Exception, match="Qdrant failed"): - ebook.sync_book(str(book_file)) + assert ebook.sync_book(str(book_file)) == { + "status": "error", + "error": "Qdrant failed", + } def test_sync_book_file_not_found(): """Test handling 
of missing files.""" - with pytest.raises(FileNotFoundError): - ebook.sync_book("/nonexistent/file.epub") + assert ebook.sync_book("/nonexistent/file.epub") == { + "status": "error", + "error": "Book file not found: /nonexistent/file.epub", + } def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client): @@ -363,7 +325,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client): calls = mock_voyage_client.embed.call_args_list texts = [c[0][0] for c in calls] assert texts == [ - [large_section_content.strip()], [large_page_1.strip()], [large_page_2.strip()], + [large_section_content.strip()], ] diff --git a/tests/memory/workers/tasks/test_email_tasks.py b/tests/memory/workers/tasks/test_email_tasks.py index 8ae3db3..ca8f921 100644 --- a/tests/memory/workers/tasks/test_email_tasks.py +++ b/tests/memory/workers/tasks/test_email_tasks.py @@ -69,13 +69,21 @@ def test_email_account(db_session): def test_process_simple_email(db_session, test_email_account, qdrant): """Test processing a simple email message.""" - mail_message_id = process_message( + res = process_message( account_id=test_email_account.id, message_id="101", folder="INBOX", raw_email=SIMPLE_EMAIL_RAW, ) + mail_message_id = res["mail_message_id"] + assert res == { + "status": "processed", + "mail_message_id": mail_message_id, + "message_id": "101", + "chunks_count": 1, + "attachments_count": 0, + } mail_message = ( db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one() ) @@ -98,7 +106,7 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant): message_id="302", folder="Archive", raw_email=EMAIL_WITH_ATTACHMENT_RAW, - ) + )["mail_message_id"] # Check mail message specifics mail_message = ( db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one() @@ -125,25 +133,25 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant): def test_process_empty_message(db_session, test_email_account, qdrant): """Test processing an empty/invalid message.""" - source_id = process_message( + res = process_message( account_id=test_email_account.id, message_id="999", folder="Archive", raw_email="", ) - - assert source_id is None + assert res == {"reason": "empty_content", "status": "skipped"} def test_process_duplicate_message(db_session, test_email_account, qdrant): """Test that duplicate messages are detected and not stored again.""" # First call should succeed and create records - source_id_1 = process_message( + res = process_message( account_id=test_email_account.id, message_id="101", folder="INBOX", raw_email=SIMPLE_EMAIL_RAW, ) + source_id_1 = res.get("mail_message_id") assert source_id_1 is not None, "First call should return a source_id" @@ -157,7 +165,7 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant): message_id="101", folder="INBOX", raw_email=SIMPLE_EMAIL_RAW, - ) + ).get("mail_message_id") assert source_id_2 is None, "Second call should return None for duplicate message" diff --git a/tests/memory/workers/test_email.py b/tests/memory/workers/test_email.py index 154aca0..e17898b 100644 --- a/tests/memory/workers/test_email.py +++ b/tests/memory/workers/test_email.py @@ -12,7 +12,7 @@ from memory.common.db.models import ( EmailAttachment, MailMessage, ) -from memory.common.parsers.email import Attachment +from memory.common.parsers.email import Attachment, parse_email_message from memory.workers.email import ( create_mail_message, extract_email_uid, @@ -226,14 +226,14 @@ def 
test_create_mail_message(db_session): "--boundary--" ) folder = "INBOX" + parsed_email = parse_email_message(raw_email, "321") # Call function mail_message = create_mail_message( db_session=db_session, - raw_email=raw_email, folder=folder, tags=["test"], - message_id="123", + parsed_email=parsed_email, ) db_session.commit() @@ -344,6 +344,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): recipients=["recipient@example.com"], content="This is a test email for vectorization", folder="INBOX", + modality="mail", ) db_session.add(mail_message) db_session.flush() @@ -373,6 +374,7 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): recipients=["recipient@example.com"], content="This is a test email with attachments", folder="INBOX", + modality="mail", ) db_session.add(mail_message) db_session.flush()