From 8cfaeaea72cf3d5ec0c79e738dc7cb1dbc8cd5f3 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 20 May 2025 21:28:26 +0200 Subject: [PATCH] add commics --- db/migrations/env.py | 26 +- .../20250503_140056_initial_structure.py | 10 +- ...20250504_125653_update_field_for_chunks.py | 40 +++ .../versions/20250504_234552_add_comics.py | 43 +++ docker-compose.yaml | 36 +- docker/ingest_hub/Dockerfile | 5 +- docker/workers/Dockerfile | 2 +- requirements-workers.txt | 4 +- src/memory/api/app.py | 10 +- src/memory/common/db/connection.py | 5 +- src/memory/common/db/models.py | 76 ++++- src/memory/common/embedding.py | 44 ++- src/memory/common/parsers/comics.py | 133 ++++++++ src/memory/common/qdrant.py | 37 ++- src/memory/common/settings.py | 18 + src/memory/workers/celery_app.py | 16 +- src/memory/workers/ingest.py | 22 +- src/memory/workers/tasks/__init__.py | 20 +- src/memory/workers/tasks/comic.py | 172 ++++++++++ src/memory/workers/tasks/maintenance.py | 157 +++++++++ tests/memory/common/parsers/test_comic.py | 199 +++++++++++ tests/memory/common/test_qdrant.py | 51 ++- .../memory/workers/tasks/test_maintenance.py | 310 ++++++++++++++++++ 23 files changed, 1335 insertions(+), 101 deletions(-) create mode 100644 db/migrations/versions/20250504_125653_update_field_for_chunks.py create mode 100644 db/migrations/versions/20250504_234552_add_comics.py create mode 100644 src/memory/common/parsers/comics.py create mode 100644 src/memory/workers/tasks/comic.py create mode 100644 src/memory/workers/tasks/maintenance.py create mode 100644 tests/memory/common/parsers/test_comic.py create mode 100644 tests/memory/workers/tasks/test_maintenance.py diff --git a/db/migrations/env.py b/db/migrations/env.py index 796789b..c1ab961 100644 --- a/db/migrations/env.py +++ b/db/migrations/env.py @@ -1,6 +1,7 @@ """ Alembic environment configuration. """ + from logging.config import fileConfig from sqlalchemy import engine_from_config @@ -29,16 +30,28 @@ target_metadata = Base.metadata # can be acquired: # my_important_option = config.get_main_option("my_important_option") +# Tables to exclude from auto-generation - these are managed by Celery +excluded_tables = { + "celery_taskmeta", + "celery_tasksetmeta", +} + + +def include_object(object, name, type_, reflected, compare_to): + if type_ == "table" and name in excluded_tables: + return False + return True + def run_migrations_offline() -> None: """ Run migrations in 'offline' mode - creates SQL scripts. - + This configures the context with just a URL and not an Engine, though an Engine is acceptable here as well. By skipping the Engine creation we don't even need a DBAPI to be available. - + Calls to context.execute() here emit the given string to the script output. """ @@ -48,6 +61,7 @@ def run_migrations_offline() -> None: target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, + include_object=include_object, ) with context.begin_transaction(): @@ -57,7 +71,7 @@ def run_migrations_offline() -> None: def run_migrations_online() -> None: """ Run migrations in 'online' mode - directly to the database. - + In this scenario we need to create an Engine and associate a connection with the context. 
""" @@ -69,7 +83,9 @@ def run_migrations_online() -> None: with connectable.connect() as connection: context.configure( - connection=connection, target_metadata=target_metadata + connection=connection, + target_metadata=target_metadata, + include_object=include_object, ) with context.begin_transaction(): @@ -79,4 +95,4 @@ def run_migrations_online() -> None: if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() \ No newline at end of file + run_migrations_online() diff --git a/db/migrations/versions/20250503_140056_initial_structure.py b/db/migrations/versions/20250503_140056_initial_structure.py index 697220a..72479cd 100644 --- a/db/migrations/versions/20250503_140056_initial_structure.py +++ b/db/migrations/versions/20250503_140056_initial_structure.py @@ -21,7 +21,7 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto") - op.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"") + op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') op.create_table( "email_accounts", sa.Column("id", sa.BigInteger(), nullable=False), @@ -82,12 +82,6 @@ def upgrade() -> None: server_default=sa.text("now()"), nullable=False, ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.text("now()"), - nullable=False, - ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("url"), ) @@ -269,7 +263,7 @@ def upgrade() -> None: "misc_doc", sa.Column("id", sa.BigInteger(), nullable=False), sa.Column("path", sa.Text(), nullable=True), - sa.Column("mime_type", sa.Text(), nullable=True), + sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True), sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) diff --git a/db/migrations/versions/20250504_125653_update_field_for_chunks.py b/db/migrations/versions/20250504_125653_update_field_for_chunks.py new file mode 100644 index 0000000..7654d82 --- /dev/null +++ b/db/migrations/versions/20250504_125653_update_field_for_chunks.py @@ -0,0 +1,40 @@ +"""Update field for chunks + +Revision ID: d292d48ec74e +Revises: 4684845ca51e +Create Date: 2025-05-04 12:56:53.231393 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "d292d48ec74e" +down_revision: Union[str, None] = "4684845ca51e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "chunk", + sa.Column( + "checked_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + ) + op.drop_column("misc_doc", "mime_type") + + +def downgrade() -> None: + op.add_column( + "misc_doc", + sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True), + ) + op.drop_column("chunk", "checked_at") diff --git a/db/migrations/versions/20250504_234552_add_comics.py b/db/migrations/versions/20250504_234552_add_comics.py new file mode 100644 index 0000000..4bb55bf --- /dev/null +++ b/db/migrations/versions/20250504_234552_add_comics.py @@ -0,0 +1,43 @@ +"""Add comics + +Revision ID: b78b1fff9974 +Revises: d292d48ec74e +Create Date: 2025-05-04 23:45:52.733301 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'b78b1fff9974' +down_revision: Union[str, None] = 'd292d48ec74e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('comic', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('author', sa.Text(), nullable=True), + sa.Column('published', sa.DateTime(timezone=True), nullable=True), + sa.Column('volume', sa.Text(), nullable=True), + sa.Column('issue', sa.Text(), nullable=True), + sa.Column('page', sa.Integer(), nullable=True), + sa.Column('url', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['id'], ['source_item.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('comic_author_idx', 'comic', ['author'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('comic_author_idx', table_name='comic') + op.drop_table('comic') + # ### end Alembic commands ### \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index 25a16e4..f9bd21a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -21,6 +21,7 @@ volumes: # ------------------------------ X-templates ---------------------------- x-common-env: &env RABBITMQ_USER: kb + RABBITMQ_HOST: rabbitmq QDRANT_HOST: qdrant DB_HOST: postgres FILE_STORAGE_DIR: /app/memory_files @@ -173,49 +174,28 @@ services: environment: <<: *worker-env QUEUES: "email" - deploy: { resources: { limits: { cpus: "2", memory: 3g } } } + # deploy: { resources: { limits: { cpus: "2", memory: 3g } } } worker-text: <<: *worker-base environment: <<: *worker-env QUEUES: "medium_embed" - deploy: { resources: { limits: { cpus: "2", memory: 3g } } } + # deploy: { resources: { limits: { cpus: "2", memory: 3g } } } worker-photo: <<: *worker-base environment: <<: *worker-env - QUEUES: "photo_embed" - deploy: { resources: { limits: { cpus: "4", memory: 4g } } } + QUEUES: "photo_embed,comic" + # deploy: { resources: { limits: { cpus: "4", memory: 4g } } } - worker-ocr: + worker-maintenance: <<: *worker-base environment: <<: *worker-env - QUEUES: "low_ocr" - deploy: { resources: { limits: { cpus: "4", memory: 4g } } } - - worker-git: - <<: *worker-base - environment: - <<: *worker-env - QUEUES: "git_summary" - deploy: { resources: { limits: { cpus: "1", memory: 1g } } } - - worker-rss: - <<: *worker-base - environment: - <<: *worker-env - QUEUES: "rss" - deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } } - - worker-docs: - <<: *worker-base - environment: - <<: *worker-env - QUEUES: "docs" - deploy: { resources: { limits: { cpus: "1", memory: 1g } } } + QUEUES: "maintenance" + # deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } } ingest-hub: <<: *worker-base diff --git a/docker/ingest_hub/Dockerfile b/docker/ingest_hub/Dockerfile index 88d8981..18659de 100644 --- a/docker/ingest_hub/Dockerfile +++ b/docker/ingest_hub/Dockerfile @@ -9,7 +9,7 @@ COPY src/ ./src/ # Install dependencies RUN apt-get update && apt-get install -y \ - libpq-dev gcc supervisor && \ + libpq-dev gcc supervisor && \ pip install -e ".[workers]" && \ apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* @@ -28,8 +28,7 @@ RUN mkdir -p /app/memory_files RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files USER kb -# 
Default queues to process -ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email" +ENV QUEUES="docs,email,maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"] \ No newline at end of file diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile index be84e43..07336e4 100644 --- a/docker/workers/Dockerfile +++ b/docker/workers/Dockerfile @@ -35,7 +35,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \ USER kb # Default queues to process -ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email" +ENV QUEUES="docs,email,maintenance" ENV PYTHONPATH="/app" ENTRYPOINT ["./entry.sh"] \ No newline at end of file diff --git a/requirements-workers.txt b/requirements-workers.txt index bed6f1a..ba9dfc8 100644 --- a/requirements-workers.txt +++ b/requirements-workers.txt @@ -1,4 +1,6 @@ celery==5.3.6 openai==1.25.0 pillow==10.3.0 -pypandoc==1.15.0 \ No newline at end of file +pypandoc==1.15.0 +beautifulsoup4==4.13.4 +feedparser==6.0.10 \ No newline at end of file diff --git a/src/memory/api/app.py b/src/memory/api/app.py index f40b194..41f2b79 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -101,7 +101,7 @@ def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[Search and source.filename.replace( str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files" ), - content=source.content, + content=source.display_contents, chunks=sorted(chunks, key=lambda x: x.score, reverse=True), ) for source, chunks in items.items() @@ -143,7 +143,7 @@ def query_chunks( ) if r.score >= min_score ] - for collection in embedding.DEFAULT_COLLECTIONS + for collection in embedding.ALL_COLLECTIONS } @@ -164,7 +164,7 @@ async def search( modalities: Annotated[list[str], Query()] = [], files: list[UploadFile] = File([]), limit: int = Query(10, ge=1, le=100), - min_text_score: float = Query(0.5, ge=0.0, le=1.0), + min_text_score: float = Query(0.3, ge=0.0, le=1.0), min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0), ): """ @@ -181,11 +181,11 @@ async def search( """ upload_data = [await input_type(item) for item in [query, *files]] logger.error( - f"Querying chunks for {modalities}, query: {query}, previews: {previews}" + f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}" ) client = qdrant.get_qdrant_client() - allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys()) + allowed_modalities = set(modalities or embedding.ALL_COLLECTIONS.keys()) text_results = query_chunks( client, upload_data, diff --git a/src/memory/common/db/connection.py b/src/memory/common/db/connection.py index 76f2d4e..6941776 100644 --- a/src/memory/common/db/connection.py +++ b/src/memory/common/db/connection.py @@ -1,6 +1,7 @@ """ Database connection utilities. """ + from contextlib import contextmanager from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, scoped_session @@ -24,14 +25,14 @@ def get_scoped_session(): """Create a thread-local scoped session factory""" engine = get_engine() session_factory = sessionmaker(bind=engine) - return scoped_session(session_factory) + return scoped_session(session_factory) @contextmanager def make_session(): """ Context manager for database sessions. 
- + Yields: SQLAlchemy session that will be automatically closed """ diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index b73df0d..34c3168 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -6,7 +6,8 @@ import pathlib import re from email.message import EmailMessage from pathlib import Path -from typing import Any +import textwrap +from typing import Any, ClassVar from PIL import Image from sqlalchemy import ( ARRAY, @@ -99,8 +100,9 @@ class Chunk(Base): content = Column(Text) # Direct content storage embedding_model = Column(Text) created_at = Column(DateTime(timezone=True), server_default=func.now()) - vector = list[float] | None # the vector generated by the embedding model - item_metadata = dict[str, Any] | None + checked_at = Column(DateTime(timezone=True), server_default=func.now()) + vector: ClassVar[list[float] | None] = None + item_metadata: ClassVar[dict[str, Any] | None] = None # One of file_path or content must be populated __table_args__ = ( @@ -121,7 +123,7 @@ class Chunk(Base): items = [] for file_path in files: - if file_path.suffix == ".png": + if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}: if file_path.exists(): items.append(Image.open(file_path)) elif file_path.suffix == ".bin": @@ -172,6 +174,16 @@ class SourceItem(Base): """Get vector IDs from associated chunks.""" return [chunk.id for chunk in self.chunks] + def as_payload(self) -> dict: + return { + "source_id": self.id, + "tags": self.tags, + } + + @property + def display_contents(self) -> str | None: + return self.content or self.filename + class MailMessage(SourceItem): __tablename__ = "mail_message" @@ -236,6 +248,26 @@ class MailMessage(SourceItem): def body(self) -> str: return self.parsed_content["body"] + @property + def display_contents(self) -> str | None: + content = self.parsed_content + return textwrap.dedent( + """ + Subject: {subject} + From: {sender} + To: {recipients} + Date: {date} + Body: + {body} + """ + ).format( + subject=content.get("subject", ""), + sender=content.get("from", ""), + recipients=content.get("to", ""), + date=content.get("date", ""), + body=content.get("body", ""), + ) + # Add indexes __table_args__ = ( Index("mail_sent_idx", "sent_at"), @@ -341,6 +373,41 @@ class Photo(SourceItem): __table_args__ = (Index("photo_taken_idx", "exif_taken_at"),) +class Comic(SourceItem): + __tablename__ = "comic" + + id = Column( + BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True + ) + title = Column(Text) + author = Column(Text, nullable=True) + published = Column(DateTime(timezone=True), nullable=True) + volume = Column(Text, nullable=True) + issue = Column(Text, nullable=True) + page = Column(Integer, nullable=True) + url = Column(Text, nullable=True) + + __mapper_args__ = { + "polymorphic_identity": "comic", + } + + __table_args__ = (Index("comic_author_idx", "author"),) + + def as_payload(self) -> dict: + payload = { + "source_id": self.id, + "tags": self.tags, + "title": self.title, + "author": self.author, + "published": self.published, + "volume": self.volume, + "issue": self.issue, + "page": self.page, + "url": self.url, + } + return {k: v for k, v in payload.items() if v is not None} + + class BookDoc(SourceItem): __tablename__ = "book_doc" @@ -379,7 +446,6 @@ class MiscDoc(SourceItem): BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True ) path = Column(Text) - mime_type = Column(Text) __mapper_args__ = { "polymorphic_identity": "misc_doc", diff --git 
a/src/memory/common/embedding.py b/src/memory/common/embedding.py index e575f0a..2b68dff 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -32,7 +32,7 @@ class Collection(TypedDict): shards: NotRequired[int] -DEFAULT_COLLECTIONS: dict[str, Collection] = { +ALL_COLLECTIONS: dict[str, Collection] = { "mail": { "dimension": 1024, "distance": "Cosine", @@ -69,6 +69,11 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = { "distance": "Cosine", "model": settings.MIXED_EMBEDDING_MODEL, }, + "comic": { + "dimension": 1024, + "distance": "Cosine", + "model": settings.MIXED_EMBEDDING_MODEL, + }, "doc": { "dimension": 1024, "distance": "Cosine", @@ -77,12 +82,12 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = { } TEXT_COLLECTIONS = { coll - for coll, params in DEFAULT_COLLECTIONS.items() + for coll, params in ALL_COLLECTIONS.items() if params["model"] == settings.TEXT_EMBEDDING_MODEL } MULTIMODAL_COLLECTIONS = { coll - for coll, params in DEFAULT_COLLECTIONS.items() + for coll, params in ALL_COLLECTIONS.items() if params["model"] == settings.MIXED_EMBEDDING_MODEL } @@ -99,6 +104,22 @@ TYPES = { } +def get_mimetype(image: Image.Image) -> str | None: + format_to_mime = { + "JPEG": "image/jpeg", + "PNG": "image/png", + "GIF": "image/gif", + "BMP": "image/bmp", + "TIFF": "image/tiff", + "WEBP": "image/webp", + } + + if not image.format: + return None + + return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}") + + def get_modality(mime_type: str) -> str: for type, mime_types in TYPES.items(): if mime_type in mime_types: @@ -237,3 +258,20 @@ def embed( for vector in embed_page(page) ] return modality, chunks + + +def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk: + image = Image.open(file_path) + mime_type = get_mimetype(image) + if mime_type is None: + raise ValueError("Unsupported image format") + + vector = embed_mixed([image] + texts)[0] + + return Chunk( + id=str(uuid.uuid4()), + file_path=file_path.absolute().as_posix(), + content=None, + embedding_model=settings.MIXED_EMBEDDING_MODEL, + vector=vector, + ) diff --git a/src/memory/common/parsers/comics.py b/src/memory/common/parsers/comics.py new file mode 100644 index 0000000..12f0231 --- /dev/null +++ b/src/memory/common/parsers/comics.py @@ -0,0 +1,133 @@ +import logging +from typing import TypedDict, NotRequired + +from bs4 import BeautifulSoup, Tag +import requests +import json + +logger = logging.getLogger(__name__) + + +class ComicInfo(TypedDict): + title: str + image_url: str + published_date: NotRequired[str] + url: str + + +def extract_smbc(url: str) -> ComicInfo: + """ + Extract the title, published date, and image URL from a SMBC comic. 
+ + Returns: + ComicInfo with title, image_url, published_date, and comic_url + """ + response = requests.get(url) + response.raise_for_status() + soup = BeautifulSoup(response.text, "html.parser") + + comic_img = soup.find("img", id="cc-comic") + title = "" + image_url = "" + + if comic_img and isinstance(comic_img, Tag): + if comic_img.has_attr("src"): + image_url = str(comic_img["src"]) + if comic_img.has_attr("title"): + title = str(comic_img["title"]) + + published_date = "" + comic_url = "" + + script_ld = soup.find("script", type="application/ld+json") + if script_ld and isinstance(script_ld, Tag) and script_ld.string: + try: + data = json.loads(script_ld.string) + published_date = data.get("datePublished", "") + + # Use JSON-LD URL if available + title = title or data.get("name", "") + comic_url = data.get("url") + except (json.JSONDecodeError, AttributeError): + pass + + permalink_input = soup.find("input", id="permalinktext") + if not comic_url and permalink_input and isinstance(permalink_input, Tag): + comic_url = permalink_input.get("value", "") + + return { + "title": title, + "image_url": image_url, + "published_date": published_date, + "url": comic_url or url, + } + + +def extract_xkcd(url: str) -> ComicInfo: + """ + Extract comic information from an XKCD comic. + + This function parses an XKCD comic page to extract the title from the hover text, + the image URL, and the permanent URL of the comic. + + Args: + url: The URL of the XKCD comic to parse + + Returns: + ComicInfo with title, image_url, and comic_url + """ + response = requests.get(url) + response.raise_for_status() + soup = BeautifulSoup(response.text, "html.parser") + + def get_comic_img() -> Tag | None: + """Extract the comic image tag.""" + comic_div = soup.find("div", id="comic") + if comic_div and isinstance(comic_div, Tag): + img = comic_div.find("img") + return img if isinstance(img, Tag) else None + return None + + def get_title() -> str: + """Extract title from image title attribute with fallbacks.""" + # Primary source: hover text from the image (most informative) + img = get_comic_img() + if img and img.has_attr("title"): + return str(img["title"]) + + # Secondary source: og:title meta tag + og_title = soup.find("meta", property="og:title") + if og_title and isinstance(og_title, Tag) and og_title.has_attr("content"): + return str(og_title["content"]) + + # Last resort: page title div + title_div = soup.find("div", id="ctitle") + return title_div.text.strip() if title_div else "" + + def get_image_url() -> str: + """Extract and normalize the image URL.""" + img = get_comic_img() + if not img or not img.has_attr("src"): + return "" + + image_src = str(img["src"]) + return f"https:{image_src}" if image_src.startswith("//") else image_src + + def get_permanent_url() -> str: + """Extract the permanent URL to the comic.""" + og_url = soup.find("meta", property="og:url") + if og_url and isinstance(og_url, Tag) and og_url.has_attr("content"): + return str(og_url["content"]) + + # Fallback: look for permanent link text + for a_tag in soup.find_all("a"): + text = a_tag.get_text() + if text.startswith("https://xkcd.com/") and text.strip().endswith("/"): + return str(text.strip()) + return url + + return { + "title": get_title(), + "image_url": get_image_url(), + "url": get_permanent_url(), + } diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py index 282f5f1..03e4105 100644 --- a/src/memory/common/qdrant.py +++ b/src/memory/common/qdrant.py @@ -1,5 +1,5 @@ import logging -from typing import Any, 
cast +from typing import Any, cast, Iterator, Sequence import qdrant_client from qdrant_client.http import models as qdrant_models @@ -7,7 +7,7 @@ from qdrant_client.http.exceptions import UnexpectedResponse from memory.common import settings from memory.common.embedding import ( Collection, - DEFAULT_COLLECTIONS, + ALL_COLLECTIONS, DistanceType, Vector, ) @@ -91,7 +91,7 @@ def initialize_collections( If None, defaults to the DEFAULT_COLLECTIONS. """ if collections is None: - collections = DEFAULT_COLLECTIONS + collections = ALL_COLLECTIONS logger.info(f"Initializing collections:") for name, params in collections.items(): @@ -184,13 +184,13 @@ def search_vectors( ) -def delete_vectors( +def delete_points( client: qdrant_client.QdrantClient, collection_name: str, ids: list[str], ) -> None: """ - Delete vectors from a collection. + Delete points from a collection. Args: client: Qdrant client @@ -222,3 +222,30 @@ def get_collection_info( """ info = client.get_collection(collection_name) return info.model_dump() + + +def batch_ids( + client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000 +) -> Iterator[list[str]]: + """Iterate over all IDs in a collection.""" + offset = None + while resp := client.scroll( + collection_name=collection_name, + with_vectors=False, + offset=offset, + limit=batch_size, + ): + points, offset = resp + yield [point.id for point in points] + + if not offset: + return + + +def find_missing_points( + client: qdrant_client.QdrantClient, collection_name: str, ids: Sequence[str] +) -> set[str]: + found = client.retrieve( + collection_name, ids=ids, with_payload=False, with_vectors=False + ) + return set(ids) - {str(r.id) for r in found} diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index b1d3715..80481b5 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -29,13 +29,27 @@ def make_db_url( DB_URL = os.getenv("DATABASE_URL", make_db_url()) +# Celery settings +RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb") +RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb") +RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq") +CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}") + + +# File storage settings FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files")) FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True) CHUNK_STORAGE_DIR = pathlib.Path( os.getenv("CHUNK_STORAGE_DIR", FILE_STORAGE_DIR / "chunks") ) CHUNK_STORAGE_DIR.mkdir(parents=True, exist_ok=True) + +COMIC_STORAGE_DIR = pathlib.Path( + os.getenv("COMIC_STORAGE_DIR", FILE_STORAGE_DIR / "comics") +) +COMIC_STORAGE_DIR.mkdir(parents=True, exist_ok=True) + # Maximum attachment size to store directly in the database (10MB) MAX_INLINE_ATTACHMENT_SIZE = int( os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024) @@ -51,8 +65,12 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60")) # Worker settings +# Intervals are in seconds EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600)) +CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400)) +CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600)) +CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24)) # Embedding settings TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large") diff --git a/src/memory/workers/celery_app.py b/src/memory/workers/celery_app.py index 7879b77..f730bc8 100644 --- a/src/memory/workers/celery_app.py +++ 
b/src/memory/workers/celery_app.py @@ -1,18 +1,15 @@ -import os from celery import Celery from memory.common import settings def rabbit_url() -> str: - user = os.getenv("RABBITMQ_USER", "guest") - password = os.getenv("RABBITMQ_PASSWORD", "guest") - return f"amqp://{user}:{password}@rabbitmq:5672//" + return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//" app = Celery( "memory", broker=rabbit_url(), - backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}") + backend=settings.CELERY_RESULT_BACKEND, ) @@ -27,16 +24,15 @@ app.conf.update( "memory.workers.tasks.text.*": {"queue": "medium_embed"}, "memory.workers.tasks.email.*": {"queue": "email"}, "memory.workers.tasks.photo.*": {"queue": "photo_embed"}, - "memory.workers.tasks.ocr.*": {"queue": "low_ocr"}, - "memory.workers.tasks.git.*": {"queue": "git_summary"}, - "memory.workers.tasks.rss.*": {"queue": "rss"}, + "memory.workers.tasks.comic.*": {"queue": "comic"}, "memory.workers.tasks.docs.*": {"queue": "docs"}, + "memory.workers.tasks.maintenance.*": {"queue": "maintenance"}, }, ) - @app.on_after_configure.connect def ensure_qdrant_initialised(sender, **_): from memory.common import qdrant - qdrant.setup_qdrant() \ No newline at end of file + + qdrant.setup_qdrant() diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py index 8e63034..98dcd9f 100644 --- a/src/memory/workers/ingest.py +++ b/src/memory/workers/ingest.py @@ -1,11 +1,23 @@ +import logging -from memory.workers.celery_app import app from memory.common import settings +from memory.workers.celery_app import app +from memory.workers.tasks import CLEAN_ALL_COLLECTIONS, REINGEST_MISSING_CHUNKS + +logger = logging.getLogger(__name__) app.conf.beat_schedule = { - 'sync-mail-all': { - 'task': 'memory.workers.tasks.email.sync_all_accounts', - 'schedule': settings.EMAIL_SYNC_INTERVAL, + "sync-mail-all": { + "task": "memory.workers.tasks.email.sync_all_accounts", + "schedule": settings.EMAIL_SYNC_INTERVAL, }, -} \ No newline at end of file + "clean-all-collections": { + "task": CLEAN_ALL_COLLECTIONS, + "schedule": settings.CLEAN_COLLECTION_INTERVAL, + }, + "reingest-missing-chunks": { + "task": REINGEST_MISSING_CHUNKS, + "schedule": settings.CHUNK_REINGEST_INTERVAL, + }, +} diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index bdd287d..cba4b0c 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -1,8 +1,24 @@ """ Import sub-modules so Celery can register their @app.task decorators. 
""" -from memory.workers.tasks import docs, email # noqa + +from memory.workers.tasks import docs, email, comic # noqa from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL +from memory.workers.tasks.maintenance import ( + CLEAN_ALL_COLLECTIONS, + CLEAN_COLLECTION, + REINGEST_MISSING_CHUNKS, +) -__all__ = ["docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"] \ No newline at end of file +__all__ = [ + "docs", + "email", + "comic", + "SYNC_ACCOUNT", + "SYNC_ALL_ACCOUNTS", + "PROCESS_EMAIL", + "CLEAN_ALL_COLLECTIONS", + "CLEAN_COLLECTION", + "REINGEST_MISSING_CHUNKS", +] diff --git a/src/memory/workers/tasks/comic.py b/src/memory/workers/tasks/comic.py new file mode 100644 index 0000000..549bd99 --- /dev/null +++ b/src/memory/workers/tasks/comic.py @@ -0,0 +1,172 @@ +import hashlib +import logging +from datetime import datetime +from typing import Callable + +import feedparser +import requests + +from memory.common import embedding, qdrant, settings +from memory.common.db.connection import make_session +from memory.common.db.models import Comic, clean_filename +from memory.common.parsers import comics +from memory.workers.celery_app import app + +logger = logging.getLogger(__name__) + + +SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics" +SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc" +SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd" +SYNC_COMIC = "memory.workers.tasks.comic.sync_comic" + + +BASE_SMBC_URL = "https://www.smbc-comics.com/" +SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss" + +BASE_XKCD_URL = "https://xkcd.com/" +XKCD_RSS_URL = "https://xkcd.com/atom.xml" + + +def find_new_urls(base_url: str, rss_url: str) -> set[str]: + try: + feed = feedparser.parse(rss_url) + except Exception as e: + logger.error(f"Failed to fetch or parse {rss_url}: {e}") + return set() + + urls = {item.get("link") or item.get("id") for item in feed.entries} + + with make_session() as session: + known = { + c.url + for c in session.query(Comic.url).filter( + Comic.author == base_url, + Comic.url.in_(urls), + ) + } + + return urls - known + + +def fetch_new_comics( + base_url: str, rss_url: str, parser: Callable[[str], comics.ComicInfo] +) -> set[str]: + new_urls = find_new_urls(base_url, rss_url) + + for url in new_urls: + data = parser(url) | {"author": base_url, "url": url} + sync_comic.delay(**data) + return new_urls + + +@app.task(name=SYNC_COMIC) +def sync_comic( + url: str, + image_url: str, + title: str, + author: str, + published_date: datetime | None = None, +): + """Synchronize a comic from a URL.""" + with make_session() as session: + if session.query(Comic).filter(Comic.url == url).first(): + return + + response = requests.get(image_url) + file_type = image_url.split(".")[-1] + mime_type = f"image/{file_type}" + filename = ( + settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}" + ) + if response.status_code == 200: + filename.parent.mkdir(parents=True, exist_ok=True) + filename.write_bytes(response.content) + + sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest() + comic = Comic( + title=title, + url=url, + published=published_date, + author=author, + filename=filename.resolve().as_posix(), + mime_type=mime_type, + size=len(response.content), + sha256=sha256, + tags={"comic", author}, + modality="comic", + ) + chunk = embedding.embed_image(filename, [title, author]) + comic.chunks = [chunk] + + with make_session() as session: + session.add(comic) + session.add(chunk) + 
session.flush() + + qdrant.upsert_vectors( + client=qdrant.get_qdrant_client(), + collection_name="comic", + ids=[str(chunk.id)], + vectors=[chunk.vector], + payloads=[comic.as_payload()], + ) + + session.commit() + + +@app.task(name=SYNC_SMBC) +def sync_smbc() -> set[str]: + """Synchronize SMBC comics from RSS feed.""" + return fetch_new_comics(BASE_SMBC_URL, SMBC_RSS_URL, comics.extract_smbc) + + +@app.task(name=SYNC_XKCD) +def sync_xkcd() -> set[str]: + """Synchronize XKCD comics from RSS feed.""" + return fetch_new_comics(BASE_XKCD_URL, XKCD_RSS_URL, comics.extract_xkcd) + + +@app.task(name=SYNC_ALL_COMICS) +def sync_all_comics(): + """Synchronize all active comics.""" + sync_smbc.delay() + sync_xkcd.delay() + + +@app.task(name="memory.workers.tasks.comic.full_sync_comic") +def trigger_comic_sync(): + def prev_smbc_comic(url: str) -> str | None: + from bs4 import BeautifulSoup + + response = requests.get(url) + soup = BeautifulSoup(response.text, "html.parser") + if link := soup.find("a", attrs={"class", "cc-prev"}): + return link.attrs["href"] + return None + + next_url = "https://www.smbc-comics.com" + urls = [] + logger.info(f"syncing {next_url}") + while next_url := prev_smbc_comic(next_url): + if len(urls) % 10 == 0: + logger.info(f"got {len(urls)}") + try: + data = comics.extract_smbc(next_url) | { + "author": "https://www.smbc-comics.com/" + } + sync_comic.delay(**data) + except Exception as e: + logger.error(f"failed to sync {next_url}: {e}") + urls.append(next_url) + + logger.info(f"syncing {BASE_XKCD_URL}") + for i in range(1, 308): + if i % 10 == 0: + logger.info(f"got {i}") + url = f"{BASE_XKCD_URL}/{i}" + try: + data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"} + sync_comic.delay(**data) + except Exception as e: + logger.error(f"failed to sync {url}: {e}") diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py new file mode 100644 index 0000000..243c13b --- /dev/null +++ b/src/memory/workers/tasks/maintenance.py @@ -0,0 +1,157 @@ +import logging +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import contains_eager + +from memory.common import embedding, qdrant, settings +from memory.common.db.connection import make_session +from memory.common.db.models import Chunk, SourceItem +from memory.workers.celery_app import app + +logger = logging.getLogger(__name__) + + +CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections" +CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection" +REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks" +REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk" + + +@app.task(name=CLEAN_COLLECTION) +def clean_collection(collection: str) -> dict[str, int]: + logger.info(f"Cleaning collection {collection}") + client = qdrant.get_qdrant_client() + batches, deleted, checked = 0, 0, 0 + for batch in qdrant.batch_ids(client, collection): + batches += 1 + batch_ids = set(batch) + with make_session() as session: + db_ids = { + str(c.id) for c in session.query(Chunk).filter(Chunk.id.in_(batch_ids)) + } + ids_to_delete = batch_ids - db_ids + checked += len(batch_ids) + if ids_to_delete: + qdrant.delete_points(client, collection, list(ids_to_delete)) + deleted += len(ids_to_delete) + return { + "batches": batches, + "deleted": deleted, + "checked": checked, + } + + +@app.task(name=CLEAN_ALL_COLLECTIONS) +def 
clean_all_collections(): + logger.info("Cleaning all collections") + for collection in embedding.ALL_COLLECTIONS: + clean_collection.delay(collection) + + +@app.task(name=REINGEST_CHUNK) +def reingest_chunk(chunk_id: str, collection: str): + logger.info(f"Reingesting chunk {chunk_id}") + with make_session() as session: + chunk = session.query(Chunk).get(chunk_id) + if not chunk: + logger.error(f"Chunk {chunk_id} not found") + return + + if collection not in embedding.ALL_COLLECTIONS: + raise ValueError(f"Unsupported collection {collection}") + + data = chunk.data + if collection in embedding.MULTIMODAL_COLLECTIONS: + vector = embedding.embed_mixed(data)[0] + elif len(data) == 1 and isinstance(data[0], str): + vector = embedding.embed_text([data[0]])[0] + else: + raise ValueError(f"Unsupported data type for collection {collection}") + + client = qdrant.get_qdrant_client() + qdrant.upsert_vectors( + client, + collection, + [chunk_id], + [vector], + [chunk.source.as_payload()], + ) + chunk.checked_at = datetime.now() + session.commit() + + +def check_batch(batch: Sequence[Chunk]) -> dict: + client = qdrant.get_qdrant_client() + by_collection = defaultdict(list) + for chunk in batch: + by_collection[chunk.source.modality].append(chunk) + + stats = {} + for collection, chunks in by_collection.items(): + missing = qdrant.find_missing_points( + client, collection, [str(c.id) for c in chunks] + ) + + for chunk in chunks: + if str(chunk.id) in missing: + reingest_chunk.delay(str(chunk.id), collection) + else: + chunk.checked_at = datetime.now() + + stats[collection] = { + "missing": len(missing), + "correct": len(chunks) - len(missing), + "total": len(chunks), + } + + return stats + + +@app.task(name=REINGEST_MISSING_CHUNKS) +def reingest_missing_chunks(batch_size: int = 1000): + logger.info("Reingesting missing chunks") + total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0}) + since = datetime.now() - timedelta(minutes=settings.CHUNK_REINGEST_SINCE_MINUTES) + + with make_session() as session: + total_count = session.query(Chunk).filter(Chunk.checked_at < since).count() + + logger.info( + f"Found {total_count} chunks to check, processing in batches of {batch_size}" + ) + + num_batches = (total_count + batch_size - 1) // batch_size + + for batch_num in range(num_batches): + stmt = ( + select(Chunk) + .join(SourceItem, Chunk.source_id == SourceItem.id) + .filter(Chunk.checked_at < since) + .options( + contains_eager(Chunk.source).load_only( + SourceItem.id, SourceItem.modality, SourceItem.tags + ) + ) + .order_by(Chunk.id) + .limit(batch_size) + ) + chunks = session.execute(stmt).scalars().all() + + if not chunks: + break + + logger.info( + f"Processing batch {batch_num + 1}/{num_batches} with {len(chunks)} chunks" + ) + batch_stats = check_batch(chunks) + session.commit() + + for collection, stats in batch_stats.items(): + total_stats[collection]["missing"] += stats["missing"] + total_stats[collection]["correct"] += stats["correct"] + total_stats[collection]["total"] += stats["total"] + + return dict(total_stats) diff --git a/tests/memory/common/parsers/test_comic.py b/tests/memory/common/parsers/test_comic.py new file mode 100644 index 0000000..79f22e6 --- /dev/null +++ b/tests/memory/common/parsers/test_comic.py @@ -0,0 +1,199 @@ +import textwrap +from memory.common.parsers.comics import extract_smbc, extract_xkcd +import pytest +from unittest.mock import patch, Mock +import requests + + +MOCK_SMBC_HTML = """ + + + + Saturday Morning Breakfast Cereal - Time + + + + + +
+ +
+ + + +""" + + +@pytest.mark.parametrize( + "to_remove, overrides", + [ + # Normal case - all data present + ("", {}), + # Missing title attribute on image + ( + 'title="I don\'t know why either, but it was fun to draw."', + {"title": "Saturday Morning Breakfast Cereal - Time"}, + ), + # Missing src attribute on image + ( + 'src="https://www.smbc-comics.com/comics/1746375102-20250504.webp"', + {"image_url": ""}, + ), + # # Missing entire img tag + ( + '', + {"title": "Saturday Morning Breakfast Cereal - Time", "image_url": ""}, + ), + # # Corrupt JSON-LD data + ( + '"datePublished":"2025-05-04T12:11:21-04:00"', + {"published_date": "", "url": "http://www.smbc-comics.com/comic/time-6"}, + ), + # # Missing JSON-LD script entirely + ( + '', + {"published_date": "", "url": "http://www.smbc-comics.com/comic/time-6"}, + ), + # # Missing permalink input + ( + '', + {}, + ), + # Missing URL in JSON-LD + ( + '"url": "https://www.smbc-comics.com/comic/time-6",', + {"url": "http://www.smbc-comics.com/comic/time-6"}, + ), + ], +) +def test_extract_smbc(to_remove, overrides): + """Test successful extraction of comic info from SMBC.""" + expected = { + "title": "I don't know why either, but it was fun to draw.", + "image_url": "https://www.smbc-comics.com/comics/1746375102-20250504.webp", + "published_date": "2025-05-04T12:11:21-04:00", + "url": "https://www.smbc-comics.com/comic/time-6", + } + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.text = MOCK_SMBC_HTML.replace(to_remove, "") + assert extract_smbc("https://www.smbc-comics.com/") == expected | overrides + + +# Create a stripped-down version of the XKCD HTML +MOCK_XKCD_HTML = """ + + + +xkcd: Unstoppable Force and Immovable Object + + + + + +
Unstoppable Force and Immovable Object
+ +
+Unstoppable Force and Immovable Object +
+ +Permanent link to this comic: https://xkcd.com/3084/
+Image URL (for hotlinking/embedding): +https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png + + +""" + + +@pytest.mark.parametrize( + "to_remove, overrides", + [ + # Normal case - all data present + ("", {}), + # Missing title attribute on image + ( + 'title="Unstoppable force-carrying particles can't interact with immovable matter by definition."', + { + "title": "Unstoppable Force and Immovable Object" + }, # Falls back to og:title + ), + # Missing og:title meta tag - falls back to ctitle + ( + '', + {}, # Still gets title from image title + ), + # Missing both title and og:title - falls back to ctitle + ( + 'title="Unstoppable force-carrying particles can't interact with immovable matter by definition."\n', + {"title": "Unstoppable Force and Immovable Object"}, # Falls back to ctitle + ), + # Missing image src attribute + ( + 'src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png"', + {"image_url": ""}, + ), + # Missing entire img tag + ( + 'Unstoppable Force and Immovable Object', + { + "image_url": "", + "title": "Unstoppable Force and Immovable Object", + }, # Falls back to og:title + ), + # Missing entire comic div + ( + '
\nUnstoppable Force and Immovable Object\n
', + {"image_url": "", "title": "Unstoppable Force and Immovable Object"}, + ), + # Missing og:url tag + ( + '', + {}, # Should fallback to permalink link + ), + # Missing permanent link + ( + 'Permanent link to this comic: https://xkcd.com/3084/
', + {"url": "https://xkcd.com/3084/"}, # Should still get URL from og:url + ), + # Missing both og:url and permanent link + ( + '\nPermanent link to this comic: https://xkcd.com/3084/
', + {"url": "https://xkcd.com/test"}, # Falls back to original URL + ), + ], +) +def test_extract_xkcd(to_remove, overrides): + """Test successful extraction of comic info from XKCD.""" + expected = { + "title": "Unstoppable force-carrying particles can't interact with immovable matter by definition.", + "image_url": "https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png", + "url": "https://xkcd.com/3084/", + } + + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = 200 + modified_html = MOCK_XKCD_HTML + for item in to_remove.split("\n"): + modified_html = modified_html.replace(item, "") + + mock_get.return_value.text = modified_html + result = extract_xkcd("https://xkcd.com/test") + assert result == expected | overrides diff --git a/tests/memory/common/test_qdrant.py b/tests/memory/common/test_qdrant.py index 8cf34dc..0cd037a 100644 --- a/tests/memory/common/test_qdrant.py +++ b/tests/memory/common/test_qdrant.py @@ -1,40 +1,43 @@ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, Mock import qdrant_client from qdrant_client.http import models as qdrant_models from qdrant_client.http.exceptions import UnexpectedResponse from memory.common.qdrant import ( - DEFAULT_COLLECTIONS, + ALL_COLLECTIONS, ensure_collection_exists, initialize_collections, upsert_vectors, - delete_vectors, + delete_points, + batch_ids, ) @pytest.fixture def mock_qdrant_client(): - with patch.object(qdrant_client, "QdrantClient", return_value=MagicMock()) as mock_client: + with patch.object( + qdrant_client, "QdrantClient", return_value=MagicMock() + ) as mock_client: yield mock_client def test_ensure_collection_exists_existing(mock_qdrant_client): mock_qdrant_client.get_collection.return_value = {} assert not ensure_collection_exists(mock_qdrant_client, "test_collection", 128) - + mock_qdrant_client.get_collection.assert_called_once_with("test_collection") mock_qdrant_client.create_collection.assert_not_called() def test_ensure_collection_exists_new(mock_qdrant_client): mock_qdrant_client.get_collection.side_effect = UnexpectedResponse( - status_code=404, reason_phrase='asd', content=b'asd', headers=None + status_code=404, reason_phrase="asd", content=b"asd", headers=None ) - + assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128) - + mock_qdrant_client.get_collection.assert_called_once_with("test_collection") mock_qdrant_client.create_collection.assert_called_once() mock_qdrant_client.create_payload_index.assert_called_once() @@ -42,22 +45,22 @@ def test_ensure_collection_exists_new(mock_qdrant_client): def test_initialize_collections(mock_qdrant_client): initialize_collections(mock_qdrant_client) - - assert mock_qdrant_client.get_collection.call_count == len(DEFAULT_COLLECTIONS) + + assert mock_qdrant_client.get_collection.call_count == len(ALL_COLLECTIONS) def test_upsert_vectors(mock_qdrant_client): ids = ["1", "2"] vectors = [[0.1, 0.2], [0.3, 0.4]] payloads = [{"tag": "test1"}, {"tag": "test2"}] - + upsert_vectors(mock_qdrant_client, "test_collection", ids, vectors, payloads) - + mock_qdrant_client.upsert.assert_called_once() args, kwargs = mock_qdrant_client.upsert.call_args assert kwargs["collection_name"] == "test_collection" assert len(kwargs["points"]) == 2 - + # Check points were created correctly points = kwargs["points"] assert points[0].id == "1" @@ -70,12 +73,24 @@ def test_upsert_vectors(mock_qdrant_client): def test_delete_vectors(mock_qdrant_client): ids = ["1", "2"] - - 
delete_vectors(mock_qdrant_client, "test_collection", ids) - + + delete_points(mock_qdrant_client, "test_collection", ids) + mock_qdrant_client.delete.assert_called_once() args, kwargs = mock_qdrant_client.delete.call_args - + assert kwargs["collection_name"] == "test_collection" assert isinstance(kwargs["points_selector"], qdrant_models.PointIdsList) - assert kwargs["points_selector"].points == ids \ No newline at end of file + assert kwargs["points_selector"].points == ids + + +def test_batch_ids(mock_qdrant_client): + mock_qdrant_client.scroll.side_effect = [ + ([Mock(id="1"), Mock(id="2")], "3"), + ([Mock(id="3"), Mock(id="4")], None), + ] + + assert list(batch_ids(mock_qdrant_client, "test_collection")) == [ + ["1", "2"], + ["3", "4"], + ] diff --git a/tests/memory/workers/tasks/test_maintenance.py b/tests/memory/workers/tasks/test_maintenance.py new file mode 100644 index 0000000..ca43183 --- /dev/null +++ b/tests/memory/workers/tasks/test_maintenance.py @@ -0,0 +1,310 @@ +import uuid +from datetime import datetime, timedelta +from unittest.mock import patch, call + +import pytest +from PIL import Image + +from memory.common import qdrant as qd +from memory.common import embedding, settings +from memory.common.db.models import Chunk, SourceItem +from memory.workers.tasks.maintenance import ( + clean_collection, + reingest_chunk, + check_batch, + reingest_missing_chunks, +) + + +@pytest.fixture +def source(db_session): + s = SourceItem(id=1, modality="text", sha256=b"123") + db_session.add(s) + db_session.commit() + return s + + +@pytest.fixture +def mock_uuid4(): + i = 0 + + def uuid4(): + nonlocal i + i += 1 + return f"00000000-0000-0000-0000-00000000000{i}" + + with patch("uuid.uuid4", side_effect=uuid4): + yield + + +@pytest.fixture +def test_image(mock_file_storage): + img = Image.new("RGB", (100, 100), color=(73, 109, 137)) + img_path = settings.CHUNK_STORAGE_DIR / "test.png" + img.save(img_path) + return img_path + + +@pytest.fixture(params=["text", "photo"]) +def chunk(request, test_image, db_session): + """Parametrized fixture for chunk configuration""" + collection = request.param + if collection == "photo": + content = None + file_path = str(test_image) + else: + content = "Test content for reingestion" + file_path = None + + chunk = Chunk( + id=str(uuid.uuid4()), + source=SourceItem(id=1, modality=collection, sha256=b"123"), + content=content, + file_path=file_path, + embedding_model="test-model", + checked_at=datetime(2025, 1, 1), + ) + db_session.add(chunk) + db_session.commit() + return chunk + + +def test_clean_collection_no_mismatches(db_session, qdrant, source): + """Test when all Qdrant points exist in the database - nothing should be deleted.""" + # Create chunks in the database + chunk_ids = [str(uuid.uuid4()) for _ in range(3000)] + collection = "text" + + # Add chunks to the database + for chunk_id in chunk_ids: + db_session.add( + Chunk( + id=chunk_id, + source=source, + content="Test content", + embedding_model="test-model", + ) + ) + db_session.commit() + qd.ensure_collection_exists(qdrant, collection, 1024) + qd.upsert_vectors(qdrant, collection, chunk_ids, [[1] * 1024] * len(chunk_ids)) + + assert set(chunk_ids) == { + str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch + } + + clean_collection(collection) + + # Check that the chunks are still in the database - no points were deleted + assert set(chunk_ids) == { + str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch + } + + +def 
test_clean_collection_with_orphaned_vectors(db_session, qdrant, source): + """Test when there are vectors in Qdrant that don't exist in the database.""" + existing_ids = [str(uuid.uuid4()) for _ in range(3000)] + orphaned_ids = [str(uuid.uuid4()) for _ in range(3000)] + all_ids = existing_ids + orphaned_ids + collection = "text" + + # Add only the existing chunks to the database + for chunk_id in existing_ids: + db_session.add( + Chunk( + id=chunk_id, + source=source, + content="Test content", + embedding_model="test-model", + ) + ) + db_session.commit() + qd.ensure_collection_exists(qdrant, collection, 1024) + qd.upsert_vectors(qdrant, collection, all_ids, [[1] * 1024] * len(all_ids)) + + clean_collection(collection) + + # The orphaned vectors should be deleted + assert set(existing_ids) == { + str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch + } + + +def test_clean_collection_empty_batches(db_session, qdrant): + collection = "text" + qd.ensure_collection_exists(qdrant, collection, 1024) + + clean_collection(collection) + + assert not [i for b in qd.batch_ids(qdrant, collection) for i in b] + + +def test_reingest_chunk(db_session, qdrant, chunk): + """Test reingesting a chunk using parameterized fixtures""" + collection = chunk.source.modality + qd.ensure_collection_exists(qdrant, collection, 1024) + + start = datetime.now() + test_vector = [0.1] * 1024 + + with patch.object(embedding, "embed_chunks", return_value=[test_vector]): + reingest_chunk(str(chunk.id), collection) + + vectors = qd.search_vectors(qdrant, collection, test_vector, limit=1) + assert len(vectors) == 1 + assert str(vectors[0].id) == str(chunk.id) + assert vectors[0].payload == chunk.source.as_payload() + db_session.refresh(chunk) + assert chunk.checked_at.isoformat() > start.isoformat() + + +def test_reingest_chunk_not_found(db_session, qdrant): + """Test when the chunk to reingest doesn't exist.""" + non_existent_id = str(uuid.uuid4()) + collection = "text" + + reingest_chunk(non_existent_id, collection) + + +def test_reingest_chunk_unsupported_collection(db_session, qdrant, source): + """Test reingesting with an unsupported collection type.""" + chunk_id = str(uuid.uuid4()) + chunk = Chunk( + id=chunk_id, + source=source, + content="Test content", + embedding_model="test-model", + ) + db_session.add(chunk) + db_session.commit() + + unsupported_collection = "unsupported" + qd.ensure_collection_exists(qdrant, unsupported_collection, 1024) + + with pytest.raises( + ValueError, match=f"Unsupported collection {unsupported_collection}" + ): + reingest_chunk(chunk_id, unsupported_collection) + + +def test_check_batch_empty(db_session, qdrant): + assert check_batch([]) == {} + + +def test_check_batch(db_session, qdrant): + modalities = ["text", "photo", "mail"] + chunks = [ + Chunk( + id=f"00000000-0000-0000-0000-0000000000{i:02d}", + source=SourceItem(modality=modality, sha256=f"123{i}".encode()), + content="Test content", + file_path=None, + embedding_model="test-model", + checked_at=datetime(2025, 1, 1), + ) + for modality in modalities + for i in range(5) + ] + db_session.add_all(chunks) + db_session.commit() + start_time = datetime.now() + + for modality in modalities: + qd.ensure_collection_exists(qdrant, modality, 1024) + + for chunk in chunks[::2]: + qd.upsert_vectors(qdrant, chunk.source.modality, [str(chunk.id)], [[1] * 1024]) + + with patch.object(reingest_chunk, "delay") as mock_reingest: + stats = check_batch(chunks) + + assert mock_reingest.call_args_list == [ + 
call("00000000-0000-0000-0000-000000000001", "text"), + call("00000000-0000-0000-0000-000000000003", "text"), + call("00000000-0000-0000-0000-000000000000", "photo"), + call("00000000-0000-0000-0000-000000000002", "photo"), + call("00000000-0000-0000-0000-000000000004", "photo"), + call("00000000-0000-0000-0000-000000000001", "mail"), + call("00000000-0000-0000-0000-000000000003", "mail"), + ] + assert stats == { + "mail": {"correct": 3, "missing": 2, "total": 5}, + "text": {"correct": 3, "missing": 2, "total": 5}, + "photo": {"correct": 2, "missing": 3, "total": 5}, + } + db_session.commit() + for chunk in chunks[::2]: + assert chunk.checked_at.isoformat() > start_time.isoformat() + for chunk in chunks[1::2]: + assert chunk.checked_at.isoformat()[:19] == "2025-01-01T00:00:00" + + +@pytest.mark.parametrize("batch_size", [4, 10, 100]) +def test_reingest_missing_chunks(db_session, qdrant, batch_size): + now = datetime.now() + old_time = now - timedelta(minutes=120) # Older than the threshold + + modalities = ["text", "photo", "mail"] + ids_generator = (f"00000000-0000-0000-0000-00000000{i:04d}" for i in range(1000)) + + old_chunks = [ + Chunk( + id=next(ids_generator), + source=SourceItem(modality=modality, sha256=f"{modality}-{i}".encode()), + content="Old content", + file_path=None, + embedding_model="test-model", + checked_at=old_time, + ) + for modality in modalities + for i in range(20) + ] + + recent_chunks = [ + Chunk( + id=next(ids_generator), + source=SourceItem( + modality=modality, sha256=f"recent-{modality}-{i}".encode() + ), + content="Recent content", + file_path=None, + embedding_model="test-model", + checked_at=now, + ) + for modality in modalities + for i in range(5) + ] + + db_session.add_all(old_chunks + recent_chunks) + db_session.commit() + + for modality in modalities: + qd.ensure_collection_exists(qdrant, modality, 1024) + + for chunk in old_chunks[::2]: + qd.upsert_vectors(qdrant, chunk.source.modality, [str(chunk.id)], [[1] * 1024]) + + with patch.object(reingest_chunk, "delay", reingest_chunk): + with patch.object(settings, "CHUNK_REINGEST_SINCE_MINUTES", 60): + with patch.object(embedding, "embed_chunks", return_value=[[1] * 1024]): + result = reingest_missing_chunks(batch_size=batch_size) + + assert result == { + "photo": {"correct": 10, "missing": 10, "total": 20}, + "mail": {"correct": 10, "missing": 10, "total": 20}, + "text": {"correct": 10, "missing": 10, "total": 20}, + } + + db_session.commit() + # All the old chunks should be reingested + client = qd.get_qdrant_client() + for modality in modalities: + qdrant_ids = [ + i for b in qd.batch_ids(client, modality, batch_size=1000) for i in b + ] + db_ids = [str(c.id) for c in old_chunks if c.source.modality == modality] + assert set(qdrant_ids) == set(db_ids) + + +def test_reingest_missing_chunks_no_chunks(db_session): + assert reingest_missing_chunks() == {}