add comics

Daniel O'Connell 2025-05-20 21:28:26 +02:00
parent 743a76c3d1
commit 8cfaeaea72
23 changed files with 1335 additions and 101 deletions

View File

@@ -1,6 +1,7 @@
"""
Alembic environment configuration.
"""
from logging.config import fileConfig
from sqlalchemy import engine_from_config
@@ -29,6 +30,18 @@ target_metadata = Base.metadata
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# Tables to exclude from auto-generation - these are managed by Celery
excluded_tables = {
"celery_taskmeta",
"celery_tasksetmeta",
}
def include_object(object, name, type_, reflected, compare_to):
if type_ == "table" and name in excluded_tables:
return False
return True
def run_migrations_offline() -> None:
"""
@@ -48,6 +61,7 @@ def run_migrations_offline() -> None:
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
include_object=include_object,
)
with context.begin_transaction():
@@ -69,7 +83,9 @@ def run_migrations_online() -> None:
with connectable.connect() as connection:
context.configure(
-connection=connection, target_metadata=target_metadata
connection=connection,
target_metadata=target_metadata,
include_object=include_object,
)
with context.begin_transaction():

View File

@@ -21,7 +21,7 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
-op.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')
op.create_table(
"email_accounts",
sa.Column("id", sa.BigInteger(), nullable=False),
@@ -82,12 +82,6 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=False,
),
-sa.Column(
-"updated_at",
-sa.DateTime(timezone=True),
-server_default=sa.text("now()"),
-nullable=False,
-),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
@@ -269,7 +263,7 @@ def upgrade() -> None:
"misc_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("path", sa.Text(), nullable=True),
-sa.Column("mime_type", sa.Text(), nullable=True),
sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -0,0 +1,40 @@
"""Update field for chunks
Revision ID: d292d48ec74e
Revises: 4684845ca51e
Create Date: 2025-05-04 12:56:53.231393
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "d292d48ec74e"
down_revision: Union[str, None] = "4684845ca51e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"chunk",
sa.Column(
"checked_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
)
op.drop_column("misc_doc", "mime_type")
def downgrade() -> None:
op.add_column(
"misc_doc",
sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True),
)
op.drop_column("chunk", "checked_at")

View File

@@ -0,0 +1,43 @@
"""Add comics
Revision ID: b78b1fff9974
Revises: d292d48ec74e
Create Date: 2025-05-04 23:45:52.733301
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b78b1fff9974'
down_revision: Union[str, None] = 'd292d48ec74e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('comic',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('author', sa.Text(), nullable=True),
sa.Column('published', sa.DateTime(timezone=True), nullable=True),
sa.Column('volume', sa.Text(), nullable=True),
sa.Column('issue', sa.Text(), nullable=True),
sa.Column('page', sa.Integer(), nullable=True),
sa.Column('url', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['id'], ['source_item.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('comic_author_idx', 'comic', ['author'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('comic_author_idx', table_name='comic')
op.drop_table('comic')
# ### end Alembic commands ###

View File

@@ -21,6 +21,7 @@ volumes:
# ------------------------------ X-templates ----------------------------
x-common-env: &env
RABBITMQ_USER: kb
RABBITMQ_HOST: rabbitmq
QDRANT_HOST: qdrant
DB_HOST: postgres
FILE_STORAGE_DIR: /app/memory_files
@@ -173,49 +174,28 @@ services:
environment:
<<: *worker-env
QUEUES: "email"
-deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
# deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
worker-text:
<<: *worker-base
environment:
<<: *worker-env
QUEUES: "medium_embed"
-deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
# deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
worker-photo:
<<: *worker-base
environment:
<<: *worker-env
-QUEUES: "photo_embed"
QUEUES: "photo_embed,comic"
-deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
# deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
-worker-ocr:
worker-maintenance:
<<: *worker-base
environment:
<<: *worker-env
-QUEUES: "low_ocr"
QUEUES: "maintenance"
-deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
# deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
-worker-git:
-<<: *worker-base
-environment:
-<<: *worker-env
-QUEUES: "git_summary"
-deploy: { resources: { limits: { cpus: "1", memory: 1g } } }
-worker-rss:
-<<: *worker-base
-environment:
-<<: *worker-env
-QUEUES: "rss"
-deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
-worker-docs:
-<<: *worker-base
-environment:
-<<: *worker-env
-QUEUES: "docs"
-deploy: { resources: { limits: { cpus: "1", memory: 1g } } }
ingest-hub:
<<: *worker-base

View File

@@ -9,7 +9,7 @@ COPY src/ ./src/
# Install dependencies
RUN apt-get update && apt-get install -y \
libpq-dev gcc supervisor && \
pip install -e ".[workers]" && \
apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
@@ -28,8 +28,7 @@ RUN mkdir -p /app/memory_files
RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files
USER kb
-# Default queues to process
-ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
ENV QUEUES="docs,email,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]

View File

@@ -35,7 +35,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
USER kb
# Default queues to process
-ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
ENV QUEUES="docs,email,maintenance"
ENV PYTHONPATH="/app"
ENTRYPOINT ["./entry.sh"]

View File

@@ -2,3 +2,5 @@ celery==5.3.6
openai==1.25.0
pillow==10.3.0
pypandoc==1.15.0
beautifulsoup4==4.13.4
feedparser==6.0.10

View File

@@ -101,7 +101,7 @@ def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[Search
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
-content=source.content,
content=source.display_contents,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
)
for source, chunks in items.items()
@@ -143,7 +143,7 @@ def query_chunks(
)
if r.score >= min_score
]
-for collection in embedding.DEFAULT_COLLECTIONS
for collection in embedding.ALL_COLLECTIONS
}
@@ -164,7 +164,7 @@ async def search(
modalities: Annotated[list[str], Query()] = [],
files: list[UploadFile] = File([]),
limit: int = Query(10, ge=1, le=100),
-min_text_score: float = Query(0.5, ge=0.0, le=1.0),
min_text_score: float = Query(0.3, ge=0.0, le=1.0),
min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
):
"""
@@ -181,11 +181,11 @@ async def search(
"""
upload_data = [await input_type(item) for item in [query, *files]]
logger.error(
-f"Querying chunks for {modalities}, query: {query}, previews: {previews}"
f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
)
client = qdrant.get_qdrant_client()
-allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys())
allowed_modalities = set(modalities or embedding.ALL_COLLECTIONS.keys())
text_results = query_chunks(
client,
upload_data,
View File

@@ -1,6 +1,7 @@
"""
Database connection utilities.
"""
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session

View File

@@ -6,7 +6,8 @@ import pathlib
import re
from email.message import EmailMessage
from pathlib import Path
-from typing import Any
import textwrap
from typing import Any, ClassVar
from PIL import Image
from sqlalchemy import (
ARRAY,
@@ -99,8 +100,9 @@ class Chunk(Base):
content = Column(Text)  # Direct content storage
embedding_model = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
-vector = list[float] | None  # the vector generated by the embedding model
-item_metadata = dict[str, Any] | None
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: ClassVar[list[float] | None] = None
item_metadata: ClassVar[dict[str, Any] | None] = None
# One of file_path or content must be populated
__table_args__ = (
@@ -121,7 +123,7 @@ class Chunk(Base):
items = []
for file_path in files:
-if file_path.suffix == ".png":
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
if file_path.exists():
items.append(Image.open(file_path))
elif file_path.suffix == ".bin":
@@ -172,6 +174,16 @@ class SourceItem(Base):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def as_payload(self) -> dict:
return {
"source_id": self.id,
"tags": self.tags,
}
@property
def display_contents(self) -> str | None:
return self.content or self.filename
class MailMessage(SourceItem):
__tablename__ = "mail_message"
@@ -236,6 +248,26 @@ class MailMessage(SourceItem):
def body(self) -> str:
return self.parsed_content["body"]
@property
def display_contents(self) -> str | None:
content = self.parsed_content
return textwrap.dedent(
"""
Subject: {subject}
From: {sender}
To: {recipients}
Date: {date}
Body:
{body}
"""
).format(
subject=content.get("subject", ""),
sender=content.get("from", ""),
recipients=content.get("to", ""),
date=content.get("date", ""),
body=content.get("body", ""),
)
# Add indexes
__table_args__ = (
Index("mail_sent_idx", "sent_at"),
@@ -341,6 +373,41 @@ class Photo(SourceItem):
__table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
class Comic(SourceItem):
__tablename__ = "comic"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
title = Column(Text)
author = Column(Text, nullable=True)
published = Column(DateTime(timezone=True), nullable=True)
volume = Column(Text, nullable=True)
issue = Column(Text, nullable=True)
page = Column(Integer, nullable=True)
url = Column(Text, nullable=True)
__mapper_args__ = {
"polymorphic_identity": "comic",
}
__table_args__ = (Index("comic_author_idx", "author"),)
def as_payload(self) -> dict:
payload = {
"source_id": self.id,
"tags": self.tags,
"title": self.title,
"author": self.author,
"published": self.published,
"volume": self.volume,
"issue": self.issue,
"page": self.page,
"url": self.url,
}
return {k: v for k, v in payload.items() if v is not None}
class BookDoc(SourceItem):
__tablename__ = "book_doc"
@@ -379,7 +446,6 @@ class MiscDoc(SourceItem):
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
path = Column(Text)
-mime_type = Column(Text)
__mapper_args__ = {
"polymorphic_identity": "misc_doc",

View File

@@ -32,7 +32,7 @@ class Collection(TypedDict):
shards: NotRequired[int]
-DEFAULT_COLLECTIONS: dict[str, Collection] = {
ALL_COLLECTIONS: dict[str, Collection] = {
"mail": {
"dimension": 1024,
"distance": "Cosine",
@@ -69,6 +69,11 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = {
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"comic": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
},
"doc": { "doc": {
"dimension": 1024, "dimension": 1024,
"distance": "Cosine", "distance": "Cosine",
@ -77,12 +82,12 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = {
} }
TEXT_COLLECTIONS = { TEXT_COLLECTIONS = {
coll coll
for coll, params in DEFAULT_COLLECTIONS.items() for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.TEXT_EMBEDDING_MODEL if params["model"] == settings.TEXT_EMBEDDING_MODEL
} }
MULTIMODAL_COLLECTIONS = { MULTIMODAL_COLLECTIONS = {
coll coll
for coll, params in DEFAULT_COLLECTIONS.items() for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.MIXED_EMBEDDING_MODEL if params["model"] == settings.MIXED_EMBEDDING_MODEL
} }
@ -99,6 +104,22 @@ TYPES = {
} }
def get_mimetype(image: Image.Image) -> str | None:
format_to_mime = {
"JPEG": "image/jpeg",
"PNG": "image/png",
"GIF": "image/gif",
"BMP": "image/bmp",
"TIFF": "image/tiff",
"WEBP": "image/webp",
}
if not image.format:
return None
return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}")
def get_modality(mime_type: str) -> str:
for type, mime_types in TYPES.items():
if mime_type in mime_types:
@@ -237,3 +258,20 @@ def embed(
for vector in embed_page(page)
]
return modality, chunks
def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
image = Image.open(file_path)
mime_type = get_mimetype(image)
if mime_type is None:
raise ValueError("Unsupported image format")
vector = embed_mixed([image] + texts)[0]
return Chunk(
id=str(uuid.uuid4()),
file_path=file_path.absolute().as_posix(),
content=None,
embedding_model=settings.MIXED_EMBEDDING_MODEL,
vector=vector,
)
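For orientation, a minimal sketch (not part of the commit) of how the new embed_image helper is used; the comic sync task further down calls it the same way, and the path and strings here are illustrative only:
# Illustrative sketch; mirrors embedding.embed_image(filename, [title, author]) as called by the comic task
import pathlib
from memory.common import embedding
chunk = embedding.embed_image(pathlib.Path("/tmp/example-comic.png"), ["Example title", "Example author"])
# chunk.file_path points at the image on disk and chunk.vector holds the mixed-modality embedding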

View File

@@ -0,0 +1,133 @@
import logging
from typing import TypedDict, NotRequired
from bs4 import BeautifulSoup, Tag
import requests
import json
logger = logging.getLogger(__name__)
class ComicInfo(TypedDict):
title: str
image_url: str
published_date: NotRequired[str]
url: str
def extract_smbc(url: str) -> ComicInfo:
"""
Extract the title, published date, and image URL from a SMBC comic.
Returns:
ComicInfo with title, image_url, published_date, and comic_url
"""
response = requests.get(url)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
comic_img = soup.find("img", id="cc-comic")
title = ""
image_url = ""
if comic_img and isinstance(comic_img, Tag):
if comic_img.has_attr("src"):
image_url = str(comic_img["src"])
if comic_img.has_attr("title"):
title = str(comic_img["title"])
published_date = ""
comic_url = ""
script_ld = soup.find("script", type="application/ld+json")
if script_ld and isinstance(script_ld, Tag) and script_ld.string:
try:
data = json.loads(script_ld.string)
published_date = data.get("datePublished", "")
# Use JSON-LD URL if available
title = title or data.get("name", "")
comic_url = data.get("url")
except (json.JSONDecodeError, AttributeError):
pass
permalink_input = soup.find("input", id="permalinktext")
if not comic_url and permalink_input and isinstance(permalink_input, Tag):
comic_url = permalink_input.get("value", "")
return {
"title": title,
"image_url": image_url,
"published_date": published_date,
"url": comic_url or url,
}
def extract_xkcd(url: str) -> ComicInfo:
"""
Extract comic information from an XKCD comic.
This function parses an XKCD comic page to extract the title from the hover text,
the image URL, and the permanent URL of the comic.
Args:
url: The URL of the XKCD comic to parse
Returns:
ComicInfo with title, image_url, and comic_url
"""
response = requests.get(url)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
def get_comic_img() -> Tag | None:
"""Extract the comic image tag."""
comic_div = soup.find("div", id="comic")
if comic_div and isinstance(comic_div, Tag):
img = comic_div.find("img")
return img if isinstance(img, Tag) else None
return None
def get_title() -> str:
"""Extract title from image title attribute with fallbacks."""
# Primary source: hover text from the image (most informative)
img = get_comic_img()
if img and img.has_attr("title"):
return str(img["title"])
# Secondary source: og:title meta tag
og_title = soup.find("meta", property="og:title")
if og_title and isinstance(og_title, Tag) and og_title.has_attr("content"):
return str(og_title["content"])
# Last resort: page title div
title_div = soup.find("div", id="ctitle")
return title_div.text.strip() if title_div else ""
def get_image_url() -> str:
"""Extract and normalize the image URL."""
img = get_comic_img()
if not img or not img.has_attr("src"):
return ""
image_src = str(img["src"])
return f"https:{image_src}" if image_src.startswith("//") else image_src
def get_permanent_url() -> str:
"""Extract the permanent URL to the comic."""
og_url = soup.find("meta", property="og:url")
if og_url and isinstance(og_url, Tag) and og_url.has_attr("content"):
return str(og_url["content"])
# Fallback: look for permanent link text
for a_tag in soup.find_all("a"):
text = a_tag.get_text()
if text.startswith("https://xkcd.com/") and text.strip().endswith("/"):
return str(text.strip())
return url
return {
"title": get_title(),
"image_url": get_image_url(),
"url": get_permanent_url(),
}
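For reference, a brief sketch (not part of the commit) of how these parsers are meant to be called; the URLs come from the test fixtures further down:
# Illustrative usage of the parsers above; both return a ComicInfo dict
from memory.common.parsers.comics import extract_smbc, extract_xkcd
smbc = extract_smbc("https://www.smbc-comics.com/comic/time-6")
xkcd = extract_xkcd("https://xkcd.com/3084/")
print(smbc["title"], smbc["image_url"])
print(xkcd["title"], xkcd["url"])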

View File

@@ -1,5 +1,5 @@
import logging
-from typing import Any, cast
from typing import Any, cast, Iterator, Sequence
import qdrant_client
from qdrant_client.http import models as qdrant_models
@@ -7,7 +7,7 @@ from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings
from memory.common.embedding import (
Collection,
-DEFAULT_COLLECTIONS,
ALL_COLLECTIONS,
DistanceType,
Vector,
)
@@ -91,7 +91,7 @@ def initialize_collections(
If None, defaults to the DEFAULT_COLLECTIONS.
"""
if collections is None:
-collections = DEFAULT_COLLECTIONS
collections = ALL_COLLECTIONS
logger.info(f"Initializing collections:")
for name, params in collections.items():
@@ -184,13 +184,13 @@ def search_vectors(
)
-def delete_vectors(
def delete_points(
client: qdrant_client.QdrantClient,
collection_name: str,
ids: list[str],
) -> None:
"""
-Delete vectors from a collection.
Delete points from a collection.
Args:
client: Qdrant client
@@ -222,3 +222,30 @@ def get_collection_info(
"""
info = client.get_collection(collection_name)
return info.model_dump()
def batch_ids(
client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
) -> Iterator[list[str]]:
"""Iterate over all IDs in a collection."""
offset = None
while resp := client.scroll(
collection_name=collection_name,
with_vectors=False,
offset=offset,
limit=batch_size,
):
points, offset = resp
yield [point.id for point in points]
if not offset:
return
def find_missing_points(
client: qdrant_client.QdrantClient, collection_name: str, ids: Sequence[str]
) -> set[str]:
found = client.retrieve(
collection_name, ids=ids, with_payload=False, with_vectors=False
)
return set(ids) - {str(r.id) for r in found}

View File

@@ -29,13 +29,27 @@ def make_db_url(
DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Celery settings
RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb")
RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb")
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
# File storage settings
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
CHUNK_STORAGE_DIR = pathlib.Path(
os.getenv("CHUNK_STORAGE_DIR", FILE_STORAGE_DIR / "chunks")
)
CHUNK_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
COMIC_STORAGE_DIR = pathlib.Path(
os.getenv("COMIC_STORAGE_DIR", FILE_STORAGE_DIR / "comics")
)
COMIC_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
# Maximum attachment size to store directly in the database (10MB)
MAX_INLINE_ATTACHMENT_SIZE = int(
os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024)
@@ -51,8 +65,12 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
# Worker settings
# Intervals are in seconds
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400))
CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600))
CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
# Embedding settings
TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")

View File

@@ -1,18 +1,15 @@
-import os
from celery import Celery
from memory.common import settings
def rabbit_url() -> str:
-user = os.getenv("RABBITMQ_USER", "guest")
-password = os.getenv("RABBITMQ_PASSWORD", "guest")
-return f"amqp://{user}:{password}@rabbitmq:5672//"
return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//"
app = Celery(
"memory",
broker=rabbit_url(),
-backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}")
backend=settings.CELERY_RESULT_BACKEND,
)
@@ -27,16 +24,15 @@ app.conf.update(
"memory.workers.tasks.text.*": {"queue": "medium_embed"},
"memory.workers.tasks.email.*": {"queue": "email"},
"memory.workers.tasks.photo.*": {"queue": "photo_embed"},
-"memory.workers.tasks.ocr.*": {"queue": "low_ocr"},
"memory.workers.tasks.comic.*": {"queue": "comic"},
-"memory.workers.tasks.git.*": {"queue": "git_summary"},
-"memory.workers.tasks.rss.*": {"queue": "rss"},
"memory.workers.tasks.docs.*": {"queue": "docs"},
"memory.workers.tasks.maintenance.*": {"queue": "maintenance"},
},
)
@app.on_after_configure.connect
def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant
qdrant.setup_qdrant()

View File

@@ -1,11 +1,23 @@
import logging
from memory.common import settings
from memory.workers.celery_app import app
from memory.workers.tasks import CLEAN_ALL_COLLECTIONS, REINGEST_MISSING_CHUNKS
logger = logging.getLogger(__name__)
app.conf.beat_schedule = {
-'sync-mail-all': {
-'task': 'memory.workers.tasks.email.sync_all_accounts',
-'schedule': settings.EMAIL_SYNC_INTERVAL,
"sync-mail-all": {
"task": "memory.workers.tasks.email.sync_all_accounts",
"schedule": settings.EMAIL_SYNC_INTERVAL,
},
"clean-all-collections": {
"task": CLEAN_ALL_COLLECTIONS,
"schedule": settings.CLEAN_COLLECTION_INTERVAL,
},
"reingest-missing-chunks": {
"task": REINGEST_MISSING_CHUNKS,
"schedule": settings.CHUNK_REINGEST_INTERVAL,
},
}

View File

@@ -1,8 +1,24 @@
"""
Import sub-modules so Celery can register their @app.task decorators.
"""
-from memory.workers.tasks import docs, email  # noqa
from memory.workers.tasks import docs, email, comic  # noqa
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
from memory.workers.tasks.maintenance import (
CLEAN_ALL_COLLECTIONS,
CLEAN_COLLECTION,
REINGEST_MISSING_CHUNKS,
)
__all__ = ["docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"] __all__ = [
"docs",
"email",
"comic",
"SYNC_ACCOUNT",
"SYNC_ALL_ACCOUNTS",
"PROCESS_EMAIL",
"CLEAN_ALL_COLLECTIONS",
"CLEAN_COLLECTION",
"REINGEST_MISSING_CHUNKS",
]

View File

@@ -0,0 +1,172 @@
import hashlib
import logging
from datetime import datetime
from typing import Callable
import feedparser
import requests
from memory.common import embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Comic, clean_filename
from memory.common.parsers import comics
from memory.workers.celery_app import app
logger = logging.getLogger(__name__)
SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
BASE_SMBC_URL = "https://www.smbc-comics.com/"
SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
BASE_XKCD_URL = "https://xkcd.com/"
XKCD_RSS_URL = "https://xkcd.com/atom.xml"
def find_new_urls(base_url: str, rss_url: str) -> set[str]:
try:
feed = feedparser.parse(rss_url)
except Exception as e:
logger.error(f"Failed to fetch or parse {rss_url}: {e}")
return set()
urls = {item.get("link") or item.get("id") for item in feed.entries}
with make_session() as session:
known = {
c.url
for c in session.query(Comic.url).filter(
Comic.author == base_url,
Comic.url.in_(urls),
)
}
return urls - known
def fetch_new_comics(
base_url: str, rss_url: str, parser: Callable[[str], comics.ComicInfo]
) -> set[str]:
new_urls = find_new_urls(base_url, rss_url)
for url in new_urls:
data = parser(url) | {"author": base_url, "url": url}
sync_comic.delay(**data)
return new_urls
@app.task(name=SYNC_COMIC)
def sync_comic(
url: str,
image_url: str,
title: str,
author: str,
published_date: datetime | None = None,
):
"""Synchronize a comic from a URL."""
with make_session() as session:
if session.query(Comic).filter(Comic.url == url).first():
return
response = requests.get(image_url)
file_type = image_url.split(".")[-1]
mime_type = f"image/{file_type}"
filename = (
settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}"
)
if response.status_code == 200:
filename.parent.mkdir(parents=True, exist_ok=True)
filename.write_bytes(response.content)
sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest()
comic = Comic(
title=title,
url=url,
published=published_date,
author=author,
filename=filename.resolve().as_posix(),
mime_type=mime_type,
size=len(response.content),
sha256=sha256,
tags={"comic", author},
modality="comic",
)
chunk = embedding.embed_image(filename, [title, author])
comic.chunks = [chunk]
with make_session() as session:
session.add(comic)
session.add(chunk)
session.flush()
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name="comic",
ids=[str(chunk.id)],
vectors=[chunk.vector],
payloads=[comic.as_payload()],
)
session.commit()
@app.task(name=SYNC_SMBC)
def sync_smbc() -> set[str]:
"""Synchronize SMBC comics from RSS feed."""
return fetch_new_comics(BASE_SMBC_URL, SMBC_RSS_URL, comics.extract_smbc)
@app.task(name=SYNC_XKCD)
def sync_xkcd() -> set[str]:
"""Synchronize XKCD comics from RSS feed."""
return fetch_new_comics(BASE_XKCD_URL, XKCD_RSS_URL, comics.extract_xkcd)
@app.task(name=SYNC_ALL_COMICS)
def sync_all_comics():
"""Synchronize all active comics."""
sync_smbc.delay()
sync_xkcd.delay()
@app.task(name="memory.workers.tasks.comic.full_sync_comic")
def trigger_comic_sync():
def prev_smbc_comic(url: str) -> str | None:
from bs4 import BeautifulSoup
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
if link := soup.find("a", attrs={"class": "cc-prev"}):
return link.attrs["href"]
return None
next_url = "https://www.smbc-comics.com"
urls = []
logger.info(f"syncing {next_url}")
while next_url := prev_smbc_comic(next_url):
if len(urls) % 10 == 0:
logger.info(f"got {len(urls)}")
try:
data = comics.extract_smbc(next_url) | {
"author": "https://www.smbc-comics.com/"
}
sync_comic.delay(**data)
except Exception as e:
logger.error(f"failed to sync {next_url}: {e}")
urls.append(next_url)
logger.info(f"syncing {BASE_XKCD_URL}")
for i in range(1, 308):
if i % 10 == 0:
logger.info(f"got {i}")
url = f"{BASE_XKCD_URL}/{i}"
try:
data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"}
sync_comic.delay(**data)
except Exception as e:
logger.error(f"failed to sync {url}: {e}")

View File

@@ -0,0 +1,157 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from memory.common import embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.workers.celery_app import app
logger = logging.getLogger(__name__)
CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections"
CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection"
REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks"
REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk"
@app.task(name=CLEAN_COLLECTION)
def clean_collection(collection: str) -> dict[str, int]:
logger.info(f"Cleaning collection {collection}")
client = qdrant.get_qdrant_client()
batches, deleted, checked = 0, 0, 0
for batch in qdrant.batch_ids(client, collection):
batches += 1
batch_ids = set(batch)
with make_session() as session:
db_ids = {
str(c.id) for c in session.query(Chunk).filter(Chunk.id.in_(batch_ids))
}
ids_to_delete = batch_ids - db_ids
checked += len(batch_ids)
if ids_to_delete:
qdrant.delete_points(client, collection, list(ids_to_delete))
deleted += len(ids_to_delete)
return {
"batches": batches,
"deleted": deleted,
"checked": checked,
}
@app.task(name=CLEAN_ALL_COLLECTIONS)
def clean_all_collections():
logger.info("Cleaning all collections")
for collection in embedding.ALL_COLLECTIONS:
clean_collection.delay(collection)
@app.task(name=REINGEST_CHUNK)
def reingest_chunk(chunk_id: str, collection: str):
logger.info(f"Reingesting chunk {chunk_id}")
with make_session() as session:
chunk = session.query(Chunk).get(chunk_id)
if not chunk:
logger.error(f"Chunk {chunk_id} not found")
return
if collection not in embedding.ALL_COLLECTIONS:
raise ValueError(f"Unsupported collection {collection}")
data = chunk.data
if collection in embedding.MULTIMODAL_COLLECTIONS:
vector = embedding.embed_mixed(data)[0]
elif len(data) == 1 and isinstance(data[0], str):
vector = embedding.embed_text([data[0]])[0]
else:
raise ValueError(f"Unsupported data type for collection {collection}")
client = qdrant.get_qdrant_client()
qdrant.upsert_vectors(
client,
collection,
[chunk_id],
[vector],
[chunk.source.as_payload()],
)
chunk.checked_at = datetime.now()
session.commit()
def check_batch(batch: Sequence[Chunk]) -> dict:
client = qdrant.get_qdrant_client()
by_collection = defaultdict(list)
for chunk in batch:
by_collection[chunk.source.modality].append(chunk)
stats = {}
for collection, chunks in by_collection.items():
missing = qdrant.find_missing_points(
client, collection, [str(c.id) for c in chunks]
)
for chunk in chunks:
if str(chunk.id) in missing:
reingest_chunk.delay(str(chunk.id), collection)
else:
chunk.checked_at = datetime.now()
stats[collection] = {
"missing": len(missing),
"correct": len(chunks) - len(missing),
"total": len(chunks),
}
return stats
@app.task(name=REINGEST_MISSING_CHUNKS)
def reingest_missing_chunks(batch_size: int = 1000):
logger.info("Reingesting missing chunks")
total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0})
since = datetime.now() - timedelta(minutes=settings.CHUNK_REINGEST_SINCE_MINUTES)
with make_session() as session:
total_count = session.query(Chunk).filter(Chunk.checked_at < since).count()
logger.info(
f"Found {total_count} chunks to check, processing in batches of {batch_size}"
)
num_batches = (total_count + batch_size - 1) // batch_size
for batch_num in range(num_batches):
stmt = (
select(Chunk)
.join(SourceItem, Chunk.source_id == SourceItem.id)
.filter(Chunk.checked_at < since)
.options(
contains_eager(Chunk.source).load_only(
SourceItem.id, SourceItem.modality, SourceItem.tags
)
)
.order_by(Chunk.id)
.limit(batch_size)
)
chunks = session.execute(stmt).scalars().all()
if not chunks:
break
logger.info(
f"Processing batch {batch_num + 1}/{num_batches} with {len(chunks)} chunks"
)
batch_stats = check_batch(chunks)
session.commit()
for collection, stats in batch_stats.items():
total_stats[collection]["missing"] += stats["missing"]
total_stats[collection]["correct"] += stats["correct"]
total_stats[collection]["total"] += stats["total"]
return dict(total_stats)

View File

@@ -0,0 +1,199 @@
import textwrap
from memory.common.parsers.comics import extract_smbc, extract_xkcd
import pytest
from unittest.mock import patch, Mock
import requests
MOCK_SMBC_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>Saturday Morning Breakfast Cereal - Time</title>
<meta property="og:image" content="https://www.smbc-comics.com/comics/1746375102-20250504.webp" />
<script type='application/ld+json'>
{
"@context": "http://www.schema.org",
"@type": "ComicStory",
"name": "Saturday Morning Breakfast Cereal - Time",
"url": "https://www.smbc-comics.com/comic/time-6",
"author":"Zach Weinersmith",
"about":"Saturday Morning Breakfast Cereal - Time",
"image":"https://www.smbc-comics.com/comics/1746375102-20250504.webp",
"thumbnailUrl":"https://www.smbc-comics.com/comicsthumbs/1746375102-20250504.webp",
"datePublished":"2025-05-04T12:11:21-04:00"
}
</script>
</head>
<body>
<div id="cc-comicbody">
<img title="I don't know why either, but it was fun to draw."
src="https://www.smbc-comics.com/comics/1746375102-20250504.webp"
id="cc-comic" />
</div>
<div id="permalink">
<input id="permalinktext" type="text" value="http://www.smbc-comics.com/comic/time-6" />
</div>
</body>
</html>
"""
@pytest.mark.parametrize(
"to_remove, overrides",
[
# Normal case - all data present
("", {}),
# Missing title attribute on image
(
'title="I don\'t know why either, but it was fun to draw."',
{"title": "Saturday Morning Breakfast Cereal - Time"},
),
# Missing src attribute on image
(
'src="https://www.smbc-comics.com/comics/1746375102-20250504.webp"',
{"image_url": ""},
),
# # Missing entire img tag
(
'<img title="I don\'t know why either, but it was fun to draw." \n src="https://www.smbc-comics.com/comics/1746375102-20250504.webp" \n id="cc-comic" />',
{"title": "Saturday Morning Breakfast Cereal - Time", "image_url": ""},
),
# # Corrupt JSON-LD data
(
'"datePublished":"2025-05-04T12:11:21-04:00"',
{"published_date": "", "url": "http://www.smbc-comics.com/comic/time-6"},
),
# # Missing JSON-LD script entirely
(
'<script type=\'application/ld+json\'>\n {\n "@context": "http://www.schema.org",\n "@type": "ComicStory",\n "name": "Saturday Morning Breakfast Cereal - Time",\n "url": "https://www.smbc-comics.com/comic/time-6",\n "author":"Zach Weinersmith",\n "about":"Saturday Morning Breakfast Cereal - Time",\n "image":"https://www.smbc-comics.com/comics/1746375102-20250504.webp",\n "thumbnailUrl":"https://www.smbc-comics.com/comicsthumbs/1746375102-20250504.webp",\n "datePublished":"2025-05-04T12:11:21-04:00"\n }\n </script>',
{"published_date": "", "url": "http://www.smbc-comics.com/comic/time-6"},
),
# # Missing permalink input
(
'<div id="permalink">\n <input id="permalinktext" type="text" value="http://www.smbc-comics.com/comic/time-6" />\n </div>',
{},
),
# Missing URL in JSON-LD
(
'"url": "https://www.smbc-comics.com/comic/time-6",',
{"url": "http://www.smbc-comics.com/comic/time-6"},
),
],
)
def test_extract_smbc(to_remove, overrides):
"""Test successful extraction of comic info from SMBC."""
expected = {
"title": "I don't know why either, but it was fun to draw.",
"image_url": "https://www.smbc-comics.com/comics/1746375102-20250504.webp",
"published_date": "2025-05-04T12:11:21-04:00",
"url": "https://www.smbc-comics.com/comic/time-6",
}
with patch("requests.get") as mock_get:
mock_get.return_value.status_code = 200
mock_get.return_value.text = MOCK_SMBC_HTML.replace(to_remove, "")
assert extract_smbc("https://www.smbc-comics.com/") == expected | overrides
# Create a stripped-down version of the XKCD HTML
MOCK_XKCD_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>xkcd: Unstoppable Force and Immovable Object</title>
<meta property="og:title" content="Unstoppable Force and Immovable Object">
<meta property="og:url" content="https://xkcd.com/3084/">
<meta property="og:image" content="https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object_2x.png">
</head>
<body>
<div id="ctitle">Unstoppable Force and Immovable Object</div>
<div id="comic">
<img src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png"
title="Unstoppable force-carrying particles can&#39;t interact with immovable matter by definition."
alt="Unstoppable Force and Immovable Object" />
</div>
Permanent link to this comic: <a href="https://xkcd.com/3084">https://xkcd.com/3084/</a><br />
Image URL (for hotlinking/embedding): <a href="https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png">
https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png</a>
</body>
</html>
"""
@pytest.mark.parametrize(
"to_remove, overrides",
[
# Normal case - all data present
("", {}),
# Missing title attribute on image
(
'title="Unstoppable force-carrying particles can&#39;t interact with immovable matter by definition."',
{
"title": "Unstoppable Force and Immovable Object"
}, # Falls back to og:title
),
# Missing og:title meta tag - falls back to ctitle
(
'<meta property="og:title" content="Unstoppable Force and Immovable Object">',
{}, # Still gets title from image title
),
# Missing both title and og:title - falls back to ctitle
(
'title="Unstoppable force-carrying particles can&#39;t interact with immovable matter by definition."\n<meta property="og:title" content="Unstoppable Force and Immovable Object">',
{"title": "Unstoppable Force and Immovable Object"}, # Falls back to ctitle
),
# Missing image src attribute
(
'src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png"',
{"image_url": ""},
),
# Missing entire img tag
(
'<img src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png" \n title="Unstoppable force-carrying particles can&#39;t interact with immovable matter by definition." \n alt="Unstoppable Force and Immovable Object" />',
{
"image_url": "",
"title": "Unstoppable Force and Immovable Object",
}, # Falls back to og:title
),
# Missing entire comic div
(
'<div id="comic">\n<img src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png" \n title="Unstoppable force-carrying particles can&#39;t interact with immovable matter by definition." \n alt="Unstoppable Force and Immovable Object" />\n</div>',
{"image_url": "", "title": "Unstoppable Force and Immovable Object"},
),
# Missing og:url tag
(
'<meta property="og:url" content="https://xkcd.com/3084/">',
{}, # Should fallback to permalink link
),
# Missing permanent link
(
'Permanent link to this comic: <a href="https://xkcd.com/3084">https://xkcd.com/3084/</a><br />',
{"url": "https://xkcd.com/3084/"}, # Should still get URL from og:url
),
# Missing both og:url and permanent link
(
'<meta property="og:url" content="https://xkcd.com/3084/">\nPermanent link to this comic: <a href="https://xkcd.com/3084">https://xkcd.com/3084/</a><br />',
{"url": "https://xkcd.com/test"}, # Falls back to original URL
),
],
)
def test_extract_xkcd(to_remove, overrides):
"""Test successful extraction of comic info from XKCD."""
expected = {
"title": "Unstoppable force-carrying particles can't interact with immovable matter by definition.",
"image_url": "https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png",
"url": "https://xkcd.com/3084/",
}
with patch("requests.get") as mock_get:
mock_get.return_value.status_code = 200
modified_html = MOCK_XKCD_HTML
for item in to_remove.split("\n"):
modified_html = modified_html.replace(item, "")
mock_get.return_value.text = modified_html
result = extract_xkcd("https://xkcd.com/test")
assert result == expected | overrides

View File

@@ -1,22 +1,25 @@
import pytest
-from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, Mock
import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common.qdrant import (
-DEFAULT_COLLECTIONS,
ALL_COLLECTIONS,
ensure_collection_exists,
initialize_collections,
upsert_vectors,
-delete_vectors,
delete_points,
batch_ids,
)
@pytest.fixture
def mock_qdrant_client():
-with patch.object(qdrant_client, "QdrantClient", return_value=MagicMock()) as mock_client:
with patch.object(
qdrant_client, "QdrantClient", return_value=MagicMock()
) as mock_client:
yield mock_client
@@ -30,7 +33,7 @@ def test_ensure_collection_exists_existing(mock_qdrant_client):
def test_ensure_collection_exists_new(mock_qdrant_client):
mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
-status_code=404, reason_phrase='asd', content=b'asd', headers=None
status_code=404, reason_phrase="asd", content=b"asd", headers=None
)
assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
@@ -43,7 +46,7 @@ def test_ensure_collection_exists_new(mock_qdrant_client):
def test_initialize_collections(mock_qdrant_client):
initialize_collections(mock_qdrant_client)
-assert mock_qdrant_client.get_collection.call_count == len(DEFAULT_COLLECTIONS)
assert mock_qdrant_client.get_collection.call_count == len(ALL_COLLECTIONS)
def test_upsert_vectors(mock_qdrant_client):
@@ -71,7 +74,7 @@ def test_upsert_vectors(mock_qdrant_client):
def test_delete_vectors(mock_qdrant_client):
ids = ["1", "2"]
-delete_vectors(mock_qdrant_client, "test_collection", ids)
delete_points(mock_qdrant_client, "test_collection", ids)
mock_qdrant_client.delete.assert_called_once()
args, kwargs = mock_qdrant_client.delete.call_args
@@ -79,3 +82,15 @@ def test_delete_vectors(mock_qdrant_client):
assert kwargs["collection_name"] == "test_collection"
assert isinstance(kwargs["points_selector"], qdrant_models.PointIdsList)
assert kwargs["points_selector"].points == ids
def test_batch_ids(mock_qdrant_client):
mock_qdrant_client.scroll.side_effect = [
([Mock(id="1"), Mock(id="2")], "3"),
([Mock(id="3"), Mock(id="4")], None),
]
assert list(batch_ids(mock_qdrant_client, "test_collection")) == [
["1", "2"],
["3", "4"],
]

View File

@@ -0,0 +1,310 @@
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
import pytest
from PIL import Image
from memory.common import qdrant as qd
from memory.common import embedding, settings
from memory.common.db.models import Chunk, SourceItem
from memory.workers.tasks.maintenance import (
clean_collection,
reingest_chunk,
check_batch,
reingest_missing_chunks,
)
@pytest.fixture
def source(db_session):
s = SourceItem(id=1, modality="text", sha256=b"123")
db_session.add(s)
db_session.commit()
return s
@pytest.fixture
def mock_uuid4():
i = 0
def uuid4():
nonlocal i
i += 1
return f"00000000-0000-0000-0000-00000000000{i}"
with patch("uuid.uuid4", side_effect=uuid4):
yield
@pytest.fixture
def test_image(mock_file_storage):
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
img_path = settings.CHUNK_STORAGE_DIR / "test.png"
img.save(img_path)
return img_path
@pytest.fixture(params=["text", "photo"])
def chunk(request, test_image, db_session):
"""Parametrized fixture for chunk configuration"""
collection = request.param
if collection == "photo":
content = None
file_path = str(test_image)
else:
content = "Test content for reingestion"
file_path = None
chunk = Chunk(
id=str(uuid.uuid4()),
source=SourceItem(id=1, modality=collection, sha256=b"123"),
content=content,
file_path=file_path,
embedding_model="test-model",
checked_at=datetime(2025, 1, 1),
)
db_session.add(chunk)
db_session.commit()
return chunk
def test_clean_collection_no_mismatches(db_session, qdrant, source):
"""Test when all Qdrant points exist in the database - nothing should be deleted."""
# Create chunks in the database
chunk_ids = [str(uuid.uuid4()) for _ in range(3000)]
collection = "text"
# Add chunks to the database
for chunk_id in chunk_ids:
db_session.add(
Chunk(
id=chunk_id,
source=source,
content="Test content",
embedding_model="test-model",
)
)
db_session.commit()
qd.ensure_collection_exists(qdrant, collection, 1024)
qd.upsert_vectors(qdrant, collection, chunk_ids, [[1] * 1024] * len(chunk_ids))
assert set(chunk_ids) == {
str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch
}
clean_collection(collection)
# Check that the chunks are still in the database - no points were deleted
assert set(chunk_ids) == {
str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch
}
def test_clean_collection_with_orphaned_vectors(db_session, qdrant, source):
"""Test when there are vectors in Qdrant that don't exist in the database."""
existing_ids = [str(uuid.uuid4()) for _ in range(3000)]
orphaned_ids = [str(uuid.uuid4()) for _ in range(3000)]
all_ids = existing_ids + orphaned_ids
collection = "text"
# Add only the existing chunks to the database
for chunk_id in existing_ids:
db_session.add(
Chunk(
id=chunk_id,
source=source,
content="Test content",
embedding_model="test-model",
)
)
db_session.commit()
qd.ensure_collection_exists(qdrant, collection, 1024)
qd.upsert_vectors(qdrant, collection, all_ids, [[1] * 1024] * len(all_ids))
clean_collection(collection)
# The orphaned vectors should be deleted
assert set(existing_ids) == {
str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch
}
def test_clean_collection_empty_batches(db_session, qdrant):
collection = "text"
qd.ensure_collection_exists(qdrant, collection, 1024)
clean_collection(collection)
assert not [i for b in qd.batch_ids(qdrant, collection) for i in b]
def test_reingest_chunk(db_session, qdrant, chunk):
"""Test reingesting a chunk using parameterized fixtures"""
collection = chunk.source.modality
qd.ensure_collection_exists(qdrant, collection, 1024)
start = datetime.now()
test_vector = [0.1] * 1024
with patch.object(embedding, "embed_chunks", return_value=[test_vector]):
reingest_chunk(str(chunk.id), collection)
vectors = qd.search_vectors(qdrant, collection, test_vector, limit=1)
assert len(vectors) == 1
assert str(vectors[0].id) == str(chunk.id)
assert vectors[0].payload == chunk.source.as_payload()
db_session.refresh(chunk)
assert chunk.checked_at.isoformat() > start.isoformat()
def test_reingest_chunk_not_found(db_session, qdrant):
"""Test when the chunk to reingest doesn't exist."""
non_existent_id = str(uuid.uuid4())
collection = "text"
reingest_chunk(non_existent_id, collection)
def test_reingest_chunk_unsupported_collection(db_session, qdrant, source):
"""Test reingesting with an unsupported collection type."""
chunk_id = str(uuid.uuid4())
chunk = Chunk(
id=chunk_id,
source=source,
content="Test content",
embedding_model="test-model",
)
db_session.add(chunk)
db_session.commit()
unsupported_collection = "unsupported"
qd.ensure_collection_exists(qdrant, unsupported_collection, 1024)
with pytest.raises(
ValueError, match=f"Unsupported collection {unsupported_collection}"
):
reingest_chunk(chunk_id, unsupported_collection)
def test_check_batch_empty(db_session, qdrant):
assert check_batch([]) == {}
def test_check_batch(db_session, qdrant):
modalities = ["text", "photo", "mail"]
chunks = [
Chunk(
id=f"00000000-0000-0000-0000-0000000000{i:02d}",
source=SourceItem(modality=modality, sha256=f"123{i}".encode()),
content="Test content",
file_path=None,
embedding_model="test-model",
checked_at=datetime(2025, 1, 1),
)
for modality in modalities
for i in range(5)
]
db_session.add_all(chunks)
db_session.commit()
start_time = datetime.now()
for modality in modalities:
qd.ensure_collection_exists(qdrant, modality, 1024)
for chunk in chunks[::2]:
qd.upsert_vectors(qdrant, chunk.source.modality, [str(chunk.id)], [[1] * 1024])
with patch.object(reingest_chunk, "delay") as mock_reingest:
stats = check_batch(chunks)
assert mock_reingest.call_args_list == [
call("00000000-0000-0000-0000-000000000001", "text"),
call("00000000-0000-0000-0000-000000000003", "text"),
call("00000000-0000-0000-0000-000000000000", "photo"),
call("00000000-0000-0000-0000-000000000002", "photo"),
call("00000000-0000-0000-0000-000000000004", "photo"),
call("00000000-0000-0000-0000-000000000001", "mail"),
call("00000000-0000-0000-0000-000000000003", "mail"),
]
assert stats == {
"mail": {"correct": 3, "missing": 2, "total": 5},
"text": {"correct": 3, "missing": 2, "total": 5},
"photo": {"correct": 2, "missing": 3, "total": 5},
}
db_session.commit()
for chunk in chunks[::2]:
assert chunk.checked_at.isoformat() > start_time.isoformat()
for chunk in chunks[1::2]:
assert chunk.checked_at.isoformat()[:19] == "2025-01-01T00:00:00"
@pytest.mark.parametrize("batch_size", [4, 10, 100])
def test_reingest_missing_chunks(db_session, qdrant, batch_size):
now = datetime.now()
old_time = now - timedelta(minutes=120) # Older than the threshold
modalities = ["text", "photo", "mail"]
ids_generator = (f"00000000-0000-0000-0000-00000000{i:04d}" for i in range(1000))
old_chunks = [
Chunk(
id=next(ids_generator),
source=SourceItem(modality=modality, sha256=f"{modality}-{i}".encode()),
content="Old content",
file_path=None,
embedding_model="test-model",
checked_at=old_time,
)
for modality in modalities
for i in range(20)
]
recent_chunks = [
Chunk(
id=next(ids_generator),
source=SourceItem(
modality=modality, sha256=f"recent-{modality}-{i}".encode()
),
content="Recent content",
file_path=None,
embedding_model="test-model",
checked_at=now,
)
for modality in modalities
for i in range(5)
]
db_session.add_all(old_chunks + recent_chunks)
db_session.commit()
for modality in modalities:
qd.ensure_collection_exists(qdrant, modality, 1024)
for chunk in old_chunks[::2]:
qd.upsert_vectors(qdrant, chunk.source.modality, [str(chunk.id)], [[1] * 1024])
with patch.object(reingest_chunk, "delay", reingest_chunk):
with patch.object(settings, "CHUNK_REINGEST_SINCE_MINUTES", 60):
with patch.object(embedding, "embed_chunks", return_value=[[1] * 1024]):
result = reingest_missing_chunks(batch_size=batch_size)
assert result == {
"photo": {"correct": 10, "missing": 10, "total": 20},
"mail": {"correct": 10, "missing": 10, "total": 20},
"text": {"correct": 10, "missing": 10, "total": 20},
}
db_session.commit()
# All the old chunks should be reingested
client = qd.get_qdrant_client()
for modality in modalities:
qdrant_ids = [
i for b in qd.batch_ids(client, modality, batch_size=1000) for i in b
]
db_ids = [str(c.id) for c in old_chunks if c.source.modality == modality]
assert set(qdrant_ids) == set(db_ids)
def test_reingest_missing_chunks_no_chunks(db_session):
assert reingest_missing_chunks() == {}