unify tasks

Daniel O'Connell 2025-05-25 20:02:47 +02:00
parent ce6f4bf5c5
commit 4aaa45e09c
24 changed files with 1360 additions and 773 deletions

View File

@ -15,7 +15,7 @@ from PIL import Image
from pydantic import BaseModel
from memory.common import embedding, qdrant, extract, settings
from memory.common.embedding import get_modality
from memory.common.collections import get_modality, TEXT_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
@ -189,7 +189,7 @@ async def search(
text_results = query_chunks(
client,
upload_data,
allowed_modalities & embedding.TEXT_COLLECTIONS,
allowed_modalities & TEXT_COLLECTIONS,
embedding.embed_text,
min_score=min_text_score,
limit=limit,
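Since the collection sets are plain sets of strings, the filtering above reduces to set intersection. A minimal sketch, assuming the sets exported by the new memory.common.collections module:

    from memory.common.collections import TEXT_COLLECTIONS, MULTIMODAL_COLLECTIONS

    allowed_modalities = {"mail", "chat", "photo"}
    allowed_modalities & TEXT_COLLECTIONS        # {"mail", "chat"}
    allowed_modalities & MULTIMODAL_COLLECTIONS  # {"photo"}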

View File

@ -0,0 +1,107 @@
import logging
from typing import Literal, NotRequired, TypedDict
from memory.common import settings
logger = logging.getLogger(__name__)
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
model: str
on_disk: NotRequired[bool]
shards: NotRequired[int]
ALL_COLLECTIONS: dict[str, Collection] = {
"mail": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"chat": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"git": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"book": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"blog": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"text": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
# Multimodal
"photo": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"comic": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"doc": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
}
TEXT_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.TEXT_EMBEDDING_MODEL
}
MULTIMODAL_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.MIXED_EMBEDDING_MODEL
}
TYPES = {
"doc": ["application/pdf", "application/docx", "application/msword"],
"text": ["text/*"],
"blog": ["text/markdown", "text/html"],
"photo": ["image/*"],
"book": [
"application/epub+zip",
"application/mobi",
"application/x-mobipocket-ebook",
],
}
def get_modality(mime_type: str) -> str:
    # Exact match first, then fall back to matching on the type stem (e.g. "text/...").
    for modality, mime_types in TYPES.items():
        if mime_type in mime_types:
            return modality
    stem = mime_type.split("/")[0]
    for modality, mime_types in TYPES.items():
        if any(pattern.startswith(stem) for pattern in mime_types):
            return modality
    return "unknown"
def collection_model(collection: str) -> str | None:
return ALL_COLLECTIONS.get(collection, {}).get("model")
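A quick usage sketch for the new module (results follow from the tables above):

    from memory.common import collections

    collections.get_modality("application/pdf")  # "doc": exact match in TYPES
    collections.get_modality("text/csv")         # "text": "text/*" matches via the stem fallback
    collections.get_modality("audio/ogg")        # "unknown": no rule covers audio
    collections.collection_model("photo")        # settings.MIXED_EMBEDDING_MODEL
    collections.collection_model("missing")      # None for unknown collections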

View File

@ -2,11 +2,12 @@
Database models for the knowledge base system.
"""
from dataclasses import dataclass
import pathlib
import re
import textwrap
from datetime import datetime
from typing import Any, ClassVar, cast
from typing import Any, ClassVar, Iterable, Sequence, cast
from PIL import Image
from sqlalchemy import (
@ -32,6 +33,9 @@ from sqlalchemy.orm import Session, relationship
from memory.common import settings
from memory.common.parsers.email import EmailMessage, parse_email_message
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
Base = declarative_base()
@ -137,6 +141,7 @@ class SourceItem(Base):
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
__tablename__ = "source_item"
__allow_unmapped__ = True
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
@ -174,6 +179,14 @@ class SourceItem(Base):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def data_chunks(self) -> Iterable[extract.DataChunk]:
return [
extract.DataChunk(
data=[cast(str, self.content)],
collection=cast(str, self.modality),
)
]
def as_payload(self) -> dict:
return {
"source_id": self.id,
@ -306,6 +319,19 @@ class EmailAttachment(SourceItem):
"tags": self.tags,
}
def data_chunks(self) -> Iterable[extract.DataChunk]:
if cast(str | None, self.filename):
contents = pathlib.Path(cast(str, self.filename)).read_bytes()
else:
contents = cast(str, self.content)
return extract.extract_data_chunks(
cast(str, self.mime_type),
contents,
collection=cast(str, self.modality),
embedding_model=collections.collection_model(cast(str, self.modality)),
)
# Add indexes
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
@ -407,6 +433,15 @@ class Comic(SourceItem):
}
return {k: v for k, v in payload.items() if v is not None}
def data_chunks(self) -> Iterable[extract.DataChunk]:
image = Image.open(pathlib.Path(cast(str, self.filename)))
return [
extract.DataChunk(
data=[image, cast(str, self.title), cast(str, self.author)],
collection=cast(str, self.modality),
)
]
class Book(Base):
"""Book-level metadata table"""
@ -503,6 +538,19 @@ class BookSection(SourceItem):
"tags": self.tags,
}
def data_chunks(self) -> Iterable[extract.DataChunk]:
texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
texts += [(cast(str, self.content), self.start_page)]
return [
extract.DataChunk(
data=[text],
collection=cast(str, self.modality),
metadata={"page": page_number},
max_size=chunker.EMBEDDING_MAX_TOKENS,
)
for text, page_number in texts
]
class BlogPost(SourceItem):
__tablename__ = "blog_post"
@ -519,6 +567,7 @@ class BlogPost(SourceItem):
description = Column(Text, nullable=True) # Meta description or excerpt
domain = Column(Text, nullable=True) # Domain of the source website
word_count = Column(Integer, nullable=True) # Approximate word count
images = Column(ARRAY(Text), nullable=True) # List of image URLs
# Store original metadata from parser
webpage_metadata = Column(JSONB, nullable=True)
@ -552,6 +601,30 @@ class BlogPost(SourceItem):
}
return {k: v for k, v in payload.items() if v}
def data_chunks(self) -> Iterable[extract.DataChunk]:
images = [Image.open(image) for image in self.images or []]
data = [cast(str, self.content)] + images
# Always embed the full content as a single chunk (if it fits within the embedding max size)
chunks = [
extract.DataChunk(
data=data,
collection=cast(str, self.modality),
max_size=chunker.EMBEDDING_MAX_TOKENS,
)
]
# If the content is long enough, also embed it as chunks of the default size.
tokens = chunker.approx_token_count(cast(str, self.content))
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks += [
extract.DataChunk(
data=data,
collection=cast(str, self.modality),
)
]
return chunks
class MiscDoc(SourceItem):
__tablename__ = "misc_doc"
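With these overrides in place, data_chunks() becomes the single extension point for embedding: the base class yields one whole-content chunk, and each subclass decides how its payload is split. A sketch of what a new source type would implement (NoteItem is hypothetical, not part of this commit):

    class NoteItem(SourceItem):  # hypothetical subclass for illustration
        __tablename__ = "note_item"
        id = Column(BigInteger, ForeignKey("source_item.id"), primary_key=True)

        def data_chunks(self) -> Iterable[extract.DataChunk]:
            # one chunk per paragraph, embedded into this item's own collection
            return [
                extract.DataChunk(data=[para], collection=cast(str, self.modality))
                for para in cast(str, self.content).split("\n\n")
            ]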

View File

@ -1,130 +1,21 @@
from collections.abc import Sequence
import logging
import pathlib
import uuid
from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast
from typing import Any, Iterable, Literal, cast
import voyageai
from PIL import Image
from memory.common import extract, settings
from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
from memory.common.db.models import Chunk
from memory.common.collections import ALL_COLLECTIONS, Vector
from memory.common.db.models import Chunk, SourceItem
from memory.common.extract import DataChunk
logger = logging.getLogger(__name__)
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
model: str
on_disk: NotRequired[bool]
shards: NotRequired[int]
ALL_COLLECTIONS: dict[str, Collection] = {
"mail": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"chat": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"git": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"book": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"blog": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
"text": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
},
# Multimodal
"photo": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"comic": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"doc": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
}
TEXT_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.TEXT_EMBEDDING_MODEL
}
MULTIMODAL_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.MIXED_EMBEDDING_MODEL
}
TYPES = {
"doc": ["application/pdf", "application/docx", "application/msword"],
"text": ["text/*"],
"blog": ["text/markdown", "text/html"],
"photo": ["image/*"],
"book": [
"application/epub+zip",
"application/mobi",
"application/x-mobipocket-ebook",
],
}
def get_mimetype(image: Image.Image) -> str | None:
format_to_mime = {
"JPEG": "image/jpeg",
"PNG": "image/png",
"GIF": "image/gif",
"BMP": "image/bmp",
"TIFF": "image/tiff",
"WEBP": "image/webp",
}
if not image.format:
return None
return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}")
def get_modality(mime_type: str) -> str:
for type, mime_types in TYPES.items():
if mime_type in mime_types:
return type
stem = mime_type.split("/")[0]
for type, mime_types in TYPES.items():
if any(mime_type.startswith(stem) for mime_type in mime_types):
return type
return "unknown"
def embed_chunks(
chunks: list[str] | list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL,
@ -164,12 +55,6 @@ def embed_text(
raise
def embed_file(
file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL
) -> list[Vector]:
return embed_text([file_path.read_text()], model)
def embed_mixed(
items: list[extract.MulitmodalChunk],
model: str = settings.MIXED_EMBEDDING_MODEL,
@ -187,22 +72,6 @@ def embed_mixed(
return embed_chunks([chunks], model, input_type)
def embed_page(page: extract.Page) -> list[Vector]:
contents = page["contents"]
chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
if all(isinstance(c, str) for c in contents):
return embed_text(
cast(list[str], contents),
model=settings.TEXT_EMBEDDING_MODEL,
chunk_size=chunk_size,
)
return embed_mixed(
cast(list[extract.MulitmodalChunk], contents),
model=settings.MIXED_EMBEDDING_MODEL,
chunk_size=chunk_size,
)
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
if isinstance(item, str):
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt"
@ -219,7 +88,9 @@ def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
def make_chunk(
page: extract.Page, vector: Vector, metadata: dict[str, Any] = {}
contents: Sequence[extract.MulitmodalChunk],
vector: Vector,
metadata: dict[str, Any] = {},
) -> Chunk:
"""Create a Chunk object from a page and a vector.
@ -227,7 +98,6 @@ def make_chunk(
a single image, or a list of strings and images.
"""
chunk_id = str(uuid.uuid4())
contents = page["contents"]
content, filename = None, None
if all(isinstance(c, str) for c in contents):
content = "\n\n".join(cast(list[str], contents))
@ -251,38 +121,42 @@ def make_chunk(
)
def embed(
mime_type: str,
content: bytes | str | pathlib.Path,
def embed_data_chunk(
chunk: DataChunk,
metadata: dict[str, Any] = {},
chunk_size: int | None = None,
) -> tuple[str, list[Chunk]]:
modality = get_modality(mime_type)
pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
chunks = [
make_chunk(page, vector, metadata)
for page in pages
for vector in embed_page(page)
]
return modality, chunks
) -> list[Chunk]:
chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS
model = chunk.embedding_model
if not model and chunk.collection:
model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model")
if not model:
model = settings.TEXT_EMBEDDING_MODEL
def embed_image(
file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
) -> Chunk:
image = Image.open(file_path)
mime_type = get_mimetype(image)
if mime_type is None:
raise ValueError("Unsupported image format")
vector = embed_mixed(
[image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
)[0]
return Chunk(
id=str(uuid.uuid4()),
file_path=file_path.absolute().as_posix(),
content=None,
embedding_model=settings.MIXED_EMBEDDING_MODEL,
vector=vector,
if model == settings.TEXT_EMBEDDING_MODEL:
vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size)
elif model == settings.MIXED_EMBEDDING_MODEL:
vectors = embed_mixed(
cast(list[extract.MulitmodalChunk], chunk.data),
chunk_size=chunk_size,
)
else:
raise ValueError(f"Unsupported model: {model}")
metadata = metadata | chunk.metadata
return [make_chunk(chunk.data, vector, metadata) for vector in vectors]
def embed_source_item(
item: SourceItem,
metadata: dict[str, Any] = {},
chunk_size: int | None = None,
) -> list[Chunk]:
return [
chunk
for data_chunk in item.data_chunks()
for chunk in embed_data_chunk(
data_chunk, item.as_payload() | metadata, chunk_size
)
]
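embed_data_chunk resolves the model in three steps: an explicit chunk.embedding_model, then the model configured for the chunk's collection, then the text default. A minimal sketch of the two entry points (item stands for any persisted SourceItem):

    from memory.common import embedding
    from memory.common.extract import DataChunk

    # low level: one chunk, model picked via the "mail" collection
    chunks = embedding.embed_data_chunk(DataChunk(data=["hello"], collection="mail"))

    # high level: fans out over item.data_chunks(), merging item.as_payload() into each chunk's metadata
    chunks = embedding.embed_source_item(item, metadata={"tags": ["example"]})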

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass, field
import io
import logging
import pathlib
@ -21,6 +22,15 @@ class Page(TypedDict):
chunk_size: NotRequired[int]
@dataclass
class DataChunk:
data: Sequence[MulitmodalChunk]
collection: str | None = None
embedding_model: str | None = None
max_size: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@contextmanager
def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
if isinstance(content, pathlib.Path):
@ -110,11 +120,13 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
return [{"contents": [cast(str, content)], "metadata": {}}]
def extract_content(
def extract_data_chunks(
mime_type: str,
content: bytes | str | pathlib.Path,
collection: str | None = None,
embedding_model: str | None = None,
chunk_size: int | None = None,
) -> list[Page]:
) -> list[DataChunk]:
pages = []
logger.info(f"Extracting content from {mime_type}")
if mime_type == "application/pdf":
@ -134,4 +146,12 @@ def extract_content(
if chunk_size:
pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
return pages
return [
DataChunk(
data=page["contents"],
collection=collection,
embedding_model=embedding_model,
max_size=chunk_size,
)
for page in pages
]
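extract_data_chunks is the old extract_content with its Page results repackaged, so callers receive DataChunk objects that already carry the collection, model, and size hints. A sketch, assuming text/* mime types still route through extract_text:

    from memory.common import extract

    chunks = extract.extract_data_chunks(
        "text/markdown",
        "# Title\n\nBody text",
        collection="blog",
    )
    # -> [DataChunk(data=["# Title\n\nBody text"], collection="blog",
    #               embedding_model=None, max_size=None, metadata={})]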

View File

@ -26,6 +26,7 @@ class EmailMessage(TypedDict):
body: str
attachments: list[Attachment]
hash: bytes
raw_email: str
RawEmailResponse = tuple[str | None, bytes]
@ -171,6 +172,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
body = extract_body(msg)
return EmailMessage(
raw_email=raw_email,
message_id=message_id,
subject=subject,
sender=from_,

View File

@ -12,7 +12,7 @@ from bs4 import BeautifulSoup, Tag
from markdownify import markdownify as md
from PIL import Image as PILImage
from memory.common.settings import FILE_STORAGE_DIR, WEBPAGE_STORAGE_DIR
from memory.common import settings
logger = logging.getLogger(__name__)
@ -96,8 +96,8 @@ def extract_date(
datetime_attr = element.get("datetime")
if datetime_attr:
date_str = str(datetime_attr)
if date := parse_date(date_str, date_format):
for fmt in ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d", date_format]:
if date := parse_date(str(datetime_attr), fmt):
return date
for text in element.find_all(string=True):
@ -178,7 +178,7 @@ def process_images(
continue
path = pathlib.Path(image.filename) # type: ignore
img_tag["src"] = str(path.relative_to(FILE_STORAGE_DIR.resolve()))
img_tag["src"] = str(path.relative_to(settings.FILE_STORAGE_DIR.resolve()))
images[img_tag["src"]] = image
except Exception as e:
logger.warning(f"Failed to process image {src}: {e}")
@ -291,7 +291,7 @@ class BaseHTMLParser:
def __init__(self, base_url: str | None = None):
self.base_url = base_url
self.image_dir = WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc)
self.image_dir = settings.WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc)
self.image_dir.mkdir(parents=True, exist_ok=True)
def parse(self, html: str, url: str) -> Article:

View File

@ -5,12 +5,7 @@ import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings
from memory.common.embedding import (
Collection,
ALL_COLLECTIONS,
DistanceType,
Vector,
)
from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector
logger = logging.getLogger(__name__)

View File

@ -10,17 +10,16 @@ from typing import Callable, Generator, Sequence, cast
from sqlalchemy.orm import Session, scoped_session
from memory.common import embedding, qdrant, settings
from memory.common import embedding, qdrant, settings, collections
from memory.common.db.models import (
EmailAccount,
EmailAttachment,
MailMessage,
SourceItem,
)
from memory.common.parsers.email import (
Attachment,
EmailMessage,
RawEmailResponse,
parse_email_message,
)
logger = logging.getLogger(__name__)
@ -54,7 +53,7 @@ def process_attachment(
return None
return EmailAttachment(
modality=embedding.get_modality(attachment["content_type"]),
modality=collections.get_modality(attachment["content_type"]),
sha256=hashlib.sha256(
real_content if real_content else str(attachment).encode()
).digest(),
@ -94,8 +93,7 @@ def create_mail_message(
db_session: Session | scoped_session,
tags: list[str],
folder: str,
raw_email: str,
message_id: str,
parsed_email: EmailMessage,
) -> MailMessage:
"""
Create a new mail message record and associated attachments.
@ -109,7 +107,7 @@ def create_mail_message(
Returns:
Newly created MailMessage
"""
parsed_email = parse_email_message(raw_email, message_id)
raw_email = parsed_email["raw_email"]
mail_message = MailMessage(
modality="mail",
sha256=parsed_email["hash"],
@ -137,52 +135,6 @@ def create_mail_message(
return mail_message
def does_message_exist(
db_session: Session | scoped_session, message_id: str, message_hash: bytes
) -> bool:
"""
Check if a message already exists in the database.
Args:
db_session: Database session
message_id: Email message ID
message_hash: SHA-256 hash of message
Returns:
True if message exists, False otherwise
"""
# Check by message_id first (faster)
if message_id:
mail_message = (
db_session.query(MailMessage)
.filter(MailMessage.message_id == message_id)
.first()
)
if mail_message is not None:
return True
# Then check by message_hash
source_item = (
db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
)
return source_item is not None
def check_message_exists(
db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
) -> bool:
account = db.query(EmailAccount).get(account_id)
if not account:
logger.error(f"Account {account_id} not found")
return False
parsed_email = parse_email_message(raw_email, message_id)
if "szczepalins" in raw_email.lower():
print(parsed_email["message_id"])
return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
def extract_email_uid(
msg_data: Sequence[tuple[bytes, bytes]],
) -> tuple[str | None, bytes]:
@ -317,11 +269,7 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
def vectorize_email(email: MailMessage):
qdrant_client = qdrant.get_qdrant_client()
_, chunks = embedding.embed(
"text/plain",
email.body,
metadata=email.as_payload(),
)
chunks = embedding.embed_source_item(email)
email.chunks = chunks
if chunks:
vector_ids = [cast(str, c.id) for c in chunks]
@ -329,7 +277,7 @@ def vectorize_email(email: MailMessage):
metadata = [c.item_metadata for c in chunks]
qdrant.upsert_vectors(
client=qdrant_client,
collection_name="mail",
collection_name=cast(str, email.modality),
ids=vector_ids,
vectors=vectors, # type: ignore
payloads=metadata, # type: ignore
@ -337,18 +285,12 @@ def vectorize_email(email: MailMessage):
embeds = defaultdict(list)
for attachment in email.attachments:
if attachment.filename:
content = pathlib.Path(attachment.filename).read_bytes()
else:
content = attachment.content
collection, chunks = embedding.embed(
attachment.mime_type, content, metadata=attachment.as_payload()
)
chunks = embedding.embed_source_item(attachment)
if not chunks:
continue
attachment.chunks = chunks
embeds[collection].extend(chunks)
embeds[attachment.modality].extend(chunks)
for collection, chunks in embeds.items():
ids = [c.id for c in chunks]

View File

@ -1,99 +1,25 @@
import hashlib
import logging
from typing import Iterable, cast
from typing import Iterable
from memory.common import chunker, embedding, qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import BlogPost
from memory.common.parsers.blogs import parse_webpage
from memory.workers.celery_app import app
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
create_task_result,
process_content_item,
safe_task_execution,
)
logger = logging.getLogger(__name__)
SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
def create_blog_post_from_article(article, tags: Iterable[str] = []) -> BlogPost:
"""Create a BlogPost model from parsed article data."""
return BlogPost(
url=article.url,
title=article.title,
published=article.published_date,
content=article.content,
sha256=hashlib.sha256(article.content.encode()).digest(),
modality="blog",
tags=tags,
mime_type="text/markdown",
size=len(article.content.encode("utf-8")),
)
def embed_blog_post(blog_post: BlogPost) -> int:
"""Embed blog post content and return count of successfully embedded chunks."""
try:
# Always embed the full content
_, chunks = embedding.embed(
"text/markdown",
cast(str, blog_post.content),
metadata=blog_post.as_payload(),
chunk_size=chunker.EMBEDDING_MAX_TOKENS,
)
# But also embed the content in chunks (unless it's really short)
if (
chunker.approx_token_count(cast(str, blog_post.content))
> chunker.DEFAULT_CHUNK_TOKENS * 2
):
_, small_chunks = embedding.embed(
"text/markdown",
cast(str, blog_post.content),
metadata=blog_post.as_payload(),
)
chunks += small_chunks
if chunks:
blog_post.chunks = chunks
blog_post.embed_status = "QUEUED" # type: ignore
return len(chunks)
else:
blog_post.embed_status = "FAILED" # type: ignore
logger.warning(f"No chunks generated for blog post: {blog_post.title}")
return 0
except Exception as e:
blog_post.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to embed blog post {blog_post.title}: {e}")
return 0
def push_to_qdrant(blog_post: BlogPost):
"""Push embeddings to Qdrant for successfully embedded blog post."""
if cast(str, blog_post.embed_status) != "QUEUED" or not blog_post.chunks:
return
try:
vector_ids = [str(chunk.id) for chunk in blog_post.chunks]
vectors = [chunk.vector for chunk in blog_post.chunks]
payloads = [chunk.item_metadata for chunk in blog_post.chunks]
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name="blog",
ids=vector_ids,
vectors=vectors,
payloads=payloads,
)
blog_post.embed_status = "STORED" # type: ignore
logger.info(f"Successfully stored embeddings for: {blog_post.title}")
except Exception as e:
blog_post.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to push embeddings to Qdrant: {e}")
raise
@app.task(name=SYNC_WEBPAGE)
@safe_task_execution
def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
"""
Synchronize a webpage from a URL.
@ -116,61 +42,25 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
"content_length": 0,
}
blog_post = create_blog_post_from_article(article, tags)
blog_post = BlogPost(
url=article.url,
title=article.title,
published=article.published_date,
content=article.content,
sha256=create_content_hash(article.content),
modality="blog",
tags=tags,
mime_type="text/markdown",
size=len(article.content.encode("utf-8")),
images=[image for image in article.images],
)
with make_session() as session:
existing_post = session.query(BlogPost).filter(BlogPost.url == url).first()
existing_post = check_content_exists(
session, BlogPost, url=url, sha256=create_content_hash(article.content)
)
if existing_post:
logger.info(f"Blog post already exists: {existing_post.title}")
return {
"blog_post_id": existing_post.id,
"url": url,
"title": existing_post.title,
"status": "already_exists",
"chunks_count": len(existing_post.chunks),
}
return create_task_result(existing_post, "already_exists", url=url)
existing_post = (
session.query(BlogPost).filter(BlogPost.sha256 == blog_post.sha256).first()
)
if existing_post:
logger.info(
f"Blog post with the same content already exists: {existing_post.title}"
)
return {
"blog_post_id": existing_post.id,
"url": url,
"title": existing_post.title,
"status": "already_exists",
"chunks_count": len(existing_post.chunks),
}
session.add(blog_post)
session.flush()
chunks_count = embed_blog_post(blog_post)
session.flush()
try:
push_to_qdrant(blog_post)
logger.info(
f"Successfully processed webpage: {blog_post.title} "
f"({chunks_count} chunks embedded)"
)
except Exception as e:
logger.error(f"Failed to push embeddings to Qdrant: {e}")
blog_post.embed_status = "FAILED" # type: ignore
session.commit()
return {
"blog_post_id": blog_post.id,
"url": url,
"title": blog_post.title,
"author": article.author,
"published_date": article.published_date,
"status": "processed",
"chunks_count": chunks_count,
"content_length": len(article.content),
"embed_status": blog_post.embed_status,
}
return process_content_item(blog_post, "blog", session, tags)

View File

@ -1,4 +1,3 @@
import hashlib
import logging
from datetime import datetime
from typing import Callable, cast
@ -6,21 +5,25 @@ from typing import Callable, cast
import feedparser
import requests
from memory.common import embedding, qdrant, settings
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import Comic, clean_filename
from memory.common.parsers import comics
from memory.workers.celery_app import app
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
process_content_item,
safe_task_execution,
)
logger = logging.getLogger(__name__)
SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
BASE_SMBC_URL = "https://www.smbc-comics.com/"
SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
@ -36,6 +39,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
return set()
urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}
urls = {url for url in urls if url}
with make_session() as session:
known = {
@ -61,6 +65,7 @@ def fetch_new_comics(
@app.task(name=SYNC_COMIC)
@safe_task_execution
def sync_comic(
url: str,
image_url: str,
@ -70,20 +75,26 @@ def sync_comic(
):
"""Synchronize a comic from a URL."""
with make_session() as session:
if session.query(Comic).filter(Comic.url == url).first():
return
existing_comic = check_content_exists(session, Comic, url=url)
if existing_comic:
return {"status": "already_exists", "comic_id": existing_comic.id}
response = requests.get(image_url)
if response.status_code != 200:
return {
"status": "failed",
"error": f"Failed to download image: {response.status_code}",
}
file_type = image_url.split(".")[-1]
mime_type = f"image/{file_type}"
filename = (
settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}"
)
if response.status_code == 200:
filename.parent.mkdir(parents=True, exist_ok=True)
filename.write_bytes(response.content)
sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest()
comic = Comic(
title=title,
url=url,
@ -92,27 +103,13 @@ def sync_comic(
filename=filename.resolve().as_posix(),
mime_type=mime_type,
size=len(response.content),
sha256=sha256,
sha256=create_content_hash(f"{image_url}{published_date}"),
tags={"comic", author},
modality="comic",
)
chunk = embedding.embed_image(filename, [title, author])
comic.chunks = [chunk]
with make_session() as session:
session.add(comic)
session.add(chunk)
session.flush()
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name="comic",
ids=[str(chunk.id)],
vectors=[chunk.vector], # type: ignore
payloads=[comic.as_payload()],
)
session.commit()
return process_content_item(comic, "comic", session)
@app.task(name=SYNC_SMBC)

View File

@ -0,0 +1,267 @@
"""
Content processing utilities for memory workers.
This module provides core functionality for processing content items through
the complete workflow: existence checking, content hashing, embedding generation,
vector storage, and result tracking.
"""
import hashlib
import traceback
import logging
from typing import Any, Callable, Iterable, Sequence, cast
from memory.common import embedding, qdrant
from memory.common.db.models import SourceItem
logger = logging.getLogger(__name__)
def check_content_exists(
session,
model_class: type[SourceItem],
**kwargs: Any,
) -> SourceItem | None:
"""
Check if content already exists in the database.
Searches for existing content by any of the provided attributes
(typically URL, file_path, or SHA256 hash).
Args:
session: Database session for querying
model_class: The SourceItem model class to search in
**kwargs: Attribute-value pairs to search for
Returns:
Existing SourceItem if found, None otherwise
"""
for key, value in kwargs.items():
if not hasattr(model_class, key):
continue
existing = (
session.query(model_class)
.filter(getattr(model_class, key) == value)
.first()
)
if existing:
return existing
return None
def create_content_hash(content: str, *additional_data: str) -> bytes:
"""
Create SHA256 hash from content and optional additional data.
Args:
content: Primary content to hash
*additional_data: Additional strings to include in the hash
Returns:
SHA256 hash digest as bytes
"""
hash_input = content + "".join(additional_data)
return hashlib.sha256(hash_input.encode()).digest()
def embed_source_item(source_item: SourceItem) -> int:
"""
Generate embeddings for a source item's content.
Processes the source item through the embedding pipeline, creating
chunks and their corresponding vector embeddings. Updates the item's
embed_status based on success or failure.
Args:
source_item: The SourceItem to embed
Returns:
Number of successfully embedded chunks
Side effects:
- Sets source_item.chunks with generated chunks
- Sets source_item.embed_status to "QUEUED" or "FAILED"
"""
try:
chunks = embedding.embed_source_item(source_item)
if chunks:
source_item.chunks = chunks
source_item.embed_status = "QUEUED" # type: ignore
return len(chunks)
else:
source_item.embed_status = "FAILED" # type: ignore
logger.warning(
f"No chunks generated for {type(source_item).__name__}: {getattr(source_item, 'title', 'unknown')}"
)
return 0
except Exception as e:
source_item.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to embed {type(source_item).__name__}: {e}")
return 0
def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
"""
Push embeddings to Qdrant vector database.
Uploads vector embeddings for all source items that have been successfully
embedded (status "QUEUED") and have chunks available.
Args:
source_items: Sequence of SourceItems to process
collection_name: Name of the Qdrant collection to store vectors in
Raises:
Exception: If the Qdrant upsert operation fails
Side effects:
- Updates embed_status to "STORED" for successful items
- Updates embed_status to "FAILED" for failed items
"""
items_to_process = [
item
for item in source_items
if cast(str, getattr(item, "embed_status", None)) == "QUEUED" and item.chunks
]
if not items_to_process:
return
all_chunks = [chunk for item in items_to_process for chunk in item.chunks]
if not all_chunks:
return
try:
vector_ids = [str(chunk.id) for chunk in all_chunks]
vectors = [chunk.vector for chunk in all_chunks]
payloads = [chunk.item_metadata for chunk in all_chunks]
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name=collection_name,
ids=vector_ids,
vectors=vectors,
payloads=payloads,
)
for item in items_to_process:
item.embed_status = "STORED" # type: ignore
logger.info(
f"Successfully stored embeddings for: {getattr(item, 'title', 'unknown')}"
)
except Exception as e:
for item in items_to_process:
item.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to push embeddings to Qdrant: {e}")
raise
def create_task_result(
item: SourceItem, status: str, **additional_fields: Any
) -> dict[str, Any]:
"""
Create standardized task result dictionary.
Generates a consistent result format for task execution reporting,
including item metadata and processing status.
Args:
item: The processed SourceItem
status: Processing status string
**additional_fields: Extra fields to include in the result
Returns:
Dictionary with standardized task result format
"""
return {
f"{type(item).__name__.lower()}_id": item.id,
"title": getattr(item, "title", None),
"status": status,
"chunks_count": len(item.chunks),
"embed_status": item.embed_status,
**additional_fields,
}
def process_content_item(
item: SourceItem, collection_name: str, session, tags: Iterable[str] = []
) -> dict[str, Any]:
"""
Execute complete content processing workflow.
Performs the full pipeline for processing a content item:
1. Add to database session and flush to get ID
2. Generate embeddings and chunks
3. Push embeddings to Qdrant vector store
4. Commit transaction and return result
Args:
item: SourceItem to process
collection_name: Qdrant collection name for vector storage
session: Database session for persistence
tags: Optional tags to associate with the item (currently unused)
Returns:
Task result dictionary with processing status and metadata
Side effects:
- Adds item to database session
- Commits database transaction
- Stores vectors in Qdrant
"""
session.add(item)
session.flush()
chunks_count = embed_source_item(item)
session.flush()
try:
push_to_qdrant([item], collection_name)
status = "processed"
logger.info(
f"Successfully processed {type(item).__name__}: {getattr(item, 'title', 'unknown')} ({chunks_count} chunks embedded)"
)
except Exception as e:
logger.error(f"Failed to push embeddings to Qdrant: {e}")
item.embed_status = "FAILED" # type: ignore
status = "failed"
session.commit()
return create_task_result(item, status, content_length=getattr(item, "size", 0))
def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
"""
Decorator for safe task execution with comprehensive error handling.
Wraps task functions to catch and log exceptions, ensuring tasks
always return a result dictionary even when they fail.
Args:
func: Task function to wrap
Returns:
Wrapped function that handles exceptions gracefully
Example:
@safe_task_execution
def my_task(arg1, arg2):
# Task implementation
return {"status": "success"}
"""
def wrapper(*args, **kwargs) -> dict:
try:
return func(*args, **kwargs)
except Exception as e:
logger.error(
f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}"
)
logger.error(f"Task {func.__name__} failed: {e}")
return {"status": "error", "error": str(e)}
return wrapper
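Together these helpers give every sync task the same shape; a hedged sketch with a hypothetical NoteItem model (app, make_session, and NoteItem are assumed imports):

    @app.task(name="memory.workers.tasks.example.sync_note")
    @safe_task_execution
    def sync_note(url: str, content: str) -> dict:
        with make_session() as session:
            if existing := check_content_exists(session, NoteItem, url=url):
                return create_task_result(existing, "already_exists", url=url)
            note = NoteItem(  # hypothetical SourceItem subclass
                url=url,
                content=content,
                sha256=create_content_hash(content),
                modality="text",
                mime_type="text/markdown",
                size=len(content.encode()),
            )
            return process_content_item(note, "text", session)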

View File

@ -1,17 +1,21 @@
import hashlib
import logging
from pathlib import Path
from typing import Iterable, cast
from memory.common import chunker, embedding, qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import Book, BookSection
from memory.common.parsers.ebook import Ebook, parse_ebook, Section
from memory.common.db.connection import make_session
from memory.workers.celery_app import app
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
embed_source_item,
push_to_qdrant,
safe_task_execution,
)
logger = logging.getLogger(__name__)
SYNC_BOOK = "memory.workers.tasks.book.sync_book"
# Minimum section length to embed (avoid noise from very short sections)
@ -46,10 +50,6 @@ def section_processor(
):
content = "\n\n".join(section.pages).strip()
if len(content) >= MIN_SECTION_LENGTH:
sha256 = hashlib.sha256(
f"{book.id}:{section.title}:{section.start_page}".encode()
).digest()
book_section = BookSection(
book_id=book.id,
section_title=section.title,
@ -59,7 +59,9 @@ def section_processor(
end_page=section.end_page,
parent_section_id=None, # Will be set after flush
content=content,
sha256=sha256,
sha256=create_content_hash(
f"{book.id}:{section.title}:{section.start_page}"
),
modality="book",
tags=book.tags,
pages=section.pages,
@ -127,76 +129,11 @@ def create_book_and_sections(
def embed_sections(all_sections: list[BookSection]) -> int:
"""Embed all sections and return count of successfully embedded sections."""
embedded_count = 0
def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
_, chunks = embedding.embed(
"text/plain",
text,
metadata=metadata,
chunk_size=chunker.EMBEDDING_MAX_TOKENS,
)
return chunks
for section in all_sections:
try:
section_chunks = embed_text(
cast(str, section.content), section.as_payload()
)
page_chunks = [
chunk
for i, page in enumerate(section.pages)
for chunk in embed_text(
page, section.as_payload() | {"page_number": i + section.start_page}
)
]
chunks = section_chunks + page_chunks
if chunks:
section.chunks = chunks
section.embed_status = "QUEUED" # type: ignore
embedded_count += 1
else:
section.embed_status = "FAILED" # type: ignore
logger.warning(
f"No chunks generated for section: {section.section_title}"
)
except IOError as e:
section.embed_status = "FAILED" # type: ignore
logger.error(f"Failed to embed section {section.section_title}: {e}")
return embedded_count
def push_to_qdrant(all_sections: list[BookSection]):
"""Push embeddings to Qdrant for all successfully embedded sections."""
vector_ids = []
vectors = []
payloads = []
to_process = [s for s in all_sections if cast(str, s.embed_status) == "QUEUED"]
all_chunks = [chunk for section in to_process for chunk in section.chunks]
if not all_chunks:
return
vector_ids = [str(chunk.id) for chunk in all_chunks]
vectors = [chunk.vector for chunk in all_chunks]
payloads = [chunk.item_metadata for chunk in all_chunks]
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name="book",
ids=vector_ids,
vectors=vectors,
payloads=payloads,
)
for section in to_process:
section.embed_status = "STORED" # type: ignore
return sum(embed_source_item(section) for section in all_sections)
@app.task(name=SYNC_BOOK)
@safe_task_execution
def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
"""
Synchronize a book from a file path.
@ -211,10 +148,8 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
with make_session() as session:
# Check for existing book
existing_book = (
session.query(Book)
.filter(Book.file_path == ebook.file_path.as_posix())
.first()
existing_book = check_content_exists(
session, Book, file_path=ebook.file_path.as_posix()
)
if existing_book:
logger.info(f"Book already exists: {existing_book.title}")
@ -230,19 +165,10 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
book, all_sections = create_book_and_sections(ebook, session, tags)
# Embed sections
embedded_count = embed_sections(all_sections)
embedded_count = sum(embed_source_item(section) for section in all_sections)
session.flush()
# Push to Qdrant
try:
push_to_qdrant(all_sections)
except Exception as e:
logger.error(f"Failed to push embeddings to Qdrant: {e}")
# Mark sections as failed
for section in all_sections:
if getattr(section, "embed_status") == "STORED":
section.embed_status = "FAILED" # type: ignore
raise
push_to_qdrant(all_sections, "book")
session.commit()

View File

@ -2,16 +2,19 @@ import logging
from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount
from memory.common.db.models import EmailAccount, MailMessage
from memory.workers.celery_app import app
from memory.workers.email import (
check_message_exists,
create_mail_message,
imap_connection,
process_folder,
vectorize_email,
)
from memory.common.parsers.email import parse_email_message
from memory.workers.tasks.content_processing import (
check_content_exists,
safe_task_execution,
)
logger = logging.getLogger(__name__)
@ -21,12 +24,13 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
@app.task(name=PROCESS_EMAIL)
@safe_task_execution
def process_message(
account_id: int,
message_id: str,
folder: str,
raw_email: str,
) -> int | None:
) -> dict:
"""
Process a single email message and store it in the database.
@ -37,29 +41,30 @@ def process_message(
raw_email: Raw email content as string
Returns:
source_id if successful, None otherwise
dict with processing result
"""
logger.info(f"Processing message {message_id} for account {account_id}")
if not raw_email.strip():
logger.warning(f"Empty email message received for account {account_id}")
return None
return {"status": "skipped", "reason": "empty_content"}
with make_session() as db:
if check_message_exists(db, account_id, message_id, raw_email):
logger.debug(f"Message {message_id} already exists in database")
return None
account = db.query(EmailAccount).get(account_id)
if not account:
logger.error(f"Account {account_id} not found")
return None
return {"status": "error", "error": "Account not found"}
mail_message = create_mail_message(
db, account.tags, folder, raw_email, message_id
)
parsed_email = parse_email_message(raw_email, message_id)
if check_content_exists(
db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
):
return {"status": "already_exists", "message_id": message_id}
mail_message = create_mail_message(db, account.tags, folder, parsed_email)
db.flush()
vectorize_email(mail_message)
db.commit()
logger.info(f"Stored embedding for message {mail_message.message_id}")
@ -71,16 +76,24 @@ def process_message(
for chunk in attachment.chunks:
logger.info(f" - {chunk.id}")
return cast(int, mail_message.id)
return {
"status": "processed",
"mail_message_id": cast(int, mail_message.id),
"message_id": message_id,
"chunks_count": len(mail_message.chunks),
"attachments_count": len(mail_message.attachments),
}
@app.task(name=SYNC_ACCOUNT)
@safe_task_execution
def sync_account(account_id: int, since_date: str | None = None) -> dict:
"""
Synchronize emails from a specific account.
Args:
account_id: ID of the EmailAccount to sync
since_date: ISO format date string to sync since
Returns:
dict with stats about the sync operation
@ -91,7 +104,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
if not account or not cast(bool, account.active):
logger.warning(f"Account {account_id} not found or inactive")
return {"error": "Account not found or inactive"}
return {"status": "error", "error": "Account not found or inactive"}
folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
if since_date:
@ -108,7 +121,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
def process_message_wrapper(
account_id: int, message_id: str, folder: str, raw_email: str
) -> int | None:
if check_message_exists(db, account_id, message_id, raw_email): # type: ignore
parsed_email = parse_email_message(raw_email, message_id)
if check_content_exists(
db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
):
return None
return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore
@ -127,9 +143,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
db.commit()
except Exception as e:
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
return {"error": str(e)}
return {"status": "error", "error": str(e)}
return {
"status": "completed",
"account": account.email_address,
"since_date": cutoff_date.isoformat(),
"folders_processed": len(folders_to_process),

View File

@ -6,7 +6,7 @@ from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from memory.common import embedding, qdrant, settings
from memory.common import collections, embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.workers.celery_app import app
@ -60,11 +60,11 @@ def reingest_chunk(chunk_id: str, collection: str):
logger.error(f"Chunk {chunk_id} not found")
return
if collection not in embedding.ALL_COLLECTIONS:
if collection not in collections.ALL_COLLECTIONS:
raise ValueError(f"Unsupported collection {collection}")
data = chunk.data
if collection in embedding.MULTIMODAL_COLLECTIONS:
if collection in collections.MULTIMODAL_COLLECTIONS:
vector = embedding.embed_mixed(data)[0]
elif len(data) == 1 and isinstance(data[0], str):
vector = embedding.embed_text([data[0]])[0]

View File

@ -205,8 +205,13 @@ def email_provider():
def mock_file_storage(tmp_path: Path):
chunk_storage_dir = tmp_path / "chunks"
chunk_storage_dir.mkdir(parents=True, exist_ok=True)
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
with patch("memory.common.settings.CHUNK_STORAGE_DIR", chunk_storage_dir):
image_storage_dir = tmp_path / "images"
image_storage_dir.mkdir(parents=True, exist_ok=True)
with (
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
):
yield

View File

@ -249,6 +249,7 @@ def test_parse_simple_email():
"body": "Test body content\n",
"attachments": [],
"sent_at": ANY,
"raw_email": msg.as_string(),
"hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
+ b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
}

View File

@ -12,6 +12,7 @@ import requests
from bs4 import BeautifulSoup, Tag
from PIL import Image as PILImage
from memory.common import settings
from memory.common.parsers.html import (
Article,
BaseHTMLParser,
@ -164,27 +165,28 @@ def test_parse_date(text, date_format, expected):
assert parse_date(text, date_format) == expected
def test_extract_date():
@pytest.mark.parametrize(
"selector, date_format, expected",
[
("time", "%B %d, %Y", datetime.fromisoformat("2023-01-15T10:30:00")),
(".named-date", "%B %d, %Y", datetime.fromisoformat("2023-01-15")),
(".date", "%Y-%m-%d", datetime.fromisoformat("2023-02-20")),
(".nonexistent", None, None),
],
)
def test_extract_date(selector, date_format, expected):
html = """
<div>
<time datetime="2023-01-15T10:30:00">January 15, 2023</time>
<span class="named-date">January 15, 2023</span>
<span class="date">2023-02-20</span>
<div class="published">March 10, 2023</div>
</div>
"""
soup = BeautifulSoup(html, "html.parser")
# Should extract datetime attribute from time tag
result = extract_date(soup, "time", "%Y-%m-%d")
assert result == "2023-01-15T10:30:00"
# Should extract from text content
result = extract_date(soup, ".date", "%Y-%m-%d")
assert result == "2023-02-20T00:00:00"
# No matching element
result = extract_date(soup, ".nonexistent", "%Y-%m-%d")
assert result is None
result = extract_date(soup, selector, date_format)
assert result == expected
def test_extract_content_element():
@ -393,8 +395,7 @@ def test_process_image_cached(mock_pil_open, mock_requests_get):
@patch("memory.common.parsers.html.process_image")
@patch("memory.common.parsers.html.FILE_STORAGE_DIR")
def test_process_images_basic(mock_file_storage_dir, mock_process_image):
def test_process_images_basic(mock_process_image):
html = """
<div>
<p>Text content</p>
@ -409,40 +410,37 @@ def test_process_images_basic(mock_file_storage_dir, mock_process_image):
content = cast(Tag, soup.find("div"))
base_url = "https://example.com"
with tempfile.TemporaryDirectory() as temp_dir:
image_dir = pathlib.Path(temp_dir)
mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir)
# Mock successful image processing with proper filenames
mock_images = []
for i in range(3):
mock_img = MagicMock(spec=PILImage.Image)
mock_img.filename = str(pathlib.Path(temp_dir) / f"image{i + 1}.jpg")
mock_img.filename = str(settings.WEBPAGE_STORAGE_DIR / f"image{i + 1}.jpg")
mock_images.append(mock_img)
mock_process_image.side_effect = mock_images
updated_content, images = process_images(content, base_url, image_dir)
updated_content, images = process_images(
content, base_url, settings.WEBPAGE_STORAGE_DIR
)
# Should have processed 3 images (skipping the one without src)
assert len(images) == 3
assert mock_process_image.call_count == 3
# Check that img src attributes were updated to relative paths
img_tags = [
tag
for tag in (updated_content.find_all("img") if updated_content else [])
if isinstance(tag, Tag)
]
src_values = []
for img in img_tags:
src = img.get("src")
if src and isinstance(src, str):
src_values.append(src)
# Should have relative paths to the processed images
for src in src_values[:3]: # First 3 have src
assert not src.startswith("http") # Should be relative paths
expected = BeautifulSoup(
"""<div>
<p>Text content</p>
<img alt="Image 1" src="images/image1.jpg"/>
<img alt="Image 2" src="images/image2.jpg"/>
<img alt="Image 3" src="images/image3.jpg"/>
<img alt="No src"/>
<p>More text</p>
</div>
""",
"html.parser",
)
assert updated_content.prettify() == expected.prettify() # type: ignore
assert images == {
"images/image1.jpg": mock_images[0],
"images/image2.jpg": mock_images[1],
"images/image3.jpg": mock_images[2],
}
def test_process_images_empty():
@ -454,8 +452,7 @@ def test_process_images_empty():
@patch("memory.common.parsers.html.process_image")
@patch("memory.common.parsers.html.FILE_STORAGE_DIR")
def test_process_images_with_failures(mock_file_storage_dir, mock_process_image):
def test_process_images_with_failures(mock_process_image):
html = """
<div>
<img src="good.jpg" alt="Good image">
@ -465,22 +462,20 @@ def test_process_images_with_failures(mock_file_storage_dir, mock_process_image)
soup = BeautifulSoup(html, "html.parser")
content = cast(Tag, soup.find("div"))
with tempfile.TemporaryDirectory() as temp_dir:
image_dir = pathlib.Path(temp_dir)
mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir)
# First image succeeds, second fails
mock_good_image = MagicMock(spec=PILImage.Image)
mock_good_image.filename = str(pathlib.Path(temp_dir) / "good.jpg")
mock_good_image.filename = settings.WEBPAGE_STORAGE_DIR / "good.jpg"
mock_process_image.side_effect = [mock_good_image, None]
updated_content, images = process_images(
content, "https://example.com", image_dir
content, "https://example.com", settings.WEBPAGE_STORAGE_DIR
)
# Should only return successful image
assert len(images) == 1
assert images[0] == mock_good_image
expected = BeautifulSoup(
html.replace("good.jpg", "images/good.jpg"), "html.parser"
).prettify()
assert updated_content.prettify() == expected # type: ignore
assert images == {"images/good.jpg": mock_good_image}
@patch("memory.common.parsers.html.process_image")
@ -494,15 +489,12 @@ def test_process_images_no_filename(mock_process_image):
mock_image.filename = None
mock_process_image.return_value = mock_image
with tempfile.TemporaryDirectory() as temp_dir:
image_dir = pathlib.Path(temp_dir)
updated_content, images = process_images(
content, "https://example.com", image_dir
content, "https://example.com", settings.WEBPAGE_STORAGE_DIR
)
# Should skip image without filename
assert len(images) == 0
assert not images
class TestBaseHTMLParser:
@ -541,7 +533,7 @@ class TestBaseHTMLParser:
assert article.title == "Article Title"
assert article.author == "John Smith" # Should prefer content over meta
assert article.published_date == "2023-01-15T00:00:00"
assert article.published_date == datetime(2023, 1, 15)
assert article.url == "https://example.com/article"
assert "This is the main content" in article.content
assert article.metadata["author"] == "Jane Doe"

View File

@ -5,14 +5,10 @@ from typing import cast
import pytest
from PIL import Image
from memory.common import settings
from memory.common import settings, collections
from memory.common.embedding import (
embed,
embed_file,
embed_mixed,
embed_page,
embed_text,
get_modality,
make_chunk,
write_to_file,
)
@ -50,7 +46,7 @@ def mock_embed(mock_voyage_client):
],
)
def test_get_modality(mime_type, expected_modality):
assert get_modality(mime_type) == expected_modality
assert collections.get_modality(mime_type) == expected_modality
def test_embed_text(mock_embed):
@ -58,59 +54,11 @@ def test_embed_text(mock_embed):
assert embed_text(texts) == [[0], [1]]
def test_embed_file(mock_embed, tmp_path):
mock_file = tmp_path / "test.txt"
mock_file.write_text("file content")
assert embed_file(mock_file) == [[0]]
def test_embed_mixed(mock_embed):
items = ["text", {"type": "image", "data": "base64"}]
assert embed_mixed(items) == [[0]]
def test_embed_page_text_only(mock_embed):
page = {"contents": ["text1", "text2"]}
assert embed_page(page) == [[0], [1]] # type: ignore
def test_embed_page_mixed_content(mock_embed):
page = {"contents": ["text", {"type": "image", "data": "base64"}]}
assert embed_page(page) == [[0]] # type: ignore
def test_embed(mock_embed):
mime_type = "text/plain"
content = "sample content"
metadata = {"source": "test"}
with patch.object(uuid, "uuid4", return_value="id1"):
modality, chunks = embed(mime_type, content, metadata)
assert modality == "text"
assert [
{
"id": c.id, # type: ignore
"file_path": c.file_path, # type: ignore
"content": c.content, # type: ignore
"embedding_model": c.embedding_model, # type: ignore
"vector": c.vector, # type: ignore
"item_metadata": c.item_metadata, # type: ignore
}
for c in chunks
] == [
{
"content": "sample content",
"embedding_model": "voyage-3-large",
"file_path": None,
"id": "id1",
"item_metadata": {"source": "test"},
"vector": [0],
},
]
def test_write_to_file_text(mock_file_storage):
"""Test writing a string to a file."""
chunk_id = "test-chunk-id"
@ -160,17 +108,14 @@ def test_write_to_file_unsupported_type(mock_file_storage):
def test_make_chunk_text_only(mock_file_storage, db_session):
"""Test creating a chunk from string content."""
page = {
"contents": ["text content 1", "text content 2"],
"metadata": {"source": "test"},
}
contents = ["text content 1", "text content 2"]
vector = [0.1, 0.2, 0.3]
metadata = {"doc_type": "test", "source": "unit-test"}
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
):
chunk = make_chunk(page, vector, metadata) # type: ignore
chunk = make_chunk(contents, vector, metadata) # type: ignore
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
@ -183,14 +128,14 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
def test_make_chunk_single_image(mock_file_storage, db_session):
"""Test creating a chunk from a single image."""
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
page = {"contents": [img], "metadata": {"source": "test"}}
contents = [img]
vector = [0.1, 0.2, 0.3]
metadata = {"doc_type": "test", "source": "unit-test"}
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
):
chunk = make_chunk(page, vector, metadata) # type: ignore
chunk = make_chunk(contents, vector, metadata) # type: ignore
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
assert chunk.content is None
@ -208,14 +153,14 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
def test_make_chunk_mixed_content(mock_file_storage, db_session):
"""Test creating a chunk from mixed content (string and image)."""
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
page = {"contents": ["text content", img], "metadata": {"source": "test"}}
contents = ["text content", img]
vector = [0.1, 0.2, 0.3]
metadata = {"doc_type": "test", "source": "unit-test"}
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
):
chunk = make_chunk(page, vector, metadata) # type: ignore
chunk = make_chunk(contents, vector, metadata) # type: ignore
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
assert chunk.content is None
@ -233,3 +178,181 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
assert (
settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
).exists()
@pytest.mark.parametrize(
"data,embedding_model,collection,expected_model,expected_count,expected_has_content",
[
# Text-only with default model
(
["text content 1", "text content 2"],
None,
None,
settings.TEXT_EMBEDDING_MODEL,
2,
True,
),
# Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content
(
["text content"],
settings.MIXED_EMBEDDING_MODEL,
None,
settings.TEXT_EMBEDDING_MODEL,
1,
True,
),
# Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content
(["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True),
(["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True),
(["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True),
# Unknown collection falls back to default
(["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True),
# Explicit model takes precedence over collection
(
["text content"],
settings.TEXT_EMBEDDING_MODEL,
"photo",
settings.TEXT_EMBEDDING_MODEL,
1,
True,
),
],
)
def test_embed_data_chunk_scenarios(
data,
embedding_model,
collection,
expected_model,
expected_count,
expected_has_content,
mock_embed,
mock_file_storage,
):
"""Test various embedding scenarios for data chunks."""
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk
chunk = DataChunk(
data=data,
embedding_model=embedding_model,
collection=collection,
metadata={"source": "test"},
)
result = embed_data_chunk(chunk, {"doc_type": "test"})
assert len(result) == expected_count
assert all(cast(str, c.embedding_model) == expected_model for c in result)
if expected_has_content:
assert all(c.content is not None for c in result)
assert all(c.file_path is None for c in result)
else:
assert all(c.content is None for c in result)
assert all(c.file_path is not None for c in result)
assert all(
c.item_metadata == {"source": "test", "doc_type": "test"} for c in result
)

def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage):
"""Test embedding mixed content (text and images)."""
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
chunk = DataChunk(
data=["text content", img],
embedding_model=settings.MIXED_EMBEDDING_MODEL,
metadata={"source": "test"},
)
result = embed_data_chunk(chunk)
assert len(result) == 1 # Mixed content returns single vector
assert result[0].content is None # Mixed content stored in files
assert result[0].file_path is not None
assert cast(str, result[0].embedding_model) == settings.MIXED_EMBEDDING_MODEL
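
Taken together, these embedding tests pin down the model selection: text-only data always ends up on the text model, while mixed data uses the explicit model, then the collection's configured model, then the text default. A minimal sketch of that precedence, assuming a hypothetical resolve_model helper rather than the actual implementation:

from memory.common import settings
from memory.common.collections import collection_model


def resolve_model(data: list, embedding_model: str | None, collection: str | None) -> str:
    # Per the tests above, text-only content is always embedded with the
    # text model, even when a mixed model or multimodal collection is given.
    if all(isinstance(item, str) for item in data):
        return settings.TEXT_EMBEDDING_MODEL
    # Otherwise: explicit model > collection's configured model > text default.
    return (
        embedding_model
        or collection_model(collection or "")
        or settings.TEXT_EMBEDDING_MODEL
    )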

@pytest.mark.parametrize(
"chunk_max_size,chunk_size_param,expected_chunk_size",
[
(512, 1024, 512), # chunk.max_size takes precedence
(None, 2048, 2048), # chunk_size parameter used when max_size is None
(256, None, 256), # chunk.max_size used when parameter is None
],
)
def test_embed_data_chunk_chunk_size_handling(
chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage
):
"""Test chunk size parameter handling."""
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk
chunk = DataChunk(
data=["text content"], max_size=chunk_max_size, metadata={"source": "test"}
)
with patch("memory.common.embedding.embed_text") as mock_embed_text:
mock_embed_text.return_value = [[0.1, 0.2, 0.3]]
result = embed_data_chunk(chunk, chunk_size=chunk_size_param)
mock_embed_text.assert_called_once()
args, kwargs = mock_embed_text.call_args
assert kwargs["chunk_size"] == expected_chunk_size
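
The three cases collapse to a single rule: the chunk's own max_size, when set, wins over the chunk_size argument. Presumably a one-liner along these lines inside embed_data_chunk:

# Sketch of the assumed precedence (not the actual implementation):
chunk_size = chunk.max_size if chunk.max_size is not None else chunk_size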

def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage):
"""Test that chunk metadata and parameter metadata are properly merged."""
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk
chunk = DataChunk(
data=["text content"], metadata={"source": "test", "type": "chunk"}
)
metadata = {
"doc_type": "test",
"source": "override",
} # chunk.metadata takes precedence over parameter metadata
result = embed_data_chunk(chunk, metadata)
assert len(result) == 1
expected_metadata = {
"source": "test",
"type": "chunk",
"doc_type": "test",
} # chunk source wins
assert result[0].item_metadata == expected_metadata
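
The expected metadata is what an ordinary dict merge produces when the chunk's metadata is applied last, so its keys win on collision. A sketch of the assumed merge:

# Parameter metadata first, chunk metadata second: chunk keys take precedence.
merged = {**metadata, **chunk.metadata}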

def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage):
"""Test error handling for unsupported embedding model."""
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk
chunk = DataChunk(
data=["text content"],
embedding_model="unsupported-model",
metadata={"source": "test"},
)
with pytest.raises(ValueError, match="Unsupported model: unsupported-model"):
embed_data_chunk(chunk)

def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage):
"""Test handling of empty data."""
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk
chunk = DataChunk(data=[], metadata={"source": "test"})
# Should handle empty data gracefully
with patch("memory.common.embedding.embed_text") as mock_embed_text:
mock_embed_text.return_value = []
result = embed_data_chunk(chunk)
assert result == []

View File

@@ -7,7 +7,6 @@ import shutil
from memory.common.extract import (
as_file,
extract_text,
extract_content,
doc_to_images,
extract_image,
docx_to_pdf,

@@ -107,52 +106,6 @@ def test_extract_image_with_str():
extract_image("test")

@pytest.mark.parametrize(
"mime_type,content",
[
("text/plain", "Text content"),
("text/html", "<html>content</html>"),
("text/markdown", "# Heading"),
("text/csv", "a,b,c"),
],
)
def test_extract_content_different_text_types(mime_type, content):
assert extract_content(mime_type, content) == [
{"contents": [content], "metadata": {}}
]

def test_extract_content_pdf():
result = extract_content("application/pdf", REGULAMIN)
assert len(result) == 2
assert all(
isinstance(page["contents"], list)
and all(isinstance(c, Image.Image) for c in page["contents"])
for page in result
)
assert all("page" in page["metadata"] for page in result)
assert all("width" in page["metadata"] for page in result)
assert all("height" in page["metadata"] for page in result)

def test_extract_content_image(tmp_path):
# Create a test image
img = Image.new("RGB", (100, 100), color="red")
img_path = tmp_path / "test_img.png"
img.save(img_path)
(result,) = extract_content("image/png", img_path)
assert isinstance(result["contents"][0], Image.Image)
assert result["contents"][0].size == (100, 100)
assert result["metadata"] == {}

def test_extract_content_unsupported_type():
assert extract_content("unsupported/type", "content") == []

@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
def test_docx_to_pdf(tmp_path):
output_path = tmp_path / "output.pdf"

View File

@@ -0,0 +1,431 @@
import pytest
from datetime import datetime
from unittest.mock import Mock, patch
from memory.common.db.models import Comic
from memory.workers.tasks import comic
import requests

@pytest.fixture
def mock_comic_info():
"""Mock comic info data for testing."""
return {
"title": "Test Comic",
"image_url": "https://example.com/comic.png",
"url": "https://example.com/comic/1",
"published_date": "2024-01-01T12:00:00Z",
}

@pytest.fixture
def mock_feed_data():
"""Mock RSS feed data."""
return {
"entries": [
{"link": "https://example.com/comic/1", "id": None},
{"link": "https://example.com/comic/2", "id": None},
{"link": None, "id": "https://example.com/comic/3"},
]
}

@pytest.fixture
def mock_image_response():
"""Mock HTTP response for comic image."""
# 1x1 PNG image (smallest valid PNG)
png_data = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
b"\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13"
b"\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```"
b"\x00\x00\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
response = Mock()
response.status_code = 200
response.content = png_data
with patch.object(requests, "get", return_value=response):
yield response

@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_success(mock_parse, mock_feed_data, db_session):
"""Test successful URL discovery from RSS feed."""
mock_parse.return_value = Mock(entries=mock_feed_data["entries"])
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
assert result == {
"https://example.com/comic/1",
"https://example.com/comic/2",
"https://example.com/comic/3",
}
mock_parse.assert_called_once_with("https://example.com/rss")

@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_with_existing_comics(mock_parse, mock_feed_data, db_session):
"""Test URL discovery when some comics already exist."""
mock_parse.return_value = Mock(entries=mock_feed_data["entries"])
# Add existing comic to database
existing_comic = Comic(
title="Existing Comic",
url="https://example.com/comic/1",
author="https://example.com",
filename="/test/path",
sha256=b"test_hash",
modality="comic",
tags=["comic"],
)
db_session.add(existing_comic)
db_session.commit()
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
# Should only return URLs not in database
assert result == {
"https://example.com/comic/2",
"https://example.com/comic/3",
}

@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_parse_error(mock_parse):
"""Test handling of RSS feed parsing errors."""
mock_parse.side_effect = Exception("Parse error")
assert (
comic.find_new_urls("https://example.com", "https://example.com/rss") == set()
)

@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_empty_feed(mock_parse):
"""Test handling of empty RSS feed."""
mock_parse.return_value = Mock(entries=[])
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
assert result == set()

@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_malformed_entries(mock_parse):
"""Test handling of malformed RSS entries."""
mock_parse.return_value = Mock(
entries=[
{"link": None, "id": None}, # Both None
{}, # Missing keys
]
)
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
assert result == set()
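
Together these tests characterise find_new_urls: take each entry's link, fall back to its id, drop anything already stored, and return an empty set on feed errors. A rough sketch under those assumptions (find_new_urls_sketch is illustrative, not the task's actual code):

import feedparser

from memory.common.db.connection import make_session
from memory.common.db.models import Comic


def find_new_urls_sketch(base_url: str, rss_url: str) -> set[str]:
    try:
        feed = feedparser.parse(rss_url)
    except Exception:
        return set()  # parse errors are swallowed, as the tests expect

    urls = {entry.get("link") or entry.get("id") for entry in feed.entries} - {None}
    with make_session() as session:
        known = {url for (url,) in session.query(Comic.url).filter(Comic.url.in_(urls))}
    return urls - known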

@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.find_new_urls")
def test_fetch_new_comics_success(mock_find_urls, mock_sync_delay, mock_comic_info):
"""Test successful comic fetching."""
mock_find_urls.return_value = {"https://example.com/comic/1"}
mock_parser = Mock(return_value=mock_comic_info)
result = comic.fetch_new_comics(
"https://example.com", "https://example.com/rss", mock_parser
)
assert result == {"https://example.com/comic/1"}
mock_parser.assert_called_once_with("https://example.com/comic/1")
expected_call_args = {
**mock_comic_info,
"author": "https://example.com",
"url": "https://example.com/comic/1",
}
mock_sync_delay.assert_called_once_with(**expected_call_args)

@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.find_new_urls")
def test_fetch_new_comics_no_new_urls(mock_find_urls, mock_sync_delay):
"""Test when no new URLs are found."""
mock_find_urls.return_value = set()
mock_parser = Mock()
result = comic.fetch_new_comics(
"https://example.com", "https://example.com/rss", mock_parser
)
assert result == set()
mock_parser.assert_not_called()
mock_sync_delay.assert_not_called()

@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.find_new_urls")
def test_fetch_new_comics_multiple_urls(
mock_find_urls, mock_sync_delay, mock_comic_info
):
"""Test fetching multiple new comics."""
urls = {"https://example.com/comic/1", "https://example.com/comic/2"}
mock_find_urls.return_value = urls
mock_parser = Mock(return_value=mock_comic_info)
result = comic.fetch_new_comics(
"https://example.com", "https://example.com/rss", mock_parser
)
assert result == urls
assert mock_parser.call_count == 2
assert mock_sync_delay.call_count == 2

@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
"""Test successful comic synchronization."""
mock_get.return_value = mock_image_response
comic.sync_comic(
url="https://example.com/comic/1",
image_url="https://example.com/image.png",
title="Test Comic",
author="https://example.com",
published_date=datetime(2024, 1, 1, 12, 0, 0),
)
# Verify comic was created in database
saved_comic = (
db_session.query(Comic)
.filter(Comic.url == "https://example.com/comic/1")
.first()
)
assert saved_comic is not None
assert saved_comic.title == "Test Comic"
assert saved_comic.author == "https://example.com"
assert saved_comic.mime_type == "image/png"
assert saved_comic.size == len(mock_image_response.content)
assert "comic" in saved_comic.tags
assert "https://example.com" in saved_comic.tags
# Verify vectors were added to Qdrant
vectors, _ = qdrant.scroll(collection_name="comic")
expected_vectors = [
(
{
"author": "https://example.com",
"published": "2024-01-01T12:00:00",
"tags": ["comic", "https://example.com"],
"title": "Test Comic",
"url": "https://example.com/comic/1",
"source_id": 1,
},
None,
)
]
assert [
({**v.payload, "tags": sorted(v.payload["tags"])}, v.vector) for v in vectors
] == expected_vectors

def test_sync_comic_already_exists(db_session):
"""Test that duplicate comics are not processed."""
# Add existing comic
existing_comic = Comic(
title="Existing Comic",
url="https://example.com/comic/1",
author="https://example.com",
filename="/test/path",
sha256=b"test_hash",
modality="comic",
tags=["comic"],
)
db_session.add(existing_comic)
db_session.commit()
with patch("memory.workers.tasks.comic.requests.get") as mock_get:
result = comic.sync_comic(
url="https://example.com/comic/1",
image_url="https://example.com/image.png",
title="Test Comic",
author="https://example.com",
)
# Should return early without making HTTP request
mock_get.assert_not_called()
assert result == {"comic_id": 1, "status": "already_exists"}

@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_http_error(mock_get, db_session, qdrant):
"""Test handling of HTTP errors when downloading image."""
mock_response = Mock()
mock_response.status_code = 404
mock_response.content = b""
mock_get.return_value = mock_response
comic.sync_comic(
url="https://example.com/comic/1",
image_url="https://example.com/image.png",
title="Test Comic",
author="https://example.com",
)
assert not (
db_session.query(Comic)
.filter(Comic.url == "https://example.com/comic/1")
.first()
)

@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_no_published_date(
mock_get, mock_image_response, db_session, qdrant
):
"""Test comic sync without published date."""
mock_get.return_value = mock_image_response
comic.sync_comic(
url="https://example.com/comic/1",
image_url="https://example.com/image.png",
title="Test Comic",
author="https://example.com",
published_date=None,
)
saved_comic = (
db_session.query(Comic)
.filter(Comic.url == "https://example.com/comic/1")
.first()
)
assert saved_comic is not None
assert saved_comic.published is None

@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_special_characters_in_title(
mock_get, mock_image_response, db_session, qdrant
):
"""Test comic sync with special characters in title."""
mock_get.return_value = mock_image_response
comic.sync_comic(
url="https://example.com/comic/1",
image_url="https://example.com/image.png",
title="Test/Comic: With*Special?Characters",
author="https://example.com",
)
# Verify comic was created with cleaned title
saved_comic = (
db_session.query(Comic)
.filter(Comic.url == "https://example.com/comic/1")
.first()
)
assert saved_comic is not None
assert saved_comic.title == "Test/Comic: With*Special?Characters"

@patch(
"memory.common.embedding.embed_source_item",
side_effect=Exception("Embedding failed"),
)
def test_sync_comic_embedding_failure(
mock_embed_source_item, mock_image_response, db_session, qdrant
):
"""Test handling of embedding failures."""
result = comic.sync_comic(
url="https://example.com/comic/1",
image_url="https://example.com/image.png",
title="Test Comic",
author="https://example.com",
)
assert result == {
"comic_id": 1,
"title": "Test Comic",
"status": "processed",
"chunks_count": 0,
"embed_status": "FAILED",
"content_length": 90,
}

@patch("memory.workers.tasks.comic.sync_xkcd.delay")
@patch("memory.workers.tasks.comic.sync_smbc.delay")
def test_sync_all_comics(mock_smbc_delay, mock_xkcd_delay):
"""Test synchronization of all comics."""
comic.sync_all_comics()
mock_smbc_delay.assert_called_once()
mock_xkcd_delay.assert_called_once()

@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.comics.extract_xkcd")
@patch("memory.workers.tasks.comic.comics.extract_smbc")
@patch("requests.get")
def test_trigger_comic_sync_smbc_navigation(
mock_get, mock_extract_smbc, mock_extract_xkcd, mock_sync_delay, mock_comic_info
):
"""Test full SMBC comic sync with navigation."""
# Mock HTML responses for navigation
mock_responses = [
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/2"></a>'),
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/1"></a>'),
Mock(text="<div>No prev link</div>"), # End of navigation
]
mock_get.side_effect = mock_responses
mock_extract_smbc.return_value = mock_comic_info
mock_extract_xkcd.return_value = mock_comic_info
comic.trigger_comic_sync()
# Should have called extract_smbc for each discovered URL
assert mock_extract_smbc.call_count == 2
mock_extract_smbc.assert_any_call("https://smbc.com/comic/2")
mock_extract_smbc.assert_any_call("https://smbc.com/comic/1")
# Should have called extract_xkcd for range 1-307
assert mock_extract_xkcd.call_count == 307
# Should have queued sync tasks
assert mock_sync_delay.call_count == 2 + 307 # SMBC + XKCD

@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.comics.extract_smbc")
@patch("requests.get")
def test_trigger_comic_sync_smbc_extraction_error(
mock_get, mock_extract_smbc, mock_sync_delay
):
"""Test handling of extraction errors during full sync."""
# Mock responses: first one has a prev link, second one doesn't
mock_responses = [
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/1"></a>'),
Mock(text="<div>No prev link</div>"),
]
mock_get.side_effect = mock_responses
mock_extract_smbc.side_effect = Exception("Extraction failed")
# Should not raise exception, just log error
comic.trigger_comic_sync()
mock_extract_smbc.assert_called_once_with("https://smbc.com/comic/1")
mock_sync_delay.assert_not_called()

@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.comics.extract_xkcd")
@patch("requests.get")
def test_trigger_comic_sync_xkcd_extraction_error(
mock_get, mock_extract_xkcd, mock_sync_delay
):
"""Test handling of XKCD extraction errors during full sync."""
mock_get.return_value = Mock(text="<div>No prev link</div>")
mock_extract_xkcd.side_effect = Exception("XKCD extraction failed")
# Should not raise exception, just log errors
comic.trigger_comic_sync()
# Should attempt all 307 XKCD comics
assert mock_extract_xkcd.call_count == 307
mock_sync_delay.assert_not_called()

View File

@@ -51,24 +51,6 @@ def mock_ebook():
)

@pytest.fixture
def mock_embedding():
"""Mock the embedding function to return dummy vectors."""
with patch("memory.workers.tasks.ebook.embedding.embed") as mock:
mock.return_value = (
"book",
[
Chunk(
vector=[0.1] * 1024,
item_metadata={"test": "data"},
content="Test content",
embedding_model="model",
)
],
)
yield mock

@pytest.fixture
def mock_qdrant():
"""Mock Qdrant operations."""

@@ -151,7 +133,7 @@ def test_create_book_and_sections(mock_ebook, db_session):
assert getattr(chapter1, "parent_section_id") is None

def test_embed_sections(db_session, mock_embedding):
def test_embed_sections(db_session):
"""Test basic embedding sections workflow."""
# Create a test book first
book = Book(

@@ -187,31 +169,6 @@ def test_embed_sections(db_session, mock_embedding):
assert hasattr(sections[0], "embed_status")

def test_push_to_qdrant(qdrant):
"""Test pushing embeddings to Qdrant."""
# Create test sections with chunks
mock_chunk = Mock(
id="00000000-0000-0000-0000-000000000000",
vector=[0.1] * 1024,
item_metadata={"test": "data"},
)
mock_section = Mock(spec=BookSection)
mock_section.embed_status = "QUEUED"
mock_section.chunks = [mock_chunk]
sections = [mock_section]
ebook.push_to_qdrant(sections) # type: ignore
assert {r.id: r.payload for r in qdrant.scroll(collection_name="book")[0]} == {
"00000000-0000-0000-0000-000000000000": {
"test": "data",
}
}
assert mock_section.embed_status == "STORED"

@patch("memory.workers.tasks.ebook.parse_ebook")
def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path):
"""Test successful book synchronization."""

@@ -229,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path):
"author": "Test Author",
"status": "processed",
"total_sections": 4,
"sections_embedded": 4,
"sections_embedded": 8,
}
book = db_session.query(Book).filter(Book.title == "Test Book").first()

@@ -270,8 +227,9 @@ def test_sync_book_already_exists(mock_parse, mock_ebook, db_session, tmp_path):
@patch("memory.workers.tasks.ebook.parse_ebook")
@patch("memory.common.embedding.embed_source_item")
def test_sync_book_embedding_failure(
mock_parse, mock_ebook, db_session, tmp_path, mock_embedding
mock_embedding, mock_parse, mock_ebook, db_session, tmp_path
):
"""Test handling of embedding failures."""
book_file = tmp_path / "test.epub"

@@ -307,14 +265,18 @@ def test_sync_book_qdrant_failure(mock_parse, mock_ebook, db_session, tmp_path):
# Since embedding is already failing, this test will complete without hitting Qdrant
# So let's just verify that the function completes without raising an exception
with patch.object(ebook, "push_to_qdrant", side_effect=Exception("Qdrant failed")):
with pytest.raises(Exception, match="Qdrant failed"):
ebook.sync_book(str(book_file))
assert ebook.sync_book(str(book_file)) == {
"status": "error",
"error": "Qdrant failed",
}

def test_sync_book_file_not_found():
"""Test handling of missing files."""
with pytest.raises(FileNotFoundError):
ebook.sync_book("/nonexistent/file.epub")
assert ebook.sync_book("/nonexistent/file.epub") == {
"status": "error",
"error": "Book file not found: /nonexistent/file.epub",
}
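
Both tests reflect this commit's shift from raising out of tasks to returning status dicts. That is presumably done with a wrapper along these lines (safe_task is a hypothetical name used here for illustration):

import functools
import logging

logger = logging.getLogger(__name__)


def safe_task(fn):
    # Turn exceptions into {"status": "error", ...} results so the worker
    # reports the failure instead of crashing the task.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.exception("Task %s failed", fn.__name__)
            return {"status": "error", "error": str(e)}

    return wrapper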

def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):

@@ -363,7 +325,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
calls = mock_voyage_client.embed.call_args_list
texts = [c[0][0] for c in calls]
assert texts == [
[large_section_content.strip()],
[large_page_1.strip()],
[large_page_2.strip()],
[large_section_content.strip()],
]

View File

@@ -69,13 +69,21 @@ def test_email_account(db_session):
def test_process_simple_email(db_session, test_email_account, qdrant):
"""Test processing a simple email message."""
mail_message_id = process_message(
res = process_message(
account_id=test_email_account.id,
message_id="101",
folder="INBOX",
raw_email=SIMPLE_EMAIL_RAW,
)
mail_message_id = res["mail_message_id"]
assert res == {
"status": "processed",
"mail_message_id": mail_message_id,
"message_id": "101",
"chunks_count": 1,
"attachments_count": 0,
}
mail_message = (
db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
)

@@ -98,7 +106,7 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant):
message_id="302",
folder="Archive",
raw_email=EMAIL_WITH_ATTACHMENT_RAW,
)
)["mail_message_id"]
# Check mail message specifics
mail_message = (
db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()

@@ -125,25 +133,25 @@ def test_process_empty_message(db_session, test_email_account, qdrant):
def test_process_empty_message(db_session, test_email_account, qdrant):
"""Test processing an empty/invalid message."""
source_id = process_message(
res = process_message(
account_id=test_email_account.id,
message_id="999",
folder="Archive",
raw_email="",
)
assert source_id is None
assert res == {"reason": "empty_content", "status": "skipped"}

def test_process_duplicate_message(db_session, test_email_account, qdrant):
"""Test that duplicate messages are detected and not stored again."""
# First call should succeed and create records
source_id_1 = process_message(
res = process_message(
account_id=test_email_account.id,
message_id="101",
folder="INBOX",
raw_email=SIMPLE_EMAIL_RAW,
)
source_id_1 = res.get("mail_message_id")
assert source_id_1 is not None, "First call should return a source_id"
@ -157,7 +165,7 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
message_id="101",
folder="INBOX",
raw_email=SIMPLE_EMAIL_RAW,
)
).get("mail_message_id")
assert source_id_2 is None, "Second call should return None for duplicate message"

View File

@@ -12,7 +12,7 @@ from memory.common.db.models import (
EmailAttachment,
MailMessage,
)
from memory.common.parsers.email import Attachment
from memory.common.parsers.email import Attachment, parse_email_message
from memory.workers.email import (
create_mail_message,
extract_email_uid,

@@ -226,14 +226,14 @@ def test_create_mail_message(db_session):
"--boundary--"
)
folder = "INBOX"
parsed_email = parse_email_message(raw_email, "321")
# Call function
mail_message = create_mail_message(
db_session=db_session,
raw_email=raw_email,
folder=folder,
tags=["test"],
message_id="123",
parsed_email=parsed_email,
)
db_session.commit()

@@ -344,6 +344,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
recipients=["recipient@example.com"],
content="This is a test email for vectorization",
folder="INBOX",
modality="mail",
)
db_session.add(mail_message)
db_session.flush()

@@ -373,6 +374,7 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
recipients=["recipient@example.com"],
content="This is a test email with attachments",
folder="INBOX",
modality="mail",
)
db_session.add(mail_message)
db_session.flush()