mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
unify tasks
This commit is contained in:
parent
ce6f4bf5c5
commit
4aaa45e09c
@ -15,7 +15,7 @@ from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memory.common import embedding, qdrant, extract, settings
|
||||
from memory.common.embedding import get_modality
|
||||
from memory.common.collections import get_modality, TEXT_COLLECTIONS
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
|
||||
@ -189,7 +189,7 @@ async def search(
|
||||
text_results = query_chunks(
|
||||
client,
|
||||
upload_data,
|
||||
allowed_modalities & embedding.TEXT_COLLECTIONS,
|
||||
allowed_modalities & TEXT_COLLECTIONS,
|
||||
embedding.embed_text,
|
||||
min_score=min_text_score,
|
||||
limit=limit,
|
||||
|
107
src/memory/common/collections.py
Normal file
107
src/memory/common/collections.py
Normal file
@ -0,0 +1,107 @@
|
||||
import logging
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
|
||||
Vector = list[float]
|
||||
|
||||
|
||||
class Collection(TypedDict):
|
||||
dimension: int
|
||||
distance: DistanceType
|
||||
model: str
|
||||
on_disk: NotRequired[bool]
|
||||
shards: NotRequired[int]
|
||||
|
||||
|
||||
ALL_COLLECTIONS: dict[str, Collection] = {
|
||||
"mail": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"chat": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"git": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"book": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"blog": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"text": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
# Multimodal
|
||||
"photo": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
},
|
||||
"comic": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
},
|
||||
"doc": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
},
|
||||
}
|
||||
TEXT_COLLECTIONS = {
|
||||
coll
|
||||
for coll, params in ALL_COLLECTIONS.items()
|
||||
if params["model"] == settings.TEXT_EMBEDDING_MODEL
|
||||
}
|
||||
MULTIMODAL_COLLECTIONS = {
|
||||
coll
|
||||
for coll, params in ALL_COLLECTIONS.items()
|
||||
if params["model"] == settings.MIXED_EMBEDDING_MODEL
|
||||
}
|
||||
|
||||
TYPES = {
|
||||
"doc": ["application/pdf", "application/docx", "application/msword"],
|
||||
"text": ["text/*"],
|
||||
"blog": ["text/markdown", "text/html"],
|
||||
"photo": ["image/*"],
|
||||
"book": [
|
||||
"application/epub+zip",
|
||||
"application/mobi",
|
||||
"application/x-mobipocket-ebook",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_modality(mime_type: str) -> str:
|
||||
for type, mime_types in TYPES.items():
|
||||
if mime_type in mime_types:
|
||||
return type
|
||||
stem = mime_type.split("/")[0]
|
||||
|
||||
for type, mime_types in TYPES.items():
|
||||
if any(mime_type.startswith(stem) for mime_type in mime_types):
|
||||
return type
|
||||
return "unknown"
|
||||
|
||||
|
||||
def collection_model(collection: str) -> str | None:
|
||||
return ALL_COLLECTIONS.get(collection, {}).get("model")
|
@ -2,11 +2,12 @@
|
||||
Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import pathlib
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar, cast
|
||||
from typing import Any, ClassVar, Iterable, Sequence, cast
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
@ -32,6 +33,9 @@ from sqlalchemy.orm import Session, relationship
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.parsers.email import EmailMessage, parse_email_message
|
||||
import memory.common.extract as extract
|
||||
import memory.common.collections as collections
|
||||
import memory.common.chunker as chunker
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@ -137,6 +141,7 @@ class SourceItem(Base):
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
__tablename__ = "source_item"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
modality = Column(Text, nullable=False)
|
||||
@ -174,6 +179,14 @@ class SourceItem(Base):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
return [
|
||||
extract.DataChunk(
|
||||
data=cast(str, self.content),
|
||||
collection=cast(str, self.modality),
|
||||
)
|
||||
]
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
@ -306,6 +319,19 @@ class EmailAttachment(SourceItem):
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
if cast(str | None, self.filename):
|
||||
contents = pathlib.Path(cast(str, self.filename)).read_bytes()
|
||||
else:
|
||||
contents = cast(str, self.content)
|
||||
|
||||
return extract.extract_data_chunks(
|
||||
cast(str, self.mime_type),
|
||||
contents,
|
||||
collection=cast(str, self.modality),
|
||||
embedding_model=collections.collection_model(cast(str, self.modality)),
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
|
||||
|
||||
@ -407,6 +433,15 @@ class Comic(SourceItem):
|
||||
}
|
||||
return {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
image = Image.open(pathlib.Path(cast(str, self.filename)))
|
||||
return [
|
||||
extract.DataChunk(
|
||||
data=[image, cast(str, self.title), cast(str, self.author)],
|
||||
collection=cast(str, self.modality),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Book(Base):
|
||||
"""Book-level metadata table"""
|
||||
@ -503,6 +538,19 @@ class BookSection(SourceItem):
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
|
||||
texts += [(cast(str, self.content), self.start_page)]
|
||||
return [
|
||||
extract.DataChunk(
|
||||
data=[text],
|
||||
collection=cast(str, self.modality),
|
||||
metadata={"page": page_number},
|
||||
max_size=chunker.EMBEDDING_MAX_TOKENS,
|
||||
)
|
||||
for text, page_number in texts
|
||||
]
|
||||
|
||||
|
||||
class BlogPost(SourceItem):
|
||||
__tablename__ = "blog_post"
|
||||
@ -519,6 +567,7 @@ class BlogPost(SourceItem):
|
||||
description = Column(Text, nullable=True) # Meta description or excerpt
|
||||
domain = Column(Text, nullable=True) # Domain of the source website
|
||||
word_count = Column(Integer, nullable=True) # Approximate word count
|
||||
images = Column(ARRAY(Text), nullable=True) # List of image URLs
|
||||
|
||||
# Store original metadata from parser
|
||||
webpage_metadata = Column(JSONB, nullable=True)
|
||||
@ -552,6 +601,30 @@ class BlogPost(SourceItem):
|
||||
}
|
||||
return {k: v for k, v in payload.items() if v}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
images = [Image.open(image) for image in self.images]
|
||||
data = [cast(str, self.content)] + images
|
||||
|
||||
# Always embed the full content as a single chunk (if possible, of course)
|
||||
chunks = [
|
||||
extract.DataChunk(
|
||||
data=data,
|
||||
collection=cast(str, self.modality),
|
||||
max_size=chunker.EMBEDDING_MAX_TOKENS,
|
||||
)
|
||||
]
|
||||
|
||||
# If the content is long enough, also embed it as chunks of the default size.
|
||||
tokens = chunker.approx_token_count(cast(str, self.content))
|
||||
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
chunks += [
|
||||
extract.DataChunk(
|
||||
data=data,
|
||||
collection=cast(str, self.modality),
|
||||
)
|
||||
]
|
||||
return chunks
|
||||
|
||||
|
||||
class MiscDoc(SourceItem):
|
||||
__tablename__ = "misc_doc"
|
||||
|
@ -1,130 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import pathlib
|
||||
import uuid
|
||||
from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast
|
||||
from typing import Any, Iterable, Literal, cast
|
||||
|
||||
import voyageai
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import extract, settings
|
||||
from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
|
||||
from memory.common.db.models import Chunk
|
||||
from memory.common.collections import ALL_COLLECTIONS, Vector
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common.extract import DataChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
|
||||
Vector = list[float]
|
||||
|
||||
|
||||
class Collection(TypedDict):
|
||||
dimension: int
|
||||
distance: DistanceType
|
||||
model: str
|
||||
on_disk: NotRequired[bool]
|
||||
shards: NotRequired[int]
|
||||
|
||||
|
||||
ALL_COLLECTIONS: dict[str, Collection] = {
|
||||
"mail": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"chat": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"git": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"book": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"blog": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
"text": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
},
|
||||
# Multimodal
|
||||
"photo": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
},
|
||||
"comic": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
},
|
||||
"doc": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
},
|
||||
}
|
||||
TEXT_COLLECTIONS = {
|
||||
coll
|
||||
for coll, params in ALL_COLLECTIONS.items()
|
||||
if params["model"] == settings.TEXT_EMBEDDING_MODEL
|
||||
}
|
||||
MULTIMODAL_COLLECTIONS = {
|
||||
coll
|
||||
for coll, params in ALL_COLLECTIONS.items()
|
||||
if params["model"] == settings.MIXED_EMBEDDING_MODEL
|
||||
}
|
||||
|
||||
TYPES = {
|
||||
"doc": ["application/pdf", "application/docx", "application/msword"],
|
||||
"text": ["text/*"],
|
||||
"blog": ["text/markdown", "text/html"],
|
||||
"photo": ["image/*"],
|
||||
"book": [
|
||||
"application/epub+zip",
|
||||
"application/mobi",
|
||||
"application/x-mobipocket-ebook",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_mimetype(image: Image.Image) -> str | None:
|
||||
format_to_mime = {
|
||||
"JPEG": "image/jpeg",
|
||||
"PNG": "image/png",
|
||||
"GIF": "image/gif",
|
||||
"BMP": "image/bmp",
|
||||
"TIFF": "image/tiff",
|
||||
"WEBP": "image/webp",
|
||||
}
|
||||
|
||||
if not image.format:
|
||||
return None
|
||||
|
||||
return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}")
|
||||
|
||||
|
||||
def get_modality(mime_type: str) -> str:
|
||||
for type, mime_types in TYPES.items():
|
||||
if mime_type in mime_types:
|
||||
return type
|
||||
stem = mime_type.split("/")[0]
|
||||
|
||||
for type, mime_types in TYPES.items():
|
||||
if any(mime_type.startswith(stem) for mime_type in mime_types):
|
||||
return type
|
||||
return "unknown"
|
||||
|
||||
|
||||
def embed_chunks(
|
||||
chunks: list[str] | list[list[extract.MulitmodalChunk]],
|
||||
model: str = settings.TEXT_EMBEDDING_MODEL,
|
||||
@ -164,12 +55,6 @@ def embed_text(
|
||||
raise
|
||||
|
||||
|
||||
def embed_file(
|
||||
file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL
|
||||
) -> list[Vector]:
|
||||
return embed_text([file_path.read_text()], model)
|
||||
|
||||
|
||||
def embed_mixed(
|
||||
items: list[extract.MulitmodalChunk],
|
||||
model: str = settings.MIXED_EMBEDDING_MODEL,
|
||||
@ -187,22 +72,6 @@ def embed_mixed(
|
||||
return embed_chunks([chunks], model, input_type)
|
||||
|
||||
|
||||
def embed_page(page: extract.Page) -> list[Vector]:
|
||||
contents = page["contents"]
|
||||
chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
|
||||
if all(isinstance(c, str) for c in contents):
|
||||
return embed_text(
|
||||
cast(list[str], contents),
|
||||
model=settings.TEXT_EMBEDDING_MODEL,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
return embed_mixed(
|
||||
cast(list[extract.MulitmodalChunk], contents),
|
||||
model=settings.MIXED_EMBEDDING_MODEL,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
|
||||
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
|
||||
if isinstance(item, str):
|
||||
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt"
|
||||
@ -219,7 +88,9 @@ def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
|
||||
|
||||
|
||||
def make_chunk(
|
||||
page: extract.Page, vector: Vector, metadata: dict[str, Any] = {}
|
||||
contents: Sequence[extract.MulitmodalChunk],
|
||||
vector: Vector,
|
||||
metadata: dict[str, Any] = {},
|
||||
) -> Chunk:
|
||||
"""Create a Chunk object from a page and a vector.
|
||||
|
||||
@ -227,7 +98,6 @@ def make_chunk(
|
||||
a single image, or a list of strings and images.
|
||||
"""
|
||||
chunk_id = str(uuid.uuid4())
|
||||
contents = page["contents"]
|
||||
content, filename = None, None
|
||||
if all(isinstance(c, str) for c in contents):
|
||||
content = "\n\n".join(cast(list[str], contents))
|
||||
@ -251,38 +121,42 @@ def make_chunk(
|
||||
)
|
||||
|
||||
|
||||
def embed(
|
||||
mime_type: str,
|
||||
content: bytes | str | pathlib.Path,
|
||||
def embed_data_chunk(
|
||||
chunk: DataChunk,
|
||||
metadata: dict[str, Any] = {},
|
||||
chunk_size: int | None = None,
|
||||
) -> tuple[str, list[Chunk]]:
|
||||
modality = get_modality(mime_type)
|
||||
pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
|
||||
chunks = [
|
||||
make_chunk(page, vector, metadata)
|
||||
for page in pages
|
||||
for vector in embed_page(page)
|
||||
) -> list[Chunk]:
|
||||
chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS
|
||||
|
||||
model = chunk.embedding_model
|
||||
if not model and chunk.collection:
|
||||
model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model")
|
||||
if not model:
|
||||
model = settings.TEXT_EMBEDDING_MODEL
|
||||
|
||||
if model == settings.TEXT_EMBEDDING_MODEL:
|
||||
vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size)
|
||||
elif model == settings.MIXED_EMBEDDING_MODEL:
|
||||
vectors = embed_mixed(
|
||||
cast(list[extract.MulitmodalChunk], chunk.data),
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
metadata = metadata | chunk.metadata
|
||||
return [make_chunk(chunk.data, vector, metadata) for vector in vectors]
|
||||
|
||||
|
||||
def embed_source_item(
|
||||
item: SourceItem,
|
||||
metadata: dict[str, Any] = {},
|
||||
chunk_size: int | None = None,
|
||||
) -> list[Chunk]:
|
||||
return [
|
||||
chunk
|
||||
for data_chunk in item.data_chunks()
|
||||
for chunk in embed_data_chunk(
|
||||
data_chunk, item.as_payload() | metadata, chunk_size
|
||||
)
|
||||
]
|
||||
return modality, chunks
|
||||
|
||||
|
||||
def embed_image(
|
||||
file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
|
||||
) -> Chunk:
|
||||
image = Image.open(file_path)
|
||||
mime_type = get_mimetype(image)
|
||||
if mime_type is None:
|
||||
raise ValueError("Unsupported image format")
|
||||
|
||||
vector = embed_mixed(
|
||||
[image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
|
||||
)[0]
|
||||
|
||||
return Chunk(
|
||||
id=str(uuid.uuid4()),
|
||||
file_path=file_path.absolute().as_posix(),
|
||||
content=None,
|
||||
embedding_model=settings.MIXED_EMBEDDING_MODEL,
|
||||
vector=vector,
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
import io
|
||||
import logging
|
||||
import pathlib
|
||||
@ -21,6 +22,15 @@ class Page(TypedDict):
|
||||
chunk_size: NotRequired[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataChunk:
|
||||
data: Sequence[MulitmodalChunk]
|
||||
collection: str | None = None
|
||||
embedding_model: str | None = None
|
||||
max_size: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
|
||||
if isinstance(content, pathlib.Path):
|
||||
@ -110,11 +120,13 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
return [{"contents": [cast(str, content)], "metadata": {}}]
|
||||
|
||||
|
||||
def extract_content(
|
||||
def extract_data_chunks(
|
||||
mime_type: str,
|
||||
content: bytes | str | pathlib.Path,
|
||||
collection: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
) -> list[Page]:
|
||||
) -> list[DataChunk]:
|
||||
pages = []
|
||||
logger.info(f"Extracting content from {mime_type}")
|
||||
if mime_type == "application/pdf":
|
||||
@ -134,4 +146,12 @@ def extract_content(
|
||||
if chunk_size:
|
||||
pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
|
||||
|
||||
return pages
|
||||
return [
|
||||
DataChunk(
|
||||
data=page["contents"],
|
||||
collection=collection,
|
||||
embedding_model=embedding_model,
|
||||
max_size=chunk_size,
|
||||
)
|
||||
for page in pages
|
||||
]
|
||||
|
@ -26,6 +26,7 @@ class EmailMessage(TypedDict):
|
||||
body: str
|
||||
attachments: list[Attachment]
|
||||
hash: bytes
|
||||
raw_email: str
|
||||
|
||||
|
||||
RawEmailResponse = tuple[str | None, bytes]
|
||||
@ -171,6 +172,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
|
||||
body = extract_body(msg)
|
||||
|
||||
return EmailMessage(
|
||||
raw_email=raw_email,
|
||||
message_id=message_id,
|
||||
subject=subject,
|
||||
sender=from_,
|
||||
|
@ -12,7 +12,7 @@ from bs4 import BeautifulSoup, Tag
|
||||
from markdownify import markdownify as md
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from memory.common.settings import FILE_STORAGE_DIR, WEBPAGE_STORAGE_DIR
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -96,9 +96,9 @@ def extract_date(
|
||||
|
||||
datetime_attr = element.get("datetime")
|
||||
if datetime_attr:
|
||||
date_str = str(datetime_attr)
|
||||
if date := parse_date(date_str, date_format):
|
||||
return date
|
||||
for format in ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d", date_format]:
|
||||
if date := parse_date(str(datetime_attr), format):
|
||||
return date
|
||||
|
||||
for text in element.find_all(string=True):
|
||||
if text and (date := parse_date(str(text).strip(), date_format)):
|
||||
@ -178,7 +178,7 @@ def process_images(
|
||||
continue
|
||||
|
||||
path = pathlib.Path(image.filename) # type: ignore
|
||||
img_tag["src"] = str(path.relative_to(FILE_STORAGE_DIR.resolve()))
|
||||
img_tag["src"] = str(path.relative_to(settings.FILE_STORAGE_DIR.resolve()))
|
||||
images[img_tag["src"]] = image
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process image {src}: {e}")
|
||||
@ -291,7 +291,7 @@ class BaseHTMLParser:
|
||||
|
||||
def __init__(self, base_url: str | None = None):
|
||||
self.base_url = base_url
|
||||
self.image_dir = WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc)
|
||||
self.image_dir = settings.WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc)
|
||||
self.image_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def parse(self, html: str, url: str) -> Article:
|
||||
|
@ -5,12 +5,7 @@ import qdrant_client
|
||||
from qdrant_client.http import models as qdrant_models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from memory.common import settings
|
||||
from memory.common.embedding import (
|
||||
Collection,
|
||||
ALL_COLLECTIONS,
|
||||
DistanceType,
|
||||
Vector,
|
||||
)
|
||||
from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -10,17 +10,16 @@ from typing import Callable, Generator, Sequence, cast
|
||||
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from memory.common import embedding, qdrant, settings
|
||||
from memory.common import embedding, qdrant, settings, collections
|
||||
from memory.common.db.models import (
|
||||
EmailAccount,
|
||||
EmailAttachment,
|
||||
MailMessage,
|
||||
SourceItem,
|
||||
)
|
||||
from memory.common.parsers.email import (
|
||||
Attachment,
|
||||
EmailMessage,
|
||||
RawEmailResponse,
|
||||
parse_email_message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -54,7 +53,7 @@ def process_attachment(
|
||||
return None
|
||||
|
||||
return EmailAttachment(
|
||||
modality=embedding.get_modality(attachment["content_type"]),
|
||||
modality=collections.get_modality(attachment["content_type"]),
|
||||
sha256=hashlib.sha256(
|
||||
real_content if real_content else str(attachment).encode()
|
||||
).digest(),
|
||||
@ -94,8 +93,7 @@ def create_mail_message(
|
||||
db_session: Session | scoped_session,
|
||||
tags: list[str],
|
||||
folder: str,
|
||||
raw_email: str,
|
||||
message_id: str,
|
||||
parsed_email: EmailMessage,
|
||||
) -> MailMessage:
|
||||
"""
|
||||
Create a new mail message record and associated attachments.
|
||||
@ -109,7 +107,7 @@ def create_mail_message(
|
||||
Returns:
|
||||
Newly created MailMessage
|
||||
"""
|
||||
parsed_email = parse_email_message(raw_email, message_id)
|
||||
raw_email = parsed_email["raw_email"]
|
||||
mail_message = MailMessage(
|
||||
modality="mail",
|
||||
sha256=parsed_email["hash"],
|
||||
@ -137,52 +135,6 @@ def create_mail_message(
|
||||
return mail_message
|
||||
|
||||
|
||||
def does_message_exist(
|
||||
db_session: Session | scoped_session, message_id: str, message_hash: bytes
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a message already exists in the database.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
message_id: Email message ID
|
||||
message_hash: SHA-256 hash of message
|
||||
|
||||
Returns:
|
||||
True if message exists, False otherwise
|
||||
"""
|
||||
# Check by message_id first (faster)
|
||||
if message_id:
|
||||
mail_message = (
|
||||
db_session.query(MailMessage)
|
||||
.filter(MailMessage.message_id == message_id)
|
||||
.first()
|
||||
)
|
||||
if mail_message is not None:
|
||||
return True
|
||||
|
||||
# Then check by message_hash
|
||||
source_item = (
|
||||
db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
|
||||
)
|
||||
return source_item is not None
|
||||
|
||||
|
||||
def check_message_exists(
|
||||
db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
|
||||
) -> bool:
|
||||
account = db.query(EmailAccount).get(account_id)
|
||||
if not account:
|
||||
logger.error(f"Account {account_id} not found")
|
||||
return False
|
||||
|
||||
parsed_email = parse_email_message(raw_email, message_id)
|
||||
if "szczepalins" in raw_email.lower():
|
||||
print(parsed_email["message_id"])
|
||||
|
||||
return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
|
||||
|
||||
|
||||
def extract_email_uid(
|
||||
msg_data: Sequence[tuple[bytes, bytes]],
|
||||
) -> tuple[str | None, bytes]:
|
||||
@ -317,11 +269,7 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
||||
def vectorize_email(email: MailMessage):
|
||||
qdrant_client = qdrant.get_qdrant_client()
|
||||
|
||||
_, chunks = embedding.embed(
|
||||
"text/plain",
|
||||
email.body,
|
||||
metadata=email.as_payload(),
|
||||
)
|
||||
chunks = embedding.embed_source_item(email)
|
||||
email.chunks = chunks
|
||||
if chunks:
|
||||
vector_ids = [cast(str, c.id) for c in chunks]
|
||||
@ -329,7 +277,7 @@ def vectorize_email(email: MailMessage):
|
||||
metadata = [c.item_metadata for c in chunks]
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name="mail",
|
||||
collection_name=cast(str, email.modality),
|
||||
ids=vector_ids,
|
||||
vectors=vectors, # type: ignore
|
||||
payloads=metadata, # type: ignore
|
||||
@ -337,18 +285,12 @@ def vectorize_email(email: MailMessage):
|
||||
|
||||
embeds = defaultdict(list)
|
||||
for attachment in email.attachments:
|
||||
if attachment.filename:
|
||||
content = pathlib.Path(attachment.filename).read_bytes()
|
||||
else:
|
||||
content = attachment.content
|
||||
collection, chunks = embedding.embed(
|
||||
attachment.mime_type, content, metadata=attachment.as_payload()
|
||||
)
|
||||
chunks = embedding.embed_source_item(attachment)
|
||||
if not chunks:
|
||||
continue
|
||||
|
||||
attachment.chunks = chunks
|
||||
embeds[collection].extend(chunks)
|
||||
embeds[attachment.modality].extend(chunks)
|
||||
|
||||
for collection, chunks in embeds.items():
|
||||
ids = [c.id for c in chunks]
|
||||
|
@ -1,99 +1,25 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Iterable, cast
|
||||
from typing import Iterable
|
||||
|
||||
from memory.common import chunker, embedding, qdrant
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import BlogPost
|
||||
from memory.common.parsers.blogs import parse_webpage
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
create_content_hash,
|
||||
create_task_result,
|
||||
process_content_item,
|
||||
safe_task_execution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
|
||||
|
||||
|
||||
def create_blog_post_from_article(article, tags: Iterable[str] = []) -> BlogPost:
|
||||
"""Create a BlogPost model from parsed article data."""
|
||||
return BlogPost(
|
||||
url=article.url,
|
||||
title=article.title,
|
||||
published=article.published_date,
|
||||
content=article.content,
|
||||
sha256=hashlib.sha256(article.content.encode()).digest(),
|
||||
modality="blog",
|
||||
tags=tags,
|
||||
mime_type="text/markdown",
|
||||
size=len(article.content.encode("utf-8")),
|
||||
)
|
||||
|
||||
|
||||
def embed_blog_post(blog_post: BlogPost) -> int:
|
||||
"""Embed blog post content and return count of successfully embedded chunks."""
|
||||
try:
|
||||
# Always embed the full content
|
||||
_, chunks = embedding.embed(
|
||||
"text/markdown",
|
||||
cast(str, blog_post.content),
|
||||
metadata=blog_post.as_payload(),
|
||||
chunk_size=chunker.EMBEDDING_MAX_TOKENS,
|
||||
)
|
||||
# But also embed the content in chunks (unless it's really short)
|
||||
if (
|
||||
chunker.approx_token_count(cast(str, blog_post.content))
|
||||
> chunker.DEFAULT_CHUNK_TOKENS * 2
|
||||
):
|
||||
_, small_chunks = embedding.embed(
|
||||
"text/markdown",
|
||||
cast(str, blog_post.content),
|
||||
metadata=blog_post.as_payload(),
|
||||
)
|
||||
chunks += small_chunks
|
||||
|
||||
if chunks:
|
||||
blog_post.chunks = chunks
|
||||
blog_post.embed_status = "QUEUED" # type: ignore
|
||||
return len(chunks)
|
||||
else:
|
||||
blog_post.embed_status = "FAILED" # type: ignore
|
||||
logger.warning(f"No chunks generated for blog post: {blog_post.title}")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
blog_post.embed_status = "FAILED" # type: ignore
|
||||
logger.error(f"Failed to embed blog post {blog_post.title}: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def push_to_qdrant(blog_post: BlogPost):
|
||||
"""Push embeddings to Qdrant for successfully embedded blog post."""
|
||||
if cast(str, blog_post.embed_status) != "QUEUED" or not blog_post.chunks:
|
||||
return
|
||||
|
||||
try:
|
||||
vector_ids = [str(chunk.id) for chunk in blog_post.chunks]
|
||||
vectors = [chunk.vector for chunk in blog_post.chunks]
|
||||
payloads = [chunk.item_metadata for chunk in blog_post.chunks]
|
||||
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant.get_qdrant_client(),
|
||||
collection_name="blog",
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=payloads,
|
||||
)
|
||||
|
||||
blog_post.embed_status = "STORED" # type: ignore
|
||||
logger.info(f"Successfully stored embeddings for: {blog_post.title}")
|
||||
|
||||
except Exception as e:
|
||||
blog_post.embed_status = "FAILED" # type: ignore
|
||||
logger.error(f"Failed to push embeddings to Qdrant: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@app.task(name=SYNC_WEBPAGE)
|
||||
@safe_task_execution
|
||||
def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
|
||||
"""
|
||||
Synchronize a webpage from a URL.
|
||||
@ -116,61 +42,25 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
|
||||
"content_length": 0,
|
||||
}
|
||||
|
||||
blog_post = create_blog_post_from_article(article, tags)
|
||||
blog_post = BlogPost(
|
||||
url=article.url,
|
||||
title=article.title,
|
||||
published=article.published_date,
|
||||
content=article.content,
|
||||
sha256=create_content_hash(article.content),
|
||||
modality="blog",
|
||||
tags=tags,
|
||||
mime_type="text/markdown",
|
||||
size=len(article.content.encode("utf-8")),
|
||||
images=[image for image in article.images],
|
||||
)
|
||||
|
||||
with make_session() as session:
|
||||
existing_post = session.query(BlogPost).filter(BlogPost.url == url).first()
|
||||
if existing_post:
|
||||
logger.info(f"Blog post already exists: {existing_post.title}")
|
||||
return {
|
||||
"blog_post_id": existing_post.id,
|
||||
"url": url,
|
||||
"title": existing_post.title,
|
||||
"status": "already_exists",
|
||||
"chunks_count": len(existing_post.chunks),
|
||||
}
|
||||
|
||||
existing_post = (
|
||||
session.query(BlogPost).filter(BlogPost.sha256 == blog_post.sha256).first()
|
||||
existing_post = check_content_exists(
|
||||
session, BlogPost, url=url, sha256=create_content_hash(article.content)
|
||||
)
|
||||
if existing_post:
|
||||
logger.info(
|
||||
f"Blog post with the same content already exists: {existing_post.title}"
|
||||
)
|
||||
return {
|
||||
"blog_post_id": existing_post.id,
|
||||
"url": url,
|
||||
"title": existing_post.title,
|
||||
"status": "already_exists",
|
||||
"chunks_count": len(existing_post.chunks),
|
||||
}
|
||||
logger.info(f"Blog post already exists: {existing_post.title}")
|
||||
return create_task_result(existing_post, "already_exists", url=url)
|
||||
|
||||
session.add(blog_post)
|
||||
session.flush()
|
||||
|
||||
chunks_count = embed_blog_post(blog_post)
|
||||
session.flush()
|
||||
|
||||
try:
|
||||
push_to_qdrant(blog_post)
|
||||
logger.info(
|
||||
f"Successfully processed webpage: {blog_post.title} "
|
||||
f"({chunks_count} chunks embedded)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to push embeddings to Qdrant: {e}")
|
||||
blog_post.embed_status = "FAILED" # type: ignore
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"blog_post_id": blog_post.id,
|
||||
"url": url,
|
||||
"title": blog_post.title,
|
||||
"author": article.author,
|
||||
"published_date": article.published_date,
|
||||
"status": "processed",
|
||||
"chunks_count": chunks_count,
|
||||
"content_length": len(article.content),
|
||||
"embed_status": blog_post.embed_status,
|
||||
}
|
||||
return process_content_item(blog_post, "blog", session, tags)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Callable, cast
|
||||
@ -6,21 +5,25 @@ from typing import Callable, cast
|
||||
import feedparser
|
||||
import requests
|
||||
|
||||
from memory.common import embedding, qdrant, settings
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Comic, clean_filename
|
||||
from memory.common.parsers import comics
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
create_content_hash,
|
||||
process_content_item,
|
||||
safe_task_execution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
|
||||
SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
|
||||
SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
|
||||
SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
|
||||
|
||||
|
||||
BASE_SMBC_URL = "https://www.smbc-comics.com/"
|
||||
SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
|
||||
|
||||
@ -36,6 +39,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
|
||||
return set()
|
||||
|
||||
urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}
|
||||
urls = {url for url in urls if url}
|
||||
|
||||
with make_session() as session:
|
||||
known = {
|
||||
@ -61,6 +65,7 @@ def fetch_new_comics(
|
||||
|
||||
|
||||
@app.task(name=SYNC_COMIC)
|
||||
@safe_task_execution
|
||||
def sync_comic(
|
||||
url: str,
|
||||
image_url: str,
|
||||
@ -70,20 +75,26 @@ def sync_comic(
|
||||
):
|
||||
"""Synchronize a comic from a URL."""
|
||||
with make_session() as session:
|
||||
if session.query(Comic).filter(Comic.url == url).first():
|
||||
return
|
||||
existing_comic = check_content_exists(session, Comic, url=url)
|
||||
if existing_comic:
|
||||
return {"status": "already_exists", "comic_id": existing_comic.id}
|
||||
|
||||
response = requests.get(image_url)
|
||||
if response.status_code != 200:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"Failed to download image: {response.status_code}",
|
||||
}
|
||||
|
||||
file_type = image_url.split(".")[-1]
|
||||
mime_type = f"image/{file_type}"
|
||||
filename = (
|
||||
settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}"
|
||||
)
|
||||
if response.status_code == 200:
|
||||
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
filename.write_bytes(response.content)
|
||||
|
||||
sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest()
|
||||
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
filename.write_bytes(response.content)
|
||||
|
||||
comic = Comic(
|
||||
title=title,
|
||||
url=url,
|
||||
@ -92,27 +103,13 @@ def sync_comic(
|
||||
filename=filename.resolve().as_posix(),
|
||||
mime_type=mime_type,
|
||||
size=len(response.content),
|
||||
sha256=sha256,
|
||||
sha256=create_content_hash(f"{image_url}{published_date}"),
|
||||
tags={"comic", author},
|
||||
modality="comic",
|
||||
)
|
||||
chunk = embedding.embed_image(filename, [title, author])
|
||||
comic.chunks = [chunk]
|
||||
|
||||
with make_session() as session:
|
||||
session.add(comic)
|
||||
session.add(chunk)
|
||||
session.flush()
|
||||
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant.get_qdrant_client(),
|
||||
collection_name="comic",
|
||||
ids=[str(chunk.id)],
|
||||
vectors=[chunk.vector], # type: ignore
|
||||
payloads=[comic.as_payload()],
|
||||
)
|
||||
|
||||
session.commit()
|
||||
return process_content_item(comic, "comic", session)
|
||||
|
||||
|
||||
@app.task(name=SYNC_SMBC)
|
||||
|
267
src/memory/workers/tasks/content_processing.py
Normal file
267
src/memory/workers/tasks/content_processing.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Content processing utilities for memory workers.
|
||||
|
||||
This module provides core functionality for processing content items through
|
||||
the complete workflow: existence checking, content hashing, embedding generation,
|
||||
vector storage, and result tracking.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import traceback
|
||||
import logging
|
||||
from typing import Any, Callable, Iterable, Sequence, cast
|
||||
|
||||
from memory.common import embedding, qdrant
|
||||
from memory.common.db.models import SourceItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_content_exists(
|
||||
session,
|
||||
model_class: type[SourceItem],
|
||||
**kwargs: Any,
|
||||
) -> SourceItem | None:
|
||||
"""
|
||||
Check if content already exists in the database.
|
||||
|
||||
Searches for existing content by any of the provided attributes
|
||||
(typically URL, file_path, or SHA256 hash).
|
||||
|
||||
Args:
|
||||
session: Database session for querying
|
||||
model_class: The SourceItem model class to search in
|
||||
**kwargs: Attribute-value pairs to search for
|
||||
|
||||
Returns:
|
||||
Existing SourceItem if found, None otherwise
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(model_class, key):
|
||||
continue
|
||||
|
||||
existing = (
|
||||
session.query(model_class)
|
||||
.filter(getattr(model_class, key) == value)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
||||
"""
|
||||
Create SHA256 hash from content and optional additional data.
|
||||
|
||||
Args:
|
||||
content: Primary content to hash
|
||||
*additional_data: Additional strings to include in the hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash digest as bytes
|
||||
"""
|
||||
hash_input = content + "".join(additional_data)
|
||||
return hashlib.sha256(hash_input.encode()).digest()
|
||||
|
||||
|
||||
def embed_source_item(source_item: SourceItem) -> int:
|
||||
"""
|
||||
Generate embeddings for a source item's content.
|
||||
|
||||
Processes the source item through the embedding pipeline, creating
|
||||
chunks and their corresponding vector embeddings. Updates the item's
|
||||
embed_status based on success or failure.
|
||||
|
||||
Args:
|
||||
source_item: The SourceItem to embed
|
||||
|
||||
Returns:
|
||||
Number of successfully embedded chunks
|
||||
|
||||
Side effects:
|
||||
- Sets source_item.chunks with generated chunks
|
||||
- Sets source_item.embed_status to "QUEUED" or "FAILED"
|
||||
"""
|
||||
try:
|
||||
chunks = embedding.embed_source_item(source_item)
|
||||
if chunks:
|
||||
source_item.chunks = chunks
|
||||
source_item.embed_status = "QUEUED" # type: ignore
|
||||
return len(chunks)
|
||||
else:
|
||||
source_item.embed_status = "FAILED" # type: ignore
|
||||
logger.warning(
|
||||
f"No chunks generated for {type(source_item).__name__}: {getattr(source_item, 'title', 'unknown')}"
|
||||
)
|
||||
return 0
|
||||
except Exception as e:
|
||||
source_item.embed_status = "FAILED" # type: ignore
|
||||
logger.error(f"Failed to embed {type(source_item).__name__}: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
|
||||
"""
|
||||
Push embeddings to Qdrant vector database.
|
||||
|
||||
Uploads vector embeddings for all source items that have been successfully
|
||||
embedded (status "QUEUED") and have chunks available.
|
||||
|
||||
Args:
|
||||
source_items: Sequence of SourceItems to process
|
||||
collection_name: Name of the Qdrant collection to store vectors in
|
||||
|
||||
Raises:
|
||||
Exception: If the Qdrant upsert operation fails
|
||||
|
||||
Side effects:
|
||||
- Updates embed_status to "STORED" for successful items
|
||||
- Updates embed_status to "FAILED" for failed items
|
||||
"""
|
||||
items_to_process = [
|
||||
item
|
||||
for item in source_items
|
||||
if cast(str, getattr(item, "embed_status", None)) == "QUEUED" and item.chunks
|
||||
]
|
||||
|
||||
if not items_to_process:
|
||||
return
|
||||
|
||||
all_chunks = [chunk for item in items_to_process for chunk in item.chunks]
|
||||
if not all_chunks:
|
||||
return
|
||||
|
||||
try:
|
||||
vector_ids = [str(chunk.id) for chunk in all_chunks]
|
||||
vectors = [chunk.vector for chunk in all_chunks]
|
||||
payloads = [chunk.item_metadata for chunk in all_chunks]
|
||||
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant.get_qdrant_client(),
|
||||
collection_name=collection_name,
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=payloads,
|
||||
)
|
||||
|
||||
for item in items_to_process:
|
||||
item.embed_status = "STORED" # type: ignore
|
||||
logger.info(
|
||||
f"Successfully stored embeddings for: {getattr(item, 'title', 'unknown')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
for item in items_to_process:
|
||||
item.embed_status = "FAILED" # type: ignore
|
||||
logger.error(f"Failed to push embeddings to Qdrant: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_task_result(
|
||||
item: SourceItem, status: str, **additional_fields: Any
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create standardized task result dictionary.
|
||||
|
||||
Generates a consistent result format for task execution reporting,
|
||||
including item metadata and processing status.
|
||||
|
||||
Args:
|
||||
item: The processed SourceItem
|
||||
status: Processing status string
|
||||
**additional_fields: Extra fields to include in the result
|
||||
|
||||
Returns:
|
||||
Dictionary with standardized task result format
|
||||
"""
|
||||
return {
|
||||
f"{type(item).__name__.lower()}_id": item.id,
|
||||
"title": getattr(item, "title", None),
|
||||
"status": status,
|
||||
"chunks_count": len(item.chunks),
|
||||
"embed_status": item.embed_status,
|
||||
**additional_fields,
|
||||
}
|
||||
|
||||
|
||||
def process_content_item(
|
||||
item: SourceItem, collection_name: str, session, tags: Iterable[str] = []
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute complete content processing workflow.
|
||||
|
||||
Performs the full pipeline for processing a content item:
|
||||
1. Add to database session and flush to get ID
|
||||
2. Generate embeddings and chunks
|
||||
3. Push embeddings to Qdrant vector store
|
||||
4. Commit transaction and return result
|
||||
|
||||
Args:
|
||||
item: SourceItem to process
|
||||
collection_name: Qdrant collection name for vector storage
|
||||
session: Database session for persistence
|
||||
tags: Optional tags to associate with the item (currently unused)
|
||||
|
||||
Returns:
|
||||
Task result dictionary with processing status and metadata
|
||||
|
||||
Side effects:
|
||||
- Adds item to database session
|
||||
- Commits database transaction
|
||||
- Stores vectors in Qdrant
|
||||
"""
|
||||
session.add(item)
|
||||
session.flush()
|
||||
|
||||
chunks_count = embed_source_item(item)
|
||||
session.flush()
|
||||
|
||||
try:
|
||||
push_to_qdrant([item], collection_name)
|
||||
status = "processed"
|
||||
logger.info(
|
||||
f"Successfully processed {type(item).__name__}: {getattr(item, 'title', 'unknown')} ({chunks_count} chunks embedded)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to push embeddings to Qdrant: {e}")
|
||||
item.embed_status = "FAILED" # type: ignore
|
||||
status = "failed"
|
||||
|
||||
session.commit()
|
||||
|
||||
return create_task_result(item, status, content_length=getattr(item, "size", 0))
|
||||
|
||||
|
||||
def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
|
||||
"""
|
||||
Decorator for safe task execution with comprehensive error handling.
|
||||
|
||||
Wraps task functions to catch and log exceptions, ensuring tasks
|
||||
always return a result dictionary even when they fail.
|
||||
|
||||
Args:
|
||||
func: Task function to wrap
|
||||
|
||||
Returns:
|
||||
Wrapped function that handles exceptions gracefully
|
||||
|
||||
Example:
|
||||
@safe_task_execution
|
||||
def my_task(arg1, arg2):
|
||||
# Task implementation
|
||||
return {"status": "success"}
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs) -> dict:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(f"Task {func.__name__} failed: {e}")
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
return wrapper
|
@ -1,17 +1,21 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable, cast
|
||||
|
||||
from memory.common import chunker, embedding, qdrant
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Book, BookSection
|
||||
from memory.common.parsers.ebook import Ebook, parse_ebook, Section
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
create_content_hash,
|
||||
embed_source_item,
|
||||
push_to_qdrant,
|
||||
safe_task_execution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SYNC_BOOK = "memory.workers.tasks.book.sync_book"
|
||||
|
||||
# Minimum section length to embed (avoid noise from very short sections)
|
||||
@ -46,10 +50,6 @@ def section_processor(
|
||||
):
|
||||
content = "\n\n".join(section.pages).strip()
|
||||
if len(content) >= MIN_SECTION_LENGTH:
|
||||
sha256 = hashlib.sha256(
|
||||
f"{book.id}:{section.title}:{section.start_page}".encode()
|
||||
).digest()
|
||||
|
||||
book_section = BookSection(
|
||||
book_id=book.id,
|
||||
section_title=section.title,
|
||||
@ -59,7 +59,9 @@ def section_processor(
|
||||
end_page=section.end_page,
|
||||
parent_section_id=None, # Will be set after flush
|
||||
content=content,
|
||||
sha256=sha256,
|
||||
sha256=create_content_hash(
|
||||
f"{book.id}:{section.title}:{section.start_page}"
|
||||
),
|
||||
modality="book",
|
||||
tags=book.tags,
|
||||
pages=section.pages,
|
||||
@ -127,76 +129,11 @@ def create_book_and_sections(
|
||||
|
||||
def embed_sections(all_sections: list[BookSection]) -> int:
|
||||
"""Embed all sections and return count of successfully embedded sections."""
|
||||
embedded_count = 0
|
||||
|
||||
def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
|
||||
_, chunks = embedding.embed(
|
||||
"text/plain",
|
||||
text,
|
||||
metadata=metadata,
|
||||
chunk_size=chunker.EMBEDDING_MAX_TOKENS,
|
||||
)
|
||||
return chunks
|
||||
|
||||
for section in all_sections:
|
||||
try:
|
||||
section_chunks = embed_text(
|
||||
cast(str, section.content), section.as_payload()
|
||||
)
|
||||
page_chunks = [
|
||||
chunk
|
||||
for i, page in enumerate(section.pages)
|
||||
for chunk in embed_text(
|
||||
page, section.as_payload() | {"page_number": i + section.start_page}
|
||||
)
|
||||
]
|
||||
chunks = section_chunks + page_chunks
|
||||
|
||||
if chunks:
|
||||
section.chunks = chunks
|
||||
section.embed_status = "QUEUED" # type: ignore
|
||||
embedded_count += 1
|
||||
else:
|
||||
section.embed_status = "FAILED" # type: ignore
|
||||
logger.warning(
|
||||
f"No chunks generated for section: {section.section_title}"
|
||||
)
|
||||
|
||||
except IOError as e:
|
||||
section.embed_status = "FAILED" # type: ignore
|
||||
logger.error(f"Failed to embed section {section.section_title}: {e}")
|
||||
|
||||
return embedded_count
|
||||
|
||||
|
||||
def push_to_qdrant(all_sections: list[BookSection]):
|
||||
"""Push embeddings to Qdrant for all successfully embedded sections."""
|
||||
vector_ids = []
|
||||
vectors = []
|
||||
payloads = []
|
||||
|
||||
to_process = [s for s in all_sections if cast(str, s.embed_status) == "QUEUED"]
|
||||
all_chunks = [chunk for section in to_process for chunk in section.chunks]
|
||||
if not all_chunks:
|
||||
return
|
||||
|
||||
vector_ids = [str(chunk.id) for chunk in all_chunks]
|
||||
vectors = [chunk.vector for chunk in all_chunks]
|
||||
payloads = [chunk.item_metadata for chunk in all_chunks]
|
||||
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant.get_qdrant_client(),
|
||||
collection_name="book",
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=payloads,
|
||||
)
|
||||
|
||||
for section in to_process:
|
||||
section.embed_status = "STORED" # type: ignore
|
||||
return sum(embed_source_item(section) for section in all_sections)
|
||||
|
||||
|
||||
@app.task(name=SYNC_BOOK)
|
||||
@safe_task_execution
|
||||
def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
|
||||
"""
|
||||
Synchronize a book from a file path.
|
||||
@ -211,10 +148,8 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
|
||||
|
||||
with make_session() as session:
|
||||
# Check for existing book
|
||||
existing_book = (
|
||||
session.query(Book)
|
||||
.filter(Book.file_path == ebook.file_path.as_posix())
|
||||
.first()
|
||||
existing_book = check_content_exists(
|
||||
session, Book, file_path=ebook.file_path.as_posix()
|
||||
)
|
||||
if existing_book:
|
||||
logger.info(f"Book already exists: {existing_book.title}")
|
||||
@ -230,19 +165,10 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
|
||||
book, all_sections = create_book_and_sections(ebook, session, tags)
|
||||
|
||||
# Embed sections
|
||||
embedded_count = embed_sections(all_sections)
|
||||
embedded_count = sum(embed_source_item(section) for section in all_sections)
|
||||
session.flush()
|
||||
|
||||
# Push to Qdrant
|
||||
try:
|
||||
push_to_qdrant(all_sections)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to push embeddings to Qdrant: {e}")
|
||||
# Mark sections as failed
|
||||
for section in all_sections:
|
||||
if getattr(section, "embed_status") == "STORED":
|
||||
section.embed_status = "FAILED" # type: ignore
|
||||
raise
|
||||
push_to_qdrant(all_sections, "book")
|
||||
|
||||
session.commit()
|
||||
|
||||
|
@ -2,16 +2,19 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import EmailAccount
|
||||
from memory.common.db.models import EmailAccount, MailMessage
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.email import (
|
||||
check_message_exists,
|
||||
create_mail_message,
|
||||
imap_connection,
|
||||
process_folder,
|
||||
vectorize_email,
|
||||
)
|
||||
|
||||
from memory.common.parsers.email import parse_email_message
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
safe_task_execution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -21,12 +24,13 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
|
||||
|
||||
|
||||
@app.task(name=PROCESS_EMAIL)
|
||||
@safe_task_execution
|
||||
def process_message(
|
||||
account_id: int,
|
||||
message_id: str,
|
||||
folder: str,
|
||||
raw_email: str,
|
||||
) -> int | None:
|
||||
) -> dict:
|
||||
"""
|
||||
Process a single email message and store it in the database.
|
||||
|
||||
@ -37,29 +41,30 @@ def process_message(
|
||||
raw_email: Raw email content as string
|
||||
|
||||
Returns:
|
||||
source_id if successful, None otherwise
|
||||
dict with processing result
|
||||
"""
|
||||
logger.info(f"Processing message {message_id} for account {account_id}")
|
||||
if not raw_email.strip():
|
||||
logger.warning(f"Empty email message received for account {account_id}")
|
||||
return None
|
||||
return {"status": "skipped", "reason": "empty_content"}
|
||||
|
||||
with make_session() as db:
|
||||
if check_message_exists(db, account_id, message_id, raw_email):
|
||||
logger.debug(f"Message {message_id} already exists in database")
|
||||
return None
|
||||
|
||||
account = db.query(EmailAccount).get(account_id)
|
||||
if not account:
|
||||
logger.error(f"Account {account_id} not found")
|
||||
return None
|
||||
return {"status": "error", "error": "Account not found"}
|
||||
|
||||
mail_message = create_mail_message(
|
||||
db, account.tags, folder, raw_email, message_id
|
||||
)
|
||||
parsed_email = parse_email_message(raw_email, message_id)
|
||||
if check_content_exists(
|
||||
db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
|
||||
):
|
||||
return {"status": "already_exists", "message_id": message_id}
|
||||
|
||||
mail_message = create_mail_message(db, account.tags, folder, parsed_email)
|
||||
|
||||
db.flush()
|
||||
vectorize_email(mail_message)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Stored embedding for message {mail_message.message_id}")
|
||||
@ -71,16 +76,24 @@ def process_message(
|
||||
for chunk in attachment.chunks:
|
||||
logger.info(f" - {chunk.id}")
|
||||
|
||||
return cast(int, mail_message.id)
|
||||
return {
|
||||
"status": "processed",
|
||||
"mail_message_id": cast(int, mail_message.id),
|
||||
"message_id": message_id,
|
||||
"chunks_count": len(mail_message.chunks),
|
||||
"attachments_count": len(mail_message.attachments),
|
||||
}
|
||||
|
||||
|
||||
@app.task(name=SYNC_ACCOUNT)
|
||||
@safe_task_execution
|
||||
def sync_account(account_id: int, since_date: str | None = None) -> dict:
|
||||
"""
|
||||
Synchronize emails from a specific account.
|
||||
|
||||
Args:
|
||||
account_id: ID of the EmailAccount to sync
|
||||
since_date: ISO format date string to sync since
|
||||
|
||||
Returns:
|
||||
dict with stats about the sync operation
|
||||
@ -91,7 +104,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
|
||||
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
|
||||
if not account or not cast(bool, account.active):
|
||||
logger.warning(f"Account {account_id} not found or inactive")
|
||||
return {"error": "Account not found or inactive"}
|
||||
return {"status": "error", "error": "Account not found or inactive"}
|
||||
|
||||
folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
|
||||
if since_date:
|
||||
@ -108,7 +121,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
|
||||
def process_message_wrapper(
|
||||
account_id: int, message_id: str, folder: str, raw_email: str
|
||||
) -> int | None:
|
||||
if check_message_exists(db, account_id, message_id, raw_email): # type: ignore
|
||||
parsed_email = parse_email_message(raw_email, message_id)
|
||||
if check_content_exists(
|
||||
db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
|
||||
):
|
||||
return None
|
||||
return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore
|
||||
|
||||
@ -127,9 +143,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"account": account.email_address,
|
||||
"since_date": cutoff_date.isoformat(),
|
||||
"folders_processed": len(folders_to_process),
|
||||
|
@ -6,7 +6,7 @@ from typing import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import contains_eager
|
||||
|
||||
from memory.common import embedding, qdrant, settings
|
||||
from memory.common import collections, embedding, qdrant, settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.workers.celery_app import app
|
||||
@ -60,11 +60,11 @@ def reingest_chunk(chunk_id: str, collection: str):
|
||||
logger.error(f"Chunk {chunk_id} not found")
|
||||
return
|
||||
|
||||
if collection not in embedding.ALL_COLLECTIONS:
|
||||
if collection not in collections.ALL_COLLECTIONS:
|
||||
raise ValueError(f"Unsupported collection {collection}")
|
||||
|
||||
data = chunk.data
|
||||
if collection in embedding.MULTIMODAL_COLLECTIONS:
|
||||
if collection in collections.MULTIMODAL_COLLECTIONS:
|
||||
vector = embedding.embed_mixed(data)[0]
|
||||
elif len(data) == 1 and isinstance(data[0], str):
|
||||
vector = embedding.embed_text([data[0]])[0]
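# Minimal sketch of the dispatch above, using only names visible in this diff
# (collections.ALL_COLLECTIONS, collections.MULTIMODAL_COLLECTIONS, embedding.embed_*):
def pick_embedder(collection: str):
    if collection not in collections.ALL_COLLECTIONS:
        raise ValueError(f"Unsupported collection {collection}")
    if collection in collections.MULTIMODAL_COLLECTIONS:
        return embedding.embed_mixed
    return embedding.embed_text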
|
||||
|
@ -205,9 +205,14 @@ def email_provider():
|
||||
def mock_file_storage(tmp_path: Path):
|
||||
chunk_storage_dir = tmp_path / "chunks"
|
||||
chunk_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
|
||||
with patch("memory.common.settings.CHUNK_STORAGE_DIR", chunk_storage_dir):
|
||||
yield
|
||||
image_storage_dir = tmp_path / "images"
|
||||
image_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
with (
|
||||
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
|
||||
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
|
||||
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -249,6 +249,7 @@ def test_parse_simple_email():
|
||||
"body": "Test body content\n",
|
||||
"attachments": [],
|
||||
"sent_at": ANY,
|
||||
"raw_email": msg.as_string(),
|
||||
"hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
|
||||
+ b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import requests
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.parsers.html import (
|
||||
Article,
|
||||
BaseHTMLParser,
|
||||
@ -164,27 +165,28 @@ def test_parse_date(text, date_format, expected):
|
||||
assert parse_date(text, date_format) == expected
|
||||
|
||||
|
||||
def test_extract_date():
|
||||
@pytest.mark.parametrize(
|
||||
"selector, date_format, expected",
|
||||
[
|
||||
("time", "%B %d, %Y", datetime.fromisoformat("2023-01-15T10:30:00")),
|
||||
(".named-date", "%B %d, %Y", datetime.fromisoformat("2023-01-15")),
|
||||
(".date", "%Y-%m-%d", datetime.fromisoformat("2023-02-20")),
|
||||
(".nonexistent", None, None),
|
||||
],
|
||||
)
|
||||
def test_extract_date(selector, date_format, expected):
|
||||
html = """
|
||||
<div>
|
||||
<time datetime="2023-01-15T10:30:00">January 15, 2023</time>
|
||||
<span class="named-date">January 15, 2023</span>
|
||||
<span class="date">2023-02-20</span>
|
||||
<div class="published">March 10, 2023</div>
|
||||
</div>
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
|
||||
# Should extract datetime attribute from time tag
|
||||
result = extract_date(soup, "time", "%Y-%m-%d")
|
||||
assert result == "2023-01-15T10:30:00"
|
||||
|
||||
# Should extract from text content
|
||||
result = extract_date(soup, ".date", "%Y-%m-%d")
|
||||
assert result == "2023-02-20T00:00:00"
|
||||
|
||||
# No matching element
|
||||
result = extract_date(soup, ".nonexistent", "%Y-%m-%d")
|
||||
assert result is None
|
||||
result = extract_date(soup, selector, date_format)
|
||||
assert result == expected
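# Hypothetical sketch of the contract the cases above pin down (the real extract_date
# lives in memory.common.parsers.html; names and structure here are assumptions):
def extract_date_sketch(soup, selector, date_format):
    element = soup.select_one(selector)
    if element is None:
        return None
    if element.has_attr("datetime"):
        return datetime.fromisoformat(element["datetime"])
    return parse_date(element.get_text(strip=True), date_format)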
|
||||
|
||||
|
||||
def test_extract_content_element():
|
||||
@ -393,8 +395,7 @@ def test_process_image_cached(mock_pil_open, mock_requests_get):
|
||||
|
||||
|
||||
@patch("memory.common.parsers.html.process_image")
|
||||
@patch("memory.common.parsers.html.FILE_STORAGE_DIR")
|
||||
def test_process_images_basic(mock_file_storage_dir, mock_process_image):
|
||||
def test_process_images_basic(mock_process_image):
|
||||
html = """
|
||||
<div>
|
||||
<p>Text content</p>
|
||||
@ -409,40 +410,37 @@ def test_process_images_basic(mock_file_storage_dir, mock_process_image):
|
||||
content = cast(Tag, soup.find("div"))
|
||||
base_url = "https://example.com"
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
image_dir = pathlib.Path(temp_dir)
|
||||
mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir)
|
||||
# Mock successful image processing with proper filenames
|
||||
mock_images = []
|
||||
for i in range(3):
|
||||
mock_img = MagicMock(spec=PILImage.Image)
|
||||
mock_img.filename = str(settings.WEBPAGE_STORAGE_DIR / f"image{i + 1}.jpg")
|
||||
mock_images.append(mock_img)
|
||||
|
||||
# Mock successful image processing with proper filenames
|
||||
mock_images = []
|
||||
for i in range(3):
|
||||
mock_img = MagicMock(spec=PILImage.Image)
|
||||
mock_img.filename = str(pathlib.Path(temp_dir) / f"image{i + 1}.jpg")
|
||||
mock_images.append(mock_img)
|
||||
mock_process_image.side_effect = mock_images
|
||||
|
||||
mock_process_image.side_effect = mock_images
|
||||
updated_content, images = process_images(
|
||||
content, base_url, settings.WEBPAGE_STORAGE_DIR
|
||||
)
|
||||
|
||||
updated_content, images = process_images(content, base_url, image_dir)
|
||||
|
||||
# Should have processed 3 images (skipping the one without src)
|
||||
assert len(images) == 3
|
||||
assert mock_process_image.call_count == 3
|
||||
|
||||
# Check that img src attributes were updated to relative paths
|
||||
img_tags = [
|
||||
tag
|
||||
for tag in (updated_content.find_all("img") if updated_content else [])
|
||||
if isinstance(tag, Tag)
|
||||
]
|
||||
src_values = []
|
||||
for img in img_tags:
|
||||
src = img.get("src")
|
||||
if src and isinstance(src, str):
|
||||
src_values.append(src)
|
||||
|
||||
# Should have relative paths to the processed images
|
||||
for src in src_values[:3]: # First 3 have src
|
||||
assert not src.startswith("http") # Should be relative paths
|
||||
expected = BeautifulSoup(
|
||||
"""<div>
|
||||
<p>Text content</p>
|
||||
<img alt="Image 1" src="images/image1.jpg"/>
|
||||
<img alt="Image 2" src="images/image2.jpg"/>
|
||||
<img alt="Image 3" src="images/image3.jpg"/>
|
||||
<img alt="No src"/>
|
||||
<p>More text</p>
|
||||
</div>
|
||||
""",
|
||||
"html.parser",
|
||||
)
|
||||
assert updated_content.prettify() == expected.prettify() # type: ignore
|
||||
assert images == {
|
||||
"images/image1.jpg": mock_images[0],
|
||||
"images/image2.jpg": mock_images[1],
|
||||
"images/image3.jpg": mock_images[2],
|
||||
}
|
||||
|
||||
|
||||
def test_process_images_empty():
|
||||
@ -454,8 +452,7 @@ def test_process_images_empty():
|
||||
|
||||
|
||||
@patch("memory.common.parsers.html.process_image")
|
||||
@patch("memory.common.parsers.html.FILE_STORAGE_DIR")
|
||||
def test_process_images_with_failures(mock_file_storage_dir, mock_process_image):
|
||||
def test_process_images_with_failures(mock_process_image):
|
||||
html = """
|
||||
<div>
|
||||
<img src="good.jpg" alt="Good image">
|
||||
@ -465,22 +462,20 @@ def test_process_images_with_failures(mock_file_storage_dir, mock_process_image)
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
content = cast(Tag, soup.find("div"))
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
image_dir = pathlib.Path(temp_dir)
|
||||
mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir)
|
||||
# First image succeeds, second fails
|
||||
mock_good_image = MagicMock(spec=PILImage.Image)
|
||||
mock_good_image.filename = settings.WEBPAGE_STORAGE_DIR / "good.jpg"
|
||||
mock_process_image.side_effect = [mock_good_image, None]
|
||||
|
||||
# First image succeeds, second fails
|
||||
mock_good_image = MagicMock(spec=PILImage.Image)
|
||||
mock_good_image.filename = str(pathlib.Path(temp_dir) / "good.jpg")
|
||||
mock_process_image.side_effect = [mock_good_image, None]
|
||||
updated_content, images = process_images(
|
||||
content, "https://example.com", settings.WEBPAGE_STORAGE_DIR
|
||||
)
|
||||
|
||||
updated_content, images = process_images(
|
||||
content, "https://example.com", image_dir
|
||||
)
|
||||
|
||||
# Should only return successful image
|
||||
assert len(images) == 1
|
||||
assert images[0] == mock_good_image
|
||||
expected = BeautifulSoup(
|
||||
html.replace("good.jpg", "images/good.jpg"), "html.parser"
|
||||
).prettify()
|
||||
assert updated_content.prettify() == expected # type: ignore
|
||||
assert images == {"images/good.jpg": mock_good_image}
|
||||
|
||||
|
||||
@patch("memory.common.parsers.html.process_image")
|
||||
@ -494,15 +489,12 @@ def test_process_images_no_filename(mock_process_image):
|
||||
mock_image.filename = None
|
||||
mock_process_image.return_value = mock_image
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
image_dir = pathlib.Path(temp_dir)
|
||||
updated_content, images = process_images(
|
||||
content, "https://example.com", settings.WEBPAGE_STORAGE_DIR
|
||||
)
|
||||
|
||||
updated_content, images = process_images(
|
||||
content, "https://example.com", image_dir
|
||||
)
|
||||
|
||||
# Should skip image without filename
|
||||
assert len(images) == 0
|
||||
# Should skip image without filename
|
||||
assert not images
|
||||
|
||||
|
||||
class TestBaseHTMLParser:
|
||||
@ -541,7 +533,7 @@ class TestBaseHTMLParser:
|
||||
|
||||
assert article.title == "Article Title"
|
||||
assert article.author == "John Smith" # Should prefer content over meta
|
||||
assert article.published_date == "2023-01-15T00:00:00"
|
||||
assert article.published_date == datetime(2023, 1, 15)
|
||||
assert article.url == "https://example.com/article"
|
||||
assert "This is the main content" in article.content
|
||||
assert article.metadata["author"] == "Jane Doe"
|
||||
|
@ -5,14 +5,10 @@ from typing import cast
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common import settings, collections
|
||||
from memory.common.embedding import (
|
||||
embed,
|
||||
embed_file,
|
||||
embed_mixed,
|
||||
embed_page,
|
||||
embed_text,
|
||||
get_modality,
|
||||
make_chunk,
|
||||
write_to_file,
|
||||
)
|
||||
@ -50,7 +46,7 @@ def mock_embed(mock_voyage_client):
|
||||
],
|
||||
)
|
||||
def test_get_modality(mime_type, expected_modality):
|
||||
assert get_modality(mime_type) == expected_modality
|
||||
assert collections.get_modality(mime_type) == expected_modality
|
||||
|
||||
|
||||
def test_embed_text(mock_embed):
|
||||
@ -58,59 +54,11 @@ def test_embed_text(mock_embed):
|
||||
assert embed_text(texts) == [[0], [1]]
|
||||
|
||||
|
||||
def test_embed_file(mock_embed, tmp_path):
|
||||
mock_file = tmp_path / "test.txt"
|
||||
mock_file.write_text("file content")
|
||||
|
||||
assert embed_file(mock_file) == [[0]]
|
||||
|
||||
|
||||
def test_embed_mixed(mock_embed):
|
||||
items = ["text", {"type": "image", "data": "base64"}]
|
||||
assert embed_mixed(items) == [[0]]
|
||||
|
||||
|
||||
def test_embed_page_text_only(mock_embed):
|
||||
page = {"contents": ["text1", "text2"]}
|
||||
assert embed_page(page) == [[0], [1]] # type: ignore
|
||||
|
||||
|
||||
def test_embed_page_mixed_content(mock_embed):
|
||||
page = {"contents": ["text", {"type": "image", "data": "base64"}]}
|
||||
assert embed_page(page) == [[0]] # type: ignore
|
||||
|
||||
|
||||
def test_embed(mock_embed):
|
||||
mime_type = "text/plain"
|
||||
content = "sample content"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
with patch.object(uuid, "uuid4", return_value="id1"):
|
||||
modality, chunks = embed(mime_type, content, metadata)
|
||||
|
||||
assert modality == "text"
|
||||
assert [
|
||||
{
|
||||
"id": c.id, # type: ignore
|
||||
"file_path": c.file_path, # type: ignore
|
||||
"content": c.content, # type: ignore
|
||||
"embedding_model": c.embedding_model, # type: ignore
|
||||
"vector": c.vector, # type: ignore
|
||||
"item_metadata": c.item_metadata, # type: ignore
|
||||
}
|
||||
for c in chunks
|
||||
] == [
|
||||
{
|
||||
"content": "sample content",
|
||||
"embedding_model": "voyage-3-large",
|
||||
"file_path": None,
|
||||
"id": "id1",
|
||||
"item_metadata": {"source": "test"},
|
||||
"vector": [0],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_write_to_file_text(mock_file_storage):
|
||||
"""Test writing a string to a file."""
|
||||
chunk_id = "test-chunk-id"
|
||||
@ -160,17 +108,14 @@ def test_write_to_file_unsupported_type(mock_file_storage):
|
||||
|
||||
def test_make_chunk_text_only(mock_file_storage, db_session):
|
||||
"""Test creating a chunk from string content."""
|
||||
page = {
|
||||
"contents": ["text content 1", "text content 2"],
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
contents = ["text content 1", "text content 2"]
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
metadata = {"doc_type": "test", "source": "unit-test"}
|
||||
|
||||
with patch.object(
|
||||
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
):
|
||||
chunk = make_chunk(page, vector, metadata) # type: ignore
|
||||
chunk = make_chunk(contents, vector, metadata) # type: ignore
|
||||
|
||||
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
|
||||
assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
|
||||
@ -183,14 +128,14 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
|
||||
def test_make_chunk_single_image(mock_file_storage, db_session):
|
||||
"""Test creating a chunk from a single image."""
|
||||
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
|
||||
page = {"contents": [img], "metadata": {"source": "test"}}
|
||||
contents = [img]
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
metadata = {"doc_type": "test", "source": "unit-test"}
|
||||
|
||||
with patch.object(
|
||||
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
|
||||
):
|
||||
chunk = make_chunk(page, vector, metadata) # type: ignore
|
||||
chunk = make_chunk(contents, vector, metadata) # type: ignore
|
||||
|
||||
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
|
||||
assert chunk.content is None
|
||||
@ -208,14 +153,14 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
|
||||
def test_make_chunk_mixed_content(mock_file_storage, db_session):
|
||||
"""Test creating a chunk from mixed content (string and image)."""
|
||||
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
|
||||
page = {"contents": ["text content", img], "metadata": {"source": "test"}}
|
||||
contents = ["text content", img]
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
metadata = {"doc_type": "test", "source": "unit-test"}
|
||||
|
||||
with patch.object(
|
||||
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
|
||||
):
|
||||
chunk = make_chunk(page, vector, metadata) # type: ignore
|
||||
chunk = make_chunk(contents, vector, metadata) # type: ignore
|
||||
|
||||
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
|
||||
assert chunk.content is None
|
||||
@ -233,3 +178,181 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
|
||||
assert (
|
||||
settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
|
||||
).exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data,embedding_model,collection,expected_model,expected_count,expected_has_content",
|
||||
[
|
||||
# Text-only with default model
|
||||
(
|
||||
["text content 1", "text content 2"],
|
||||
None,
|
||||
None,
|
||||
settings.TEXT_EMBEDDING_MODEL,
|
||||
2,
|
||||
True,
|
||||
),
|
||||
# Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content
|
||||
(
|
||||
["text content"],
|
||||
settings.MIXED_EMBEDDING_MODEL,
|
||||
None,
|
||||
settings.TEXT_EMBEDDING_MODEL,
|
||||
1,
|
||||
True,
|
||||
),
|
||||
# Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content
|
||||
(["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True),
|
||||
(["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True),
|
||||
(["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True),
|
||||
# Unknown collection falls back to default
|
||||
(["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True),
|
||||
# Explicit model takes precedence over collection
|
||||
(
|
||||
["text content"],
|
||||
settings.TEXT_EMBEDDING_MODEL,
|
||||
"photo",
|
||||
settings.TEXT_EMBEDDING_MODEL,
|
||||
1,
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_embed_data_chunk_scenarios(
|
||||
data,
|
||||
embedding_model,
|
||||
collection,
|
||||
expected_model,
|
||||
expected_count,
|
||||
expected_has_content,
|
||||
mock_embed,
|
||||
mock_file_storage,
|
||||
):
|
||||
"""Test various embedding scenarios for data chunks."""
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.embedding import embed_data_chunk
|
||||
|
||||
chunk = DataChunk(
|
||||
data=data,
|
||||
embedding_model=embedding_model,
|
||||
collection=collection,
|
||||
metadata={"source": "test"},
|
||||
)
|
||||
|
||||
result = embed_data_chunk(chunk, {"doc_type": "test"})
|
||||
|
||||
assert len(result) == expected_count
|
||||
assert all(cast(str, c.embedding_model) == expected_model for c in result)
|
||||
if expected_has_content:
|
||||
assert all(c.content is not None for c in result)
|
||||
assert all(c.file_path is None for c in result)
|
||||
else:
|
||||
assert all(c.content is None for c in result)
|
||||
assert all(c.file_path is not None for c in result)
|
||||
assert all(
|
||||
c.item_metadata == {"source": "test", "doc_type": "test"} for c in result
|
||||
)
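# Sketch of the selection rule the cases above encode (an assumption about
# embed_data_chunk, not its actual code): text-only data always embeds with the
# text model, while data containing images falls back to the mixed model.
def pick_model_sketch(data, embedding_model=None):
    if all(isinstance(item, str) for item in data):
        return settings.TEXT_EMBEDDING_MODEL
    return embedding_model or settings.MIXED_EMBEDDING_MODEL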
|
||||
|
||||
|
||||
def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage):
|
||||
"""Test embedding mixed content (text and images)."""
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.embedding import embed_data_chunk
|
||||
|
||||
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
|
||||
chunk = DataChunk(
|
||||
data=["text content", img],
|
||||
embedding_model=settings.MIXED_EMBEDDING_MODEL,
|
||||
metadata={"source": "test"},
|
||||
)
|
||||
|
||||
result = embed_data_chunk(chunk)
|
||||
|
||||
assert len(result) == 1 # Mixed content returns single vector
|
||||
assert result[0].content is None # Mixed content stored in files
|
||||
assert result[0].file_path is not None
|
||||
assert cast(str, result[0].embedding_model) == settings.MIXED_EMBEDDING_MODEL
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_max_size,chunk_size_param,expected_chunk_size",
|
||||
[
|
||||
(512, 1024, 512), # chunk.max_size takes precedence
|
||||
(None, 2048, 2048), # chunk_size parameter used when max_size is None
|
||||
(256, None, 256), # chunk.max_size used when parameter is None
|
||||
],
|
||||
)
|
||||
def test_embed_data_chunk_chunk_size_handling(
|
||||
chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage
|
||||
):
|
||||
"""Test chunk size parameter handling."""
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.embedding import embed_data_chunk
|
||||
|
||||
chunk = DataChunk(
|
||||
data=["text content"], max_size=chunk_max_size, metadata={"source": "test"}
|
||||
)
|
||||
|
||||
with patch("memory.common.embedding.embed_text") as mock_embed_text:
|
||||
mock_embed_text.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
result = embed_data_chunk(chunk, chunk_size=chunk_size_param)
|
||||
|
||||
mock_embed_text.assert_called_once()
|
||||
args, kwargs = mock_embed_text.call_args
|
||||
assert kwargs["chunk_size"] == expected_chunk_size
|
||||
|
||||
|
||||
def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage):
|
||||
"""Test that chunk metadata and parameter metadata are properly merged."""
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.embedding import embed_data_chunk
|
||||
|
||||
chunk = DataChunk(
|
||||
data=["text content"], metadata={"source": "test", "type": "chunk"}
|
||||
)
|
||||
metadata = {
|
||||
"doc_type": "test",
|
||||
"source": "override",
|
||||
} # chunk.metadata takes precedence over parameter metadata
|
||||
|
||||
result = embed_data_chunk(chunk, metadata)
|
||||
|
||||
assert len(result) == 1
|
||||
expected_metadata = {
|
||||
"source": "test",
|
||||
"type": "chunk",
|
||||
"doc_type": "test",
|
||||
} # chunk source wins
|
||||
assert result[0].item_metadata == expected_metadata
|
||||
|
||||
|
||||
def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage):
|
||||
"""Test error handling for unsupported embedding model."""
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.embedding import embed_data_chunk
|
||||
|
||||
chunk = DataChunk(
|
||||
data=["text content"],
|
||||
embedding_model="unsupported-model",
|
||||
metadata={"source": "test"},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model: unsupported-model"):
|
||||
embed_data_chunk(chunk)
|
||||
|
||||
|
||||
def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage):
|
||||
"""Test handling of empty data."""
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.embedding import embed_data_chunk
|
||||
|
||||
chunk = DataChunk(data=[], metadata={"source": "test"})
|
||||
|
||||
# Should handle empty data gracefully
|
||||
with patch("memory.common.embedding.embed_text") as mock_embed_text:
|
||||
mock_embed_text.return_value = []
|
||||
|
||||
result = embed_data_chunk(chunk)
|
||||
|
||||
assert result == []
|
||||
|
@ -7,7 +7,6 @@ import shutil
|
||||
from memory.common.extract import (
|
||||
as_file,
|
||||
extract_text,
|
||||
extract_content,
|
||||
doc_to_images,
|
||||
extract_image,
|
||||
docx_to_pdf,
|
||||
@ -107,52 +106,6 @@ def test_extract_image_with_str():
|
||||
extract_image("test")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mime_type,content",
|
||||
[
|
||||
("text/plain", "Text content"),
|
||||
("text/html", "<html>content</html>"),
|
||||
("text/markdown", "# Heading"),
|
||||
("text/csv", "a,b,c"),
|
||||
],
|
||||
)
|
||||
def test_extract_content_different_text_types(mime_type, content):
|
||||
assert extract_content(mime_type, content) == [
|
||||
{"contents": [content], "metadata": {}}
|
||||
]
|
||||
|
||||
|
||||
def test_extract_content_pdf():
|
||||
result = extract_content("application/pdf", REGULAMIN)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(
|
||||
isinstance(page["contents"], list)
|
||||
and all(isinstance(c, Image.Image) for c in page["contents"])
|
||||
for page in result
|
||||
)
|
||||
assert all("page" in page["metadata"] for page in result)
|
||||
assert all("width" in page["metadata"] for page in result)
|
||||
assert all("height" in page["metadata"] for page in result)
|
||||
|
||||
|
||||
def test_extract_content_image(tmp_path):
|
||||
# Create a test image
|
||||
img = Image.new("RGB", (100, 100), color="red")
|
||||
img_path = tmp_path / "test_img.png"
|
||||
img.save(img_path)
|
||||
|
||||
(result,) = extract_content("image/png", img_path)
|
||||
|
||||
assert isinstance(result["contents"][0], Image.Image)
|
||||
assert result["contents"][0].size == (100, 100)
|
||||
assert result["metadata"] == {}
|
||||
|
||||
|
||||
def test_extract_content_unsupported_type():
|
||||
assert extract_content("unsupported/type", "content") == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
|
||||
def test_docx_to_pdf(tmp_path):
|
||||
output_path = tmp_path / "output.pdf"
|
||||
|
431
tests/memory/workers/tasks/test_comic_tasks.py
Normal file
@ -0,0 +1,431 @@
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from memory.common.db.models import Comic
|
||||
from memory.workers.tasks import comic
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_comic_info():
|
||||
"""Mock comic info data for testing."""
|
||||
return {
|
||||
"title": "Test Comic",
|
||||
"image_url": "https://example.com/comic.png",
|
||||
"url": "https://example.com/comic/1",
|
||||
"published_date": "2024-01-01T12:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feed_data():
|
||||
"""Mock RSS feed data."""
|
||||
return {
|
||||
"entries": [
|
||||
{"link": "https://example.com/comic/1", "id": None},
|
||||
{"link": "https://example.com/comic/2", "id": None},
|
||||
{"link": None, "id": "https://example.com/comic/3"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_response():
|
||||
"""Mock HTTP response for comic image."""
|
||||
# 1x1 PNG image (smallest valid PNG)
|
||||
png_data = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
|
||||
b"\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13"
|
||||
b"\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```"
|
||||
b"\x00\x00\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
response = Mock()
|
||||
response.status_code = 200
|
||||
response.content = png_data
|
||||
with patch.object(requests, "get", return_value=response):
|
||||
yield response
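# Alternative sketch: the same tiny PNG could be generated with Pillow (already used in
# these tests) instead of hardcoding the bytes; purely illustrative.
def make_tiny_png() -> bytes:
    import io
    from PIL import Image

    buf = io.BytesIO()
    Image.new("RGB", (1, 1)).save(buf, format="PNG")
    return buf.getvalue()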
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.feedparser.parse")
|
||||
def test_find_new_urls_success(mock_parse, mock_feed_data, db_session):
|
||||
"""Test successful URL discovery from RSS feed."""
|
||||
mock_parse.return_value = Mock(entries=mock_feed_data["entries"])
|
||||
|
||||
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
|
||||
|
||||
assert result == {
|
||||
"https://example.com/comic/1",
|
||||
"https://example.com/comic/2",
|
||||
"https://example.com/comic/3",
|
||||
}
|
||||
mock_parse.assert_called_once_with("https://example.com/rss")
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.feedparser.parse")
|
||||
def test_find_new_urls_with_existing_comics(mock_parse, mock_feed_data, db_session):
|
||||
"""Test URL discovery when some comics already exist."""
|
||||
mock_parse.return_value = Mock(entries=mock_feed_data["entries"])
|
||||
|
||||
# Add existing comic to database
|
||||
existing_comic = Comic(
|
||||
title="Existing Comic",
|
||||
url="https://example.com/comic/1",
|
||||
author="https://example.com",
|
||||
filename="/test/path",
|
||||
sha256=b"test_hash",
|
||||
modality="comic",
|
||||
tags=["comic"],
|
||||
)
|
||||
db_session.add(existing_comic)
|
||||
db_session.commit()
|
||||
|
||||
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
|
||||
|
||||
# Should only return URLs not in database
|
||||
assert result == {
|
||||
"https://example.com/comic/2",
|
||||
"https://example.com/comic/3",
|
||||
}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.feedparser.parse")
|
||||
def test_find_new_urls_parse_error(mock_parse):
|
||||
"""Test handling of RSS feed parsing errors."""
|
||||
mock_parse.side_effect = Exception("Parse error")
|
||||
|
||||
assert (
|
||||
comic.find_new_urls("https://example.com", "https://example.com/rss") == set()
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.feedparser.parse")
|
||||
def test_find_new_urls_empty_feed(mock_parse):
|
||||
"""Test handling of empty RSS feed."""
|
||||
mock_parse.return_value = Mock(entries=[])
|
||||
|
||||
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.feedparser.parse")
|
||||
def test_find_new_urls_malformed_entries(mock_parse):
|
||||
"""Test handling of malformed RSS entries."""
|
||||
mock_parse.return_value = Mock(
|
||||
entries=[
|
||||
{"link": None, "id": None}, # Both None
|
||||
{}, # Missing keys
|
||||
]
|
||||
)
|
||||
|
||||
result = comic.find_new_urls("https://example.com", "https://example.com/rss")
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||
@patch("memory.workers.tasks.comic.find_new_urls")
|
||||
def test_fetch_new_comics_success(mock_find_urls, mock_sync_delay, mock_comic_info):
|
||||
"""Test successful comic fetching."""
|
||||
mock_find_urls.return_value = {"https://example.com/comic/1"}
|
||||
mock_parser = Mock(return_value=mock_comic_info)
|
||||
|
||||
result = comic.fetch_new_comics(
|
||||
"https://example.com", "https://example.com/rss", mock_parser
|
||||
)
|
||||
|
||||
assert result == {"https://example.com/comic/1"}
|
||||
mock_parser.assert_called_once_with("https://example.com/comic/1")
|
||||
expected_call_args = {
|
||||
**mock_comic_info,
|
||||
"author": "https://example.com",
|
||||
"url": "https://example.com/comic/1",
|
||||
}
|
||||
mock_sync_delay.assert_called_once_with(**expected_call_args)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||
@patch("memory.workers.tasks.comic.find_new_urls")
|
||||
def test_fetch_new_comics_no_new_urls(mock_find_urls, mock_sync_delay):
|
||||
"""Test when no new URLs are found."""
|
||||
mock_find_urls.return_value = set()
|
||||
mock_parser = Mock()
|
||||
|
||||
result = comic.fetch_new_comics(
|
||||
"https://example.com", "https://example.com/rss", mock_parser
|
||||
)
|
||||
|
||||
assert result == set()
|
||||
mock_parser.assert_not_called()
|
||||
mock_sync_delay.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||
@patch("memory.workers.tasks.comic.find_new_urls")
|
||||
def test_fetch_new_comics_multiple_urls(
|
||||
mock_find_urls, mock_sync_delay, mock_comic_info
|
||||
):
|
||||
"""Test fetching multiple new comics."""
|
||||
urls = {"https://example.com/comic/1", "https://example.com/comic/2"}
|
||||
mock_find_urls.return_value = urls
|
||||
mock_parser = Mock(return_value=mock_comic_info)
|
||||
|
||||
result = comic.fetch_new_comics(
|
||||
"https://example.com", "https://example.com/rss", mock_parser
|
||||
)
|
||||
|
||||
assert result == urls
|
||||
assert mock_parser.call_count == 2
|
||||
assert mock_sync_delay.call_count == 2
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.requests.get")
|
||||
def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
|
||||
"""Test successful comic synchronization."""
|
||||
mock_get.return_value = mock_image_response
|
||||
|
||||
comic.sync_comic(
|
||||
url="https://example.com/comic/1",
|
||||
image_url="https://example.com/image.png",
|
||||
title="Test Comic",
|
||||
author="https://example.com",
|
||||
published_date=datetime(2024, 1, 1, 12, 0, 0),
|
||||
)
|
||||
|
||||
# Verify comic was created in database
|
||||
saved_comic = (
|
||||
db_session.query(Comic)
|
||||
.filter(Comic.url == "https://example.com/comic/1")
|
||||
.first()
|
||||
)
|
||||
assert saved_comic is not None
|
||||
assert saved_comic.title == "Test Comic"
|
||||
assert saved_comic.author == "https://example.com"
|
||||
assert saved_comic.mime_type == "image/png"
|
||||
assert saved_comic.size == len(mock_image_response.content)
|
||||
assert "comic" in saved_comic.tags
|
||||
assert "https://example.com" in saved_comic.tags
|
||||
|
||||
# Verify vectors were added to Qdrant
|
||||
vectors, _ = qdrant.scroll(collection_name="comic")
|
||||
expected_vectors = [
|
||||
(
|
||||
{
|
||||
"author": "https://example.com",
|
||||
"published": "2024-01-01T12:00:00",
|
||||
"tags": ["comic", "https://example.com"],
|
||||
"title": "Test Comic",
|
||||
"url": "https://example.com/comic/1",
|
||||
"source_id": 1,
|
||||
},
|
||||
None,
|
||||
)
|
||||
]
|
||||
assert [
|
||||
({**v.payload, "tags": sorted(v.payload["tags"])}, v.vector) for v in vectors
|
||||
] == expected_vectors
|
||||
|
||||
|
||||
def test_sync_comic_already_exists(db_session):
|
||||
"""Test that duplicate comics are not processed."""
|
||||
# Add existing comic
|
||||
existing_comic = Comic(
|
||||
title="Existing Comic",
|
||||
url="https://example.com/comic/1",
|
||||
author="https://example.com",
|
||||
filename="/test/path",
|
||||
sha256=b"test_hash",
|
||||
modality="comic",
|
||||
tags=["comic"],
|
||||
)
|
||||
db_session.add(existing_comic)
|
||||
db_session.commit()
|
||||
|
||||
with patch("memory.workers.tasks.comic.requests.get") as mock_get:
|
||||
result = comic.sync_comic(
|
||||
url="https://example.com/comic/1",
|
||||
image_url="https://example.com/image.png",
|
||||
title="Test Comic",
|
||||
author="https://example.com",
|
||||
)
|
||||
|
||||
# Should return early without making HTTP request
|
||||
mock_get.assert_not_called()
|
||||
assert result == {"comic_id": 1, "status": "already_exists"}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.requests.get")
|
||||
def test_sync_comic_http_error(mock_get, db_session, qdrant):
|
||||
"""Test handling of HTTP errors when downloading image."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.content = b""
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
comic.sync_comic(
|
||||
url="https://example.com/comic/1",
|
||||
image_url="https://example.com/image.png",
|
||||
title="Test Comic",
|
||||
author="https://example.com",
|
||||
)
|
||||
|
||||
assert not (
|
||||
db_session.query(Comic)
|
||||
.filter(Comic.url == "https://example.com/comic/1")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.requests.get")
|
||||
def test_sync_comic_no_published_date(
|
||||
mock_get, mock_image_response, db_session, qdrant
|
||||
):
|
||||
"""Test comic sync without published date."""
|
||||
mock_get.return_value = mock_image_response
|
||||
|
||||
comic.sync_comic(
|
||||
url="https://example.com/comic/1",
|
||||
image_url="https://example.com/image.png",
|
||||
title="Test Comic",
|
||||
author="https://example.com",
|
||||
published_date=None,
|
||||
)
|
||||
|
||||
saved_comic = (
|
||||
db_session.query(Comic)
|
||||
.filter(Comic.url == "https://example.com/comic/1")
|
||||
.first()
|
||||
)
|
||||
assert saved_comic is not None
|
||||
assert saved_comic.published is None
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.requests.get")
|
||||
def test_sync_comic_special_characters_in_title(
|
||||
mock_get, mock_image_response, db_session, qdrant
|
||||
):
|
||||
"""Test comic sync with special characters in title."""
|
||||
mock_get.return_value = mock_image_response
|
||||
|
||||
comic.sync_comic(
|
||||
url="https://example.com/comic/1",
|
||||
image_url="https://example.com/image.png",
|
||||
title="Test/Comic: With*Special?Characters",
|
||||
author="https://example.com",
|
||||
)
|
||||
|
||||
# Verify comic was created with the title stored as-is (special characters preserved)
|
||||
saved_comic = (
|
||||
db_session.query(Comic)
|
||||
.filter(Comic.url == "https://example.com/comic/1")
|
||||
.first()
|
||||
)
|
||||
assert saved_comic is not None
|
||||
assert saved_comic.title == "Test/Comic: With*Special?Characters"
|
||||
|
||||
|
||||
@patch(
|
||||
"memory.common.embedding.embed_source_item",
|
||||
side_effect=Exception("Embedding failed"),
|
||||
)
|
||||
def test_sync_comic_embedding_failure(
|
||||
mock_embed_source_item, mock_image_response, db_session, qdrant
|
||||
):
|
||||
"""Test handling of embedding failures."""
|
||||
result = comic.sync_comic(
|
||||
url="https://example.com/comic/1",
|
||||
image_url="https://example.com/image.png",
|
||||
title="Test Comic",
|
||||
author="https://example.com",
|
||||
)
|
||||
assert result == {
|
||||
"comic_id": 1,
|
||||
"title": "Test Comic",
|
||||
"status": "processed",
|
||||
"chunks_count": 0,
|
||||
"embed_status": "FAILED",
|
||||
"content_length": 90,
|
||||
}
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_xkcd.delay")
|
||||
@patch("memory.workers.tasks.comic.sync_smbc.delay")
|
||||
def test_sync_all_comics(mock_smbc_delay, mock_xkcd_delay):
|
||||
"""Test synchronization of all comics."""
|
||||
comic.sync_all_comics()
|
||||
|
||||
mock_smbc_delay.assert_called_once()
|
||||
mock_xkcd_delay.assert_called_once()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||
@patch("memory.workers.tasks.comic.comics.extract_xkcd")
|
||||
@patch("memory.workers.tasks.comic.comics.extract_smbc")
|
||||
@patch("requests.get")
|
||||
def test_trigger_comic_sync_smbc_navigation(
|
||||
mock_get, mock_extract_smbc, mock_extract_xkcd, mock_sync_delay, mock_comic_info
|
||||
):
|
||||
"""Test full SMBC comic sync with navigation."""
|
||||
# Mock HTML responses for navigation
|
||||
mock_responses = [
|
||||
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/2"></a>'),
|
||||
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/1"></a>'),
|
||||
Mock(text="<div>No prev link</div>"), # End of navigation
|
||||
]
|
||||
mock_get.side_effect = mock_responses
|
||||
mock_extract_smbc.return_value = mock_comic_info
|
||||
mock_extract_xkcd.return_value = mock_comic_info
|
||||
|
||||
comic.trigger_comic_sync()
|
||||
|
||||
# Should have called extract_smbc for each discovered URL
|
||||
assert mock_extract_smbc.call_count == 2
|
||||
mock_extract_smbc.assert_any_call("https://smbc.com/comic/2")
|
||||
mock_extract_smbc.assert_any_call("https://smbc.com/comic/1")
|
||||
|
||||
# Should have called extract_xkcd for range 1-307
|
||||
assert mock_extract_xkcd.call_count == 307
|
||||
|
||||
# Should have queued sync tasks
|
||||
assert mock_sync_delay.call_count == 2 + 307 # SMBC + XKCD
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||
@patch("memory.workers.tasks.comic.comics.extract_smbc")
|
||||
@patch("requests.get")
|
||||
def test_trigger_comic_sync_smbc_extraction_error(
|
||||
mock_get, mock_extract_smbc, mock_sync_delay
|
||||
):
|
||||
"""Test handling of extraction errors during full sync."""
|
||||
# Mock responses: first one has a prev link, second one doesn't
|
||||
mock_responses = [
|
||||
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/1"></a>'),
|
||||
Mock(text="<div>No prev link</div>"),
|
||||
]
|
||||
mock_get.side_effect = mock_responses
|
||||
mock_extract_smbc.side_effect = Exception("Extraction failed")
|
||||
|
||||
# Should not raise exception, just log error
|
||||
comic.trigger_comic_sync()
|
||||
|
||||
mock_extract_smbc.assert_called_once_with("https://smbc.com/comic/1")
|
||||
mock_sync_delay.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||
@patch("memory.workers.tasks.comic.comics.extract_xkcd")
|
||||
@patch("requests.get")
|
||||
def test_trigger_comic_sync_xkcd_extraction_error(
|
||||
mock_get, mock_extract_xkcd, mock_sync_delay
|
||||
):
|
||||
"""Test handling of XKCD extraction errors during full sync."""
|
||||
mock_get.return_value = Mock(text="<div>No prev link</div>")
|
||||
mock_extract_xkcd.side_effect = Exception("XKCD extraction failed")
|
||||
|
||||
# Should not raise exception, just log errors
|
||||
comic.trigger_comic_sync()
|
||||
|
||||
# Should attempt all 307 XKCD comics
|
||||
assert mock_extract_xkcd.call_count == 307
|
||||
mock_sync_delay.assert_not_called()
|
@ -51,24 +51,6 @@ def mock_ebook():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding():
|
||||
"""Mock the embedding function to return dummy vectors."""
|
||||
with patch("memory.workers.tasks.ebook.embedding.embed") as mock:
|
||||
mock.return_value = (
|
||||
"book",
|
||||
[
|
||||
Chunk(
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"test": "data"},
|
||||
content="Test content",
|
||||
embedding_model="model",
|
||||
)
|
||||
],
|
||||
)
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant():
|
||||
"""Mock Qdrant operations."""
|
||||
@ -151,7 +133,7 @@ def test_create_book_and_sections(mock_ebook, db_session):
|
||||
assert getattr(chapter1, "parent_section_id") is None
|
||||
|
||||
|
||||
def test_embed_sections(db_session, mock_embedding):
|
||||
def test_embed_sections(db_session):
|
||||
"""Test basic embedding sections workflow."""
|
||||
# Create a test book first
|
||||
book = Book(
|
||||
@ -187,31 +169,6 @@ def test_embed_sections(db_session, mock_embedding):
|
||||
assert hasattr(sections[0], "embed_status")
|
||||
|
||||
|
||||
def test_push_to_qdrant(qdrant):
|
||||
"""Test pushing embeddings to Qdrant."""
|
||||
# Create test sections with chunks
|
||||
mock_chunk = Mock(
|
||||
id="00000000-0000-0000-0000-000000000000",
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"test": "data"},
|
||||
)
|
||||
|
||||
mock_section = Mock(spec=BookSection)
|
||||
mock_section.embed_status = "QUEUED"
|
||||
mock_section.chunks = [mock_chunk]
|
||||
|
||||
sections = [mock_section]
|
||||
|
||||
ebook.push_to_qdrant(sections) # type: ignore
|
||||
|
||||
assert {r.id: r.payload for r in qdrant.scroll(collection_name="book")[0]} == {
|
||||
"00000000-0000-0000-0000-000000000000": {
|
||||
"test": "data",
|
||||
}
|
||||
}
|
||||
assert mock_section.embed_status == "STORED"
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.ebook.parse_ebook")
|
||||
def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path):
|
||||
"""Test successful book synchronization."""
|
||||
@ -229,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path):
|
||||
"author": "Test Author",
|
||||
"status": "processed",
|
||||
"total_sections": 4,
|
||||
"sections_embedded": 4,
|
||||
"sections_embedded": 8,
|
||||
}
|
||||
|
||||
book = db_session.query(Book).filter(Book.title == "Test Book").first()
|
||||
@ -270,8 +227,9 @@ def test_sync_book_already_exists(mock_parse, mock_ebook, db_session, tmp_path):
|
||||
|
||||
|
||||
@patch("memory.workers.tasks.ebook.parse_ebook")
|
||||
@patch("memory.common.embedding.embed_source_item")
|
||||
def test_sync_book_embedding_failure(
|
||||
mock_parse, mock_ebook, db_session, tmp_path, mock_embedding
|
||||
mock_embedding, mock_parse, mock_ebook, db_session, tmp_path
|
||||
):
|
||||
"""Test handling of embedding failures."""
|
||||
book_file = tmp_path / "test.epub"
|
||||
@ -307,14 +265,18 @@ def test_sync_book_qdrant_failure(mock_parse, mock_ebook, db_session, tmp_path):
|
||||
# With push_to_qdrant patched to raise, sync_book should catch the failure and
# report it in the returned status dict instead of propagating the exception.
|
||||
with patch.object(ebook, "push_to_qdrant", side_effect=Exception("Qdrant failed")):
|
||||
with pytest.raises(Exception, match="Qdrant failed"):
|
||||
ebook.sync_book(str(book_file))
|
||||
assert ebook.sync_book(str(book_file)) == {
|
||||
"status": "error",
|
||||
"error": "Qdrant failed",
|
||||
}
|
||||
|
||||
|
||||
def test_sync_book_file_not_found():
|
||||
"""Test handling of missing files."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
ebook.sync_book("/nonexistent/file.epub")
|
||||
assert ebook.sync_book("/nonexistent/file.epub") == {
|
||||
"status": "error",
|
||||
"error": "Book file not found: /nonexistent/file.epub",
|
||||
}
|
||||
|
||||
|
||||
def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
|
||||
@ -363,7 +325,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
|
||||
calls = mock_voyage_client.embed.call_args_list
|
||||
texts = [c[0][0] for c in calls]
|
||||
assert texts == [
|
||||
[large_section_content.strip()],
|
||||
[large_page_1.strip()],
|
||||
[large_page_2.strip()],
|
||||
[large_section_content.strip()],
|
||||
]
|
||||
|
@ -69,13 +69,21 @@ def test_email_account(db_session):
|
||||
|
||||
def test_process_simple_email(db_session, test_email_account, qdrant):
|
||||
"""Test processing a simple email message."""
|
||||
mail_message_id = process_message(
|
||||
res = process_message(
|
||||
account_id=test_email_account.id,
|
||||
message_id="101",
|
||||
folder="INBOX",
|
||||
raw_email=SIMPLE_EMAIL_RAW,
|
||||
)
|
||||
|
||||
mail_message_id = res["mail_message_id"]
|
||||
assert res == {
|
||||
"status": "processed",
|
||||
"mail_message_id": mail_message_id,
|
||||
"message_id": "101",
|
||||
"chunks_count": 1,
|
||||
"attachments_count": 0,
|
||||
}
|
||||
mail_message = (
|
||||
db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
|
||||
)
|
||||
@ -98,7 +106,7 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant):
|
||||
message_id="302",
|
||||
folder="Archive",
|
||||
raw_email=EMAIL_WITH_ATTACHMENT_RAW,
|
||||
)
|
||||
)["mail_message_id"]
|
||||
# Check mail message specifics
|
||||
mail_message = (
|
||||
db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
|
||||
@ -125,25 +133,25 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant):
|
||||
|
||||
def test_process_empty_message(db_session, test_email_account, qdrant):
|
||||
"""Test processing an empty/invalid message."""
|
||||
source_id = process_message(
|
||||
res = process_message(
|
||||
account_id=test_email_account.id,
|
||||
message_id="999",
|
||||
folder="Archive",
|
||||
raw_email="",
|
||||
)
|
||||
|
||||
assert source_id is None
|
||||
assert res == {"reason": "empty_content", "status": "skipped"}
|
||||
|
||||
|
||||
def test_process_duplicate_message(db_session, test_email_account, qdrant):
|
||||
"""Test that duplicate messages are detected and not stored again."""
|
||||
# First call should succeed and create records
|
||||
source_id_1 = process_message(
|
||||
res = process_message(
|
||||
account_id=test_email_account.id,
|
||||
message_id="101",
|
||||
folder="INBOX",
|
||||
raw_email=SIMPLE_EMAIL_RAW,
|
||||
)
|
||||
source_id_1 = res.get("mail_message_id")
|
||||
|
||||
assert source_id_1 is not None, "First call should return a source_id"
|
||||
|
||||
@ -157,7 +165,7 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
|
||||
message_id="101",
|
||||
folder="INBOX",
|
||||
raw_email=SIMPLE_EMAIL_RAW,
|
||||
)
|
||||
).get("mail_message_id")
|
||||
|
||||
assert source_id_2 is None, "Second call should return None for duplicate message"
|
||||
|
||||
|
@ -12,7 +12,7 @@ from memory.common.db.models import (
|
||||
EmailAttachment,
|
||||
MailMessage,
|
||||
)
|
||||
from memory.common.parsers.email import Attachment
|
||||
from memory.common.parsers.email import Attachment, parse_email_message
|
||||
from memory.workers.email import (
|
||||
create_mail_message,
|
||||
extract_email_uid,
|
||||
@ -226,14 +226,14 @@ def test_create_mail_message(db_session):
|
||||
"--boundary--"
|
||||
)
|
||||
folder = "INBOX"
|
||||
parsed_email = parse_email_message(raw_email, "321")
|
||||
|
||||
# Call function
|
||||
mail_message = create_mail_message(
|
||||
db_session=db_session,
|
||||
raw_email=raw_email,
|
||||
folder=folder,
|
||||
tags=["test"],
|
||||
message_id="123",
|
||||
parsed_email=parsed_email,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@ -344,6 +344,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
|
||||
recipients=["recipient@example.com"],
|
||||
content="This is a test email for vectorization",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.flush()
|
||||
@ -373,6 +374,7 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
||||
recipients=["recipient@example.com"],
|
||||
content="This is a test email with attachments",
|
||||
folder="INBOX",
|
||||
modality="mail",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.flush()
|
||||
|