Mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 21:34:42 +02:00)

Commit 4aaa45e09c ("unify tasks"), parent ce6f4bf5c5
@@ -15,7 +15,7 @@ from PIL import Image
 from pydantic import BaseModel
 
 from memory.common import embedding, qdrant, extract, settings
-from memory.common.embedding import get_modality
+from memory.common.collections import get_modality, TEXT_COLLECTIONS
 from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
 
@@ -189,7 +189,7 @@ async def search(
     text_results = query_chunks(
         client,
         upload_data,
-        allowed_modalities & embedding.TEXT_COLLECTIONS,
+        allowed_modalities & TEXT_COLLECTIONS,
         embedding.embed_text,
         min_score=min_text_score,
         limit=limit,
src/memory/common/collections.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+import logging
+from typing import Literal, NotRequired, TypedDict
+
+
+from memory.common import settings
+
+logger = logging.getLogger(__name__)
+
+
+DistanceType = Literal["Cosine", "Dot", "Euclidean"]
+Vector = list[float]
+
+
+class Collection(TypedDict):
+    dimension: int
+    distance: DistanceType
+    model: str
+    on_disk: NotRequired[bool]
+    shards: NotRequired[int]
+
+
+ALL_COLLECTIONS: dict[str, Collection] = {
+    "mail": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "chat": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "git": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "book": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "blog": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    "text": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.TEXT_EMBEDDING_MODEL,
+    },
+    # Multimodal
+    "photo": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+    "comic": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+    "doc": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
+}
+TEXT_COLLECTIONS = {
+    coll
+    for coll, params in ALL_COLLECTIONS.items()
+    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+}
+MULTIMODAL_COLLECTIONS = {
+    coll
+    for coll, params in ALL_COLLECTIONS.items()
+    if params["model"] == settings.MIXED_EMBEDDING_MODEL
+}
+
+TYPES = {
+    "doc": ["application/pdf", "application/docx", "application/msword"],
+    "text": ["text/*"],
+    "blog": ["text/markdown", "text/html"],
+    "photo": ["image/*"],
+    "book": [
+        "application/epub+zip",
+        "application/mobi",
+        "application/x-mobipocket-ebook",
+    ],
+}
+
+
+def get_modality(mime_type: str) -> str:
+    for type, mime_types in TYPES.items():
+        if mime_type in mime_types:
+            return type
+    stem = mime_type.split("/")[0]
+
+    for type, mime_types in TYPES.items():
+        if any(mime_type.startswith(stem) for mime_type in mime_types):
+            return type
+    return "unknown"
+
+
+def collection_model(collection: str) -> str | None:
+    return ALL_COLLECTIONS.get(collection, {}).get("model")
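To make the new module's intent concrete, here is a minimal usage sketch (not part of the commit); it relies only on the names defined above and on the settings module they reference:

from memory.common import settings
from memory.common.collections import TEXT_COLLECTIONS, collection_model, get_modality

# Resolve a MIME type to the collection ("modality") it should be stored in.
assert get_modality("application/pdf") == "doc"   # exact match in TYPES["doc"]
assert get_modality("image/png") == "photo"       # falls back to the "image/*" stem check

# Look up which embedding model serves a collection.
model = collection_model("doc")                   # settings.MIXED_EMBEDDING_MODEL
print("doc" in TEXT_COLLECTIONS)                  # False, since "doc" is multimodal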
@@ -2,11 +2,12 @@
 Database models for the knowledge base system.
 """
 
+from dataclasses import dataclass
 import pathlib
 import re
 import textwrap
 from datetime import datetime
-from typing import Any, ClassVar, cast
+from typing import Any, ClassVar, Iterable, Sequence, cast
 
 from PIL import Image
 from sqlalchemy import (
@@ -32,6 +33,9 @@ from sqlalchemy.orm import Session, relationship
 
 from memory.common import settings
 from memory.common.parsers.email import EmailMessage, parse_email_message
+import memory.common.extract as extract
+import memory.common.collections as collections
+import memory.common.chunker as chunker
 
 Base = declarative_base()
 
@@ -137,6 +141,7 @@ class SourceItem(Base):
     """Base class for all content in the system using SQLAlchemy's joined table inheritance."""
 
     __tablename__ = "source_item"
+    __allow_unmapped__ = True
 
     id = Column(BigInteger, primary_key=True)
     modality = Column(Text, nullable=False)
@@ -174,6 +179,14 @@ class SourceItem(Base):
         """Get vector IDs from associated chunks."""
         return [chunk.id for chunk in self.chunks]
 
+    def data_chunks(self) -> Iterable[extract.DataChunk]:
+        return [
+            extract.DataChunk(
+                data=cast(str, self.content),
+                collection=cast(str, self.modality),
+            )
+        ]
+
     def as_payload(self) -> dict:
         return {
             "source_id": self.id,
@@ -306,6 +319,19 @@ class EmailAttachment(SourceItem):
             "tags": self.tags,
         }
 
+    def data_chunks(self) -> Iterable[extract.DataChunk]:
+        if cast(str | None, self.filename):
+            contents = pathlib.Path(cast(str, self.filename)).read_bytes()
+        else:
+            contents = cast(str, self.content)
+
+        return extract.extract_data_chunks(
+            cast(str, self.mime_type),
+            contents,
+            collection=cast(str, self.modality),
+            embedding_model=collections.collection_model(cast(str, self.modality)),
+        )
+
     # Add indexes
     __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
 
@@ -407,6 +433,15 @@ class Comic(SourceItem):
         }
         return {k: v for k, v in payload.items() if v is not None}
 
+    def data_chunks(self) -> Iterable[extract.DataChunk]:
+        image = Image.open(pathlib.Path(cast(str, self.filename)))
+        return [
+            extract.DataChunk(
+                data=[image, cast(str, self.title), cast(str, self.author)],
+                collection=cast(str, self.modality),
+            )
+        ]
+
 
 class Book(Base):
     """Book-level metadata table"""
@@ -503,6 +538,19 @@ class BookSection(SourceItem):
             "tags": self.tags,
         }
 
+    def data_chunks(self) -> Iterable[extract.DataChunk]:
+        texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
+        texts += [(cast(str, self.content), self.start_page)]
+        return [
+            extract.DataChunk(
+                data=[text],
+                collection=cast(str, self.modality),
+                metadata={"page": page_number},
+                max_size=chunker.EMBEDDING_MAX_TOKENS,
+            )
+            for text, page_number in texts
+        ]
+
 
 class BlogPost(SourceItem):
     __tablename__ = "blog_post"
@@ -519,6 +567,7 @@ class BlogPost(SourceItem):
     description = Column(Text, nullable=True)  # Meta description or excerpt
     domain = Column(Text, nullable=True)  # Domain of the source website
     word_count = Column(Integer, nullable=True)  # Approximate word count
+    images = Column(ARRAY(Text), nullable=True)  # List of image URLs
 
     # Store original metadata from parser
     webpage_metadata = Column(JSONB, nullable=True)
@@ -552,6 +601,30 @@ class BlogPost(SourceItem):
         }
         return {k: v for k, v in payload.items() if v}
 
+    def data_chunks(self) -> Iterable[extract.DataChunk]:
+        images = [Image.open(image) for image in self.images]
+        data = [cast(str, self.content)] + images
+
+        # Always embed the full content as a single chunk (if possible, of course)
+        chunks = [
+            extract.DataChunk(
+                data=data,
+                collection=cast(str, self.modality),
+                max_size=chunker.EMBEDDING_MAX_TOKENS,
+            )
+        ]
+
+        # If the content is long enough, also embed it as chunks of the default size.
+        tokens = chunker.approx_token_count(cast(str, self.content))
+        if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
+            chunks += [
+                extract.DataChunk(
+                    data=data,
+                    collection=cast(str, self.modality),
+                )
+            ]
+        return chunks
+
 
 class MiscDoc(SourceItem):
     __tablename__ = "misc_doc"
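The data_chunks() methods added above are the per-model hook that replaces MIME-type dispatch in the embedding code. A small illustrative helper (not from the repo) showing how the hook is meant to be consumed; `item` is assumed to be any SourceItem instance:

from memory.common import extract

def describe_chunks(item) -> list[tuple[str | None, int | None]]:
    # Each SourceItem subclass now decides for itself what should be embedded,
    # into which collection, and with what size limit.
    return [
        (chunk.collection, chunk.max_size)
        for chunk in item.data_chunks()
        if isinstance(chunk, extract.DataChunk)
    ]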
@@ -1,130 +1,21 @@
+from collections.abc import Sequence
 import logging
 import pathlib
 import uuid
-from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast
+from typing import Any, Iterable, Literal, cast
 
 import voyageai
 from PIL import Image
 
 from memory.common import extract, settings
 from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
-from memory.common.db.models import Chunk
+from memory.common.collections import ALL_COLLECTIONS, Vector
+from memory.common.db.models import Chunk, SourceItem
+from memory.common.extract import DataChunk
 
 logger = logging.getLogger(__name__)
 
 
-DistanceType = Literal["Cosine", "Dot", "Euclidean"]
-Vector = list[float]
-
-
-class Collection(TypedDict):
-    dimension: int
-    distance: DistanceType
-    model: str
-    on_disk: NotRequired[bool]
-    shards: NotRequired[int]
-
-
-ALL_COLLECTIONS: dict[str, Collection] = {
-    "mail": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
-    },
-    "chat": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
-    },
-    "git": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
-    },
-    "book": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
-    },
-    "blog": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
-    },
-    "text": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
-    },
-    # Multimodal
-    "photo": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
-    },
-    "comic": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
-    },
-    "doc": {
-        "dimension": 1024,
-        "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
-    },
-}
-TEXT_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.TEXT_EMBEDDING_MODEL
-}
-MULTIMODAL_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.MIXED_EMBEDDING_MODEL
-}
-
-TYPES = {
-    "doc": ["application/pdf", "application/docx", "application/msword"],
-    "text": ["text/*"],
-    "blog": ["text/markdown", "text/html"],
-    "photo": ["image/*"],
-    "book": [
-        "application/epub+zip",
-        "application/mobi",
-        "application/x-mobipocket-ebook",
-    ],
-}
-
-
-def get_mimetype(image: Image.Image) -> str | None:
-    format_to_mime = {
-        "JPEG": "image/jpeg",
-        "PNG": "image/png",
-        "GIF": "image/gif",
-        "BMP": "image/bmp",
-        "TIFF": "image/tiff",
-        "WEBP": "image/webp",
-    }
-
-    if not image.format:
-        return None
-
-    return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}")
-
-
-def get_modality(mime_type: str) -> str:
-    for type, mime_types in TYPES.items():
-        if mime_type in mime_types:
-            return type
-    stem = mime_type.split("/")[0]
-
-    for type, mime_types in TYPES.items():
-        if any(mime_type.startswith(stem) for mime_type in mime_types):
-            return type
-    return "unknown"
-
-
 def embed_chunks(
     chunks: list[str] | list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
@@ -164,12 +55,6 @@ def embed_text(
         raise
 
 
-def embed_file(
-    file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL
-) -> list[Vector]:
-    return embed_text([file_path.read_text()], model)
-
-
 def embed_mixed(
     items: list[extract.MulitmodalChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
@@ -187,22 +72,6 @@ def embed_mixed(
     return embed_chunks([chunks], model, input_type)
 
 
-def embed_page(page: extract.Page) -> list[Vector]:
-    contents = page["contents"]
-    chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
-    if all(isinstance(c, str) for c in contents):
-        return embed_text(
-            cast(list[str], contents),
-            model=settings.TEXT_EMBEDDING_MODEL,
-            chunk_size=chunk_size,
-        )
-    return embed_mixed(
-        cast(list[extract.MulitmodalChunk], contents),
-        model=settings.MIXED_EMBEDDING_MODEL,
-        chunk_size=chunk_size,
-    )
-
-
 def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
     if isinstance(item, str):
         filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt"
@@ -219,7 +88,9 @@ def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
 
 
 def make_chunk(
-    page: extract.Page, vector: Vector, metadata: dict[str, Any] = {}
+    contents: Sequence[extract.MulitmodalChunk],
+    vector: Vector,
+    metadata: dict[str, Any] = {},
 ) -> Chunk:
     """Create a Chunk object from a page and a vector.
 
@@ -227,7 +98,6 @@ def make_chunk(
     a single image, or a list of strings and images.
     """
     chunk_id = str(uuid.uuid4())
-    contents = page["contents"]
    content, filename = None, None
    if all(isinstance(c, str) for c in contents):
        content = "\n\n".join(cast(list[str], contents))
@@ -251,38 +121,42 @@ def make_chunk(
     )
 
 
-def embed(
-    mime_type: str,
-    content: bytes | str | pathlib.Path,
+def embed_data_chunk(
+    chunk: DataChunk,
     metadata: dict[str, Any] = {},
     chunk_size: int | None = None,
-) -> tuple[str, list[Chunk]]:
-    modality = get_modality(mime_type)
-    pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
-    chunks = [
-        make_chunk(page, vector, metadata)
-        for page in pages
-        for vector in embed_page(page)
-    ]
-    return modality, chunks
+) -> list[Chunk]:
+    chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS
 
+    model = chunk.embedding_model
+    if not model and chunk.collection:
+        model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model")
+    if not model:
+        model = settings.TEXT_EMBEDDING_MODEL
 
-def embed_image(
-    file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
-) -> Chunk:
-    image = Image.open(file_path)
-    mime_type = get_mimetype(image)
-    if mime_type is None:
-        raise ValueError("Unsupported image format")
-
-    vector = embed_mixed(
-        [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
-    )[0]
-
-    return Chunk(
-        id=str(uuid.uuid4()),
-        file_path=file_path.absolute().as_posix(),
-        content=None,
-        embedding_model=settings.MIXED_EMBEDDING_MODEL,
-        vector=vector,
-    )
+    if model == settings.TEXT_EMBEDDING_MODEL:
+        vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size)
+    elif model == settings.MIXED_EMBEDDING_MODEL:
+        vectors = embed_mixed(
+            cast(list[extract.MulitmodalChunk], chunk.data),
+            chunk_size=chunk_size,
+        )
+    else:
+        raise ValueError(f"Unsupported model: {model}")
+
+    metadata = metadata | chunk.metadata
+    return [make_chunk(chunk.data, vector, metadata) for vector in vectors]
+
+
+def embed_source_item(
+    item: SourceItem,
+    metadata: dict[str, Any] = {},
+    chunk_size: int | None = None,
+) -> list[Chunk]:
+    return [
+        chunk
+        for data_chunk in item.data_chunks()
+        for chunk in embed_data_chunk(
+            data_chunk, item.as_payload() | metadata, chunk_size
+        )
+    ]
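A hedged sketch (not from the repo) of the new embedding entry point; it assumes `item` is a SourceItem whose data_chunks() is defined as in the models diff above:

from memory.common import embedding

def embed_item(item):
    # embed_source_item() replaces the old embed(mime_type, content) call and no
    # longer returns a modality; the item itself already knows its collection.
    chunks = embedding.embed_source_item(item)
    item.chunks = chunks
    return [(str(c.id), c.vector, c.item_metadata) for c in chunks]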
@@ -1,3 +1,4 @@
+from dataclasses import dataclass, field
 import io
 import logging
 import pathlib
@@ -21,6 +22,15 @@ class Page(TypedDict):
     chunk_size: NotRequired[int]
 
 
+@dataclass
+class DataChunk:
+    data: Sequence[MulitmodalChunk]
+    collection: str | None = None
+    embedding_model: str | None = None
+    max_size: int | None = None
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
 @contextmanager
 def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
     if isinstance(content, pathlib.Path):
@@ -110,11 +120,13 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     return [{"contents": [cast(str, content)], "metadata": {}}]
 
 
-def extract_content(
+def extract_data_chunks(
     mime_type: str,
     content: bytes | str | pathlib.Path,
+    collection: str | None = None,
+    embedding_model: str | None = None,
     chunk_size: int | None = None,
-) -> list[Page]:
+) -> list[DataChunk]:
     pages = []
     logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
@@ -134,4 +146,12 @@ def extract_content(
     if chunk_size:
         pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
 
-    return pages
+    return [
+        DataChunk(
+            data=page["contents"],
+            collection=collection,
+            embedding_model=embedding_model,
+            max_size=chunk_size,
+        )
+        for page in pages
+    ]
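A small sketch (not part of the commit) of the reworked extraction API; the file path and collection name are placeholders:

import pathlib

from memory.common import extract

chunks = extract.extract_data_chunks(
    "application/pdf",
    pathlib.Path("/tmp/example.pdf"),   # hypothetical input file
    collection="doc",
)
for chunk in chunks:
    # Each DataChunk bundles the extracted contents with routing information.
    print(chunk.collection, chunk.embedding_model, chunk.max_size)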
@@ -26,6 +26,7 @@ class EmailMessage(TypedDict):
     body: str
     attachments: list[Attachment]
     hash: bytes
+    raw_email: str
 
 
 RawEmailResponse = tuple[str | None, bytes]
@@ -171,6 +172,7 @@ def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
     body = extract_body(msg)
 
     return EmailMessage(
+        raw_email=raw_email,
         message_id=message_id,
         subject=subject,
         sender=from_,
@@ -12,7 +12,7 @@ from bs4 import BeautifulSoup, Tag
 from markdownify import markdownify as md
 from PIL import Image as PILImage
 
-from memory.common.settings import FILE_STORAGE_DIR, WEBPAGE_STORAGE_DIR
+from memory.common import settings
 
 logger = logging.getLogger(__name__)
 
@@ -96,8 +96,8 @@ def extract_date(
 
     datetime_attr = element.get("datetime")
     if datetime_attr:
-        date_str = str(datetime_attr)
-        if date := parse_date(date_str, date_format):
-            return date
+        for format in ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d", date_format]:
+            if date := parse_date(str(datetime_attr), format):
+                return date
 
     for text in element.find_all(string=True):
@@ -178,7 +178,7 @@ def process_images(
                 continue
 
             path = pathlib.Path(image.filename)  # type: ignore
-            img_tag["src"] = str(path.relative_to(FILE_STORAGE_DIR.resolve()))
+            img_tag["src"] = str(path.relative_to(settings.FILE_STORAGE_DIR.resolve()))
             images[img_tag["src"]] = image
         except Exception as e:
             logger.warning(f"Failed to process image {src}: {e}")
@@ -291,7 +291,7 @@ class BaseHTMLParser:
 
     def __init__(self, base_url: str | None = None):
         self.base_url = base_url
-        self.image_dir = WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc)
+        self.image_dir = settings.WEBPAGE_STORAGE_DIR / str(urlparse(base_url).netloc)
         self.image_dir.mkdir(parents=True, exist_ok=True)
 
     def parse(self, html: str, url: str) -> Article:
@@ -5,12 +5,7 @@ import qdrant_client
 from qdrant_client.http import models as qdrant_models
 from qdrant_client.http.exceptions import UnexpectedResponse
 from memory.common import settings
-from memory.common.embedding import (
-    Collection,
-    ALL_COLLECTIONS,
-    DistanceType,
-    Vector,
-)
+from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector
 
 logger = logging.getLogger(__name__)
 
@@ -10,17 +10,16 @@ from typing import Callable, Generator, Sequence, cast
 
 from sqlalchemy.orm import Session, scoped_session
 
-from memory.common import embedding, qdrant, settings
+from memory.common import embedding, qdrant, settings, collections
 from memory.common.db.models import (
     EmailAccount,
     EmailAttachment,
     MailMessage,
-    SourceItem,
 )
 from memory.common.parsers.email import (
     Attachment,
+    EmailMessage,
     RawEmailResponse,
-    parse_email_message,
 )
 
 logger = logging.getLogger(__name__)
@@ -54,7 +53,7 @@ def process_attachment(
         return None
 
     return EmailAttachment(
-        modality=embedding.get_modality(attachment["content_type"]),
+        modality=collections.get_modality(attachment["content_type"]),
         sha256=hashlib.sha256(
             real_content if real_content else str(attachment).encode()
         ).digest(),
@@ -94,8 +93,7 @@ def create_mail_message(
     db_session: Session | scoped_session,
     tags: list[str],
     folder: str,
-    raw_email: str,
-    message_id: str,
+    parsed_email: EmailMessage,
 ) -> MailMessage:
     """
     Create a new mail message record and associated attachments.
@@ -109,7 +107,7 @@ def create_mail_message(
     Returns:
         Newly created MailMessage
     """
-    parsed_email = parse_email_message(raw_email, message_id)
+    raw_email = parsed_email["raw_email"]
     mail_message = MailMessage(
         modality="mail",
         sha256=parsed_email["hash"],
@@ -137,52 +135,6 @@ def create_mail_message(
     return mail_message
 
 
-def does_message_exist(
-    db_session: Session | scoped_session, message_id: str, message_hash: bytes
-) -> bool:
-    """
-    Check if a message already exists in the database.
-
-    Args:
-        db_session: Database session
-        message_id: Email message ID
-        message_hash: SHA-256 hash of message
-
-    Returns:
-        True if message exists, False otherwise
-    """
-    # Check by message_id first (faster)
-    if message_id:
-        mail_message = (
-            db_session.query(MailMessage)
-            .filter(MailMessage.message_id == message_id)
-            .first()
-        )
-        if mail_message is not None:
-            return True
-
-    # Then check by message_hash
-    source_item = (
-        db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
-    )
-    return source_item is not None
-
-
-def check_message_exists(
-    db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
-) -> bool:
-    account = db.query(EmailAccount).get(account_id)
-    if not account:
-        logger.error(f"Account {account_id} not found")
-        return False
-
-    parsed_email = parse_email_message(raw_email, message_id)
-    if "szczepalins" in raw_email.lower():
-        print(parsed_email["message_id"])
-
-    return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
-
-
 def extract_email_uid(
     msg_data: Sequence[tuple[bytes, bytes]],
 ) -> tuple[str | None, bytes]:
@@ -317,11 +269,7 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
 def vectorize_email(email: MailMessage):
     qdrant_client = qdrant.get_qdrant_client()
 
-    _, chunks = embedding.embed(
-        "text/plain",
-        email.body,
-        metadata=email.as_payload(),
-    )
+    chunks = embedding.embed_source_item(email)
     email.chunks = chunks
     if chunks:
         vector_ids = [cast(str, c.id) for c in chunks]
@@ -329,7 +277,7 @@ def vectorize_email(email: MailMessage):
         metadata = [c.item_metadata for c in chunks]
         qdrant.upsert_vectors(
             client=qdrant_client,
-            collection_name="mail",
+            collection_name=cast(str, email.modality),
             ids=vector_ids,
             vectors=vectors,  # type: ignore
             payloads=metadata,  # type: ignore
@@ -337,18 +285,12 @@ def vectorize_email(email: MailMessage):
 
     embeds = defaultdict(list)
     for attachment in email.attachments:
-        if attachment.filename:
-            content = pathlib.Path(attachment.filename).read_bytes()
-        else:
-            content = attachment.content
-        collection, chunks = embedding.embed(
-            attachment.mime_type, content, metadata=attachment.as_payload()
-        )
+        chunks = embedding.embed_source_item(attachment)
         if not chunks:
             continue
 
         attachment.chunks = chunks
-        embeds[collection].extend(chunks)
+        embeds[attachment.modality].extend(chunks)
 
     for collection, chunks in embeds.items():
         ids = [c.id for c in chunks]
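The attachment handling above now groups chunks by each attachment's own modality rather than by the collection name the old embedding.embed() returned. A condensed sketch of the resulting flow (illustrative only, mirroring the code above):

from collections import defaultdict

from memory.common import embedding, qdrant

def vectorize_attachments(attachments):
    embeds = defaultdict(list)
    for attachment in attachments:
        chunks = embedding.embed_source_item(attachment)
        if chunks:
            attachment.chunks = chunks
            embeds[attachment.modality].extend(chunks)

    client = qdrant.get_qdrant_client()
    for collection, chunks in embeds.items():
        qdrant.upsert_vectors(
            client=client,
            collection_name=collection,
            ids=[str(c.id) for c in chunks],
            vectors=[c.vector for c in chunks],
            payloads=[c.item_metadata for c in chunks],
        )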
@@ -1,99 +1,25 @@
-import hashlib
 import logging
-from typing import Iterable, cast
+from typing import Iterable
 
-from memory.common import chunker, embedding, qdrant
 from memory.common.db.connection import make_session
 from memory.common.db.models import BlogPost
 from memory.common.parsers.blogs import parse_webpage
 from memory.workers.celery_app import app
+from memory.workers.tasks.content_processing import (
+    check_content_exists,
+    create_content_hash,
+    create_task_result,
+    process_content_item,
+    safe_task_execution,
+)
 
 logger = logging.getLogger(__name__)
 
 
 SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
 
 
-def create_blog_post_from_article(article, tags: Iterable[str] = []) -> BlogPost:
-    """Create a BlogPost model from parsed article data."""
-    return BlogPost(
-        url=article.url,
-        title=article.title,
-        published=article.published_date,
-        content=article.content,
-        sha256=hashlib.sha256(article.content.encode()).digest(),
-        modality="blog",
-        tags=tags,
-        mime_type="text/markdown",
-        size=len(article.content.encode("utf-8")),
-    )
-
-
-def embed_blog_post(blog_post: BlogPost) -> int:
-    """Embed blog post content and return count of successfully embedded chunks."""
-    try:
-        # Always embed the full content
-        _, chunks = embedding.embed(
-            "text/markdown",
-            cast(str, blog_post.content),
-            metadata=blog_post.as_payload(),
-            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
-        )
-        # But also embed the content in chunks (unless it's really short)
-        if (
-            chunker.approx_token_count(cast(str, blog_post.content))
-            > chunker.DEFAULT_CHUNK_TOKENS * 2
-        ):
-            _, small_chunks = embedding.embed(
-                "text/markdown",
-                cast(str, blog_post.content),
-                metadata=blog_post.as_payload(),
-            )
-            chunks += small_chunks
-
-        if chunks:
-            blog_post.chunks = chunks
-            blog_post.embed_status = "QUEUED"  # type: ignore
-            return len(chunks)
-        else:
-            blog_post.embed_status = "FAILED"  # type: ignore
-            logger.warning(f"No chunks generated for blog post: {blog_post.title}")
-            return 0
-
-    except Exception as e:
-        blog_post.embed_status = "FAILED"  # type: ignore
-        logger.error(f"Failed to embed blog post {blog_post.title}: {e}")
-        return 0
-
-
-def push_to_qdrant(blog_post: BlogPost):
-    """Push embeddings to Qdrant for successfully embedded blog post."""
-    if cast(str, blog_post.embed_status) != "QUEUED" or not blog_post.chunks:
-        return
-
-    try:
-        vector_ids = [str(chunk.id) for chunk in blog_post.chunks]
-        vectors = [chunk.vector for chunk in blog_post.chunks]
-        payloads = [chunk.item_metadata for chunk in blog_post.chunks]
-
-        qdrant.upsert_vectors(
-            client=qdrant.get_qdrant_client(),
-            collection_name="blog",
-            ids=vector_ids,
-            vectors=vectors,
-            payloads=payloads,
-        )
-
-        blog_post.embed_status = "STORED"  # type: ignore
-        logger.info(f"Successfully stored embeddings for: {blog_post.title}")
-
-    except Exception as e:
-        blog_post.embed_status = "FAILED"  # type: ignore
-        logger.error(f"Failed to push embeddings to Qdrant: {e}")
-        raise
-
-
 @app.task(name=SYNC_WEBPAGE)
+@safe_task_execution
 def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
     """
     Synchronize a webpage from a URL.
@@ -116,61 +42,25 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
             "content_length": 0,
         }
 
-    blog_post = create_blog_post_from_article(article, tags)
+    blog_post = BlogPost(
+        url=article.url,
+        title=article.title,
+        published=article.published_date,
+        content=article.content,
+        sha256=create_content_hash(article.content),
+        modality="blog",
+        tags=tags,
+        mime_type="text/markdown",
+        size=len(article.content.encode("utf-8")),
+        images=[image for image in article.images],
+    )
 
     with make_session() as session:
-        existing_post = session.query(BlogPost).filter(BlogPost.url == url).first()
+        existing_post = check_content_exists(
+            session, BlogPost, url=url, sha256=create_content_hash(article.content)
+        )
         if existing_post:
             logger.info(f"Blog post already exists: {existing_post.title}")
-            return {
-                "blog_post_id": existing_post.id,
-                "url": url,
-                "title": existing_post.title,
-                "status": "already_exists",
-                "chunks_count": len(existing_post.chunks),
-            }
+            return create_task_result(existing_post, "already_exists", url=url)
 
-        existing_post = (
-            session.query(BlogPost).filter(BlogPost.sha256 == blog_post.sha256).first()
-        )
-        if existing_post:
-            logger.info(
-                f"Blog post with the same content already exists: {existing_post.title}"
-            )
-            return {
-                "blog_post_id": existing_post.id,
-                "url": url,
-                "title": existing_post.title,
-                "status": "already_exists",
-                "chunks_count": len(existing_post.chunks),
-            }
-
-        session.add(blog_post)
-        session.flush()
-
-        chunks_count = embed_blog_post(blog_post)
-        session.flush()
-
-        try:
-            push_to_qdrant(blog_post)
-            logger.info(
-                f"Successfully processed webpage: {blog_post.title} "
-                f"({chunks_count} chunks embedded)"
-            )
-        except Exception as e:
-            logger.error(f"Failed to push embeddings to Qdrant: {e}")
-            blog_post.embed_status = "FAILED"  # type: ignore
-
-        session.commit()
-
-        return {
-            "blog_post_id": blog_post.id,
-            "url": url,
-            "title": blog_post.title,
-            "author": article.author,
-            "published_date": article.published_date,
-            "status": "processed",
-            "chunks_count": chunks_count,
-            "content_length": len(article.content),
-            "embed_status": blog_post.embed_status,
-        }
+        return process_content_item(blog_post, "blog", session, tags)
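The sync_webpage rewrite above is the template every sync task now follows: build the model, check for duplicates, then hand everything to process_content_item. A generic sketch (not from the repo) of that shape, where `item` stands for any already-constructed SourceItem subclass:

from typing import cast

from memory.common.db.connection import make_session
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_task_result,
    process_content_item,
)

def run_sync(item, tags: list[str] = []) -> dict:
    with make_session() as session:
        existing = check_content_exists(session, type(item), sha256=item.sha256)
        if existing:
            return create_task_result(existing, "already_exists")
        # Adds, embeds, pushes to Qdrant, commits, and reports in one call.
        return process_content_item(item, cast(str, item.modality), session, tags)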
@@ -1,4 +1,3 @@
-import hashlib
 import logging
 from datetime import datetime
 from typing import Callable, cast
@@ -6,21 +5,25 @@ from typing import Callable, cast
 import feedparser
 import requests
 
-from memory.common import embedding, qdrant, settings
+from memory.common import settings
 from memory.common.db.connection import make_session
 from memory.common.db.models import Comic, clean_filename
 from memory.common.parsers import comics
 from memory.workers.celery_app import app
+from memory.workers.tasks.content_processing import (
+    check_content_exists,
+    create_content_hash,
+    process_content_item,
+    safe_task_execution,
+)
 
 logger = logging.getLogger(__name__)
 
 
 SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
 SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
 SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
 SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
 
 
 BASE_SMBC_URL = "https://www.smbc-comics.com/"
 SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
 
@@ -36,6 +39,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
         return set()
 
     urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}
+    urls = {url for url in urls if url}
 
     with make_session() as session:
         known = {
@@ -61,6 +65,7 @@ def fetch_new_comics(
 
 
 @app.task(name=SYNC_COMIC)
+@safe_task_execution
 def sync_comic(
     url: str,
     image_url: str,
@@ -70,20 +75,26 @@ def sync_comic(
 ):
     """Synchronize a comic from a URL."""
     with make_session() as session:
-        if session.query(Comic).filter(Comic.url == url).first():
-            return
+        existing_comic = check_content_exists(session, Comic, url=url)
+        if existing_comic:
+            return {"status": "already_exists", "comic_id": existing_comic.id}
 
     response = requests.get(image_url)
+    if response.status_code != 200:
+        return {
+            "status": "failed",
+            "error": f"Failed to download image: {response.status_code}",
+        }
 
     file_type = image_url.split(".")[-1]
     mime_type = f"image/{file_type}"
     filename = (
         settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}"
     )
-    if response.status_code == 200:
-        filename.parent.mkdir(parents=True, exist_ok=True)
-        filename.write_bytes(response.content)
+    filename.parent.mkdir(parents=True, exist_ok=True)
+    filename.write_bytes(response.content)
 
-    sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest()
     comic = Comic(
         title=title,
         url=url,
@@ -92,27 +103,13 @@ def sync_comic(
         filename=filename.resolve().as_posix(),
         mime_type=mime_type,
         size=len(response.content),
-        sha256=sha256,
+        sha256=create_content_hash(f"{image_url}{published_date}"),
         tags={"comic", author},
         modality="comic",
     )
-    chunk = embedding.embed_image(filename, [title, author])
-    comic.chunks = [chunk]
 
     with make_session() as session:
-        session.add(comic)
-        session.add(chunk)
-        session.flush()
-
-        qdrant.upsert_vectors(
-            client=qdrant.get_qdrant_client(),
-            collection_name="comic",
-            ids=[str(chunk.id)],
-            vectors=[chunk.vector],  # type: ignore
-            payloads=[comic.as_payload()],
-        )
-
-        session.commit()
+        return process_content_item(comic, "comic", session)
 
 
 @app.task(name=SYNC_SMBC)
267
src/memory/workers/tasks/content_processing.py
Normal file
267
src/memory/workers/tasks/content_processing.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
Content processing utilities for memory workers.
|
||||||
|
|
||||||
|
This module provides core functionality for processing content items through
|
||||||
|
the complete workflow: existence checking, content hashing, embedding generation,
|
||||||
|
vector storage, and result tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import traceback
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Iterable, Sequence, cast
|
||||||
|
|
||||||
|
from memory.common import embedding, qdrant
|
||||||
|
from memory.common.db.models import SourceItem
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_content_exists(
|
||||||
|
session,
|
||||||
|
model_class: type[SourceItem],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> SourceItem | None:
|
||||||
|
"""
|
||||||
|
Check if content already exists in the database.
|
||||||
|
|
||||||
|
Searches for existing content by any of the provided attributes
|
||||||
|
(typically URL, file_path, or SHA256 hash).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session for querying
|
||||||
|
model_class: The SourceItem model class to search in
|
||||||
|
**kwargs: Attribute-value pairs to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Existing SourceItem if found, None otherwise
|
||||||
|
"""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if not hasattr(model_class, key):
|
||||||
|
continue
|
||||||
|
|
||||||
|
existing = (
|
||||||
|
session.query(model_class)
|
||||||
|
.filter(getattr(model_class, key) == value)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_content_hash(content: str, *additional_data: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Create SHA256 hash from content and optional additional data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Primary content to hash
|
||||||
|
*additional_data: Additional strings to include in the hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SHA256 hash digest as bytes
|
||||||
|
"""
|
||||||
|
hash_input = content + "".join(additional_data)
|
||||||
|
return hashlib.sha256(hash_input.encode()).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def embed_source_item(source_item: SourceItem) -> int:
|
||||||
|
"""
|
||||||
|
Generate embeddings for a source item's content.
|
||||||
|
|
||||||
|
Processes the source item through the embedding pipeline, creating
|
||||||
|
chunks and their corresponding vector embeddings. Updates the item's
|
||||||
|
embed_status based on success or failure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_item: The SourceItem to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of successfully embedded chunks
|
||||||
|
|
||||||
|
Side effects:
|
||||||
|
- Sets source_item.chunks with generated chunks
|
||||||
|
- Sets source_item.embed_status to "QUEUED" or "FAILED"
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
chunks = embedding.embed_source_item(source_item)
|
||||||
|
if chunks:
|
||||||
|
source_item.chunks = chunks
|
||||||
|
source_item.embed_status = "QUEUED" # type: ignore
|
||||||
|
return len(chunks)
|
||||||
|
else:
|
||||||
|
source_item.embed_status = "FAILED" # type: ignore
|
||||||
|
logger.warning(
|
||||||
|
f"No chunks generated for {type(source_item).__name__}: {getattr(source_item, 'title', 'unknown')}"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
source_item.embed_status = "FAILED" # type: ignore
|
||||||
|
logger.error(f"Failed to embed {type(source_item).__name__}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
|
||||||
|
"""
|
||||||
|
Push embeddings to Qdrant vector database.
|
||||||
|
|
||||||
|
Uploads vector embeddings for all source items that have been successfully
|
||||||
|
embedded (status "QUEUED") and have chunks available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_items: Sequence of SourceItems to process
|
||||||
|
collection_name: Name of the Qdrant collection to store vectors in
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the Qdrant upsert operation fails
|
||||||
|
|
||||||
|
Side effects:
|
||||||
|
- Updates embed_status to "STORED" for successful items
|
||||||
|
- Updates embed_status to "FAILED" for failed items
|
||||||
|
"""
|
||||||
|
items_to_process = [
|
||||||
|
item
|
||||||
|
for item in source_items
|
||||||
|
if cast(str, getattr(item, "embed_status", None)) == "QUEUED" and item.chunks
|
||||||
|
]
|
||||||
|
|
||||||
|
if not items_to_process:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_chunks = [chunk for item in items_to_process for chunk in item.chunks]
|
||||||
|
if not all_chunks:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
vector_ids = [str(chunk.id) for chunk in all_chunks]
|
||||||
|
vectors = [chunk.vector for chunk in all_chunks]
|
||||||
|
payloads = [chunk.item_metadata for chunk in all_chunks]
|
||||||
|
|
||||||
|
qdrant.upsert_vectors(
|
||||||
|
client=qdrant.get_qdrant_client(),
|
||||||
|
collection_name=collection_name,
|
||||||
|
ids=vector_ids,
|
||||||
|
vectors=vectors,
|
||||||
|
payloads=payloads,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in items_to_process:
|
||||||
|
item.embed_status = "STORED" # type: ignore
|
||||||
|
logger.info(
|
||||||
|
f"Successfully stored embeddings for: {getattr(item, 'title', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
for item in items_to_process:
|
||||||
|
item.embed_status = "FAILED" # type: ignore
        logger.error(f"Failed to push embeddings to Qdrant: {e}")
        raise


def create_task_result(
    item: SourceItem, status: str, **additional_fields: Any
) -> dict[str, Any]:
    """
    Create standardized task result dictionary.

    Generates a consistent result format for task execution reporting,
    including item metadata and processing status.

    Args:
        item: The processed SourceItem
        status: Processing status string
        **additional_fields: Extra fields to include in the result

    Returns:
        Dictionary with standardized task result format
    """
    return {
        f"{type(item).__name__.lower()}_id": item.id,
        "title": getattr(item, "title", None),
        "status": status,
        "chunks_count": len(item.chunks),
        "embed_status": item.embed_status,
        **additional_fields,
    }


def process_content_item(
    item: SourceItem, collection_name: str, session, tags: Iterable[str] = []
) -> dict[str, Any]:
    """
    Execute complete content processing workflow.

    Performs the full pipeline for processing a content item:
    1. Add to database session and flush to get ID
    2. Generate embeddings and chunks
    3. Push embeddings to Qdrant vector store
    4. Commit transaction and return result

    Args:
        item: SourceItem to process
        collection_name: Qdrant collection name for vector storage
        session: Database session for persistence
        tags: Optional tags to associate with the item (currently unused)

    Returns:
        Task result dictionary with processing status and metadata

    Side effects:
        - Adds item to database session
        - Commits database transaction
        - Stores vectors in Qdrant
    """
    session.add(item)
    session.flush()

    chunks_count = embed_source_item(item)
    session.flush()

    try:
        push_to_qdrant([item], collection_name)
        status = "processed"
        logger.info(
            f"Successfully processed {type(item).__name__}: {getattr(item, 'title', 'unknown')} ({chunks_count} chunks embedded)"
        )
    except Exception as e:
        logger.error(f"Failed to push embeddings to Qdrant: {e}")
        item.embed_status = "FAILED"  # type: ignore
        status = "failed"

    session.commit()

    return create_task_result(item, status, content_length=getattr(item, "size", 0))


def safe_task_execution(func: Callable[..., dict]) -> Callable[..., dict]:
    """
    Decorator for safe task execution with comprehensive error handling.

    Wraps task functions to catch and log exceptions, ensuring tasks
    always return a result dictionary even when they fail.

    Args:
        func: Task function to wrap

    Returns:
        Wrapped function that handles exceptions gracefully

    Example:
        @safe_task_execution
        def my_task(arg1, arg2):
            # Task implementation
            return {"status": "success"}
    """

    def wrapper(*args, **kwargs) -> dict:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(
                f"Task {func.__name__} failed with traceback:\n{traceback.format_exc()}"
            )
            logger.error(f"Task {func.__name__} failed: {e}")
            return {"status": "error", "error": str(e)}

    return wrapper
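Reviewer sketch, not part of the commit: how these helpers are meant to be combined by a worker task, based only on the functions above. The task function name and the "doc" collection below are illustrative assumptions.

# Hypothetical task body -- shows the intended call pattern only.
from memory.common.db.connection import make_session
from memory.workers.tasks.content_processing import (
    process_content_item,
    safe_task_execution,
)


@safe_task_execution
def sync_example_item(item):  # 'item' would be any SourceItem instance
    with make_session() as session:
        # Adds the item, embeds it, pushes vectors to the chosen collection,
        # commits, and returns the create_task_result() dict.
        return process_content_item(item, "doc", session)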
@ -1,17 +1,21 @@
import hashlib
import logging
from pathlib import Path
from typing import Iterable, cast

from memory.common import chunker, embedding, qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import Book, BookSection
from memory.common.parsers.ebook import Ebook, parse_ebook, Section
from memory.common.db.connection import make_session
from memory.workers.celery_app import app
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
    embed_source_item,
    push_to_qdrant,
    safe_task_execution,
)

logger = logging.getLogger(__name__)


SYNC_BOOK = "memory.workers.tasks.book.sync_book"

# Minimum section length to embed (avoid noise from very short sections)
@ -46,10 +50,6 @@ def section_processor(
):
    content = "\n\n".join(section.pages).strip()
    if len(content) >= MIN_SECTION_LENGTH:
        sha256 = hashlib.sha256(
            f"{book.id}:{section.title}:{section.start_page}".encode()
        ).digest()

        book_section = BookSection(
            book_id=book.id,
            section_title=section.title,
@ -59,7 +59,9 @@ def section_processor(
            end_page=section.end_page,
            parent_section_id=None,  # Will be set after flush
            content=content,
            sha256=sha256,
            sha256=create_content_hash(
                f"{book.id}:{section.title}:{section.start_page}"
            ),
            modality="book",
            tags=book.tags,
            pages=section.pages,
@ -127,76 +129,11 @@ def create_book_and_sections(

def embed_sections(all_sections: list[BookSection]) -> int:
    """Embed all sections and return count of successfully embedded sections."""
    embedded_count = 0
    return sum(embed_source_item(section) for section in all_sections)

    def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
        _, chunks = embedding.embed(
            "text/plain",
            text,
            metadata=metadata,
            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
        )
        return chunks

    for section in all_sections:
        try:
            section_chunks = embed_text(
                cast(str, section.content), section.as_payload()
            )
            page_chunks = [
                chunk
                for i, page in enumerate(section.pages)
                for chunk in embed_text(
                    page, section.as_payload() | {"page_number": i + section.start_page}
                )
            ]
            chunks = section_chunks + page_chunks

            if chunks:
                section.chunks = chunks
                section.embed_status = "QUEUED"  # type: ignore
                embedded_count += 1
            else:
                section.embed_status = "FAILED"  # type: ignore
                logger.warning(
                    f"No chunks generated for section: {section.section_title}"
                )

        except IOError as e:
            section.embed_status = "FAILED"  # type: ignore
            logger.error(f"Failed to embed section {section.section_title}: {e}")

    return embedded_count


def push_to_qdrant(all_sections: list[BookSection]):
    """Push embeddings to Qdrant for all successfully embedded sections."""
    vector_ids = []
    vectors = []
    payloads = []

    to_process = [s for s in all_sections if cast(str, s.embed_status) == "QUEUED"]
    all_chunks = [chunk for section in to_process for chunk in section.chunks]
    if not all_chunks:
        return

    vector_ids = [str(chunk.id) for chunk in all_chunks]
    vectors = [chunk.vector for chunk in all_chunks]
    payloads = [chunk.item_metadata for chunk in all_chunks]

    qdrant.upsert_vectors(
        client=qdrant.get_qdrant_client(),
        collection_name="book",
        ids=vector_ids,
        vectors=vectors,
        payloads=payloads,
    )

    for section in to_process:
        section.embed_status = "STORED"  # type: ignore


@app.task(name=SYNC_BOOK)
@safe_task_execution
def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
    """
    Synchronize a book from a file path.
@ -211,10 +148,8 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:

    with make_session() as session:
        # Check for existing book
        existing_book = (
            session.query(Book)
            .filter(Book.file_path == ebook.file_path.as_posix())
            .first()
        existing_book = check_content_exists(
            session, Book, file_path=ebook.file_path.as_posix()
        )
        if existing_book:
            logger.info(f"Book already exists: {existing_book.title}")
@ -230,19 +165,10 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
        book, all_sections = create_book_and_sections(ebook, session, tags)

        # Embed sections
        embedded_count = embed_sections(all_sections)
        embedded_count = sum(embed_source_item(section) for section in all_sections)
        session.flush()

        # Push to Qdrant
        try:
            push_to_qdrant(all_sections)
        except Exception as e:
            logger.error(f"Failed to push embeddings to Qdrant: {e}")
            # Mark sections as failed
            for section in all_sections:
                if getattr(section, "embed_status") == "STORED":
                    section.embed_status = "FAILED"  # type: ignore
            raise
        push_to_qdrant(all_sections, "book")

        session.commit()

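Reviewer note, not part of the commit: the inline hashlib call above is replaced by create_content_hash from content_processing. Its actual definition is not shown in this diff; a sketch consistent with the code it replaces would be:

import hashlib


def create_content_hash(content: str) -> bytes:
    # Assumed to mirror the old inline sha256(...).digest() call it replaces.
    return hashlib.sha256(content.encode()).digest()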
@ -2,16 +2,19 @@ import logging
from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount
from memory.common.db.models import EmailAccount, MailMessage
from memory.workers.celery_app import app
from memory.workers.email import (
    check_message_exists,
    create_mail_message,
    imap_connection,
    process_folder,
    vectorize_email,
)
from memory.common.parsers.email import parse_email_message
from memory.workers.tasks.content_processing import (
    check_content_exists,
    safe_task_execution,
)

logger = logging.getLogger(__name__)

@ -21,12 +24,13 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"


@app.task(name=PROCESS_EMAIL)
@safe_task_execution
def process_message(
    account_id: int,
    message_id: str,
    folder: str,
    raw_email: str,
) -> int | None:
) -> dict:
    """
    Process a single email message and store it in the database.

@ -37,29 +41,30 @@ def process_message(
        raw_email: Raw email content as string

    Returns:
        source_id if successful, None otherwise
        dict with processing result
    """
    logger.info(f"Processing message {message_id} for account {account_id}")
    if not raw_email.strip():
        logger.warning(f"Empty email message received for account {account_id}")
        return None
        return {"status": "skipped", "reason": "empty_content"}

    with make_session() as db:
        if check_message_exists(db, account_id, message_id, raw_email):
            logger.debug(f"Message {message_id} already exists in database")
            return None

        account = db.query(EmailAccount).get(account_id)
        if not account:
            logger.error(f"Account {account_id} not found")
            return None
            return {"status": "error", "error": "Account not found"}

        mail_message = create_mail_message(
            db, account.tags, folder, raw_email, message_id
        )
        parsed_email = parse_email_message(raw_email, message_id)
        if check_content_exists(
            db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
        ):
            return {"status": "already_exists", "message_id": message_id}

        mail_message = create_mail_message(db, account.tags, folder, parsed_email)

        db.flush()
        vectorize_email(mail_message)

        db.commit()

        logger.info(f"Stored embedding for message {mail_message.message_id}")
@ -71,16 +76,24 @@ def process_message(
            for chunk in attachment.chunks:
                logger.info(f"  - {chunk.id}")

        return cast(int, mail_message.id)
        return {
            "status": "processed",
            "mail_message_id": cast(int, mail_message.id),
            "message_id": message_id,
            "chunks_count": len(mail_message.chunks),
            "attachments_count": len(mail_message.attachments),
        }


@app.task(name=SYNC_ACCOUNT)
@safe_task_execution
def sync_account(account_id: int, since_date: str | None = None) -> dict:
    """
    Synchronize emails from a specific account.

    Args:
        account_id: ID of the EmailAccount to sync
        since_date: ISO format date string to sync since

    Returns:
        dict with stats about the sync operation
@ -91,7 +104,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
    account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
    if not account or not cast(bool, account.active):
        logger.warning(f"Account {account_id} not found or inactive")
        return {"error": "Account not found or inactive"}
        return {"status": "error", "error": "Account not found or inactive"}

    folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
    if since_date:
@ -108,7 +121,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
    def process_message_wrapper(
        account_id: int, message_id: str, folder: str, raw_email: str
    ) -> int | None:
        if check_message_exists(db, account_id, message_id, raw_email):  # type: ignore
        parsed_email = parse_email_message(raw_email, message_id)
        if check_content_exists(
            db, MailMessage, message_id=message_id, sha256=parsed_email["hash"]
        ):
            return None
        return process_message.delay(account_id, message_id, folder, raw_email)  # type: ignore

@ -127,9 +143,10 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
        db.commit()
    except Exception as e:
        logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
        return {"error": str(e)}
        return {"status": "error", "error": str(e)}

    return {
        "status": "completed",
        "account": account.email_address,
        "since_date": cutoff_date.isoformat(),
        "folders_processed": len(folders_to_process),
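Reviewer note, not part of the commit: process_message now always reports through a dict rather than an optional id. For reference, the shapes returned by the branches above are roughly:

# Collected from the branches above; literal values are examples only.
POSSIBLE_RESULTS = [
    {"status": "skipped", "reason": "empty_content"},
    {"status": "error", "error": "Account not found"},
    {"status": "already_exists", "message_id": "<message-id>"},
    {
        "status": "processed",
        "mail_message_id": 123,
        "message_id": "<message-id>",
        "chunks_count": 2,
        "attachments_count": 0,
    },
]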
@ -6,7 +6,7 @@ from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import contains_eager

from memory.common import embedding, qdrant, settings
from memory.common import collections, embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.workers.celery_app import app
@ -60,11 +60,11 @@ def reingest_chunk(chunk_id: str, collection: str):
        logger.error(f"Chunk {chunk_id} not found")
        return

    if collection not in embedding.ALL_COLLECTIONS:
    if collection not in collections.ALL_COLLECTIONS:
        raise ValueError(f"Unsupported collection {collection}")

    data = chunk.data
    if collection in embedding.MULTIMODAL_COLLECTIONS:
    if collection in collections.MULTIMODAL_COLLECTIONS:
        vector = embedding.embed_mixed(data)[0]
    elif len(data) == 1 and isinstance(data[0], str):
        vector = embedding.embed_text([data[0]])[0]
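Reviewer sketch, not part of the commit: reingest_chunk now resolves collections through memory.common.collections. A rough outline of the lookup it performs (MULTIMODAL_COLLECTIONS is assumed to be defined in that module alongside ALL_COLLECTIONS):

from memory.common import collections, embedding


def pick_embedder(collection: str):
    # Mirrors the checks in reingest_chunk above.
    if collection not in collections.ALL_COLLECTIONS:
        raise ValueError(f"Unsupported collection {collection}")
    if collection in collections.MULTIMODAL_COLLECTIONS:
        return embedding.embed_mixed
    return embedding.embed_text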
@ -205,8 +205,13 @@ def email_provider():
def mock_file_storage(tmp_path: Path):
    chunk_storage_dir = tmp_path / "chunks"
    chunk_storage_dir.mkdir(parents=True, exist_ok=True)
    with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
        with patch("memory.common.settings.CHUNK_STORAGE_DIR", chunk_storage_dir):
    image_storage_dir = tmp_path / "images"
    image_storage_dir.mkdir(parents=True, exist_ok=True)
    with (
        patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
        patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
        patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
    ):
        yield

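Reviewer sketch, not part of the commit: how a test leans on the widened fixture (the test name below is made up).

def test_storage_dirs_are_sandboxed(mock_file_storage):
    from memory.common import settings

    # All three patched dirs point inside pytest's tmp_path for this test.
    assert settings.CHUNK_STORAGE_DIR.name == "chunks"
    assert settings.WEBPAGE_STORAGE_DIR.name == "images"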
@ -249,6 +249,7 @@ def test_parse_simple_email():
        "body": "Test body content\n",
        "attachments": [],
        "sent_at": ANY,
        "raw_email": msg.as_string(),
        "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
        + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
    }
@ -12,6 +12,7 @@ import requests
from bs4 import BeautifulSoup, Tag
from PIL import Image as PILImage

from memory.common import settings
from memory.common.parsers.html import (
    Article,
    BaseHTMLParser,
@ -164,27 +165,28 @@ def test_parse_date(text, date_format, expected):
    assert parse_date(text, date_format) == expected


def test_extract_date():
@pytest.mark.parametrize(
    "selector, date_format, expected",
    [
        ("time", "%B %d, %Y", datetime.fromisoformat("2023-01-15T10:30:00")),
        (".named-date", "%B %d, %Y", datetime.fromisoformat("2023-01-15")),
        (".date", "%Y-%m-%d", datetime.fromisoformat("2023-02-20")),
        (".nonexistent", None, None),
    ],
)
def test_extract_date(selector, date_format, expected):
    html = """
    <div>
        <time datetime="2023-01-15T10:30:00">January 15, 2023</time>
        <span class="named-date">January 15, 2023</span>
        <span class="date">2023-02-20</span>
        <div class="published">March 10, 2023</div>
    </div>
    """
    soup = BeautifulSoup(html, "html.parser")

    # Should extract datetime attribute from time tag
    result = extract_date(soup, "time", "%Y-%m-%d")
    assert result == "2023-01-15T10:30:00"

    # Should extract from text content
    result = extract_date(soup, ".date", "%Y-%m-%d")
    assert result == "2023-02-20T00:00:00"

    # No matching element
    result = extract_date(soup, ".nonexistent", "%Y-%m-%d")
    assert result is None
    result = extract_date(soup, selector, date_format)
    assert result == expected


def test_extract_content_element():
@ -393,8 +395,7 @@ def test_process_image_cached(mock_pil_open, mock_requests_get):


@patch("memory.common.parsers.html.process_image")
@patch("memory.common.parsers.html.FILE_STORAGE_DIR")
def test_process_images_basic(mock_file_storage_dir, mock_process_image):
def test_process_images_basic(mock_process_image):
    html = """
    <div>
        <p>Text content</p>
@ -409,40 +410,37 @@ def test_process_images_basic(mock_file_storage_dir, mock_process_image):
    content = cast(Tag, soup.find("div"))
    base_url = "https://example.com"

    with tempfile.TemporaryDirectory() as temp_dir:
        image_dir = pathlib.Path(temp_dir)
        mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir)

    # Mock successful image processing with proper filenames
    mock_images = []
    for i in range(3):
        mock_img = MagicMock(spec=PILImage.Image)
        mock_img.filename = str(pathlib.Path(temp_dir) / f"image{i + 1}.jpg")
        mock_img.filename = str(settings.WEBPAGE_STORAGE_DIR / f"image{i + 1}.jpg")
        mock_images.append(mock_img)

    mock_process_image.side_effect = mock_images

    updated_content, images = process_images(content, base_url, image_dir)
    updated_content, images = process_images(
        content, base_url, settings.WEBPAGE_STORAGE_DIR
    )

    # Should have processed 3 images (skipping the one without src)
    assert len(images) == 3
    assert mock_process_image.call_count == 3

    # Check that img src attributes were updated to relative paths
    img_tags = [
        tag
        for tag in (updated_content.find_all("img") if updated_content else [])
        if isinstance(tag, Tag)
    ]
    src_values = []
    for img in img_tags:
        src = img.get("src")
        if src and isinstance(src, str):
            src_values.append(src)

    # Should have relative paths to the processed images
    for src in src_values[:3]:  # First 3 have src
        assert not src.startswith("http")  # Should be relative paths
    expected = BeautifulSoup(
        """<div>
        <p>Text content</p>
        <img alt="Image 1" src="images/image1.jpg"/>
        <img alt="Image 2" src="images/image2.jpg"/>
        <img alt="Image 3" src="images/image3.jpg"/>
        <img alt="No src"/>
        <p>More text</p>
        </div>
    """,
        "html.parser",
    )
    assert updated_content.prettify() == expected.prettify()  # type: ignore
    assert images == {
        "images/image1.jpg": mock_images[0],
        "images/image2.jpg": mock_images[1],
        "images/image3.jpg": mock_images[2],
    }


def test_process_images_empty():
@ -454,8 +452,7 @@ def test_process_images_empty():


@patch("memory.common.parsers.html.process_image")
@patch("memory.common.parsers.html.FILE_STORAGE_DIR")
def test_process_images_with_failures(mock_file_storage_dir, mock_process_image):
def test_process_images_with_failures(mock_process_image):
    html = """
    <div>
        <img src="good.jpg" alt="Good image">
@ -465,22 +462,20 @@ def test_process_images_with_failures(mock_file_storage_dir, mock_process_image)
    soup = BeautifulSoup(html, "html.parser")
    content = cast(Tag, soup.find("div"))

    with tempfile.TemporaryDirectory() as temp_dir:
        image_dir = pathlib.Path(temp_dir)
        mock_file_storage_dir.resolve.return_value = pathlib.Path(temp_dir)

    # First image succeeds, second fails
    mock_good_image = MagicMock(spec=PILImage.Image)
    mock_good_image.filename = str(pathlib.Path(temp_dir) / "good.jpg")
    mock_good_image.filename = settings.WEBPAGE_STORAGE_DIR / "good.jpg"
    mock_process_image.side_effect = [mock_good_image, None]

    updated_content, images = process_images(
        content, "https://example.com", image_dir
        content, "https://example.com", settings.WEBPAGE_STORAGE_DIR
    )

    # Should only return successful image
    assert len(images) == 1
    assert images[0] == mock_good_image
    expected = BeautifulSoup(
        html.replace("good.jpg", "images/good.jpg"), "html.parser"
    ).prettify()
    assert updated_content.prettify() == expected  # type: ignore
    assert images == {"images/good.jpg": mock_good_image}


@patch("memory.common.parsers.html.process_image")
@ -494,15 +489,12 @@ def test_process_images_no_filename(mock_process_image):
    mock_image.filename = None
    mock_process_image.return_value = mock_image

    with tempfile.TemporaryDirectory() as temp_dir:
        image_dir = pathlib.Path(temp_dir)

    updated_content, images = process_images(
        content, "https://example.com", image_dir
        content, "https://example.com", settings.WEBPAGE_STORAGE_DIR
    )

    # Should skip image without filename
    assert len(images) == 0
    assert not images


class TestBaseHTMLParser:
@ -541,7 +533,7 @@ class TestBaseHTMLParser:

        assert article.title == "Article Title"
        assert article.author == "John Smith"  # Should prefer content over meta
        assert article.published_date == "2023-01-15T00:00:00"
        assert article.published_date == datetime(2023, 1, 15)
        assert article.url == "https://example.com/article"
        assert "This is the main content" in article.content
        assert article.metadata["author"] == "Jane Doe"
@ -5,14 +5,10 @@ from typing import cast
import pytest
from PIL import Image

from memory.common import settings
from memory.common import settings, collections
from memory.common.embedding import (
    embed,
    embed_file,
    embed_mixed,
    embed_page,
    embed_text,
    get_modality,
    make_chunk,
    write_to_file,
)
@ -50,7 +46,7 @@ def mock_embed(mock_voyage_client):
    ],
)
def test_get_modality(mime_type, expected_modality):
    assert get_modality(mime_type) == expected_modality
    assert collections.get_modality(mime_type) == expected_modality


def test_embed_text(mock_embed):
@ -58,59 +54,11 @@ def test_embed_text(mock_embed):
    assert embed_text(texts) == [[0], [1]]


def test_embed_file(mock_embed, tmp_path):
    mock_file = tmp_path / "test.txt"
    mock_file.write_text("file content")

    assert embed_file(mock_file) == [[0]]


def test_embed_mixed(mock_embed):
    items = ["text", {"type": "image", "data": "base64"}]
    assert embed_mixed(items) == [[0]]


def test_embed_page_text_only(mock_embed):
    page = {"contents": ["text1", "text2"]}
    assert embed_page(page) == [[0], [1]]  # type: ignore


def test_embed_page_mixed_content(mock_embed):
    page = {"contents": ["text", {"type": "image", "data": "base64"}]}
    assert embed_page(page) == [[0]]  # type: ignore


def test_embed(mock_embed):
    mime_type = "text/plain"
    content = "sample content"
    metadata = {"source": "test"}

    with patch.object(uuid, "uuid4", return_value="id1"):
        modality, chunks = embed(mime_type, content, metadata)

    assert modality == "text"
    assert [
        {
            "id": c.id,  # type: ignore
            "file_path": c.file_path,  # type: ignore
            "content": c.content,  # type: ignore
            "embedding_model": c.embedding_model,  # type: ignore
            "vector": c.vector,  # type: ignore
            "item_metadata": c.item_metadata,  # type: ignore
        }
        for c in chunks
    ] == [
        {
            "content": "sample content",
            "embedding_model": "voyage-3-large",
            "file_path": None,
            "id": "id1",
            "item_metadata": {"source": "test"},
            "vector": [0],
        },
    ]


def test_write_to_file_text(mock_file_storage):
    """Test writing a string to a file."""
    chunk_id = "test-chunk-id"
@ -160,17 +108,14 @@ def test_write_to_file_unsupported_type(mock_file_storage):

def test_make_chunk_text_only(mock_file_storage, db_session):
    """Test creating a chunk from string content."""
    page = {
        "contents": ["text content 1", "text content 2"],
        "metadata": {"source": "test"},
    }
    contents = ["text content 1", "text content 2"]
    vector = [0.1, 0.2, 0.3]
    metadata = {"doc_type": "test", "source": "unit-test"}

    with patch.object(
        uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
    ):
        chunk = make_chunk(page, vector, metadata)  # type: ignore
        chunk = make_chunk(contents, vector, metadata)  # type: ignore

    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
    assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
@ -183,14 +128,14 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
def test_make_chunk_single_image(mock_file_storage, db_session):
    """Test creating a chunk from a single image."""
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    page = {"contents": [img], "metadata": {"source": "test"}}
    contents = [img]
    vector = [0.1, 0.2, 0.3]
    metadata = {"doc_type": "test", "source": "unit-test"}

    with patch.object(
        uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
    ):
        chunk = make_chunk(page, vector, metadata)  # type: ignore
        chunk = make_chunk(contents, vector, metadata)  # type: ignore

    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
    assert chunk.content is None
@ -208,14 +153,14 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
def test_make_chunk_mixed_content(mock_file_storage, db_session):
    """Test creating a chunk from mixed content (string and image)."""
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    page = {"contents": ["text content", img], "metadata": {"source": "test"}}
    contents = ["text content", img]
    vector = [0.1, 0.2, 0.3]
    metadata = {"doc_type": "test", "source": "unit-test"}

    with patch.object(
        uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
    ):
        chunk = make_chunk(page, vector, metadata)  # type: ignore
        chunk = make_chunk(contents, vector, metadata)  # type: ignore

    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
    assert chunk.content is None
@ -233,3 +178,181 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
    assert (
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
    ).exists()


@pytest.mark.parametrize(
    "data,embedding_model,collection,expected_model,expected_count,expected_has_content",
    [
        # Text-only with default model
        (
            ["text content 1", "text content 2"],
            None,
            None,
            settings.TEXT_EMBEDDING_MODEL,
            2,
            True,
        ),
        # Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content
        (
            ["text content"],
            settings.MIXED_EMBEDDING_MODEL,
            None,
            settings.TEXT_EMBEDDING_MODEL,
            1,
            True,
        ),
        # Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content
        (["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True),
        (["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True),
        (["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True),
        # Unknown collection falls back to default
        (["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True),
        # Explicit model takes precedence over collection
        (
            ["text content"],
            settings.TEXT_EMBEDDING_MODEL,
            "photo",
            settings.TEXT_EMBEDDING_MODEL,
            1,
            True,
        ),
    ],
)
def test_embed_data_chunk_scenarios(
    data,
    embedding_model,
    collection,
    expected_model,
    expected_count,
    expected_has_content,
    mock_embed,
    mock_file_storage,
):
    """Test various embedding scenarios for data chunks."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    chunk = DataChunk(
        data=data,
        embedding_model=embedding_model,
        collection=collection,
        metadata={"source": "test"},
    )

    result = embed_data_chunk(chunk, {"doc_type": "test"})

    assert len(result) == expected_count
    assert all(cast(str, c.embedding_model) == expected_model for c in result)
    if expected_has_content:
        assert all(c.content is not None for c in result)
        assert all(c.file_path is None for c in result)
    else:
        assert all(c.content is None for c in result)
        assert all(c.file_path is not None for c in result)
    assert all(
        c.item_metadata == {"source": "test", "doc_type": "test"} for c in result
    )


def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage):
    """Test embedding mixed content (text and images)."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    chunk = DataChunk(
        data=["text content", img],
        embedding_model=settings.MIXED_EMBEDDING_MODEL,
        metadata={"source": "test"},
    )

    result = embed_data_chunk(chunk)

    assert len(result) == 1  # Mixed content returns single vector
    assert result[0].content is None  # Mixed content stored in files
    assert result[0].file_path is not None
    assert cast(str, result[0].embedding_model) == settings.MIXED_EMBEDDING_MODEL


@pytest.mark.parametrize(
    "chunk_max_size,chunk_size_param,expected_chunk_size",
    [
        (512, 1024, 512),  # chunk.max_size takes precedence
        (None, 2048, 2048),  # chunk_size parameter used when max_size is None
        (256, None, 256),  # chunk.max_size used when parameter is None
    ],
)
def test_embed_data_chunk_chunk_size_handling(
    chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage
):
    """Test chunk size parameter handling."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    chunk = DataChunk(
        data=["text content"], max_size=chunk_max_size, metadata={"source": "test"}
    )

    with patch("memory.common.embedding.embed_text") as mock_embed_text:
        mock_embed_text.return_value = [[0.1, 0.2, 0.3]]

        result = embed_data_chunk(chunk, chunk_size=chunk_size_param)

        mock_embed_text.assert_called_once()
        args, kwargs = mock_embed_text.call_args
        assert kwargs["chunk_size"] == expected_chunk_size


def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage):
    """Test that chunk metadata and parameter metadata are properly merged."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    chunk = DataChunk(
        data=["text content"], metadata={"source": "test", "type": "chunk"}
    )
    metadata = {
        "doc_type": "test",
        "source": "override",
    }  # chunk.metadata takes precedence over parameter metadata

    result = embed_data_chunk(chunk, metadata)

    assert len(result) == 1
    expected_metadata = {
        "source": "test",
        "type": "chunk",
        "doc_type": "test",
    }  # chunk source wins
    assert result[0].item_metadata == expected_metadata


def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage):
    """Test error handling for unsupported embedding model."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    chunk = DataChunk(
        data=["text content"],
        embedding_model="unsupported-model",
        metadata={"source": "test"},
    )

    with pytest.raises(ValueError, match="Unsupported model: unsupported-model"):
        embed_data_chunk(chunk)


def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage):
    """Test handling of empty data."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    chunk = DataChunk(data=[], metadata={"source": "test"})

    # Should handle empty data gracefully
    with patch("memory.common.embedding.embed_text") as mock_embed_text:
        mock_embed_text.return_value = []

        result = embed_data_chunk(chunk)

        assert result == []
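Reviewer sketch, not part of the commit: the behaviour the new embed_data_chunk tests pin down, condensed into one illustrative call (literal values are examples only).

from memory.common import settings
from memory.common.extract import DataChunk
from memory.common.embedding import embed_data_chunk

chunk = DataChunk(data=["some text"], metadata={"source": "example"})
chunks = embed_data_chunk(chunk, {"doc_type": "example"})
# Text-only data is embedded with settings.TEXT_EMBEDDING_MODEL and kept
# inline (content set, file_path None); text+image data instead yields a
# single file-backed chunk embedded with settings.MIXED_EMBEDDING_MODEL.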
@ -7,7 +7,6 @@ import shutil
from memory.common.extract import (
    as_file,
    extract_text,
    extract_content,
    doc_to_images,
    extract_image,
    docx_to_pdf,
@ -107,52 +106,6 @@ def test_extract_image_with_str():
        extract_image("test")


@pytest.mark.parametrize(
    "mime_type,content",
    [
        ("text/plain", "Text content"),
        ("text/html", "<html>content</html>"),
        ("text/markdown", "# Heading"),
        ("text/csv", "a,b,c"),
    ],
)
def test_extract_content_different_text_types(mime_type, content):
    assert extract_content(mime_type, content) == [
        {"contents": [content], "metadata": {}}
    ]


def test_extract_content_pdf():
    result = extract_content("application/pdf", REGULAMIN)

    assert len(result) == 2
    assert all(
        isinstance(page["contents"], list)
        and all(isinstance(c, Image.Image) for c in page["contents"])
        for page in result
    )
    assert all("page" in page["metadata"] for page in result)
    assert all("width" in page["metadata"] for page in result)
    assert all("height" in page["metadata"] for page in result)


def test_extract_content_image(tmp_path):
    # Create a test image
    img = Image.new("RGB", (100, 100), color="red")
    img_path = tmp_path / "test_img.png"
    img.save(img_path)

    (result,) = extract_content("image/png", img_path)

    assert isinstance(result["contents"][0], Image.Image)
    assert result["contents"][0].size == (100, 100)
    assert result["metadata"] == {}


def test_extract_content_unsupported_type():
    assert extract_content("unsupported/type", "content") == []


@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
def test_docx_to_pdf(tmp_path):
    output_path = tmp_path / "output.pdf"
431
tests/memory/workers/tasks/test_comic_tasks.py
Normal file
@ -0,0 +1,431 @@
import pytest
from datetime import datetime
from unittest.mock import Mock, patch

from memory.common.db.models import Comic
from memory.workers.tasks import comic
import requests


@pytest.fixture
def mock_comic_info():
    """Mock comic info data for testing."""
    return {
        "title": "Test Comic",
        "image_url": "https://example.com/comic.png",
        "url": "https://example.com/comic/1",
        "published_date": "2024-01-01T12:00:00Z",
    }


@pytest.fixture
def mock_feed_data():
    """Mock RSS feed data."""
    return {
        "entries": [
            {"link": "https://example.com/comic/1", "id": None},
            {"link": "https://example.com/comic/2", "id": None},
            {"link": None, "id": "https://example.com/comic/3"},
        ]
    }


@pytest.fixture
def mock_image_response():
    """Mock HTTP response for comic image."""
    # 1x1 PNG image (smallest valid PNG)
    png_data = (
        b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
        b"\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13"
        b"\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```"
        b"\x00\x00\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
    )
    response = Mock()
    response.status_code = 200
    response.content = png_data
    with patch.object(requests, "get", return_value=response):
        yield response


@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_success(mock_parse, mock_feed_data, db_session):
    """Test successful URL discovery from RSS feed."""
    mock_parse.return_value = Mock(entries=mock_feed_data["entries"])

    result = comic.find_new_urls("https://example.com", "https://example.com/rss")

    assert result == {
        "https://example.com/comic/1",
        "https://example.com/comic/2",
        "https://example.com/comic/3",
    }
    mock_parse.assert_called_once_with("https://example.com/rss")


@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_with_existing_comics(mock_parse, mock_feed_data, db_session):
    """Test URL discovery when some comics already exist."""
    mock_parse.return_value = Mock(entries=mock_feed_data["entries"])

    # Add existing comic to database
    existing_comic = Comic(
        title="Existing Comic",
        url="https://example.com/comic/1",
        author="https://example.com",
        filename="/test/path",
        sha256=b"test_hash",
        modality="comic",
        tags=["comic"],
    )
    db_session.add(existing_comic)
    db_session.commit()

    result = comic.find_new_urls("https://example.com", "https://example.com/rss")

    # Should only return URLs not in database
    assert result == {
        "https://example.com/comic/2",
        "https://example.com/comic/3",
    }


@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_parse_error(mock_parse):
    """Test handling of RSS feed parsing errors."""
    mock_parse.side_effect = Exception("Parse error")

    assert (
        comic.find_new_urls("https://example.com", "https://example.com/rss") == set()
    )


@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_empty_feed(mock_parse):
    """Test handling of empty RSS feed."""
    mock_parse.return_value = Mock(entries=[])

    result = comic.find_new_urls("https://example.com", "https://example.com/rss")

    assert result == set()


@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_malformed_entries(mock_parse):
    """Test handling of malformed RSS entries."""
    mock_parse.return_value = Mock(
        entries=[
            {"link": None, "id": None},  # Both None
            {},  # Missing keys
        ]
    )

    result = comic.find_new_urls("https://example.com", "https://example.com/rss")

    assert result == set()


@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.find_new_urls")
def test_fetch_new_comics_success(mock_find_urls, mock_sync_delay, mock_comic_info):
    """Test successful comic fetching."""
    mock_find_urls.return_value = {"https://example.com/comic/1"}
    mock_parser = Mock(return_value=mock_comic_info)

    result = comic.fetch_new_comics(
        "https://example.com", "https://example.com/rss", mock_parser
    )

    assert result == {"https://example.com/comic/1"}
    mock_parser.assert_called_once_with("https://example.com/comic/1")
    expected_call_args = {
        **mock_comic_info,
        "author": "https://example.com",
        "url": "https://example.com/comic/1",
    }
    mock_sync_delay.assert_called_once_with(**expected_call_args)


@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.find_new_urls")
def test_fetch_new_comics_no_new_urls(mock_find_urls, mock_sync_delay):
    """Test when no new URLs are found."""
    mock_find_urls.return_value = set()
    mock_parser = Mock()

    result = comic.fetch_new_comics(
        "https://example.com", "https://example.com/rss", mock_parser
    )

    assert result == set()
    mock_parser.assert_not_called()
    mock_sync_delay.assert_not_called()


@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.find_new_urls")
def test_fetch_new_comics_multiple_urls(
    mock_find_urls, mock_sync_delay, mock_comic_info
):
    """Test fetching multiple new comics."""
    urls = {"https://example.com/comic/1", "https://example.com/comic/2"}
    mock_find_urls.return_value = urls
    mock_parser = Mock(return_value=mock_comic_info)

    result = comic.fetch_new_comics(
        "https://example.com", "https://example.com/rss", mock_parser
    )

    assert result == urls
    assert mock_parser.call_count == 2
    assert mock_sync_delay.call_count == 2


@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
    """Test successful comic synchronization."""
    mock_get.return_value = mock_image_response

    comic.sync_comic(
        url="https://example.com/comic/1",
        image_url="https://example.com/image.png",
        title="Test Comic",
        author="https://example.com",
        published_date=datetime(2024, 1, 1, 12, 0, 0),
    )

    # Verify comic was created in database
    saved_comic = (
        db_session.query(Comic)
        .filter(Comic.url == "https://example.com/comic/1")
        .first()
    )
    assert saved_comic is not None
    assert saved_comic.title == "Test Comic"
    assert saved_comic.author == "https://example.com"
    assert saved_comic.mime_type == "image/png"
    assert saved_comic.size == len(mock_image_response.content)
    assert "comic" in saved_comic.tags
    assert "https://example.com" in saved_comic.tags

    # Verify vectors were added to Qdrant
    vectors, _ = qdrant.scroll(collection_name="comic")
    expected_vectors = [
        (
            {
                "author": "https://example.com",
                "published": "2024-01-01T12:00:00",
                "tags": ["comic", "https://example.com"],
                "title": "Test Comic",
                "url": "https://example.com/comic/1",
                "source_id": 1,
            },
            None,
        )
    ]
    assert [
        ({**v.payload, "tags": sorted(v.payload["tags"])}, v.vector) for v in vectors
    ] == expected_vectors


def test_sync_comic_already_exists(db_session):
    """Test that duplicate comics are not processed."""
    # Add existing comic
    existing_comic = Comic(
        title="Existing Comic",
        url="https://example.com/comic/1",
        author="https://example.com",
        filename="/test/path",
        sha256=b"test_hash",
        modality="comic",
        tags=["comic"],
    )
    db_session.add(existing_comic)
    db_session.commit()

    with patch("memory.workers.tasks.comic.requests.get") as mock_get:
        result = comic.sync_comic(
            url="https://example.com/comic/1",
            image_url="https://example.com/image.png",
            title="Test Comic",
            author="https://example.com",
        )

        # Should return early without making HTTP request
        mock_get.assert_not_called()
        assert result == {"comic_id": 1, "status": "already_exists"}


@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_http_error(mock_get, db_session, qdrant):
    """Test handling of HTTP errors when downloading image."""
    mock_response = Mock()
    mock_response.status_code = 404
    mock_response.content = b""
    mock_get.return_value = mock_response

    comic.sync_comic(
        url="https://example.com/comic/1",
        image_url="https://example.com/image.png",
        title="Test Comic",
        author="https://example.com",
    )

    assert not (
        db_session.query(Comic)
        .filter(Comic.url == "https://example.com/comic/1")
        .first()
    )


@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_no_published_date(
    mock_get, mock_image_response, db_session, qdrant
):
    """Test comic sync without published date."""
    mock_get.return_value = mock_image_response

    comic.sync_comic(
        url="https://example.com/comic/1",
        image_url="https://example.com/image.png",
        title="Test Comic",
        author="https://example.com",
        published_date=None,
    )

    saved_comic = (
        db_session.query(Comic)
        .filter(Comic.url == "https://example.com/comic/1")
        .first()
    )
    assert saved_comic is not None
    assert saved_comic.published is None


@patch("memory.workers.tasks.comic.requests.get")
def test_sync_comic_special_characters_in_title(
    mock_get, mock_image_response, db_session, qdrant
):
    """Test comic sync with special characters in title."""
    mock_get.return_value = mock_image_response

    comic.sync_comic(
        url="https://example.com/comic/1",
        image_url="https://example.com/image.png",
        title="Test/Comic: With*Special?Characters",
        author="https://example.com",
    )

    # Verify comic was created with cleaned title
    saved_comic = (
        db_session.query(Comic)
        .filter(Comic.url == "https://example.com/comic/1")
        .first()
    )
    assert saved_comic is not None
    assert saved_comic.title == "Test/Comic: With*Special?Characters"


@patch(
    "memory.common.embedding.embed_source_item",
    side_effect=Exception("Embedding failed"),
)
def test_sync_comic_embedding_failure(
    mock_embed_source_item, mock_image_response, db_session, qdrant
):
    """Test handling of embedding failures."""
    result = comic.sync_comic(
        url="https://example.com/comic/1",
        image_url="https://example.com/image.png",
        title="Test Comic",
        author="https://example.com",
    )
    assert result == {
        "comic_id": 1,
        "title": "Test Comic",
        "status": "processed",
        "chunks_count": 0,
        "embed_status": "FAILED",
        "content_length": 90,
    }


@patch("memory.workers.tasks.comic.sync_xkcd.delay")
@patch("memory.workers.tasks.comic.sync_smbc.delay")
def test_sync_all_comics(mock_smbc_delay, mock_xkcd_delay):
    """Test synchronization of all comics."""
    comic.sync_all_comics()

    mock_smbc_delay.assert_called_once()
    mock_xkcd_delay.assert_called_once()


@patch("memory.workers.tasks.comic.sync_comic.delay")
@patch("memory.workers.tasks.comic.comics.extract_xkcd")
@patch("memory.workers.tasks.comic.comics.extract_smbc")
@patch("requests.get")
def test_trigger_comic_sync_smbc_navigation(
    mock_get, mock_extract_smbc, mock_extract_xkcd, mock_sync_delay, mock_comic_info
):
    """Test full SMBC comic sync with navigation."""
    # Mock HTML responses for navigation
    mock_responses = [
        Mock(text='<a class="cc-prev" href="https://smbc.com/comic/2"></a>'),
        Mock(text='<a class="cc-prev" href="https://smbc.com/comic/1"></a>'),
        Mock(text="<div>No prev link</div>"),  # End of navigation
    ]
    mock_get.side_effect = mock_responses
    mock_extract_smbc.return_value = mock_comic_info
    mock_extract_xkcd.return_value = mock_comic_info
|
||||||
|
|
||||||
|
comic.trigger_comic_sync()
|
||||||
|
|
||||||
|
# Should have called extract_smbc for each discovered URL
|
||||||
|
assert mock_extract_smbc.call_count == 2
|
||||||
|
mock_extract_smbc.assert_any_call("https://smbc.com/comic/2")
|
||||||
|
mock_extract_smbc.assert_any_call("https://smbc.com/comic/1")
|
||||||
|
|
||||||
|
# Should have called extract_xkcd for range 1-307
|
||||||
|
assert mock_extract_xkcd.call_count == 307
|
||||||
|
|
||||||
|
# Should have queued sync tasks
|
||||||
|
assert mock_sync_delay.call_count == 2 + 307 # SMBC + XKCD
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||||
|
@patch("memory.workers.tasks.comic.comics.extract_smbc")
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_trigger_comic_sync_smbc_extraction_error(
|
||||||
|
mock_get, mock_extract_smbc, mock_sync_delay
|
||||||
|
):
|
||||||
|
"""Test handling of extraction errors during full sync."""
|
||||||
|
# Mock responses: first one has a prev link, second one doesn't
|
||||||
|
mock_responses = [
|
||||||
|
Mock(text='<a class="cc-prev" href="https://smbc.com/comic/1"></a>'),
|
||||||
|
Mock(text="<div>No prev link</div>"),
|
||||||
|
]
|
||||||
|
mock_get.side_effect = mock_responses
|
||||||
|
mock_extract_smbc.side_effect = Exception("Extraction failed")
|
||||||
|
|
||||||
|
# Should not raise exception, just log error
|
||||||
|
comic.trigger_comic_sync()
|
||||||
|
|
||||||
|
mock_extract_smbc.assert_called_once_with("https://smbc.com/comic/1")
|
||||||
|
mock_sync_delay.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.workers.tasks.comic.sync_comic.delay")
|
||||||
|
@patch("memory.workers.tasks.comic.comics.extract_xkcd")
|
||||||
|
@patch("requests.get")
|
||||||
|
def test_trigger_comic_sync_xkcd_extraction_error(
|
||||||
|
mock_get, mock_extract_xkcd, mock_sync_delay
|
||||||
|
):
|
||||||
|
"""Test handling of XKCD extraction errors during full sync."""
|
||||||
|
mock_get.return_value = Mock(text="<div>No prev link</div>")
|
||||||
|
mock_extract_xkcd.side_effect = Exception("XKCD extraction failed")
|
||||||
|
|
||||||
|
# Should not raise exception, just log errors
|
||||||
|
comic.trigger_comic_sync()
|
||||||
|
|
||||||
|
# Should attempt all 307 XKCD comics
|
||||||
|
assert mock_extract_xkcd.call_count == 307
|
||||||
|
mock_sync_delay.assert_not_called()
|
@@ -51,24 +51,6 @@ def mock_ebook():
     )


-@pytest.fixture
-def mock_embedding():
-    """Mock the embedding function to return dummy vectors."""
-    with patch("memory.workers.tasks.ebook.embedding.embed") as mock:
-        mock.return_value = (
-            "book",
-            [
-                Chunk(
-                    vector=[0.1] * 1024,
-                    item_metadata={"test": "data"},
-                    content="Test content",
-                    embedding_model="model",
-                )
-            ],
-        )
-        yield mock
-
-
 @pytest.fixture
 def mock_qdrant():
     """Mock Qdrant operations."""
@@ -151,7 +133,7 @@ def test_create_book_and_sections(mock_ebook, db_session):
     assert getattr(chapter1, "parent_section_id") is None


-def test_embed_sections(db_session, mock_embedding):
+def test_embed_sections(db_session):
     """Test basic embedding sections workflow."""
     # Create a test book first
     book = Book(
@@ -187,31 +169,6 @@ def test_embed_sections(db_session, mock_embedding):
     assert hasattr(sections[0], "embed_status")


-def test_push_to_qdrant(qdrant):
-    """Test pushing embeddings to Qdrant."""
-    # Create test sections with chunks
-    mock_chunk = Mock(
-        id="00000000-0000-0000-0000-000000000000",
-        vector=[0.1] * 1024,
-        item_metadata={"test": "data"},
-    )
-
-    mock_section = Mock(spec=BookSection)
-    mock_section.embed_status = "QUEUED"
-    mock_section.chunks = [mock_chunk]
-
-    sections = [mock_section]
-
-    ebook.push_to_qdrant(sections)  # type: ignore
-
-    assert {r.id: r.payload for r in qdrant.scroll(collection_name="book")[0]} == {
-        "00000000-0000-0000-0000-000000000000": {
-            "test": "data",
-        }
-    }
-    assert mock_section.embed_status == "STORED"
-
-
 @patch("memory.workers.tasks.ebook.parse_ebook")
 def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path):
     """Test successful book synchronization."""
@@ -229,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path):
         "author": "Test Author",
         "status": "processed",
         "total_sections": 4,
-        "sections_embedded": 4,
+        "sections_embedded": 8,
     }

     book = db_session.query(Book).filter(Book.title == "Test Book").first()
@@ -270,8 +227,9 @@ def test_sync_book_already_exists(mock_parse, mock_ebook, db_session, tmp_path):


 @patch("memory.workers.tasks.ebook.parse_ebook")
+@patch("memory.common.embedding.embed_source_item")
 def test_sync_book_embedding_failure(
-    mock_parse, mock_ebook, db_session, tmp_path, mock_embedding
+    mock_embedding, mock_parse, mock_ebook, db_session, tmp_path
 ):
     """Test handling of embedding failures."""
     book_file = tmp_path / "test.epub"
@@ -307,14 +265,18 @@ def test_sync_book_qdrant_failure(mock_parse, mock_ebook, db_session, tmp_path):
     # Since embedding is already failing, this test will complete without hitting Qdrant
     # So let's just verify that the function completes without raising an exception
     with patch.object(ebook, "push_to_qdrant", side_effect=Exception("Qdrant failed")):
-        with pytest.raises(Exception, match="Qdrant failed"):
-            ebook.sync_book(str(book_file))
+        assert ebook.sync_book(str(book_file)) == {
+            "status": "error",
+            "error": "Qdrant failed",
+        }


 def test_sync_book_file_not_found():
     """Test handling of missing files."""
-    with pytest.raises(FileNotFoundError):
-        ebook.sync_book("/nonexistent/file.epub")
+    assert ebook.sync_book("/nonexistent/file.epub") == {
+        "status": "error",
+        "error": "Book file not found: /nonexistent/file.epub",
+    }


 def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
@@ -363,7 +325,7 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
     calls = mock_voyage_client.embed.call_args_list
     texts = [c[0][0] for c in calls]
     assert texts == [
-        [large_section_content.strip()],
         [large_page_1.strip()],
         [large_page_2.strip()],
+        [large_section_content.strip()],
     ]
@@ -69,13 +69,21 @@ def test_email_account(db_session):

 def test_process_simple_email(db_session, test_email_account, qdrant):
     """Test processing a simple email message."""
-    mail_message_id = process_message(
+    res = process_message(
         account_id=test_email_account.id,
         message_id="101",
         folder="INBOX",
         raw_email=SIMPLE_EMAIL_RAW,
     )

+    mail_message_id = res["mail_message_id"]
+    assert res == {
+        "status": "processed",
+        "mail_message_id": mail_message_id,
+        "message_id": "101",
+        "chunks_count": 1,
+        "attachments_count": 0,
+    }
     mail_message = (
         db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
     )
@@ -98,7 +106,7 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant):
         message_id="302",
         folder="Archive",
         raw_email=EMAIL_WITH_ATTACHMENT_RAW,
-    )
+    )["mail_message_id"]
     # Check mail message specifics
     mail_message = (
         db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
@@ -125,25 +133,25 @@ def test_process_email_with_attachment(db_session, test_email_account, qdrant):

 def test_process_empty_message(db_session, test_email_account, qdrant):
     """Test processing an empty/invalid message."""
-    source_id = process_message(
+    res = process_message(
         account_id=test_email_account.id,
         message_id="999",
         folder="Archive",
         raw_email="",
     )
-    assert source_id is None
+    assert res == {"reason": "empty_content", "status": "skipped"}


 def test_process_duplicate_message(db_session, test_email_account, qdrant):
     """Test that duplicate messages are detected and not stored again."""
     # First call should succeed and create records
-    source_id_1 = process_message(
+    res = process_message(
         account_id=test_email_account.id,
         message_id="101",
         folder="INBOX",
         raw_email=SIMPLE_EMAIL_RAW,
     )
+    source_id_1 = res.get("mail_message_id")

     assert source_id_1 is not None, "First call should return a source_id"

@@ -157,7 +165,7 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
         message_id="101",
         folder="INBOX",
         raw_email=SIMPLE_EMAIL_RAW,
-    )
+    ).get("mail_message_id")

     assert source_id_2 is None, "Second call should return None for duplicate message"

@@ -12,7 +12,7 @@ from memory.common.db.models import (
     EmailAttachment,
     MailMessage,
 )
-from memory.common.parsers.email import Attachment
+from memory.common.parsers.email import Attachment, parse_email_message
 from memory.workers.email import (
     create_mail_message,
     extract_email_uid,
@@ -226,14 +226,14 @@ def test_create_mail_message(db_session):
         "--boundary--"
     )
     folder = "INBOX"
+    parsed_email = parse_email_message(raw_email, "321")

     # Call function
     mail_message = create_mail_message(
         db_session=db_session,
-        raw_email=raw_email,
         folder=folder,
         tags=["test"],
-        message_id="123",
+        parsed_email=parsed_email,
     )
     db_session.commit()

@@ -344,6 +344,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
         recipients=["recipient@example.com"],
         content="This is a test email for vectorization",
         folder="INBOX",
+        modality="mail",
     )
     db_session.add(mail_message)
     db_session.flush()
@@ -373,6 +374,7 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
         recipients=["recipient@example.com"],
         content="This is a test email with attachments",
         folder="INBOX",
+        modality="mail",
     )
     db_session.add(mail_message)
     db_session.flush()