mirror of https://github.com/mruwnik/memory.git
move parsers
commit f5c3e458d7
parent 0f15e4e410
@@ -1,8 +1,8 @@
"""Rename rss feed

Revision ID: f8e6a7f80928
Revision ID: 1b535e1b044e
Revises: d897c6353a84
Create Date: 2025-05-27 01:39:45.722077
Create Date: 2025-05-27 01:51:38.553777

"""

@@ -13,7 +13,7 @@ import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "f8e6a7f80928"
revision: str = "1b535e1b044e"
down_revision: Union[str, None] = "d897c6353a84"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -27,6 +27,9 @@ def upgrade() -> None:
sa.Column("title", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column(
"check_interval", sa.Integer(), server_default="3600", nullable=False
),
sa.Column("last_checked_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("active", sa.Boolean(), server_default="true", nullable=False),
sa.Column(
@@ -76,6 +76,9 @@ services:
mem_limit: 4g
cpus: "1.5"
security_opt: [ "no-new-privileges=true" ]
ports:
# PostgreSQL port for local Celery result backend
- "15432:5432"

rabbitmq:
image: rabbitmq:3.13-management
@@ -97,7 +100,9 @@ services:
security_opt: [ "no-new-privileges=true" ]
ports:
# UI only on localhost
- "127.0.0.1:15672:15672"
- "15672:15672"
# AMQP port for local Celery clients
- "15673:5672"

qdrant:
image: qdrant/qdrant:v1.14.0
@@ -114,6 +119,8 @@ services:
interval: 15s
timeout: 5s
retries: 5
ports:
- "6333:6333"
mem_limit: 4g
cpus: "2"
security_opt: [ "no-new-privileges=true" ]
@@ -181,7 +188,24 @@ services:
environment:
<<: *worker-env
QUEUES: "medium_embed"
# deploy: { resources: { limits: { cpus: "2", memory: 3g } } }

worker-ebook:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "ebooks"

worker-comic:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "comic"

worker-blogs:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "blogs"

worker-photo:
<<: *worker-base
@@ -2,28 +2,33 @@ FROM python:3.11-slim

WORKDIR /app

# Copy requirements files and setup
COPY requirements-*.txt ./
COPY setup.py ./
COPY src/ ./src/

# Install dependencies
RUN apt-get update && apt-get install -y \
libpq-dev gcc supervisor && \
pip install -e ".[workers]" && \
apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

COPY requirements-*.txt ./
RUN pip install --no-cache-dir -r requirements-common.txt
RUN pip install --no-cache-dir -r requirements-parsers.txt
RUN pip install --no-cache-dir -r requirements-workers.txt

# Copy requirements files and setup
COPY setup.py ./
COPY src/ ./src/

# Create and copy entrypoint script
COPY docker/ingest_hub/supervisor.conf /etc/supervisor/conf.d/supervisor.conf
COPY docker/workers/entry.sh ./entry.sh
RUN chmod +x entry.sh

# Create required tmpfs directories for supervisor
RUN mkdir -p /var/log/supervisor /var/run/supervisor

# Create storage directory
RUN mkdir -p /app/memory_files

COPY docker/ingest_hub/supervisor.conf /etc/supervisor/conf.d/supervisor.conf

# Create required tmpfs directories for supervisor
RUN mkdir -p /var/log/supervisor /var/run/supervisor

# Create user and set permissions
RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files
USER kb
@@ -2,12 +2,6 @@ FROM python:3.11-slim

WORKDIR /app

# Copy requirements files and setup
COPY requirements-*.txt ./
COPY setup.py ./
COPY src/ ./src/

# Install dependencies
RUN apt-get update && apt-get install -y \
libpq-dev gcc pandoc \
texlive-xetex texlive-fonts-recommended texlive-plain-generic \
@@ -18,15 +12,25 @@ RUN apt-get update && apt-get install -y \
# For optional LibreOffice support (uncomment if needed)
# libreoffice-writer \
&& apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

COPY requirements-*.txt ./
RUN pip install --no-cache-dir -r requirements-common.txt
RUN pip install --no-cache-dir -r requirements-parsers.txt
RUN pip install --no-cache-dir -r requirements-workers.txt

# Install Python dependencies
COPY setup.py ./
COPY src/ ./src/
RUN pip install -e ".[workers]"

# Create and copy entrypoint script
# Copy entrypoint scripts and set permissions
COPY docker/workers/entry.sh ./entry.sh
COPY docker/workers/unnest-table.lua ./unnest-table.lua
RUN chmod +x entry.sh

RUN mkdir -p /app/memory_files

COPY docker/workers/unnest-table.lua ./unnest-table.lua

# Create user and set permissions
RUN useradd -m kb
RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
@@ -35,7 +39,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
USER kb

# Default queues to process
ENV QUEUES="docs,email,maintenance"
ENV QUEUES="ebooks,email,comic,blogs,photo_embed,maintenance"
ENV PYTHONPATH="/app"

ENTRYPOINT ["./entry.sh"]
@@ -5,8 +5,3 @@ alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
PyMuPDF==1.25.5
ebooklib==0.18.0
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.4.0

requirements-parsers.txt (Normal file, 5 lines)
@@ -0,0 +1,5 @@
PyMuPDF==1.25.5
ebooklib==0.18.0
beautifulsoup4==4.13.4
markdownify==0.13.1
pillow==10.4.0
@@ -1,6 +1,6 @@
celery==5.3.6
openai==1.25.0
pillow==10.3.0
pillow==10.4.0
pypandoc==1.15.0
beautifulsoup4==4.13.4
feedparser==6.0.10
setup.py (9 lines changed)
@@ -14,6 +14,7 @@ def read_requirements(filename: str) -> list[str]:

# Read requirements files
common_requires = read_requirements("requirements-common.txt")
parsers_requires = read_requirements("requirements-parsers.txt")
api_requires = read_requirements("requirements-api.txt")
workers_requires = read_requirements("requirements-workers.txt")
dev_requires = read_requirements("requirements-dev.txt")
@@ -26,9 +27,13 @@ setup(
python_requires=">=3.10",
extras_require={
"api": api_requires + common_requires,
"workers": workers_requires + common_requires,
"workers": workers_requires + common_requires + parsers_requires,
"common": common_requires,
"dev": dev_requires,
"all": api_requires + workers_requires + common_requires + dev_requires,
"all": api_requires
+ workers_requires
+ common_requires
+ dev_requires
+ parsers_requires,
},
)
@@ -15,7 +15,7 @@ from PIL import Image
from pydantic import BaseModel

from memory.common import embedding, qdrant, extract, settings
from memory.common.collections import get_modality, TEXT_COLLECTIONS
from memory.common.collections import get_modality, TEXT_COLLECTIONS, ALL_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem

@@ -95,8 +95,8 @@ def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[Search
return [
SearchResult(
id=source.id,
size=source.size,
mime_type=source.mime_type,
size=source.size or len(source.content),
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
@@ -110,7 +110,7 @@ def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[Search

def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[tuple[str, list[extract.Page]]],
upload_data: list[extract.DataChunk],
allowed_modalities: set[str],
embedder: Callable,
min_score: float = 0.0,
@@ -119,15 +119,9 @@ def query_chunks(
if not upload_data:
return {}

chunks = [
chunk
for content_type, pages in upload_data
if get_modality(content_type) in allowed_modalities
for page in pages
for chunk in page["contents"]
]

chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
if not chunks:
logger.error(f"No chunks to embed for {allowed_modalities}")
return {}

vector = embedder(chunks, input_type="query")[0]
@@ -143,18 +137,18 @@ def query_chunks(
)
if r.score >= min_score
]
for collection in embedding.ALL_COLLECTIONS
for collection in allowed_modalities
}


async def input_type(item: str | UploadFile) -> tuple[str, list[extract.Page]]:
async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
if not item:
return "text/plain", []
return []

if isinstance(item, str):
return "text/plain", extract.extract_text(item)
return extract.extract_text(item)
content_type = item.content_type or "application/octet-stream"
return content_type, extract.extract_content(content_type, await item.read())
return extract.extract_data_chunks(content_type, await item.read())


@app.post("/search", response_model=list[SearchResult])
@@ -179,13 +173,15 @@ async def search(
Returns:
- List of search results sorted by score
"""
upload_data = [await input_type(item) for item in [query, *files]]
upload_data = [
chunk for item in [query, *files] for chunk in await input_type(item)
]
logger.error(
f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
)

client = qdrant.get_qdrant_client()
allowed_modalities = set(modalities or embedding.ALL_COLLECTIONS.keys())
allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
text_results = query_chunks(
client,
upload_data,
@@ -212,6 +208,7 @@ async def search(
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
logger.error(f"Found chunks: {chunks}")

results = group_chunks(
[
@@ -245,3 +242,16 @@ def get_file_by_path(path: str):
raise HTTPException(status_code=404, detail=f"File not found at path: {path}")

return FileResponse(path=file_path, filename=file_path.name)


def main():
"""Run the FastAPI server in debug mode with auto-reloading."""
import uvicorn

uvicorn.run(
"memory.api.app:app", host="0.0.0.0", port=8000, reload=True, log_level="debug"
)


if __name__ == "__main__":
main()
@@ -33,7 +33,6 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship

from memory.common import settings
from memory.common.parsers.email import EmailMessage, parse_email_message
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
@@ -126,8 +125,8 @@ class Chunk(Base):
embedding_model = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: ClassVar[list[float] | None] = None
item_metadata: ClassVar[dict[str, Any] | None] = None
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []

# One of file_path or content must be populated
@@ -275,7 +274,7 @@ class MailMessage(SourceItem):
def attachments_path(self) -> pathlib.Path:
clean_sender = clean_filename(cast(str, self.sender))
clean_folder = clean_filename(cast(str | None, self.folder) or "INBOX")
return pathlib.Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder
return pathlib.Path(settings.EMAIL_STORAGE_DIR) / clean_sender / clean_folder

def safe_filename(self, filename: str) -> pathlib.Path:
suffix = pathlib.Path(filename).suffix
@@ -297,7 +296,9 @@ class MailMessage(SourceItem):
}

@property
def parsed_content(self) -> EmailMessage:
def parsed_content(self):
from memory.parsers.email import parse_email_message

return parse_email_message(cast(str, self.content), cast(str, self.message_id))

@property
@@ -563,6 +564,8 @@ class BookSection(SourceItem):

def as_payload(self) -> dict:
vals = {
"title": self.book.title,
"author": self.book.author,
"source_id": self.id,
"book_id": self.book_id,
"section_title": self.section_title,
@@ -636,7 +639,9 @@ class BlogPost(SourceItem):
return {k: v for k, v in payload.items() if v}

def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
images = [Image.open(image) for image in self.images]
images = [
Image.open(settings.FILE_STORAGE_DIR / image) for image in self.images
]

content = cast(str, self.content)
full_text = [content.strip(), *images]
@@ -705,6 +710,9 @@ class ArticleFeed(Base):
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
check_interval = Column(
Integer, nullable=False, server_default="3600", doc="Seconds between checks"
)
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
@@ -20,6 +20,7 @@ def embed_chunks(
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]}")
vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed(
@@ -71,18 +72,24 @@ def embed_mixed(
return embed_chunks([chunks], model, input_type)


def embed_chunk(chunk: Chunk) -> Chunk:
model = cast(str, chunk.embedding_model)
if model == settings.TEXT_EMBEDDING_MODEL:
content = cast(str, chunk.content)
elif model == settings.MIXED_EMBEDDING_MODEL:
content = [cast(str, chunk.content)] + chunk.images
else:
raise ValueError(f"Unsupported model: {chunk.embedding_model}")
vectors = embed_chunks([content], model) # type: ignore
chunk.vector = vectors[0] # type: ignore
return chunk
def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
model_chunks = [
chunk for chunk in chunks if cast(str, chunk.embedding_model) == model
]
if not model_chunks:
return []

vectors = embed_chunks([chunk.content for chunk in model_chunks], model) # type: ignore
for chunk, vector in zip(model_chunks, vectors):
chunk.vector = vector
return model_chunks


def embed_source_item(item: SourceItem) -> list[Chunk]:
return [embed_chunk(chunk) for chunk in item.data_chunks()]
chunks = list(item.data_chunks())
if not chunks:
return []

text_chunks = embed_by_model(chunks, settings.TEXT_EMBEDDING_MODEL)
mixed_chunks = embed_by_model(chunks, settings.MIXED_EMBEDDING_MODEL)
return text_chunks + mixed_chunks
@@ -20,6 +20,7 @@ MulitmodalChunk = Image.Image | str
class DataChunk:
data: Sequence[MulitmodalChunk]
metadata: dict[str, Any] = field(default_factory=dict)
mime_type: str = "text/plain"


@contextmanager
@@ -36,7 +37,9 @@ def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None

def page_to_image(page: pymupdf.Page) -> Image.Image:
pix = page.get_pixmap() # type: ignore
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
image.format = "jpeg"
return image


def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
@@ -50,6 +53,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
"width": page.rect.width,
"height": page.rect.height,
},
mime_type="image/jpeg",
)
for page in pdf.pages()
]
@@ -99,7 +103,8 @@ def extract_image(content: bytes | str | pathlib.Path) -> list[DataChunk]:
image = Image.open(io.BytesIO(content))
else:
raise ValueError(f"Unsupported content type: {type(content)}")
return [DataChunk(data=[image])]
image_format = image.format or "jpeg"
return [DataChunk(data=[image], mime_type=f"image/{image_format.lower()}")]


def extract_text(
@@ -112,7 +117,7 @@ def extract_text(

content = cast(str, content)
chunks = chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
return [DataChunk(data=[c]) for c in chunks]
return [DataChunk(data=[c], mime_type="text/plain") for c in chunks if c.strip()]


def extract_data_chunks(
@@ -39,20 +39,32 @@ CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")

# File storage settings
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
EBOOK_STORAGE_DIR = pathlib.Path(
os.getenv("EBOOK_STORAGE_DIR", FILE_STORAGE_DIR / "ebooks")
)
EMAIL_STORAGE_DIR = pathlib.Path(
os.getenv("EMAIL_STORAGE_DIR", FILE_STORAGE_DIR / "emails")
)
CHUNK_STORAGE_DIR = pathlib.Path(
os.getenv("CHUNK_STORAGE_DIR", FILE_STORAGE_DIR / "chunks")
)
COMIC_STORAGE_DIR = pathlib.Path(
os.getenv("COMIC_STORAGE_DIR", FILE_STORAGE_DIR / "comics")
)
PHOTO_STORAGE_DIR = pathlib.Path(
os.getenv("PHOTO_STORAGE_DIR", FILE_STORAGE_DIR / "photos")
)
WEBPAGE_STORAGE_DIR = pathlib.Path(
os.getenv("WEBPAGE_STORAGE_DIR", FILE_STORAGE_DIR / "webpages")
)

storage_dirs = [
FILE_STORAGE_DIR,
EBOOK_STORAGE_DIR,
EMAIL_STORAGE_DIR,
CHUNK_STORAGE_DIR,
COMIC_STORAGE_DIR,
PHOTO_STORAGE_DIR,
WEBPAGE_STORAGE_DIR,
]
for dir in storage_dirs:
@@ -83,3 +95,4 @@ CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60
# Embedding settings
TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")
EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", 50))
@@ -1,26 +1,22 @@
from dataclasses import dataclass, field
import logging
import re
import time
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from dataclasses import dataclass, field
from typing import Generator, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from bs4 import BeautifulSoup
from memory.common.parsers.blogs import is_substack

from memory.common.parsers.feeds import (
from memory.parsers.blogs import is_substack
from memory.parsers.feeds import (
DanluuParser,
HTMLListParser,
RiftersParser,
FeedItem,
FeedParser,
HTMLListParser,
RiftersParser,
SubstackAPIParser,
)
from memory.common.parsers.html import (
fetch_html,
extract_url,
get_base_url,
)
from memory.parsers.html import extract_url, fetch_html, get_base_url

logger = logging.getLogger(__name__)
@@ -6,7 +6,7 @@ from typing import cast

from bs4 import BeautifulSoup, Tag

from memory.common.parsers.html import (
from memory.parsers.html import (
BaseHTMLParser,
Article,
parse_date,
@@ -1,10 +1,10 @@
import email
import hashlib
import logging
import pathlib
from datetime import datetime
from email.utils import parsedate_to_datetime
from typing import TypedDict
import pathlib

logger = logging.getLogger(__name__)
@@ -10,7 +10,7 @@ import feedparser
from bs4 import BeautifulSoup, Tag
import requests

from memory.common.parsers.html import (
from memory.parsers.html import (
get_base_url,
to_absolute_url,
extract_title,
@@ -1,11 +1,11 @@
from datetime import datetime
import hashlib
import logging
import pathlib
import re
from dataclasses import dataclass, field
import pathlib
from datetime import datetime
from typing import Any
from urllib.parse import urljoin, urlparse
import hashlib

import requests
from bs4 import BeautifulSoup, Tag
@@ -153,6 +153,7 @@ def process_image(url: str, image_dir: pathlib.Path) -> PILImage.Image | None:
ext = pathlib.Path(urlparse(url).path).suffix or ".jpg"
filename = f"{url_hash}{ext}"
local_path = image_dir / filename
local_path.parent.mkdir(parents=True, exist_ok=True)

# Download if not already cached
if not local_path.exists():
@@ -21,12 +21,11 @@ app.conf.update(
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
task_routes={
"memory.workers.tasks.text.*": {"queue": "medium_embed"},
"memory.workers.tasks.email.*": {"queue": "email"},
"memory.workers.tasks.photo.*": {"queue": "photo_embed"},
"memory.workers.tasks.comic.*": {"queue": "comic"},
"memory.workers.tasks.ebook.*": {"queue": "ebooks"},
"memory.workers.tasks.blogs.*": {"queue": "blogs"},
"memory.workers.tasks.docs.*": {"queue": "docs"},
"memory.workers.tasks.maintenance.*": {"queue": "maintenance"},
},
)
@@ -1,7 +1,6 @@
import hashlib
import imaplib
import logging
import pathlib
import re
from collections import defaultdict
from contextlib import contextmanager
@@ -16,7 +15,7 @@ from memory.common.db.models import (
EmailAttachment,
MailMessage,
)
from memory.common.parsers.email import (
from memory.parsers.email import (
Attachment,
EmailMessage,
RawEmailResponse,
@@ -2,26 +2,42 @@
Import sub-modules so Celery can register their @app.task decorators.
"""

from memory.workers.tasks import docs, email, comic, blogs # noqa
from memory.workers.tasks.blogs import SYNC_WEBPAGE
from memory.workers.tasks import email, comic, blogs, ebook # noqa
from memory.workers.tasks.blogs import (
SYNC_WEBPAGE,
SYNC_ARTICLE_FEED,
SYNC_ALL_ARTICLE_FEEDS,
SYNC_WEBSITE_ARCHIVE,
)
from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_SMBC, SYNC_XKCD
from memory.workers.tasks.ebook import SYNC_BOOK
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
from memory.workers.tasks.maintenance import (
CLEAN_ALL_COLLECTIONS,
CLEAN_COLLECTION,
REINGEST_MISSING_CHUNKS,
REINGEST_CHUNK,
)


__all__ = [
"docs",
"email",
"comic",
"blogs",
"ebook",
"SYNC_WEBPAGE",
"SYNC_ARTICLE_FEED",
"SYNC_ALL_ARTICLE_FEEDS",
"SYNC_WEBSITE_ARCHIVE",
"SYNC_ALL_COMICS",
"SYNC_SMBC",
"SYNC_XKCD",
"SYNC_BOOK",
"SYNC_ACCOUNT",
"SYNC_ALL_ACCOUNTS",
"PROCESS_EMAIL",
"CLEAN_ALL_COLLECTIONS",
"CLEAN_COLLECTION",
"REINGEST_MISSING_CHUNKS",
"REINGEST_CHUNK",
]
@ -1,9 +1,12 @@
|
||||
import logging
|
||||
from typing import Iterable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, cast
|
||||
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import BlogPost
|
||||
from memory.common.parsers.blogs import parse_webpage
|
||||
from memory.common.db.models import ArticleFeed, BlogPost
|
||||
from memory.parsers.blogs import parse_webpage
|
||||
from memory.parsers.feeds import get_feed_parser
|
||||
from memory.parsers.archives import get_archive_fetcher
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
@ -16,6 +19,9 @@ from memory.workers.tasks.content_processing import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYNC_WEBPAGE = "memory.workers.tasks.blogs.sync_webpage"
|
||||
SYNC_ARTICLE_FEED = "memory.workers.tasks.blogs.sync_article_feed"
|
||||
SYNC_ALL_ARTICLE_FEEDS = "memory.workers.tasks.blogs.sync_all_article_feeds"
|
||||
SYNC_WEBSITE_ARCHIVE = "memory.workers.tasks.blogs.sync_website_archive"
|
||||
|
||||
|
||||
@app.task(name=SYNC_WEBPAGE)
|
||||
@ -31,7 +37,9 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
|
||||
Returns:
|
||||
dict: Summary of what was processed
|
||||
"""
|
||||
logger.info(f"Syncing webpage: {url}")
|
||||
article = parse_webpage(url)
|
||||
logger.debug(f"Article: {article.title} - {article.url}")
|
||||
|
||||
if not article.content:
|
||||
logger.warning(f"Article content too short or empty: {url}")
|
||||
@ -57,10 +65,158 @@ def sync_webpage(url: str, tags: Iterable[str] = []) -> dict:
|
||||
|
||||
with make_session() as session:
|
||||
existing_post = check_content_exists(
|
||||
session, BlogPost, url=url, sha256=create_content_hash(article.content)
|
||||
session, BlogPost, url=article.url, sha256=blog_post.sha256
|
||||
)
|
||||
if existing_post:
|
||||
logger.info(f"Blog post already exists: {existing_post.title}")
|
||||
return create_task_result(existing_post, "already_exists", url=url)
|
||||
return create_task_result(existing_post, "already_exists", url=article.url)
|
||||
|
||||
return process_content_item(blog_post, "blog", session, tags)
|
||||
|
||||
|
||||
@app.task(name=SYNC_ARTICLE_FEED)
|
||||
@safe_task_execution
|
||||
def sync_article_feed(feed_id: int) -> dict:
|
||||
"""
|
||||
Synchronize articles from a specific ArticleFeed.
|
||||
|
||||
Args:
|
||||
feed_id: ID of the ArticleFeed to sync
|
||||
|
||||
Returns:
|
||||
dict: Summary of sync operation including stats
|
||||
"""
|
||||
with make_session() as session:
|
||||
feed = session.query(ArticleFeed).filter(ArticleFeed.id == feed_id).first()
|
||||
if not feed or not cast(bool, feed.active):
|
||||
logger.warning(f"Feed {feed_id} not found or inactive")
|
||||
return {"status": "error", "error": "Feed not found or inactive"}
|
||||
|
||||
last_checked_at = cast(datetime | None, feed.last_checked_at)
|
||||
if last_checked_at and datetime.now() - last_checked_at < timedelta(
|
||||
seconds=cast(int, feed.check_interval)
|
||||
):
|
||||
logger.info(f"Feed {feed_id} checked too recently, skipping")
|
||||
return {"status": "skipped_recent_check", "feed_id": feed_id}
|
||||
|
||||
logger.info(f"Syncing feed: {feed.title} ({feed.url})")
|
||||
|
||||
parser = get_feed_parser(cast(str, feed.url), last_checked_at)
|
||||
if not parser:
|
||||
logger.error(f"No parser available for feed: {feed.url}")
|
||||
return {"status": "error", "error": "No parser available for feed"}
|
||||
|
||||
articles_found = 0
|
||||
new_articles = 0
|
||||
errors = 0
|
||||
task_ids = []
|
||||
|
||||
try:
|
||||
for feed_item in parser.parse_feed():
|
||||
articles_found += 1
|
||||
|
||||
existing = check_content_exists(session, BlogPost, url=feed_item.url)
|
||||
if existing:
|
||||
continue
|
||||
|
||||
feed_tags = cast(list[str] | None, feed.tags) or []
|
||||
task_ids.append(sync_webpage.delay(feed_item.url, feed_tags).id)
|
||||
new_articles += 1
|
||||
|
||||
logger.info(f"Scheduled sync for: {feed_item.title} ({feed_item.url})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing feed {feed.url}: {e}")
|
||||
errors += 1
|
||||
|
||||
feed.last_checked_at = datetime.now() # type: ignore
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"feed_id": feed_id,
|
||||
"feed_title": feed.title,
|
||||
"feed_url": feed.url,
|
||||
"articles_found": articles_found,
|
||||
"new_articles": new_articles,
|
||||
"errors": errors,
|
||||
"task_ids": task_ids,
|
||||
}
|
||||
|
||||
|
||||
@app.task(name=SYNC_ALL_ARTICLE_FEEDS)
|
||||
def sync_all_article_feeds() -> list[dict]:
|
||||
"""
|
||||
Trigger sync for all active ArticleFeeds.
|
||||
|
||||
Returns:
|
||||
List of task results for each feed sync
|
||||
"""
|
||||
with make_session() as session:
|
||||
active_feeds = session.query(ArticleFeed).filter(ArticleFeed.active).all()
|
||||
|
||||
results = [
|
||||
{
|
||||
"feed_id": feed.id,
|
||||
"feed_title": feed.title,
|
||||
"feed_url": feed.url,
|
||||
"task_id": sync_article_feed.delay(feed.id).id,
|
||||
}
|
||||
for feed in active_feeds
|
||||
]
|
||||
logger.info(f"Scheduled sync for {len(results)} active feeds")
|
||||
return results
|
||||
|
||||
|
||||
@app.task(name=SYNC_WEBSITE_ARCHIVE)
|
||||
@safe_task_execution
|
||||
def sync_website_archive(
|
||||
url: str, tags: Iterable[str] = [], max_pages: int = 100
|
||||
) -> dict:
|
||||
"""
|
||||
Synchronize all articles from a website's archive.
|
||||
|
||||
Args:
|
||||
url: Base URL of the website to sync
|
||||
tags: Additional tags to apply to all articles
|
||||
max_pages: Maximum number of pages to process
|
||||
|
||||
Returns:
|
||||
dict: Summary of archive sync operation
|
||||
"""
|
||||
logger.info(f"Starting archive sync for: {url}")
|
||||
|
||||
# Get archive fetcher for the website
|
||||
fetcher = get_archive_fetcher(url)
|
||||
if not fetcher:
|
||||
logger.error(f"No archive fetcher available for: {url}")
|
||||
return {"status": "error", "error": "No archive fetcher available"}
|
||||
|
||||
# Override max_pages if provided
|
||||
fetcher.max_pages = max_pages
|
||||
|
||||
articles_found = 0
|
||||
new_articles = 0
|
||||
task_ids = []
|
||||
|
||||
for feed_item in fetcher.fetch_all_items():
|
||||
articles_found += 1
|
||||
|
||||
with make_session() as session:
|
||||
existing = check_content_exists(session, BlogPost, url=feed_item.url)
|
||||
if existing:
|
||||
continue
|
||||
|
||||
task_ids.append(sync_webpage.delay(feed_item.url, list(tags)).id)
|
||||
new_articles += 1
|
||||
|
||||
logger.info(f"Scheduled sync for: {feed_item.title} ({feed_item.url})")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"website_url": url,
|
||||
"articles_found": articles_found,
|
||||
"new_articles": new_articles,
|
||||
"task_ids": task_ids,
|
||||
"max_pages_processed": fetcher.max_pages,
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import requests
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Comic, clean_filename
|
||||
from memory.common.parsers import comics
|
||||
from memory.parsers import comics
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
|
@ -221,6 +221,7 @@ def process_content_item(
|
||||
try:
|
||||
push_to_qdrant([item], collection_name)
|
||||
status = "processed"
|
||||
item.embed_status = "STORED" # type: ignore
|
||||
logger.info(
|
||||
f"Successfully processed {type(item).__name__}: {getattr(item, 'title', 'unknown')} ({chunks_count} chunks embedded)"
|
||||
)
|
||||
@ -228,7 +229,6 @@ def process_content_item(
|
||||
logger.error(f"Failed to push embeddings to Qdrant: {e}")
|
||||
item.embed_status = "FAILED" # type: ignore
|
||||
status = "failed"
|
||||
|
||||
session.commit()
|
||||
|
||||
return create_task_result(item, status, content_length=getattr(item, "size", 0))
|
||||
|
@ -1,5 +0,0 @@
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
@app.task(name="kb.text.ping")
|
||||
def ping():
|
||||
return "pong"
|
@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import pathlib
|
||||
from typing import Iterable, cast
|
||||
|
||||
import memory.common.settings as settings
|
||||
from memory.parsers.ebook import Ebook, parse_ebook, Section
|
||||
from memory.common.db.models import Book, BookSection
|
||||
from memory.common.parsers.ebook import Ebook, parse_ebook, Section
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.tasks.content_processing import (
|
||||
@ -16,7 +17,7 @@ from memory.workers.tasks.content_processing import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYNC_BOOK = "memory.workers.tasks.book.sync_book"
|
||||
SYNC_BOOK = "memory.workers.tasks.ebook.sync_book"
|
||||
|
||||
# Minimum section length to embed (avoid noise from very short sections)
|
||||
MIN_SECTION_LENGTH = 100
|
||||
@ -59,6 +60,8 @@ def section_processor(
|
||||
end_page=section.end_page,
|
||||
parent_section_id=None, # Will be set after flush
|
||||
content=content,
|
||||
size=len(content),
|
||||
mime_type="text/plain",
|
||||
sha256=create_content_hash(
|
||||
f"{book.id}:{section.title}:{section.start_page}"
|
||||
),
|
||||
@ -94,7 +97,13 @@ def create_all_sections(
|
||||
|
||||
def validate_and_parse_book(file_path: str) -> Ebook:
|
||||
"""Validate file exists and parse the ebook."""
|
||||
path = Path(file_path)
|
||||
logger.info(f"Validating and parsing book from {file_path}")
|
||||
path = pathlib.Path(file_path)
|
||||
if not path.is_absolute():
|
||||
path = settings.EBOOK_STORAGE_DIR / path
|
||||
|
||||
logger.info(f"Resolved path: {path}")
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Book file not found: {path}")
|
||||
|
||||
@ -145,6 +154,7 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
|
||||
dict: Summary of what was processed
|
||||
"""
|
||||
ebook = validate_and_parse_book(file_path)
|
||||
logger.info(f"Ebook parsed: {ebook.title}")
|
||||
|
||||
with make_session() as session:
|
||||
# Check for existing book
|
||||
@ -161,14 +171,22 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
|
||||
"sections_processed": 0,
|
||||
}
|
||||
|
||||
logger.info("Creating book and sections with relationships")
|
||||
# Create book and sections with relationships
|
||||
book, all_sections = create_book_and_sections(ebook, session, tags)
|
||||
|
||||
# Embed sections
|
||||
logger.info("Embedding sections")
|
||||
embedded_count = sum(embed_source_item(section) for section in all_sections)
|
||||
session.flush()
|
||||
|
||||
for section in all_sections:
|
||||
logger.info(
|
||||
f"Embedded section: {section.section_title} - {section.content[:100]}"
|
||||
)
|
||||
logger.info("Pushing to Qdrant")
|
||||
push_to_qdrant(all_sections, "book")
|
||||
logger.info("Committing session")
|
||||
|
||||
session.commit()
|
||||
|
||||
|
@ -10,7 +10,7 @@ from memory.workers.email import (
|
||||
process_folder,
|
||||
vectorize_email,
|
||||
)
|
||||
from memory.common.parsers.email import parse_email_message
|
||||
from memory.parsers.email import parse_email_message
|
||||
from memory.workers.tasks.content_processing import (
|
||||
check_content_exists,
|
||||
safe_task_execution,
|
||||
|
@ -23,6 +23,10 @@ REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk"
|
||||
@app.task(name=CLEAN_COLLECTION)
|
||||
def clean_collection(collection: str) -> dict[str, int]:
|
||||
logger.info(f"Cleaning collection {collection}")
|
||||
|
||||
if collection not in collections.ALL_COLLECTIONS:
|
||||
raise ValueError(f"Unsupported collection {collection}")
|
||||
|
||||
client = qdrant.get_qdrant_client()
|
||||
batches, deleted, checked = 0, 0, 0
|
||||
for batch in qdrant.batch_ids(client, collection):
|
||||
@ -47,7 +51,7 @@ def clean_collection(collection: str) -> dict[str, int]:
|
||||
@app.task(name=CLEAN_ALL_COLLECTIONS)
|
||||
def clean_all_collections():
|
||||
logger.info("Cleaning all collections")
|
||||
for collection in embedding.ALL_COLLECTIONS:
|
||||
for collection in collections.ALL_COLLECTIONS:
|
||||
clean_collection.delay(collection) # type: ignore
|
||||
|
||||
|
||||
@ -111,10 +115,12 @@ def check_batch(batch: Sequence[Chunk]) -> dict:
|
||||
|
||||
|
||||
@app.task(name=REINGEST_MISSING_CHUNKS)
|
||||
def reingest_missing_chunks(batch_size: int = 1000):
|
||||
def reingest_missing_chunks(
|
||||
batch_size: int = 1000, minutes_ago: int = settings.CHUNK_REINGEST_SINCE_MINUTES
|
||||
):
|
||||
logger.info("Reingesting missing chunks")
|
||||
total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0})
|
||||
since = datetime.now() - timedelta(minutes=settings.CHUNK_REINGEST_SINCE_MINUTES)
|
||||
since = datetime.now() - timedelta(minutes=minutes_ago)
|
||||
|
||||
with make_session() as session:
|
||||
total_count = session.query(Chunk).filter(Chunk.checked_at < since).count()
|
||||
|
@ -753,7 +753,65 @@ def test_email_attachment_cascade_delete(db_session: Session):
|
||||
assert deleted_attachment is None
|
||||
|
||||
|
||||
# BookSection tests
|
||||
def test_subclass_deletion_cascades_to_source_item(db_session: Session):
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test_email_cascade",
|
||||
content="test email content",
|
||||
message_id="<cascade_test@example.com>",
|
||||
subject="Cascade Test",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
folder="INBOX",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
source_item_id = mail_message.id
|
||||
mail_message_id = mail_message.id
|
||||
|
||||
# Verify both records exist
|
||||
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is not None
|
||||
assert (
|
||||
db_session.query(MailMessage).filter_by(id=mail_message_id).first() is not None
|
||||
)
|
||||
|
||||
# Delete the MailMessage subclass
|
||||
db_session.delete(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
# Verify both the MailMessage and SourceItem records are deleted
|
||||
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
|
||||
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
|
||||
|
||||
|
||||
def test_subclass_deletion_cascades_from_source_item(db_session: Session):
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test_email_cascade",
|
||||
content="test email content",
|
||||
message_id="<cascade_test@example.com>",
|
||||
subject="Cascade Test",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
folder="INBOX",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
source_item_id = mail_message.id
|
||||
mail_message_id = mail_message.id
|
||||
|
||||
# Verify both records exist
|
||||
source_item = db_session.query(SourceItem).get(source_item_id)
|
||||
assert source_item
|
||||
assert db_session.query(MailMessage).get(mail_message_id)
|
||||
|
||||
# Delete the MailMessage subclass
|
||||
db_session.delete(source_item)
|
||||
db_session.commit()
|
||||
|
||||
# Verify both the MailMessage and SourceItem records are deleted
|
||||
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
|
||||
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -3,7 +3,7 @@ from urllib.parse import urlparse, parse_qs
|
||||
|
||||
import pytest
|
||||
|
||||
from memory.common.parsers.archives import (
|
||||
from memory.parsers.archives import (
|
||||
ArchiveFetcher,
|
||||
LinkFetcher,
|
||||
HTMLArchiveFetcher,
|
||||
@ -14,7 +14,7 @@ from memory.common.parsers.archives import (
|
||||
get_archive_fetcher,
|
||||
FETCHER_REGISTRY,
|
||||
)
|
||||
from memory.common.parsers.feeds import (
|
||||
from memory.parsers.feeds import (
|
||||
FeedItem,
|
||||
FeedParser,
|
||||
HTMLListParser,
|
||||
@ -56,7 +56,7 @@ def test_archive_fetcher_find_next_page_base():
|
||||
assert fetcher._find_next_page(parser, 0) is None
|
||||
|
||||
|
||||
@patch("memory.common.parsers.archives.time.sleep")
|
||||
@patch("memory.parsers.archives.time.sleep")
|
||||
def test_archive_fetcher_fetch_all_items_single_page(mock_sleep):
|
||||
items = [
|
||||
FeedItem(title="Item 1", url="https://example.com/1"),
|
||||
@ -80,7 +80,7 @@ def test_archive_fetcher_fetch_all_items_single_page(mock_sleep):
|
||||
mock_sleep.assert_not_called() # No delay for single page
|
||||
|
||||
|
||||
@patch("memory.common.parsers.archives.time.sleep")
|
||||
@patch("memory.parsers.archives.time.sleep")
|
||||
def test_archive_fetcher_fetch_all_items_multiple_pages(mock_sleep):
|
||||
page1_items = [FeedItem(title="Item 1", url="https://example.com/1")]
|
||||
page2_items = [FeedItem(title="Item 2", url="https://example.com/2")]
|
||||
@ -258,7 +258,7 @@ def test_html_archive_fetcher_find_next_page(html, selectors, expected_url):
|
||||
)
|
||||
parser = MockParser("https://example.com", content=html)
|
||||
|
||||
with patch("memory.common.parsers.archives.extract_url") as mock_extract:
|
||||
with patch("memory.parsers.archives.extract_url") as mock_extract:
|
||||
mock_extract.return_value = expected_url
|
||||
|
||||
result = fetcher._find_next_page(parser)
|
||||
@ -308,7 +308,7 @@ def test_html_parser_factory():
|
||||
],
|
||||
)
|
||||
def test_substack_archive_fetcher_post_init(start_url, expected_api_url):
|
||||
with patch("memory.common.parsers.archives.get_base_url") as mock_get_base:
|
||||
with patch("memory.parsers.archives.get_base_url") as mock_get_base:
|
||||
mock_get_base.return_value = "https://example.substack.com"
|
||||
|
||||
fetcher = SubstackArchiveFetcher(SubstackAPIParser, start_url)
|
||||
@ -413,10 +413,10 @@ def test_html_next_url_archive_fetcher_find_next_page():
|
||||
],
|
||||
)
|
||||
def test_get_archive_fetcher_registry_matches(url, expected_fetcher_type):
|
||||
with patch("memory.common.parsers.archives.fetch_html") as mock_fetch:
|
||||
with patch("memory.parsers.archives.fetch_html") as mock_fetch:
|
||||
mock_fetch.return_value = "<html><body>Not substack</body></html>"
|
||||
|
||||
with patch("memory.common.parsers.archives.is_substack") as mock_is_substack:
|
||||
with patch("memory.parsers.archives.is_substack") as mock_is_substack:
|
||||
mock_is_substack.return_value = False
|
||||
|
||||
fetcher = get_archive_fetcher(url)
|
||||
@ -430,7 +430,7 @@ def test_get_archive_fetcher_registry_matches(url, expected_fetcher_type):
|
||||
def test_get_archive_fetcher_tuple_registry():
|
||||
url = "https://putanumonit.com"
|
||||
|
||||
with patch("memory.common.parsers.archives.fetch_html") as mock_fetch:
|
||||
with patch("memory.parsers.archives.fetch_html") as mock_fetch:
|
||||
mock_fetch.return_value = "<html><body>Not substack</body></html>"
|
||||
|
||||
fetcher = get_archive_fetcher(url)
|
||||
@ -442,7 +442,7 @@ def test_get_archive_fetcher_tuple_registry():
|
||||
def test_get_archive_fetcher_direct_parser_registry():
|
||||
url = "https://danluu.com"
|
||||
|
||||
with patch("memory.common.parsers.archives.fetch_html") as mock_fetch:
|
||||
with patch("memory.parsers.archives.fetch_html") as mock_fetch:
|
||||
mock_fetch.return_value = "<html><body>Not substack</body></html>"
|
||||
|
||||
fetcher = get_archive_fetcher(url)
|
||||
@ -455,10 +455,10 @@ def test_get_archive_fetcher_direct_parser_registry():
|
||||
def test_get_archive_fetcher_substack():
|
||||
url = "https://example.substack.com"
|
||||
|
||||
with patch("memory.common.parsers.archives.fetch_html") as mock_fetch:
|
||||
with patch("memory.parsers.archives.fetch_html") as mock_fetch:
|
||||
mock_fetch.return_value = "<html><body>Substack content</body></html>"
|
||||
|
||||
with patch("memory.common.parsers.archives.is_substack") as mock_is_substack:
|
||||
with patch("memory.parsers.archives.is_substack") as mock_is_substack:
|
||||
mock_is_substack.return_value = True
|
||||
|
||||
fetcher = get_archive_fetcher(url)
|
||||
@ -470,10 +470,10 @@ def test_get_archive_fetcher_substack():
|
||||
def test_get_archive_fetcher_no_match():
|
||||
url = "https://unknown.com"
|
||||
|
||||
with patch("memory.common.parsers.archives.fetch_html") as mock_fetch:
|
||||
with patch("memory.parsers.archives.fetch_html") as mock_fetch:
|
||||
mock_fetch.return_value = "<html><body>Regular website</body></html>"
|
||||
|
||||
with patch("memory.common.parsers.archives.is_substack") as mock_is_substack:
|
||||
with patch("memory.parsers.archives.is_substack") as mock_is_substack:
|
||||
mock_is_substack.return_value = False
|
||||
|
||||
fetcher = get_archive_fetcher(url)
|
@ -2,7 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from memory.common.parsers.comics import extract_smbc, extract_xkcd
|
||||
from memory.parsers.comics import extract_smbc, extract_xkcd
|
||||
|
||||
MOCK_SMBC_HTML = """
|
||||
<!DOCTYPE html>
|
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import fitz
|
||||
|
||||
from memory.common.parsers.ebook import (
|
||||
from memory.parsers.ebook import (
|
||||
Peekable,
|
||||
extract_epub_metadata,
|
||||
get_pages,
|
||||
@ -381,7 +381,7 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
|
||||
section2.pages = ["Content of section 2"]
|
||||
|
||||
# Mock extract_sections to return our sections
|
||||
with patch("memory.common.parsers.ebook.extract_sections") as mock_extract:
|
||||
with patch("memory.parsers.ebook.extract_sections") as mock_extract:
|
||||
mock_extract.return_value = [section1, section2]
|
||||
|
||||
mock_open.return_value = mock_doc
|
@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from email.utils import formatdate
|
||||
from unittest.mock import ANY
|
||||
import pytest
|
||||
from memory.common.parsers.email import (
|
||||
from memory.parsers.email import (
|
||||
compute_message_hash,
|
||||
extract_attachments,
|
||||
extract_body,
|
@ -6,7 +6,7 @@ import json
|
||||
import pytest
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
from memory.common.parsers.feeds import (
|
||||
from memory.parsers.feeds import (
|
||||
FeedItem,
|
||||
FeedParser,
|
||||
RSSAtomParser,
|
||||
@ -61,7 +61,7 @@ def test_select_in(data, path, expected):
|
||||
assert select_in(data, path) == expected
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_json_parser_fetch_items_with_content(mock_fetch_html):
|
||||
content = json.dumps(
|
||||
[
|
||||
@ -80,7 +80,7 @@ def test_json_parser_fetch_items_with_content(mock_fetch_html):
|
||||
mock_fetch_html.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_json_parser_fetch_items_without_content(mock_fetch_html):
|
||||
content = json.dumps([{"title": "Article", "url": "https://example.com/1"}])
|
||||
mock_fetch_html.return_value = content
|
||||
@ -92,7 +92,7 @@ def test_json_parser_fetch_items_without_content(mock_fetch_html):
|
||||
mock_fetch_html.assert_called_once_with("https://example.com/feed.json")
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_json_parser_fetch_items_invalid_json(mock_fetch_html):
|
||||
mock_fetch_html.return_value = "invalid json content"
|
||||
|
||||
@ -220,7 +220,7 @@ def test_feed_parser_parse_feed_with_invalid_items():
|
||||
]
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.feedparser.parse")
|
||||
@patch("memory.parsers.feeds.feedparser.parse")
|
||||
@pytest.mark.parametrize("since_date", [None, datetime(2023, 1, 1)])
|
||||
def test_rss_atom_parser_fetch_items(mock_parse, since_date):
|
||||
mock_feed = MagicMock()
|
||||
@ -239,7 +239,7 @@ def test_rss_atom_parser_fetch_items(mock_parse, since_date):
|
||||
assert items == ["entry1", "entry2"]
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.feedparser.parse")
|
||||
@patch("memory.parsers.feeds.feedparser.parse")
|
||||
def test_rss_atom_parser_fetch_items_with_content(mock_parse):
|
||||
mock_feed = MagicMock()
|
||||
mock_feed.entries = ["entry1"]
|
||||
@ -411,7 +411,7 @@ def test_rss_atom_parser_extract_metadata():
|
||||
}
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_html_list_parser_fetch_items_with_content(mock_fetch_html):
|
||||
html = """
|
||||
<ul>
|
||||
@ -430,7 +430,7 @@ def test_html_list_parser_fetch_items_with_content(mock_fetch_html):
|
||||
mock_fetch_html.assert_not_called()
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_html_list_parser_fetch_items_without_content(mock_fetch_html):
|
||||
html = """
|
||||
<ul>
|
||||
@ -502,7 +502,7 @@ def test_html_list_parser_extract_title(html, title_selector, expected):
|
||||
parser.title_selector = title_selector
|
||||
|
||||
if expected and title_selector:
|
||||
with patch("memory.common.parsers.feeds.extract_title") as mock_extract:
|
||||
with patch("memory.parsers.feeds.extract_title") as mock_extract:
|
||||
mock_extract.return_value = expected
|
||||
title = parser.extract_title(item)
|
||||
mock_extract.assert_called_once_with(item, title_selector)
|
||||
@ -555,7 +555,7 @@ def test_html_list_parser_extract_date_with_selector():
|
||||
parser = HTMLListParser(url="https://example.com")
|
||||
parser.date_selector = ".date"
|
||||
|
||||
with patch("memory.common.parsers.feeds.extract_date") as mock_extract:
|
||||
with patch("memory.parsers.feeds.extract_date") as mock_extract:
|
||||
mock_extract.return_value = datetime(2023, 1, 15)
|
||||
date = parser.extract_date(item)
|
||||
mock_extract.assert_called_once_with(item, ".date", "%Y-%m-%d")
|
||||
@ -787,7 +787,7 @@ def test_get_feed_parser_registry(url, expected_parser_class):
|
||||
assert parser.url == url
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_get_feed_parser_rss_content(mock_fetch_html):
|
||||
mock_fetch_html.return_value = "<?xml version='1.0'?><rss>"
|
||||
|
||||
@ -796,7 +796,7 @@ def test_get_feed_parser_rss_content(mock_fetch_html):
|
||||
assert parser.url == "https://example.com/unknown"
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_get_feed_parser_with_feed_link(mock_fetch_html):
|
||||
html = """
|
||||
<html>
|
||||
@ -812,19 +812,19 @@ def test_get_feed_parser_with_feed_link(mock_fetch_html):
|
||||
assert parser.url == "https://example.com/feed.xml"
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_get_feed_parser_recursive_paths(mock_fetch_html):
|
||||
# Mock the initial call to return HTML without feed links
|
||||
html = "<html><body>No feed links</body></html>"
|
||||
mock_fetch_html.return_value = html
|
||||
|
||||
# Mock the recursive calls to avoid actual HTTP requests
|
||||
with patch("memory.common.parsers.feeds.get_feed_parser") as mock_recursive:
|
||||
with patch("memory.parsers.feeds.get_feed_parser") as mock_recursive:
|
||||
# Set up the mock to return None for recursive calls
|
||||
mock_recursive.return_value = None
|
||||
|
||||
# Call the original function directly
|
||||
from memory.common.parsers.feeds import (
|
||||
from memory.parsers.feeds import (
|
||||
get_feed_parser as original_get_feed_parser,
|
||||
)
|
||||
|
||||
@ -833,13 +833,13 @@ def test_get_feed_parser_recursive_paths(mock_fetch_html):
|
||||
assert parser is None
|
||||
|
||||
|
||||
@patch("memory.common.parsers.feeds.fetch_html")
|
||||
@patch("memory.parsers.feeds.fetch_html")
|
||||
def test_get_feed_parser_no_match(mock_fetch_html):
|
||||
html = "<html><body>No feed links</body></html>"
|
||||
mock_fetch_html.return_value = html
|
||||
|
||||
# Mock the recursive calls to avoid actual HTTP requests
|
||||
with patch("memory.common.parsers.feeds.get_feed_parser") as mock_recursive:
|
||||
with patch("memory.parsers.feeds.get_feed_parser") as mock_recursive:
|
||||
mock_recursive.return_value = None
|
||||
parser = get_feed_parser("https://unknown.com")
|
||||
|
@ -1,11 +1,10 @@
import hashlib
import pathlib
import re
import tempfile
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse
import re
import hashlib

import pytest
import requests
@ -13,8 +12,7 @@ from bs4 import BeautifulSoup, Tag
from PIL import Image as PILImage

from memory.common import settings
from memory.common.parsers.html import (
    Article,
from memory.parsers.html import (
    BaseHTMLParser,
    convert_to_markdown,
    extract_author,
@ -311,8 +309,8 @@ def test_extract_metadata():
    assert isinstance(metadata, dict)


@patch("memory.common.parsers.html.requests.get")
@patch("memory.common.parsers.html.PILImage.open")
@patch("memory.parsers.html.requests.get")
@patch("memory.parsers.html.PILImage.open")
def test_process_image_success(mock_pil_open, mock_requests_get):
    # Setup mocks
    mock_response = MagicMock()
@ -345,7 +343,7 @@ def test_process_image_success(mock_pil_open, mock_requests_get):
    assert result == mock_image


@patch("memory.common.parsers.html.requests.get")
@patch("memory.parsers.html.requests.get")
def test_process_image_http_error(mock_requests_get):
    # Setup mock to raise HTTP error
    mock_requests_get.side_effect = requests.RequestException("Network error")
@ -359,8 +357,8 @@ def test_process_image_http_error(mock_requests_get):
        process_image(url, image_dir)


@patch("memory.common.parsers.html.requests.get")
@patch("memory.common.parsers.html.PILImage.open")
@patch("memory.parsers.html.requests.get")
@patch("memory.parsers.html.PILImage.open")
def test_process_image_pil_error(mock_pil_open, mock_requests_get):
    # Setup mocks
    mock_response = MagicMock()
@ -378,8 +376,8 @@ def test_process_image_pil_error(mock_pil_open, mock_requests_get):
    assert result is None


@patch("memory.common.parsers.html.requests.get")
@patch("memory.common.parsers.html.PILImage.open")
@patch("memory.parsers.html.requests.get")
@patch("memory.parsers.html.PILImage.open")
def test_process_image_cached(mock_pil_open, mock_requests_get):
    # Create a temporary file to simulate cached image
    with tempfile.TemporaryDirectory() as temp_dir:
@ -404,7 +402,7 @@ def test_process_image_cached(mock_pil_open, mock_requests_get):
        assert result == mock_image


@patch("memory.common.parsers.html.process_image")
@patch("memory.parsers.html.process_image")
def test_process_images_basic(mock_process_image):
    html = """
    <div>
@ -461,7 +459,7 @@ def test_process_images_empty():
    assert result_images == {}


@patch("memory.common.parsers.html.process_image")
@patch("memory.parsers.html.process_image")
def test_process_images_with_failures(mock_process_image):
    html = """
    <div>
@ -488,7 +486,7 @@ def test_process_images_with_failures(mock_process_image):
    assert images == {"images/good.jpg": mock_good_image}


@patch("memory.common.parsers.html.process_image")
@patch("memory.parsers.html.process_image")
def test_process_images_no_filename(mock_process_image):
    html = '<div><img src="test.jpg" alt="Test"></div>'
    soup = BeautifulSoup(html, "html.parser")
@ -769,7 +767,7 @@ class TestBaseHTMLParser:

        assert article.author == "Fixed Author"

    @patch("memory.common.parsers.html.process_images")
    @patch("memory.parsers.html.process_images")
    def test_parse_with_images(self, mock_process_images):
        # Mock the image processing to return test data
        mock_image = MagicMock(spec=PILImage.Image)
566 tests/memory/workers/tasks/test_blogs_tasks.py Normal file
@ -0,0 +1,566 @@
import pytest
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock

from memory.common.db.models import ArticleFeed, BlogPost
from memory.workers.tasks import blogs
from memory.parsers.blogs import Article


@pytest.fixture
def mock_article():
    """Mock article data for testing."""
    return Article(
        title="Test Article",
        url="https://example.com/article/1",
        content="This is test article content with enough text to be processed.",
        published_date=datetime(2024, 1, 1, 12, 0, 0),
        images={},  # Article.images is dict[str, PILImage.Image]
    )


@pytest.fixture
def mock_empty_article():
    """Mock article with empty content."""
    return Article(
        title="Empty Article",
        url="https://example.com/empty",
        content="",
        published_date=datetime(2024, 1, 1, 12, 0, 0),
        images={},
    )


@pytest.fixture
def sample_article_feed(db_session):
    """Create a sample ArticleFeed for testing."""
    feed = ArticleFeed(
        url="https://example.com/feed.xml",
        title="Test Feed",
        description="A test RSS feed",
        tags=["test", "blog"],
        check_interval=3600,
        active=True,
        last_checked_at=None,  # Avoid timezone issues
    )
    db_session.add(feed)
    db_session.commit()
    return feed


@pytest.fixture
def inactive_article_feed(db_session):
    """Create an inactive ArticleFeed for testing."""
    feed = ArticleFeed(
        url="https://example.com/inactive.xml",
        title="Inactive Feed",
        description="An inactive RSS feed",
        tags=["test"],
        check_interval=3600,
        active=False,
    )
    db_session.add(feed)
    db_session.commit()
    return feed


@pytest.fixture
def recently_checked_feed(db_session):
    """Create a recently checked ArticleFeed."""
    from sqlalchemy import text

    # Use a very recent timestamp that will trigger the "recently checked" condition
    # The check_interval is 3600 seconds, so 30 seconds ago should be "recent"
    recent_time = datetime.now() - timedelta(seconds=30)

    feed = ArticleFeed(
        url="https://example.com/recent.xml",
        title="Recent Feed",
        description="A recently checked feed",
        tags=["test"],
        check_interval=3600,
        active=True,
    )
    db_session.add(feed)
    db_session.flush()  # Get the ID

    # Manually set the last_checked_at to avoid timezone issues
    db_session.execute(
        text(
            "UPDATE article_feeds SET last_checked_at = :timestamp WHERE id = :feed_id"
        ),
        {"timestamp": recent_time, "feed_id": feed.id},
    )
    db_session.commit()
    return feed
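Note: the recently_checked_feed fixture drops to raw SQL only to sidestep timezone handling on last_checked_at. A minimal alternative sketch, assuming SQLAlchemy 1.4+/2.x Core and the same ArticleFeed model, that performs the same update through the ORM layer:

from sqlalchemy import update

db_session.execute(
    update(ArticleFeed)
    .where(ArticleFeed.id == feed.id)
    .values(last_checked_at=recent_time)  # recent_time as defined in the fixture
)
db_session.commit()

Whether this avoids the timezone issue depends on how the column and driver are configured, so the raw-SQL version in the fixture is the conservative choice.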
@pytest.fixture
def mock_feed_item():
    """Mock feed item for testing."""
    item = Mock()
    item.url = "https://example.com/article/1"
    item.title = "Test Article"
    return item


@pytest.fixture
def mock_feed_parser():
    """Mock feed parser for testing."""
    parser = Mock()
    parser.parse_feed.return_value = [
        Mock(url="https://example.com/article/1", title="Test Article")
    ]
    return parser


@pytest.fixture
def mock_archive_fetcher():
    """Mock archive fetcher for testing."""
    fetcher = Mock()
    fetcher.max_pages = 100
    fetcher.fetch_all_items.return_value = [
        Mock(url="https://example.com/archive/1", title="Archive Article 1"),
        Mock(url="https://example.com/archive/2", title="Archive Article 2"),
    ]
    return fetcher
@patch("memory.workers.tasks.blogs.parse_webpage")
def test_sync_webpage_success(mock_parse, mock_article, db_session, qdrant):
    """Test successful webpage synchronization."""
    mock_parse.return_value = mock_article

    result = blogs.sync_webpage("https://example.com/article/1", ["test", "blog"])

    mock_parse.assert_called_once_with("https://example.com/article/1")

    # Verify the BlogPost was created in the database
    blog_post = (
        db_session.query(BlogPost)
        .filter_by(url="https://example.com/article/1")
        .first()
    )
    assert blog_post is not None
    assert blog_post.title == "Test Article"
    assert (
        blog_post.content
        == "This is test article content with enough text to be processed."
    )
    assert blog_post.modality == "blog"
    assert blog_post.mime_type == "text/markdown"
    assert blog_post.images == []  # Empty because mock article.images is {}
    assert "test" in blog_post.tags
    assert "blog" in blog_post.tags

    # Verify the result
    assert result["status"] == "processed"
    assert result["blogpost_id"] == blog_post.id
    assert result["title"] == "Test Article"


@patch("memory.workers.tasks.blogs.parse_webpage")
def test_sync_webpage_empty_content(mock_parse, mock_empty_article, db_session):
    """Test webpage sync with empty content."""
    mock_parse.return_value = mock_empty_article

    result = blogs.sync_webpage("https://example.com/empty")

    assert result == {
        "url": "https://example.com/empty",
        "title": "Empty Article",
        "status": "skipped_short_content",
        "content_length": 0,
    }


@patch("memory.workers.tasks.blogs.parse_webpage")
def test_sync_webpage_already_exists(mock_parse, mock_article, db_session):
    """Test webpage sync when content already exists."""
    mock_parse.return_value = mock_article

    # Add existing blog post with same content hash
    from memory.workers.tasks.content_processing import create_content_hash

    existing_post = BlogPost(
        url="https://example.com/article/1",
        title="Test Article",
        content="This is test article content with enough text to be processed.",
        sha256=create_content_hash(
            "This is test article content with enough text to be processed."
        ),
        modality="blog",
        tags=["test"],
        mime_type="text/markdown",
        size=65,
    )
    db_session.add(existing_post)
    db_session.commit()

    result = blogs.sync_webpage("https://example.com/article/1", ["test"])

    assert result["status"] == "already_exists"
    assert result["blogpost_id"] == existing_post.id

    # Verify no duplicate was created
    blog_posts = (
        db_session.query(BlogPost).filter_by(url="https://example.com/article/1").all()
    )
    assert len(blog_posts) == 1
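The dedup check above hinges on create_content_hash producing the same sha256 bytes for identical content. Its implementation is not part of this diff; a plausible minimal sketch, assuming it simply hashes the UTF-8 text and returns the raw digest (which is what comparing it against BlogPost.sha256 implies), would be:

import hashlib

def create_content_hash(content: str) -> bytes:
    # Raw 32-byte SHA-256 digest of the UTF-8 encoded content (assumed behaviour).
    return hashlib.sha256(content.encode("utf-8")).digest()

If the real helper normalises whitespace or mixes in other fields, the test still holds, since both the fixture and the task go through the same function.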
@patch("memory.workers.tasks.blogs.get_feed_parser")
def test_sync_article_feed_success(
    mock_get_parser, sample_article_feed, mock_feed_parser, db_session
):
    """Test successful article feed synchronization."""
    mock_get_parser.return_value = mock_feed_parser

    with patch("memory.workers.tasks.blogs.sync_webpage") as mock_sync_webpage:
        mock_sync_webpage.delay.return_value = Mock(id="task-123")

        result = blogs.sync_article_feed(sample_article_feed.id)

        assert result["status"] == "completed"
        assert result["feed_id"] == sample_article_feed.id
        assert result["feed_title"] == "Test Feed"
        assert result["feed_url"] == "https://example.com/feed.xml"
        assert result["articles_found"] == 1
        assert result["new_articles"] == 1
        assert result["errors"] == 0
        assert result["task_ids"] == ["task-123"]

        # Verify sync_webpage was called with correct arguments
        mock_sync_webpage.delay.assert_called_once_with(
            "https://example.com/article/1", ["test", "blog"]
        )

        # Verify last_checked_at was updated
        db_session.refresh(sample_article_feed)
        assert sample_article_feed.last_checked_at is not None


def test_sync_article_feed_not_found(db_session):
    """Test sync with non-existent feed ID."""
    result = blogs.sync_article_feed(99999)

    assert result == {"status": "error", "error": "Feed not found or inactive"}


def test_sync_article_feed_inactive(inactive_article_feed, db_session):
    """Test sync with inactive feed."""
    result = blogs.sync_article_feed(inactive_article_feed.id)

    assert result == {"status": "error", "error": "Feed not found or inactive"}


@patch("memory.workers.tasks.blogs.get_feed_parser")
def test_sync_article_feed_no_parser(mock_get_parser, sample_article_feed, db_session):
    """Test sync when no parser is available."""
    mock_get_parser.return_value = None

    result = blogs.sync_article_feed(sample_article_feed.id)

    assert result == {"status": "error", "error": "No parser available for feed"}


@patch("memory.workers.tasks.blogs.get_feed_parser")
def test_sync_article_feed_with_existing_articles(
    mock_get_parser, sample_article_feed, db_session
):
    """Test sync when some articles already exist."""
    # Create existing blog post
    existing_post = BlogPost(
        url="https://example.com/article/1",
        title="Existing Article",
        content="Existing content",
        sha256=b"existing_hash" + bytes(24),
        modality="blog",
        tags=["test"],
        mime_type="text/markdown",
        size=100,
    )
    db_session.add(existing_post)
    db_session.commit()

    # Mock parser with multiple items
    mock_parser = Mock()
    mock_parser.parse_feed.return_value = [
        Mock(url="https://example.com/article/1", title="Existing Article"),
        Mock(url="https://example.com/article/2", title="New Article"),
    ]
    mock_get_parser.return_value = mock_parser

    with patch("memory.workers.tasks.blogs.sync_webpage") as mock_sync_webpage:
        mock_sync_webpage.delay.return_value = Mock(id="task-456")

        result = blogs.sync_article_feed(sample_article_feed.id)

        assert result["articles_found"] == 2
        assert result["new_articles"] == 1  # Only one new article
        assert result["task_ids"] == ["task-456"]

        # Verify sync_webpage was only called for the new article
        mock_sync_webpage.delay.assert_called_once_with(
            "https://example.com/article/2", ["test", "blog"]
        )
@patch("memory.workers.tasks.blogs.get_feed_parser")
def test_sync_article_feed_parser_error(
    mock_get_parser, sample_article_feed, db_session
):
    """Test sync when parser raises an exception."""
    mock_parser = Mock()
    mock_parser.parse_feed.side_effect = Exception("Parser error")
    mock_get_parser.return_value = mock_parser

    result = blogs.sync_article_feed(sample_article_feed.id)

    assert result["status"] == "completed"
    assert result["articles_found"] == 0
    assert result["new_articles"] == 0
    assert result["errors"] == 1


@patch("memory.workers.tasks.blogs.sync_article_feed")
def test_sync_all_article_feeds(mock_sync_delay, db_session):
    """Test synchronization of all active feeds."""
    # Create multiple feeds
    feed1 = ArticleFeed(
        url="https://example.com/feed1.xml",
        title="Feed 1",
        active=True,
        check_interval=3600,
    )
    feed2 = ArticleFeed(
        url="https://example.com/feed2.xml",
        title="Feed 2",
        active=True,
        check_interval=3600,
    )
    feed3 = ArticleFeed(
        url="https://example.com/feed3.xml",
        title="Feed 3",
        active=False,  # Inactive
        check_interval=3600,
    )

    db_session.add_all([feed1, feed2, feed3])
    db_session.commit()

    mock_sync_delay.delay.side_effect = [Mock(id="task-1"), Mock(id="task-2")]

    result = blogs.sync_all_article_feeds()

    assert len(result) == 2  # Only active feeds
    assert result[0]["feed_id"] == feed1.id
    assert result[0]["task_id"] == "task-1"
    assert result[1]["feed_id"] == feed2.id
    assert result[1]["task_id"] == "task-2"
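sync_all_article_feeds is the natural entry point for periodic scheduling. A sketch of wiring it into Celery beat, assuming the project's Celery app instance lives at memory.workers.celery_app and that the task is registered under the dotted name shown below (neither is confirmed by this diff):

from celery.schedules import crontab

from memory.workers.celery_app import app  # assumed location of the Celery app instance

app.conf.beat_schedule = {
    "sync-all-article-feeds": {
        # Hypothetical registered task name; use whatever name the task actually carries.
        "task": "memory.workers.tasks.blogs.sync_all_article_feeds",
        "schedule": crontab(minute=0),  # hourly, matching the 3600s check_interval default
    },
}

Per-feed throttling is still handled by check_interval, so running the dispatcher more often than the shortest interval only costs a cheap query.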
@patch("memory.workers.tasks.blogs.get_archive_fetcher")
|
||||
def test_sync_website_archive_success(
|
||||
mock_get_fetcher, mock_archive_fetcher, db_session
|
||||
):
|
||||
"""Test successful website archive synchronization."""
|
||||
mock_get_fetcher.return_value = mock_archive_fetcher
|
||||
|
||||
with patch("memory.workers.tasks.blogs.sync_webpage") as mock_sync_webpage:
|
||||
mock_sync_webpage.delay.side_effect = [Mock(id="task-1"), Mock(id="task-2")]
|
||||
|
||||
result = blogs.sync_website_archive("https://example.com", ["archive"], 50)
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["website_url"] == "https://example.com"
|
||||
assert result["articles_found"] == 2
|
||||
assert result["new_articles"] == 2
|
||||
assert result["task_ids"] == ["task-1", "task-2"]
|
||||
assert result["max_pages_processed"] == 50
|
||||
assert mock_archive_fetcher.max_pages == 50
|
||||
|
||||
# Verify sync_webpage was called for both articles
|
||||
assert mock_sync_webpage.delay.call_count == 2
|
||||
mock_sync_webpage.delay.assert_any_call(
|
||||
"https://example.com/archive/1", ["archive"]
|
||||
)
|
||||
mock_sync_webpage.delay.assert_any_call(
|
||||
"https://example.com/archive/2", ["archive"]
|
||||
)
|
||||
|
||||
@patch("memory.workers.tasks.blogs.get_archive_fetcher")
def test_sync_website_archive_no_fetcher(mock_get_fetcher, db_session):
    """Test archive sync when no fetcher is available."""
    mock_get_fetcher.return_value = None

    result = blogs.sync_website_archive("https://example.com")

    assert result == {"status": "error", "error": "No archive fetcher available"}


@patch("memory.workers.tasks.blogs.get_archive_fetcher")
def test_sync_website_archive_with_existing_articles(mock_get_fetcher, db_session):
    """Test archive sync when some articles already exist."""
    # Create existing blog post
    existing_post = BlogPost(
        url="https://example.com/archive/1",
        title="Existing Archive Article",
        content="Existing content",
        sha256=b"existing_hash" + bytes(24),
        modality="blog",
        tags=["archive"],
        mime_type="text/markdown",
        size=100,
    )
    db_session.add(existing_post)
    db_session.commit()

    # Mock fetcher
    mock_fetcher = Mock()
    mock_fetcher.max_pages = 100
    mock_fetcher.fetch_all_items.return_value = [
        Mock(url="https://example.com/archive/1", title="Existing Archive Article"),
        Mock(url="https://example.com/archive/2", title="New Archive Article"),
    ]
    mock_get_fetcher.return_value = mock_fetcher

    with patch("memory.workers.tasks.blogs.sync_webpage") as mock_sync_webpage:
        mock_sync_webpage.delay.return_value = Mock(id="task-new")

        result = blogs.sync_website_archive("https://example.com", ["archive"])

        assert result["articles_found"] == 2
        assert result["new_articles"] == 1  # Only one new article
        assert result["task_ids"] == ["task-new"]

        # Verify sync_webpage was only called for the new article
        mock_sync_webpage.delay.assert_called_once_with(
            "https://example.com/archive/2", ["archive"]
        )
@patch("memory.workers.tasks.blogs.parse_webpage")
def test_sync_webpage_with_tags(mock_parse, mock_article, db_session, qdrant):
    """Test webpage sync with custom tags."""
    mock_parse.return_value = mock_article

    result = blogs.sync_webpage("https://example.com/article/1", ["custom", "tags"])

    # Verify the BlogPost was created with custom tags
    blog_post = (
        db_session.query(BlogPost)
        .filter_by(url="https://example.com/article/1")
        .first()
    )
    assert blog_post is not None
    assert "custom" in blog_post.tags
    assert "tags" in blog_post.tags
    assert result["status"] == "processed"


@patch("memory.workers.tasks.blogs.parse_webpage")
def test_sync_webpage_parse_error(mock_parse, db_session):
    """Test webpage sync when parsing fails."""
    mock_parse.side_effect = Exception("Parse error")

    # The safe_task_execution decorator should catch this
    result = blogs.sync_webpage("https://example.com/error")

    assert result["status"] == "error"
    assert "Parse error" in result["error"]
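test_sync_webpage_parse_error relies on a safe_task_execution decorator turning uncaught exceptions into an error payload instead of crashing the task. Its implementation is outside this diff; a minimal sketch of the behaviour the assertions imply (the exact signature and module are assumptions):

import functools
import logging

logger = logging.getLogger(__name__)


def safe_task_execution(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:  # broad on purpose: tasks must never crash the worker
            logger.exception("Task %s failed", func.__name__)
            return {"status": "error", "error": str(e)}
    return wrapper

Anything beyond that (retries, error reporting, structured error codes) is speculation.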
@pytest.mark.parametrize(
    "feed_tags,expected_tags",
    [
        (["feed", "tag"], ["feed", "tag"]),
        (None, []),
        ([], []),
    ],
)
@patch("memory.workers.tasks.blogs.get_feed_parser")
def test_sync_article_feed_tags_handling(
    mock_get_parser, feed_tags, expected_tags, db_session
):
    """Test that feed tags are properly passed to sync_webpage."""
    # Create feed with specific tags
    feed = ArticleFeed(
        url="https://example.com/feed.xml",
        title="Test Feed",
        tags=feed_tags,
        check_interval=3600,
        active=True,
        last_checked_at=None,  # Avoid timezone issues
    )
    db_session.add(feed)
    db_session.commit()

    mock_parser = Mock()
    mock_parser.parse_feed.return_value = [
        Mock(url="https://example.com/article/1", title="Test")
    ]
    mock_get_parser.return_value = mock_parser

    with patch("memory.workers.tasks.blogs.sync_webpage") as mock_sync_webpage:
        mock_sync_webpage.delay.return_value = Mock(id="task-123")

        blogs.sync_article_feed(feed.id)

        # Verify sync_webpage was called with correct tags
        mock_sync_webpage.delay.assert_called_once_with(
            "https://example.com/article/1", expected_tags
        )


def test_sync_all_article_feeds_no_active_feeds(db_session):
    """Test sync_all_article_feeds when no active feeds exist."""
    # Create only inactive feeds
    inactive_feed = ArticleFeed(
        url="https://example.com/inactive.xml",
        title="Inactive Feed",
        active=False,
        check_interval=3600,
    )
    db_session.add(inactive_feed)
    db_session.commit()

    result = blogs.sync_all_article_feeds()

    assert result == []
@patch("memory.workers.tasks.blogs.sync_webpage")
@patch("memory.workers.tasks.blogs.get_archive_fetcher")
def test_sync_website_archive_default_max_pages(
    mock_get_fetcher, mock_sync_delay, db_session
):
    """Test that default max_pages is used when not specified."""
    mock_fetcher = Mock()
    mock_fetcher.max_pages = 100  # Default value
    mock_fetcher.fetch_all_items.return_value = []
    mock_get_fetcher.return_value = mock_fetcher

    result = blogs.sync_website_archive("https://example.com")

    assert result["max_pages_processed"] == 100
    assert mock_fetcher.max_pages == 100  # Should be set to default


@patch("memory.workers.tasks.blogs.sync_webpage")
@patch("memory.workers.tasks.blogs.get_archive_fetcher")
def test_sync_website_archive_empty_results(
    mock_get_fetcher, mock_sync_delay, db_session
):
    """Test archive sync when no articles are found."""
    mock_fetcher = Mock()
    mock_fetcher.max_pages = 100
    mock_fetcher.fetch_all_items.return_value = []
    mock_get_fetcher.return_value = mock_fetcher

    result = blogs.sync_website_archive("https://example.com")

    assert result["articles_found"] == 0
    assert result["new_articles"] == 0
    assert result["task_ids"] == []
@ -3,7 +3,7 @@ from pathlib import Path
from unittest.mock import patch, Mock

from memory.common.db.models import Book, BookSection, Chunk
from memory.common.parsers.ebook import Ebook, Section
from memory.parsers.ebook import Ebook, Section
from memory.workers.tasks import ebook


@ -12,7 +12,7 @@ from memory.common.db.models import (
    EmailAttachment,
    MailMessage,
)
from memory.common.parsers.email import Attachment, parse_email_message
from memory.parsers.email import Attachment, parse_email_message
from memory.workers.email import (
    create_mail_message,
    extract_email_uid,