diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 41f2b79..30dbeb3 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -15,7 +15,7 @@ from PIL import Image from pydantic import BaseModel from memory.common import embedding, qdrant, extract, settings -from memory.common.embedding import get_modality +from memory.common.collections import get_modality, TEXT_COLLECTIONS from memory.common.db.connection import make_session from memory.common.db.models import Chunk, SourceItem @@ -189,7 +189,7 @@ async def search( text_results = query_chunks( client, upload_data, - allowed_modalities & embedding.TEXT_COLLECTIONS, + allowed_modalities & TEXT_COLLECTIONS, embedding.embed_text, min_score=min_text_score, limit=limit, diff --git a/src/memory/common/collections.py b/src/memory/common/collections.py new file mode 100644 index 0000000..8a2c074 --- /dev/null +++ b/src/memory/common/collections.py @@ -0,0 +1,107 @@ +import logging +from typing import Literal, NotRequired, TypedDict + + +from memory.common import settings + +logger = logging.getLogger(__name__) + + +DistanceType = Literal["Cosine", "Dot", "Euclidean"] +Vector = list[float] + + +class Collection(TypedDict): + dimension: int + distance: DistanceType + model: str + on_disk: NotRequired[bool] + shards: NotRequired[int] + + +ALL_COLLECTIONS: dict[str, Collection] = { + "mail": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "chat": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "git": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "book": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "blog": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + "text": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.TEXT_EMBEDDING_MODEL, + }, + # Multimodal + "photo": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, + "comic": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, + "doc": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, +} +TEXT_COLLECTIONS = { + coll + for coll, params in ALL_COLLECTIONS.items() + if params["model"] == settings.TEXT_EMBEDDING_MODEL +} +MULTIMODAL_COLLECTIONS = { + coll + for coll, params in ALL_COLLECTIONS.items() + if params["model"] == settings.MIXED_EMBEDDING_MODEL +} + +TYPES = { + "doc": ["application/pdf", "application/docx", "application/msword"], + "text": ["text/*"], + "blog": ["text/markdown", "text/html"], + "photo": ["image/*"], + "book": [ + "application/epub+zip", + "application/mobi", + "application/x-mobipocket-ebook", + ], +} + + +def get_modality(mime_type: str) -> str: + for type, mime_types in TYPES.items(): + if mime_type in mime_types: + return type + stem = mime_type.split("/")[0] + + for type, mime_types in TYPES.items(): + if any(mime_type.startswith(stem) for mime_type in mime_types): + return type + return "unknown" + + +def collection_model(collection: str) -> str | None: + return ALL_COLLECTIONS.get(collection, {}).get("model") diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index c329856..b7cca39 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -2,11 +2,12 @@ Database models for the knowledge base system. 
""" +from dataclasses import dataclass import pathlib import re import textwrap from datetime import datetime -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, Iterable, Sequence, cast from PIL import Image from sqlalchemy import ( @@ -32,6 +33,9 @@ from sqlalchemy.orm import Session, relationship from memory.common import settings from memory.common.parsers.email import EmailMessage, parse_email_message +import memory.common.extract as extract +import memory.common.collections as collections +import memory.common.chunker as chunker Base = declarative_base() @@ -137,6 +141,7 @@ class SourceItem(Base): """Base class for all content in the system using SQLAlchemy's joined table inheritance.""" __tablename__ = "source_item" + __allow_unmapped__ = True id = Column(BigInteger, primary_key=True) modality = Column(Text, nullable=False) @@ -174,6 +179,14 @@ class SourceItem(Base): """Get vector IDs from associated chunks.""" return [chunk.id for chunk in self.chunks] + def data_chunks(self) -> Iterable[extract.DataChunk]: + return [ + extract.DataChunk( + data=cast(str, self.content), + collection=cast(str, self.modality), + ) + ] + def as_payload(self) -> dict: return { "source_id": self.id, @@ -306,6 +319,19 @@ class EmailAttachment(SourceItem): "tags": self.tags, } + def data_chunks(self) -> Iterable[extract.DataChunk]: + if cast(str | None, self.filename): + contents = pathlib.Path(cast(str, self.filename)).read_bytes() + else: + contents = cast(str, self.content) + + return extract.extract_data_chunks( + cast(str, self.mime_type), + contents, + collection=cast(str, self.modality), + embedding_model=collections.collection_model(cast(str, self.modality)), + ) + # Add indexes __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),) @@ -407,6 +433,15 @@ class Comic(SourceItem): } return {k: v for k, v in payload.items() if v is not None} + def data_chunks(self) -> Iterable[extract.DataChunk]: + image = Image.open(pathlib.Path(cast(str, self.filename))) + return [ + extract.DataChunk( + data=[image, cast(str, self.title), cast(str, self.author)], + collection=cast(str, self.modality), + ) + ] + class Book(Base): """Book-level metadata table""" @@ -503,6 +538,19 @@ class BookSection(SourceItem): "tags": self.tags, } + def data_chunks(self) -> Iterable[extract.DataChunk]: + texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)] + texts += [(cast(str, self.content), self.start_page)] + return [ + extract.DataChunk( + data=[text], + collection=cast(str, self.modality), + metadata={"page": page_number}, + max_size=chunker.EMBEDDING_MAX_TOKENS, + ) + for text, page_number in texts + ] + class BlogPost(SourceItem): __tablename__ = "blog_post" @@ -519,6 +567,7 @@ class BlogPost(SourceItem): description = Column(Text, nullable=True) # Meta description or excerpt domain = Column(Text, nullable=True) # Domain of the source website word_count = Column(Integer, nullable=True) # Approximate word count + images = Column(ARRAY(Text), nullable=True) # List of image URLs # Store original metadata from parser webpage_metadata = Column(JSONB, nullable=True) @@ -552,6 +601,30 @@ class BlogPost(SourceItem): } return {k: v for k, v in payload.items() if v} + def data_chunks(self) -> Iterable[extract.DataChunk]: + images = [Image.open(image) for image in self.images] + data = [cast(str, self.content)] + images + + # Always embed the full content as a single chunk (if possible, of course) + chunks = [ + extract.DataChunk( + data=data, + collection=cast(str, 
self.modality), + max_size=chunker.EMBEDDING_MAX_TOKENS, + ) + ] + + # If the content is long enough, also embed it as chunks of the default size. + tokens = chunker.approx_token_count(cast(str, self.content)) + if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: + chunks += [ + extract.DataChunk( + data=data, + collection=cast(str, self.modality), + ) + ] + return chunks + class MiscDoc(SourceItem): __tablename__ = "misc_doc" diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index be88ac1..6eb26d0 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,130 +1,21 @@ +from collections.abc import Sequence import logging import pathlib import uuid -from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast +from typing import Any, Iterable, Literal, cast import voyageai from PIL import Image from memory.common import extract, settings from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS -from memory.common.db.models import Chunk +from memory.common.collections import ALL_COLLECTIONS, Vector +from memory.common.db.models import Chunk, SourceItem +from memory.common.extract import DataChunk logger = logging.getLogger(__name__) -DistanceType = Literal["Cosine", "Dot", "Euclidean"] -Vector = list[float] - - -class Collection(TypedDict): - dimension: int - distance: DistanceType - model: str - on_disk: NotRequired[bool] - shards: NotRequired[int] - - -ALL_COLLECTIONS: dict[str, Collection] = { - "mail": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "chat": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "git": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "book": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "blog": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - "text": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.TEXT_EMBEDDING_MODEL, - }, - # Multimodal - "photo": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.MIXED_EMBEDDING_MODEL, - }, - "comic": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.MIXED_EMBEDDING_MODEL, - }, - "doc": { - "dimension": 1024, - "distance": "Cosine", - "model": settings.MIXED_EMBEDDING_MODEL, - }, -} -TEXT_COLLECTIONS = { - coll - for coll, params in ALL_COLLECTIONS.items() - if params["model"] == settings.TEXT_EMBEDDING_MODEL -} -MULTIMODAL_COLLECTIONS = { - coll - for coll, params in ALL_COLLECTIONS.items() - if params["model"] == settings.MIXED_EMBEDDING_MODEL -} - -TYPES = { - "doc": ["application/pdf", "application/docx", "application/msword"], - "text": ["text/*"], - "blog": ["text/markdown", "text/html"], - "photo": ["image/*"], - "book": [ - "application/epub+zip", - "application/mobi", - "application/x-mobipocket-ebook", - ], -} - - -def get_mimetype(image: Image.Image) -> str | None: - format_to_mime = { - "JPEG": "image/jpeg", - "PNG": "image/png", - "GIF": "image/gif", - "BMP": "image/bmp", - "TIFF": "image/tiff", - "WEBP": "image/webp", - } - - if not image.format: - return None - - return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}") - - -def get_modality(mime_type: str) -> str: - for type, mime_types in TYPES.items(): - if mime_type in mime_types: - return type - stem = mime_type.split("/")[0] - - for type, mime_types in 
TYPES.items(): - if any(mime_type.startswith(stem) for mime_type in mime_types): - return type - return "unknown" - - def embed_chunks( chunks: list[str] | list[list[extract.MulitmodalChunk]], model: str = settings.TEXT_EMBEDDING_MODEL, @@ -164,12 +55,6 @@ def embed_text( raise -def embed_file( - file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL -) -> list[Vector]: - return embed_text([file_path.read_text()], model) - - def embed_mixed( items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL, @@ -187,22 +72,6 @@ def embed_mixed( return embed_chunks([chunks], model, input_type) -def embed_page(page: extract.Page) -> list[Vector]: - contents = page["contents"] - chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS) - if all(isinstance(c, str) for c in contents): - return embed_text( - cast(list[str], contents), - model=settings.TEXT_EMBEDDING_MODEL, - chunk_size=chunk_size, - ) - return embed_mixed( - cast(list[extract.MulitmodalChunk], contents), - model=settings.MIXED_EMBEDDING_MODEL, - chunk_size=chunk_size, - ) - - def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: if isinstance(item, str): filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt" @@ -219,7 +88,9 @@ def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: def make_chunk( - page: extract.Page, vector: Vector, metadata: dict[str, Any] = {} + contents: Sequence[extract.MulitmodalChunk], + vector: Vector, + metadata: dict[str, Any] = {}, ) -> Chunk: """Create a Chunk object from a page and a vector. @@ -227,7 +98,6 @@ def make_chunk( a single image, or a list of strings and images. """ chunk_id = str(uuid.uuid4()) - contents = page["contents"] content, filename = None, None if all(isinstance(c, str) for c in contents): content = "\n\n".join(cast(list[str], contents)) @@ -251,38 +121,42 @@ def make_chunk( ) -def embed( - mime_type: str, - content: bytes | str | pathlib.Path, +def embed_data_chunk( + chunk: DataChunk, metadata: dict[str, Any] = {}, chunk_size: int | None = None, -) -> tuple[str, list[Chunk]]: - modality = get_modality(mime_type) - pages = extract.extract_content(mime_type, content, chunk_size=chunk_size) - chunks = [ - make_chunk(page, vector, metadata) - for page in pages - for vector in embed_page(page) +) -> list[Chunk]: + chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS + + model = chunk.embedding_model + if not model and chunk.collection: + model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model") + if not model: + model = settings.TEXT_EMBEDDING_MODEL + + if model == settings.TEXT_EMBEDDING_MODEL: + vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size) + elif model == settings.MIXED_EMBEDDING_MODEL: + vectors = embed_mixed( + cast(list[extract.MulitmodalChunk], chunk.data), + chunk_size=chunk_size, + ) + else: + raise ValueError(f"Unsupported model: {model}") + + metadata = metadata | chunk.metadata + return [make_chunk(chunk.data, vector, metadata) for vector in vectors] + + +def embed_source_item( + item: SourceItem, + metadata: dict[str, Any] = {}, + chunk_size: int | None = None, +) -> list[Chunk]: + return [ + chunk + for data_chunk in item.data_chunks() + for chunk in embed_data_chunk( + data_chunk, item.as_payload() | metadata, chunk_size + ) ] - return modality, chunks - - -def embed_image( - file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None -) -> Chunk: - image = Image.open(file_path) - mime_type = get_mimetype(image) - if mime_type 
is None: - raise ValueError("Unsupported image format") - - vector = embed_mixed( - [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS - )[0] - - return Chunk( - id=str(uuid.uuid4()), - file_path=file_path.absolute().as_posix(), - content=None, - embedding_model=settings.MIXED_EMBEDDING_MODEL, - vector=vector, - ) diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py index 83bfae0..7b7f4d2 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import io import logging import pathlib @@ -21,6 +22,15 @@ class Page(TypedDict): chunk_size: NotRequired[int] +@dataclass +class DataChunk: + data: Sequence[MulitmodalChunk] + collection: str | None = None + embedding_model: str | None = None + max_size: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + @contextmanager def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]: if isinstance(content, pathlib.Path): @@ -110,11 +120,13 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]: return [{"contents": [cast(str, content)], "metadata": {}}] -def extract_content( +def extract_data_chunks( mime_type: str, content: bytes | str | pathlib.Path, + collection: str | None = None, + embedding_model: str | None = None, chunk_size: int | None = None, -) -> list[Page]: +) -> list[DataChunk]: pages = [] logger.info(f"Extracting content from {mime_type}") if mime_type == "application/pdf": @@ -134,4 +146,12 @@ def extract_content( if chunk_size: pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages] - return pages + return [ + DataChunk( + data=page["contents"], + collection=collection, + embedding_model=embedding_model, + max_size=chunk_size, + ) + for page in pages + ] diff --git a/src/memory/common/parsers/email.py b/src/memory/common/parsers/email.py index ac0501b..89c7a0a 100644 --- a/src/memory/common/parsers/email.py +++ b/src/memory/common/parsers/email.py @@ -26,6 +26,7 @@ class EmailMessage(TypedDict): body: str attachments: list[Attachment] hash: bytes + raw_email: str RawEmailResponse = tuple[str | None, bytes] @@ -171,6 +172,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage: body = extract_body(msg) return EmailMessage( + raw_email=raw_email, message_id=message_id, subject=subject, sender=from_, diff --git a/src/memory/common/parsers/html.py b/src/memory/common/parsers/html.py index 286e12a..ab64d56 100644 --- a/src/memory/common/parsers/html.py +++ b/src/memory/common/parsers/html.py @@ -12,7 +12,7 @@ from bs4 import BeautifulSoup, Tag from markdownify import markdownify as md from PIL import Image as PILImage -from memory.common.settings import FILE_STORAGE_DIR, WEBPAGE_STORAGE_DIR +from memory.common import settings logger = logging.getLogger(__name__) @@ -96,9 +96,9 @@ def extract_date( datetime_attr = element.get("datetime") if datetime_attr: - date_str = str(datetime_attr) - if date := parse_date(date_str, date_format): - return date + for format in ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d", date_format]: + if date := parse_date(str(datetime_attr), format): + return date for text in element.find_all(string=True): if text and (date := parse_date(str(text).strip(), date_format)): @@ -178,7 +178,7 @@ def process_images( continue path = pathlib.Path(image.filename) # type: ignore - img_tag["src"] = str(path.relative_to(FILE_STORAGE_DIR.resolve())) + img_tag["src"] = 
str(path.relative_to(settings.FILE_STORAGE_DIR.resolve())) images[img_tag["src"]] = image except Exception as e: logger.warning(f"Failed to process image {src}: {e}") @@ -291,7 +291,7 @@ class BaseHTMLParser: def __init__(self, base_url: str | None = None): self.base_url = base_url - self.image_dir = WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc) + self.image_dir = settings.WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc) self.image_dir.mkdir(parents=True, exist_ok=True) def parse(self, html: str, url: str) -> Article: diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py index 9bfde5e..92b489b 100644 --- a/src/memory/common/qdrant.py +++ b/src/memory/common/qdrant.py @@ -5,12 +5,7 @@ import qdrant_client from qdrant_client.http import models as qdrant_models from qdrant_client.http.exceptions import UnexpectedResponse from memory.common import settings -from memory.common.embedding import ( - Collection, - ALL_COLLECTIONS, - DistanceType, - Vector, -) +from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector logger = logging.getLogger(__name__) diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index c9f3fbb..493637f 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -10,17 +10,16 @@ from typing import Callable, Generator, Sequence, cast from sqlalchemy.orm import Session, scoped_session -from memory.common import embedding, qdrant, settings +from memory.common import embedding, qdrant, settings, collections from memory.common.db.models import ( EmailAccount, EmailAttachment, MailMessage, - SourceItem, ) from memory.common.parsers.email import ( Attachment, + EmailMessage, RawEmailResponse, - parse_email_message, ) logger = logging.getLogger(__name__) @@ -54,7 +53,7 @@ def process_attachment( return None return EmailAttachment( - modality=embedding.get_modality(attachment["content_type"]), + modality=collections.get_modality(attachment["content_type"]), sha256=hashlib.sha256( real_content if real_content else str(attachment).encode() ).digest(), @@ -94,8 +93,7 @@ def create_mail_message( db_session: Session | scoped_session, tags: list[str], folder: str, - raw_email: str, - message_id: str, + parsed_email: EmailMessage, ) -> MailMessage: """ Create a new mail message record and associated attachments. @@ -109,7 +107,7 @@ def create_mail_message( Returns: Newly created MailMessage """ - parsed_email = parse_email_message(raw_email, message_id) + raw_email = parsed_email["raw_email"] mail_message = MailMessage( modality="mail", sha256=parsed_email["hash"], @@ -137,52 +135,6 @@ def create_mail_message( return mail_message -def does_message_exist( - db_session: Session | scoped_session, message_id: str, message_hash: bytes -) -> bool: - """ - Check if a message already exists in the database. 
- - Args: - db_session: Database session - message_id: Email message ID - message_hash: SHA-256 hash of message - - Returns: - True if message exists, False otherwise - """ - # Check by message_id first (faster) - if message_id: - mail_message = ( - db_session.query(MailMessage) - .filter(MailMessage.message_id == message_id) - .first() - ) - if mail_message is not None: - return True - - # Then check by message_hash - source_item = ( - db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() - ) - return source_item is not None - - -def check_message_exists( - db: Session | scoped_session, account_id: int, message_id: str, raw_email: str -) -> bool: - account = db.query(EmailAccount).get(account_id) - if not account: - logger.error(f"Account {account_id} not found") - return False - - parsed_email = parse_email_message(raw_email, message_id) - if "szczepalins" in raw_email.lower(): - print(parsed_email["message_id"]) - - return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"]) - - def extract_email_uid( msg_data: Sequence[tuple[bytes, bytes]], ) -> tuple[str | None, bytes]: @@ -317,11 +269,7 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, def vectorize_email(email: MailMessage): qdrant_client = qdrant.get_qdrant_client() - _, chunks = embedding.embed( - "text/plain", - email.body, - metadata=email.as_payload(), - ) + chunks = embedding.embed_source_item(email) email.chunks = chunks if chunks: vector_ids = [cast(str, c.id) for c in chunks] @@ -329,7 +277,7 @@ def vectorize_email(email: MailMessage): metadata = [c.item_metadata for c in chunks] qdrant.upsert_vectors( client=qdrant_client, - collection_name="mail", + collection_name=cast(str, email.modality), ids=vector_ids, vectors=vectors, # type: ignore payloads=metadata, # type: ignore @@ -337,18 +285,12 @@ def vectorize_email(email: MailMessage): embeds = defaultdict(list) for attachment in email.attachments: - if attachment.filename: - content = pathlib.Path(attachment.filename).read_bytes() - else: - content = attachment.content - collection, chunks = embedding.embed( - attachment.mime_type, content, metadata=attachment.as_payload() - ) + chunks = embedding.embed_source_item(attachment) if not chunks: continue attachment.chunks = chunks - embeds[collection].extend(chunks) + embeds[attachment.modality].extend(chunks) for collection, chunks in embeds.items(): ids = [c.id for c in chunks] diff --git a/src/memory/workers/tasks/blogs.py b/src/memory/workers/tasks/blogs.py index 14609e1..6506972 100644 --- a/src/memory/workers/tasks/blogs.py +++ b/src/memory/workers/tasks/blogs.py @@ -1,99 +1,25 @@ -import hashlib import logging -from typing import Iterable, cast +from typing import Iterable -from memory.common import chunker, embedding, qdrant from memory.common.db.connection import make_session from memory.common.db.models import BlogPost from memory.common.parsers.blogs import parse_webpage from memory.workers.celery_app import app +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + create_task_result, + process_content_item, + safe_task_execution, +) logger = logging.getLogger(__name__) - SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage" -def create_blog_post_from_article(article, tags: Iterable[str] = []) -> BlogPost: - """Create a BlogPost model from parsed article data.""" - return BlogPost( - url=article.url, - title=article.title, - published=article.published_date, - content=article.content, - 
sha256=hashlib.sha256(article.content.encode()).digest(), - modality="blog", - tags=tags, - mime_type="text/markdown", - size=len(article.content.encode("utf-8")), - ) - - -def embed_blog_post(blog_post: BlogPost) -> int: - """Embed blog post content and return count of successfully embedded chunks.""" - try: - # Always embed the full content - _, chunks = embedding.embed( - "text/markdown", - cast(str, blog_post.content), - metadata=blog_post.as_payload(), - chunk_size=chunker.EMBEDDING_MAX_TOKENS, - ) - # But also embed the content in chunks (unless it's really short) - if ( - chunker.approx_token_count(cast(str, blog_post.content)) - > chunker.DEFAULT_CHUNK_TOKENS * 2 - ): - _, small_chunks = embedding.embed( - "text/markdown", - cast(str, blog_post.content), - metadata=blog_post.as_payload(), - ) - chunks += small_chunks - - if chunks: - blog_post.chunks = chunks - blog_post.embed_status = "QUEUED" # type: ignore - return len(chunks) - else: - blog_post.embed_status = "FAILED" # type: ignore - logger.warning(f"No chunks generated for blog post: {blog_post.title}") - return 0 - - except Exception as e: - blog_post.embed_status = "FAILED" # type: ignore - logger.error(f"Failed to embed blog post {blog_post.title}: {e}") - return 0 - - -def push_to_qdrant(blog_post: BlogPost): - """Push embeddings to Qdrant for successfully embedded blog post.""" - if cast(str, blog_post.embed_status) != "QUEUED" or not blog_post.chunks: - return - - try: - vector_ids = [str(chunk.id) for chunk in blog_post.chunks] - vectors = [chunk.vector for chunk in blog_post.chunks] - payloads = [chunk.item_metadata for chunk in blog_post.chunks] - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name="blog", - ids=vector_ids, - vectors=vectors, - payloads=payloads, - ) - - blog_post.embed_status = "STORED" # type: ignore - logger.info(f"Successfully stored embeddings for: {blog_post.title}") - - except Exception as e: - blog_post.embed_status = "FAILED" # type: ignore - logger.error(f"Failed to push embeddings to Qdrant: {e}") - raise - - @app.task(name=SYNC_WEBPAGE) +@safe_task_execution def sync_webpage(url: str, tags: Iterable[str] = []) -> dict: """ Synchronize a webpage from a URL. 
@@ -116,61 +42,25 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict: "content_length": 0, } - blog_post = create_blog_post_from_article(article, tags) + blog_post = BlogPost( + url=article.url, + title=article.title, + published=article.published_date, + content=article.content, + sha256=create_content_hash(article.content), + modality="blog", + tags=tags, + mime_type="text/markdown", + size=len(article.content.encode("utf-8")), + images=[image for image in article.images], + ) with make_session() as session: - existing_post = session.query(BlogPost).filter(BlogPost.url == url).first() - if existing_post: - logger.info(f"Blog post already exists: {existing_post.title}") - return { - "blog_post_id": existing_post.id, - "url": url, - "title": existing_post.title, - "status": "already_exists", - "chunks_count": len(existing_post.chunks), - } - - existing_post = ( - session.query(BlogPost).filter(BlogPost.sha256 == blog_post.sha256).first() + existing_post = check_content_exists( + session, BlogPost, url=url, sha256=create_content_hash(article.content) ) if existing_post: - logger.info( - f"Blog post with the same content already exists: {existing_post.title}" - ) - return { - "blog_post_id": existing_post.id, - "url": url, - "title": existing_post.title, - "status": "already_exists", - "chunks_count": len(existing_post.chunks), - } + logger.info(f"Blog post already exists: {existing_post.title}") + return create_task_result(existing_post, "already_exists", url=url) - session.add(blog_post) - session.flush() - - chunks_count = embed_blog_post(blog_post) - session.flush() - - try: - push_to_qdrant(blog_post) - logger.info( - f"Successfully processed webpage: {blog_post.title} " - f"({chunks_count} chunks embedded)" - ) - except Exception as e: - logger.error(f"Failed to push embeddings to Qdrant: {e}") - blog_post.embed_status = "FAILED" # type: ignore - - session.commit() - - return { - "blog_post_id": blog_post.id, - "url": url, - "title": blog_post.title, - "author": article.author, - "published_date": article.published_date, - "status": "processed", - "chunks_count": chunks_count, - "content_length": len(article.content), - "embed_status": blog_post.embed_status, - } + return process_content_item(blog_post, "blog", session, tags) diff --git a/src/memory/workers/tasks/comic.py b/src/memory/workers/tasks/comic.py index 0016403..ea650d4 100644 --- a/src/memory/workers/tasks/comic.py +++ b/src/memory/workers/tasks/comic.py @@ -1,4 +1,3 @@ -import hashlib import logging from datetime import datetime from typing import Callable, cast @@ -6,21 +5,25 @@ from typing import Callable, cast import feedparser import requests -from memory.common import embedding, qdrant, settings +from memory.common import settings from memory.common.db.connection import make_session from memory.common.db.models import Comic, clean_filename from memory.common.parsers import comics from memory.workers.celery_app import app +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + process_content_item, + safe_task_execution, +) logger = logging.getLogger(__name__) - SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics" SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc" SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd" SYNC_COMIC = "memory.workers.tasks.comic.sync_comic" - BASE_SMBC_URL = "https://www.smbc-comics.com/" SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss" @@ -36,6 +39,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]: 
return set() urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries} + urls = {url for url in urls if url} with make_session() as session: known = { @@ -61,6 +65,7 @@ def fetch_new_comics( @app.task(name=SYNC_COMIC) +@safe_task_execution def sync_comic( url: str, image_url: str, @@ -70,20 +75,26 @@ def sync_comic( ): """Synchronize a comic from a URL.""" with make_session() as session: - if session.query(Comic).filter(Comic.url == url).first(): - return + existing_comic = check_content_exists(session, Comic, url=url) + if existing_comic: + return {"status": "already_exists", "comic_id": existing_comic.id} response = requests.get(image_url) + if response.status_code != 200: + return { + "status": "failed", + "error": f"Failed to download image: {response.status_code}", + } + file_type = image_url.split(".")[-1] mime_type = f"image/{file_type}" filename = ( settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}" ) - if response.status_code == 200: - filename.parent.mkdir(parents=True, exist_ok=True) - filename.write_bytes(response.content) - sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest() + filename.parent.mkdir(parents=True, exist_ok=True) + filename.write_bytes(response.content) + comic = Comic( title=title, url=url, @@ -92,27 +103,13 @@ def sync_comic( filename=filename.resolve().as_posix(), mime_type=mime_type, size=len(response.content), - sha256=sha256, + sha256=create_content_hash(f"{image_url}{published_date}"), tags={"comic", author}, modality="comic", ) - chunk = embedding.embed_image(filename, [title, author]) - comic.chunks = [chunk] with make_session() as session: - session.add(comic) - session.add(chunk) - session.flush() - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name="comic", - ids=[str(chunk.id)], - vectors=[chunk.vector], # type: ignore - payloads=[comic.as_payload()], - ) - - session.commit() + return process_content_item(comic, "comic", session) @app.task(name=SYNC_SMBC) diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py new file mode 100644 index 0000000..3c3fb2c --- /dev/null +++ b/src/memory/workers/tasks/content_processing.py @@ -0,0 +1,267 @@ +""" +Content processing utilities for memory workers. + +This module provides core functionality for processing content items through +the complete workflow: existence checking, content hashing, embedding generation, +vector storage, and result tracking. +""" + +import hashlib +import traceback +import logging +from typing import Any, Callable, Iterable, Sequence, cast + +from memory.common import embedding, qdrant +from memory.common.db.models import SourceItem + +logger = logging.getLogger(__name__) + + +def check_content_exists( + session, + model_class: type[SourceItem], + **kwargs: Any, +) -> SourceItem | None: + """ + Check if content already exists in the database. + + Searches for existing content by any of the provided attributes + (typically URL, file_path, or SHA256 hash). 
+ + Args: + session: Database session for querying + model_class: The SourceItem model class to search in + **kwargs: Attribute-value pairs to search for + + Returns: + Existing SourceItem if found, None otherwise + """ + for key, value in kwargs.items(): + if not hasattr(model_class, key): + continue + + existing = ( + session.query(model_class) + .filter(getattr(model_class, key) == value) + .first() + ) + if existing: + return existing + + return None + + +def create_content_hash(content: str, *additional_data: str) -> bytes: + """ + Create SHA256 hash from content and optional additional data. + + Args: + content: Primary content to hash + *additional_data: Additional strings to include in the hash + + Returns: + SHA256 hash digest as bytes + """ + hash_input = content + "".join(additional_data) + return hashlib.sha256(hash_input.encode()).digest() + + +def embed_source_item(source_item: SourceItem) -> int: + """ + Generate embeddings for a source item's content. + + Processes the source item through the embedding pipeline, creating + chunks and their corresponding vector embeddings. Updates the item's + embed_status based on success or failure. + + Args: + source_item: The SourceItem to embed + + Returns: + Number of successfully embedded chunks + + Side effects: + - Sets source_item.chunks with generated chunks + - Sets source_item.embed_status to "QUEUED" or "FAILED" + """ + try: + chunks = embedding.embed_source_item(source_item) + if chunks: + source_item.chunks = chunks + source_item.embed_status = "QUEUED" # type: ignore + return len(chunks) + else: + source_item.embed_status = "FAILED" # type: ignore + logger.warning( + f"No chunks generated for {type(source_item).__name__}: {getattr(source_item, 'title', 'unknown')}" + ) + return 0 + except Exception as e: + source_item.embed_status = "FAILED" # type: ignore + logger.error(f"Failed to embed {type(source_item).__name__}: {e}") + return 0 + + +def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str): + """ + Push embeddings to Qdrant vector database. + + Uploads vector embeddings for all source items that have been successfully + embedded (status "QUEUED") and have chunks available. 
+ + Args: + source_items: Sequence of SourceItems to process + collection_name: Name of the Qdrant collection to store vectors in + + Raises: + Exception: If the Qdrant upsert operation fails + + Side effects: + - Updates embed_status to "STORED" for successful items + - Updates embed_status to "FAILED" for failed items + """ + items_to_process = [ + item + for item in source_items + if cast(str, getattr(item, "embed_status", None)) == "QUEUED" and item.chunks + ] + + if not items_to_process: + return + + all_chunks = [chunk for item in items_to_process for chunk in item.chunks] + if not all_chunks: + return + + try: + vector_ids = [str(chunk.id) for chunk in all_chunks] + vectors = [chunk.vector for chunk in all_chunks] + payloads = [chunk.item_metadata for chunk in all_chunks] + + qdrant.upsert_vectors( + client=qdrant.get_qdrant_client(), + collection_name=collection_name, + ids=vector_ids, + vectors=vectors, + payloads=payloads, + ) + + for item in items_to_process: + item.embed_status = "STORED" # type: ignore + logger.info( + f"Successfully stored embeddings for: {getattr(item, 'title', 'unknown')}" + ) + + except Exception as e: + for item in items_to_process: + item.embed_status = "FAILED" # type: ignore + logger.error(f"Failed to push embeddings to Qdrant: {e}") + raise + + +def create_task_result( + item: SourceItem, status: str, **additional_fields: Any +) -> dict[str, Any]: + """ + Create standardized task result dictionary. + + Generates a consistent result format for task execution reporting, + including item metadata and processing status. + + Args: + item: The processed SourceItem + status: Processing status string + **additional_fields: Extra fields to include in the result + + Returns: + Dictionary with standardized task result format + """ + return { + f"{type(item).__name__.lower()}_id": item.id, + "title": getattr(item, "title", None), + "status": status, + "chunks_count": len(item.chunks), + "embed_status": item.embed_status, + **additional_fields, + } + + +def process_content_item( + item: SourceItem, collection_name: str, session, tags: Iterable[str] = [] +) -> dict[str, Any]: + """ + Execute complete content processing workflow. + + Performs the full pipeline for processing a content item: + 1. Add to database session and flush to get ID + 2. Generate embeddings and chunks + 3. Push embeddings to Qdrant vector store + 4. 
Commit transaction and return result + + Args: + item: SourceItem to process + collection_name: Qdrant collection name for vector storage + session: Database session for persistence + tags: Optional tags to associate with the item (currently unused) + + Returns: + Task result dictionary with processing status and metadata + + Side effects: + - Adds item to database session + - Commits database transaction + - Stores vectors in Qdrant + """ + session.add(item) + session.flush() + + chunks_count = embed_source_item(item) + session.flush() + + try: + push_to_qdrant([item], collection_name) + status = "processed" + logger.info( + f"Successfully processed {type(item).__name__}: {getattr(item, 'title', 'unknown')} ({chunks_count} chunks embedded)" + ) + except Exception as e: + logger.error(f"Failed to push embeddings to Qdrant: {e}") + item.embed_status = "FAILED" # type: ignore + status = "failed" + + session.commit() + + return create_task_result(item, status, content_length=getattr(item, "size", 0)) + + +def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]: + """ + Decorator for safe task execution with comprehensive error handling. + + Wraps task functions to catch and log exceptions, ensuring tasks + always return a result dictionary even when they fail. + + Args: + func: Task function to wrap + + Returns: + Wrapped function that handles exceptions gracefully + + Example: + @safe_task_execution + def my_task(arg1, arg2): + # Task implementation + return {"status": "success"} + """ + + def wrapper(*args, **kwargs) -> dict: + try: + return func(*args, **kwargs) + except Exception as e: + logger.error( + f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}" + ) + logger.error(f"Task {func.__name__} failed: {e}") + return {"status": "error", "error": str(e)} + + return wrapper diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py index 919f7f8..bb2a8d9 100644 --- a/src/memory/workers/tasks/ebook.py +++ b/src/memory/workers/tasks/ebook.py @@ -1,17 +1,21 @@ -import hashlib import logging from pathlib import Path from typing import Iterable, cast -from memory.common import chunker, embedding, qdrant -from memory.common.db.connection import make_session from memory.common.db.models import Book, BookSection from memory.common.parsers.ebook import Ebook, parse_ebook, Section +from memory.common.db.connection import make_session from memory.workers.celery_app import app +from memory.workers.tasks.content_processing import ( + check_content_exists, + create_content_hash, + embed_source_item, + push_to_qdrant, + safe_task_execution, +) logger = logging.getLogger(__name__) - SYNC_BOOK = "memory.workers.tasks.book.sync_book" # Minimum section length to embed (avoid noise from very short sections) @@ -46,10 +50,6 @@ def section_processor( ): content = "\n\n".join(section.pages).strip() if len(content) >= MIN_SECTION_LENGTH: - sha256 = hashlib.sha256( - f"{book.id}:{section.title}:{section.start_page}".encode() - ).digest() - book_section = BookSection( book_id=book.id, section_title=section.title, @@ -59,7 +59,9 @@ def section_processor( end_page=section.end_page, parent_section_id=None, # Will be set after flush content=content, - sha256=sha256, + sha256=create_content_hash( + f"{book.id}:{section.title}:{section.start_page}" + ), modality="book", tags=book.tags, pages=section.pages, @@ -127,76 +129,11 @@ def create_book_and_sections( def embed_sections(all_sections: list[BookSection]) -> int: """Embed all sections and return count of 
successfully embedded sections.""" - embedded_count = 0 - - def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]: - _, chunks = embedding.embed( - "text/plain", - text, - metadata=metadata, - chunk_size=chunker.EMBEDDING_MAX_TOKENS, - ) - return chunks - - for section in all_sections: - try: - section_chunks = embed_text( - cast(str, section.content), section.as_payload() - ) - page_chunks = [ - chunk - for i, page in enumerate(section.pages) - for chunk in embed_text( - page, section.as_payload() | {"page_number": i + section.start_page} - ) - ] - chunks = section_chunks + page_chunks - - if chunks: - section.chunks = chunks - section.embed_status = "QUEUED" # type: ignore - embedded_count += 1 - else: - section.embed_status = "FAILED" # type: ignore - logger.warning( - f"No chunks generated for section: {section.section_title}" - ) - - except IOError as e: - section.embed_status = "FAILED" # type: ignore - logger.error(f"Failed to embed section {section.section_title}: {e}") - - return embedded_count - - -def push_to_qdrant(all_sections: list[BookSection]): - """Push embeddings to Qdrant for all successfully embedded sections.""" - vector_ids = [] - vectors = [] - payloads = [] - - to_process = [s for s in all_sections if cast(str, s.embed_status) == "QUEUED"] - all_chunks = [chunk for section in to_process for chunk in section.chunks] - if not all_chunks: - return - - vector_ids = [str(chunk.id) for chunk in all_chunks] - vectors = [chunk.vector for chunk in all_chunks] - payloads = [chunk.item_metadata for chunk in all_chunks] - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name="book", - ids=vector_ids, - vectors=vectors, - payloads=payloads, - ) - - for section in to_process: - section.embed_status = "STORED" # type: ignore + return sum(embed_source_item(section) for section in all_sections) @app.task(name=SYNC_BOOK) +@safe_task_execution def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: """ Synchronize a book from a file path. 
@@ -211,10 +148,8 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: with make_session() as session: # Check for existing book - existing_book = ( - session.query(Book) - .filter(Book.file_path == ebook.file_path.as_posix()) - .first() + existing_book = check_content_exists( + session, Book, file_path=ebook.file_path.as_posix() ) if existing_book: logger.info(f"Book already exists: {existing_book.title}") @@ -230,19 +165,10 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: book, all_sections = create_book_and_sections(ebook, session, tags) # Embed sections - embedded_count = embed_sections(all_sections) + embedded_count = sum(embed_source_item(section) for section in all_sections) session.flush() - # Push to Qdrant - try: - push_to_qdrant(all_sections) - except Exception as e: - logger.error(f"Failed to push embeddings to Qdrant: {e}") - # Mark sections as failed - for section in all_sections: - if getattr(section, "embed_status") == "STORED": - section.embed_status = "FAILED" # type: ignore - raise + push_to_qdrant(all_sections, "book") session.commit() diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index eb5e74a..97efd72 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -2,16 +2,19 @@ import logging from datetime import datetime from typing import cast from memory.common.db.connection import make_session -from memory.common.db.models import EmailAccount +from memory.common.db.models import EmailAccount, MailMessage from memory.workers.celery_app import app from memory.workers.email import ( - check_message_exists, create_mail_message, imap_connection, process_folder, vectorize_email, ) - +from memory.common.parsers.email import parse_email_message +from memory.workers.tasks.content_processing import ( + check_content_exists, + safe_task_execution, +) logger = logging.getLogger(__name__) @@ -21,12 +24,13 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts" @app.task(name=PROCESS_EMAIL) +@safe_task_execution def process_message( account_id: int, message_id: str, folder: str, raw_email: str, -) -> int | None: +) -> dict: """ Process a single email message and store it in the database. 
@@ -37,29 +41,30 @@ def process_message( raw_email: Raw email content as string Returns: - source_id if successful, None otherwise + dict with processing result """ logger.info(f"Processing message {message_id} for account {account_id}") if not raw_email.strip(): logger.warning(f"Empty email message received for account {account_id}") - return None + return {"status": "skipped", "reason": "empty_content"} with make_session() as db: - if check_message_exists(db, account_id, message_id, raw_email): - logger.debug(f"Message {message_id} already exists in database") - return None - account = db.query(EmailAccount).get(account_id) if not account: logger.error(f"Account {account_id} not found") - return None + return {"status": "error", "error": "Account not found"} - mail_message = create_mail_message( - db, account.tags, folder, raw_email, message_id - ) + parsed_email = parse_email_message(raw_email, message_id) + if check_content_exists( + db, MailMessage, message_id=message_id, sha256=parsed_email["hash"] + ): + return {"status": "already_exists", "message_id": message_id} + + mail_message = create_mail_message(db, account.tags, folder, parsed_email) db.flush() vectorize_email(mail_message) + db.commit() logger.info(f"Stored embedding for message {mail_message.message_id}") @@ -71,16 +76,24 @@ def process_message( for chunk in attachment.chunks: logger.info(f" - {chunk.id}") - return cast(int, mail_message.id) + return { + "status": "processed", + "mail_message_id": cast(int, mail_message.id), + "message_id": message_id, + "chunks_count": len(mail_message.chunks), + "attachments_count": len(mail_message.attachments), + } @app.task(name=SYNC_ACCOUNT) +@safe_task_execution def sync_account(account_id: int, since_date: str | None = None) -> dict: """ Synchronize emails from a specific account. 
Args: account_id: ID of the EmailAccount to sync + since_date: ISO format date string to sync since Returns: dict with stats about the sync operation @@ -91,7 +104,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first() if not account or not cast(bool, account.active): logger.warning(f"Account {account_id} not found or inactive") - return {"error": "Account not found or inactive"} + return {"status": "error", "error": "Account not found or inactive"} folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"] if since_date: @@ -108,7 +121,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: def process_message_wrapper( account_id: int, message_id: str, folder: str, raw_email: str ) -> int | None: - if check_message_exists(db, account_id, message_id, raw_email): # type: ignore + parsed_email = parse_email_message(raw_email, message_id) + if check_content_exists( + db, MailMessage, message_id=message_id, sha256=parsed_email["hash"] + ): return None return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore @@ -127,9 +143,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: db.commit() except Exception as e: logger.error(f"Error connecting to server {account.imap_server}: {str(e)}") - return {"error": str(e)} + return {"status": "error", "error": str(e)} return { + "status": "completed", "account": account.email_address, "since_date": cutoff_date.isoformat(), "folders_processed": len(folders_to_process), diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py index cdd8085..f8c978d 100644 --- a/src/memory/workers/tasks/maintenance.py +++ b/src/memory/workers/tasks/maintenance.py @@ -6,7 +6,7 @@ from typing import Sequence from sqlalchemy import select from sqlalchemy.orm import contains_eager -from memory.common import embedding, qdrant, settings +from memory.common import collections, embedding, qdrant, settings from memory.common.db.connection import make_session from memory.common.db.models import Chunk, SourceItem from memory.workers.celery_app import app @@ -60,11 +60,11 @@ def reingest_chunk(chunk_id: str, collection: str): logger.error(f"Chunk {chunk_id} not found") return - if collection not in embedding.ALL_COLLECTIONS: + if collection not in collections.ALL_COLLECTIONS: raise ValueError(f"Unsupported collection {collection}") data = chunk.data - if collection in embedding.MULTIMODAL_COLLECTIONS: + if collection in collections.MULTIMODAL_COLLECTIONS: vector = embedding.embed_mixed(data)[0] elif len(data) == 1 and isinstance(data[0], str): vector = embedding.embed_text([data[0]])[0] diff --git a/tests/conftest.py b/tests/conftest.py index a3edd06..1ac921a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -205,9 +205,14 @@ def email_provider(): def mock_file_storage(tmp_path: Path): chunk_storage_dir = tmp_path / "chunks" chunk_storage_dir.mkdir(parents=True, exist_ok=True) - with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path): - with patch("memory.common.settings.CHUNK_STORAGE_DIR", chunk_storage_dir): - yield + image_storage_dir = tmp_path / "images" + image_storage_dir.mkdir(parents=True, exist_ok=True) + with ( + patch.object(settings, "FILE_STORAGE_DIR", tmp_path), + patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir), + patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir), + ): + yield @pytest.fixture 
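The reworked mock_file_storage fixture above patches FILE_STORAGE_DIR, CHUNK_STORAGE_DIR, and WEBPAGE_STORAGE_DIR together, so anything the code writes through those settings lands in the per-test temporary directory. A minimal sketch of a test leaning on it (the test name and file name are illustrative, not part of this diff):

    from memory.common import settings

    def test_chunk_files_use_patched_storage(mock_file_storage):
        # CHUNK_STORAGE_DIR is patched to tmp_path / "chunks" by the fixture,
        # so this write never touches the real storage directory.
        target = settings.CHUNK_STORAGE_DIR / "example-chunk.txt"
        target.write_text("hello")
        assert target.read_text() == "hello"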
diff --git a/tests/memory/common/parsers/test_email_parsers.py b/tests/memory/common/parsers/test_email_parsers.py index 449ea20..ccaaf63 100644 --- a/tests/memory/common/parsers/test_email_parsers.py +++ b/tests/memory/common/parsers/test_email_parsers.py @@ -249,6 +249,7 @@ def test_parse_simple_email(): "body": "Test body content\n", "attachments": [], "sent_at": ANY, + "raw_email": msg.as_string(), "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87" + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1", } diff --git a/tests/memory/common/parsers/test_html.py b/tests/memory/common/parsers/test_html.py index c0f5d76..0ad7e1f 100644 --- a/tests/memory/common/parsers/test_html.py +++ b/tests/memory/common/parsers/test_html.py @@ -12,6 +12,7 @@ import requests from bs4 import BeautifulSoup, Tag from PIL import Image as PILImage +from memory.common import settings from memory.common.parsers.html import ( Article, BaseHTMLParser, @@ -164,27 +165,28 @@ def test_parse_date(text, date_format, expected): assert parse_date(text, date_format) == expected -def test_extract_date(): +@pytest.mark.parametrize( + "selector, date_format, expected", + [ + ("time", "%B %d, %Y", datetime.fromisoformat("2023-01-15T10:30:00")), + (".named-date", "%B %d, %Y", datetime.fromisoformat("2023-01-15")), + (".date", "%Y-%m-%d", datetime.fromisoformat("2023-02-20")), + (".nonexistent", None, None), + ], +) +def test_extract_date(selector, date_format, expected): html = """
Text content
@@ -409,40 +410,37 @@ def test_process_images_basic(mock_file_storage_dir, mock_process_image): content = cast(Tag, soup.find("div")) base_url = "https://example.com" - with tempfile.TemporaryDirectory() as temp_dir: - image_dir = pathlib.Path(temp_dir) - mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir) + # Mock successful image processing with proper filenames + mock_images = [] + for i in range(3): + mock_img = MagicMock(spec=PILImage.Image) + mock_img.filename = str(settings.WEBPAGE_STORAGE_DIR / f"image{i + 1}.jpg") + mock_images.append(mock_img) - # Mock successful image processing with proper filenames - mock_images = [] - for i in range(3): - mock_img = MagicMock(spec=PILImage.Image) - mock_img.filename = str(pathlib.Path(temp_dir) / f"image{i + 1}.jpg") - mock_images.append(mock_img) + mock_process_image.side_effect = mock_images - mock_process_image.side_effect = mock_images + updated_content, images = process_images( + content, base_url, settings.WEBPAGE_STORAGE_DIR + ) - updated_content, images = process_images(content, base_url, image_dir) - - # Should have processed 3 images (skipping the one without src) - assert len(images) == 3 - assert mock_process_image.call_count == 3 - - # Check that img src attributes were updated to relative paths - img_tags = [ - tag - for tag in (updated_content.find_all("img") if updated_content else []) - if isinstance(tag, Tag) - ] - src_values = [] - for img in img_tags: - src = img.get("src") - if src and isinstance(src, str): - src_values.append(src) - - # Should have relative paths to the processed images - for src in src_values[:3]: # First 3 have src - assert not src.startswith("http") # Should be relative paths + expected = BeautifulSoup( + """Text content
+More text
+
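Taken together, the new pieces in this diff converge on one shared ingestion path: each SourceItem subclass describes what to embed via data_chunks(), embedding.embed_source_item() turns those DataChunks into Chunk rows with vectors, and content_processing.process_content_item() persists the item and pushes its vectors to Qdrant. A rough sketch of that flow for a blog post, with parsing and error handling elided and the URL purely illustrative:

    from memory.common.db.connection import make_session
    from memory.common.db.models import BlogPost
    from memory.workers.tasks.content_processing import (
        create_content_hash,
        process_content_item,
    )

    content = "# Example post\n\nSome markdown body text."
    post = BlogPost(
        url="https://example.com/post",  # illustrative, not from this diff
        title="Example post",
        content=content,
        sha256=create_content_hash(content),
        modality="blog",
        mime_type="text/markdown",
        size=len(content.encode("utf-8")),
        images=[],
    )

    with make_session() as session:
        # Internally: post.data_chunks() -> embed_source_item(post) -> push_to_qdrant([post], "blog")
        result = process_content_item(post, "blog", session)
        print(result["status"], result["chunks_count"])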