Mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 13:24:41 +02:00)

Commit: 8cfaeaea72 ("add commics")
Parent: 743a76c3d1
@@ -1,6 +1,7 @@
 """
 Alembic environment configuration.
 """
+
 from logging.config import fileConfig
 
 from sqlalchemy import engine_from_config
@@ -29,6 +30,18 @@ target_metadata = Base.metadata
 # can be acquired:
 # my_important_option = config.get_main_option("my_important_option")
 
+# Tables to exclude from auto-generation - these are managed by Celery
+excluded_tables = {
+    "celery_taskmeta",
+    "celery_tasksetmeta",
+}
+
+
+def include_object(object, name, type_, reflected, compare_to):
+    if type_ == "table" and name in excluded_tables:
+        return False
+    return True
+
+
 def run_migrations_offline() -> None:
     """
@@ -48,6 +61,7 @@ def run_migrations_offline() -> None:
         target_metadata=target_metadata,
         literal_binds=True,
         dialect_opts={"paramstyle": "named"},
+        include_object=include_object,
     )
 
     with context.begin_transaction():
@@ -69,7 +83,9 @@ def run_migrations_online() -> None:
 
     with connectable.connect() as connection:
         context.configure(
-            connection=connection, target_metadata=target_metadata
+            connection=connection,
+            target_metadata=target_metadata,
+            include_object=include_object,
         )
 
         with context.begin_transaction():
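A minimal sketch of what the new include_object hook buys during autogenerate (the check below is illustrative, not part of the commit; it assumes the include_object and excluded_tables defined above). Celery's database result backend creates and manages celery_taskmeta and celery_tasksetmeta itself, so without the filter Alembic's autogenerate would emit drop_table operations for them:

    # The hook answers "should Alembic manage this object?"
    assert include_object(None, "celery_taskmeta", "table", True, None) is False
    assert include_object(None, "source_item", "table", True, None) is True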
@@ -21,7 +21,7 @@ depends_on: Union[str, Sequence[str], None] = None
 
 def upgrade() -> None:
     op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
-    op.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
+    op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')
     op.create_table(
         "email_accounts",
         sa.Column("id", sa.BigInteger(), nullable=False),
@@ -82,12 +82,6 @@ def upgrade() -> None:
             server_default=sa.text("now()"),
             nullable=False,
         ),
-        sa.Column(
-            "updated_at",
-            sa.DateTime(timezone=True),
-            server_default=sa.text("now()"),
-            nullable=False,
-        ),
         sa.PrimaryKeyConstraint("id"),
         sa.UniqueConstraint("url"),
     )
@@ -269,7 +263,7 @@ def upgrade() -> None:
         "misc_doc",
         sa.Column("id", sa.BigInteger(), nullable=False),
         sa.Column("path", sa.Text(), nullable=True),
-        sa.Column("mime_type", sa.Text(), nullable=True),
+        sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True),
         sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
         sa.PrimaryKeyConstraint("id"),
     )
@@ -0,0 +1,40 @@
+"""Update field for chunks
+
+Revision ID: d292d48ec74e
+Revises: 4684845ca51e
+Create Date: 2025-05-04 12:56:53.231393
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision: str = "d292d48ec74e"
+down_revision: Union[str, None] = "4684845ca51e"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "chunk",
+        sa.Column(
+            "checked_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=True,
+        ),
+    )
+    op.drop_column("misc_doc", "mime_type")
+
+
+def downgrade() -> None:
+    op.add_column(
+        "misc_doc",
+        sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True),
+    )
+    op.drop_column("chunk", "checked_at")
db/migrations/versions/20250504_234552_add_comics.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+"""Add comics
+
+Revision ID: b78b1fff9974
+Revises: d292d48ec74e
+Create Date: 2025-05-04 23:45:52.733301
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'b78b1fff9974'
+down_revision: Union[str, None] = 'd292d48ec74e'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('comic',
+        sa.Column('id', sa.BigInteger(), nullable=False),
+        sa.Column('title', sa.Text(), nullable=True),
+        sa.Column('author', sa.Text(), nullable=True),
+        sa.Column('published', sa.DateTime(timezone=True), nullable=True),
+        sa.Column('volume', sa.Text(), nullable=True),
+        sa.Column('issue', sa.Text(), nullable=True),
+        sa.Column('page', sa.Integer(), nullable=True),
+        sa.Column('url', sa.Text(), nullable=True),
+        sa.ForeignKeyConstraint(['id'], ['source_item.id'], ondelete='CASCADE'),
+        sa.PrimaryKeyConstraint('id')
+    )
+    op.create_index('comic_author_idx', 'comic', ['author'], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index('comic_author_idx', table_name='comic')
+    op.drop_table('comic')
+    # ### end Alembic commands ###
@@ -21,6 +21,7 @@ volumes:
 # ------------------------------ X-templates ----------------------------
 x-common-env: &env
   RABBITMQ_USER: kb
+  RABBITMQ_HOST: rabbitmq
   QDRANT_HOST: qdrant
   DB_HOST: postgres
   FILE_STORAGE_DIR: /app/memory_files
@@ -173,49 +174,28 @@ services:
     environment:
       <<: *worker-env
       QUEUES: "email"
-    deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
+    # deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
 
   worker-text:
     <<: *worker-base
     environment:
       <<: *worker-env
       QUEUES: "medium_embed"
-    deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
+    # deploy: { resources: { limits: { cpus: "2", memory: 3g } } }
 
   worker-photo:
     <<: *worker-base
     environment:
       <<: *worker-env
-      QUEUES: "photo_embed"
-    deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
+      QUEUES: "photo_embed,comic"
+    # deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
 
-  worker-ocr:
+  worker-maintenance:
     <<: *worker-base
     environment:
       <<: *worker-env
-      QUEUES: "low_ocr"
-    deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
-
-  worker-git:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "git_summary"
-    deploy: { resources: { limits: { cpus: "1", memory: 1g } } }
-
-  worker-rss:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "rss"
-    deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
-
-  worker-docs:
-    <<: *worker-base
-    environment:
-      <<: *worker-env
-      QUEUES: "docs"
-    deploy: { resources: { limits: { cpus: "1", memory: 1g } } }
+      QUEUES: "maintenance"
+    # deploy: { resources: { limits: { cpus: "0.5", memory: 512m } } }
 
   ingest-hub:
     <<: *worker-base
@@ -9,7 +9,7 @@ COPY src/ ./src/
 
 # Install dependencies
 RUN apt-get update && apt-get install -y \
     libpq-dev gcc supervisor && \
     pip install -e ".[workers]" && \
     apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
 
@@ -28,8 +28,7 @@ RUN mkdir -p /app/memory_files
 RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor /app/memory_files
 USER kb
 
-# Default queues to process
-ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
+ENV QUEUES="docs,email,maintenance"
 ENV PYTHONPATH="/app"
 
 ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]
@@ -35,7 +35,7 @@ RUN mkdir -p /var/cache/fontconfig /home/kb/.cache/fontconfig && \
 USER kb
 
 # Default queues to process
-ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
+ENV QUEUES="docs,email,maintenance"
 ENV PYTHONPATH="/app"
 
 ENTRYPOINT ["./entry.sh"]
@@ -2,3 +2,5 @@ celery==5.3.6
 openai==1.25.0
 pillow==10.3.0
 pypandoc==1.15.0
+beautifulsoup4==4.13.4
+feedparser==6.0.10
@@ -101,7 +101,7 @@ def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[Search
             and source.filename.replace(
                 str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
             ),
-            content=source.content,
+            content=source.display_contents,
             chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
         )
         for source, chunks in items.items()
@@ -143,7 +143,7 @@ def query_chunks(
         )
         if r.score >= min_score
     ]
-    for collection in embedding.DEFAULT_COLLECTIONS
+    for collection in embedding.ALL_COLLECTIONS
 }
 
 
@@ -164,7 +164,7 @@ async def search(
     modalities: Annotated[list[str], Query()] = [],
     files: list[UploadFile] = File([]),
     limit: int = Query(10, ge=1, le=100),
-    min_text_score: float = Query(0.5, ge=0.0, le=1.0),
+    min_text_score: float = Query(0.3, ge=0.0, le=1.0),
     min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
     """
@@ -181,11 +181,11 @@ async def search(
     """
     upload_data = [await input_type(item) for item in [query, *files]]
     logger.error(
-        f"Querying chunks for {modalities}, query: {query}, previews: {previews}"
+        f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
    )
 
     client = qdrant.get_qdrant_client()
-    allowed_modalities = set(modalities or embedding.DEFAULT_COLLECTIONS.keys())
+    allowed_modalities = set(modalities or embedding.ALL_COLLECTIONS.keys())
     text_results = query_chunks(
         client,
         upload_data,
@@ -1,6 +1,7 @@
 """
 Database connection utilities.
 """
+
 from contextlib import contextmanager
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker, scoped_session
@@ -6,7 +6,8 @@ import pathlib
 import re
 from email.message import EmailMessage
 from pathlib import Path
-from typing import Any
+import textwrap
+from typing import Any, ClassVar
 from PIL import Image
 from sqlalchemy import (
     ARRAY,
@@ -99,8 +100,9 @@ class Chunk(Base):
     content = Column(Text)  # Direct content storage
     embedding_model = Column(Text)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
-    vector = list[float] | None  # the vector generated by the embedding model
-    item_metadata = dict[str, Any] | None
+    checked_at = Column(DateTime(timezone=True), server_default=func.now())
+    vector: ClassVar[list[float] | None] = None
+    item_metadata: ClassVar[dict[str, Any] | None] = None
 
     # One of file_path or content must be populated
     __table_args__ = (
@@ -121,7 +123,7 @@ class Chunk(Base):
 
         items = []
         for file_path in files:
-            if file_path.suffix == ".png":
+            if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
                 if file_path.exists():
                     items.append(Image.open(file_path))
             elif file_path.suffix == ".bin":
@@ -172,6 +174,16 @@ class SourceItem(Base):
         """Get vector IDs from associated chunks."""
         return [chunk.id for chunk in self.chunks]
 
+    def as_payload(self) -> dict:
+        return {
+            "source_id": self.id,
+            "tags": self.tags,
+        }
+
+    @property
+    def display_contents(self) -> str | None:
+        return self.content or self.filename
+
 
 class MailMessage(SourceItem):
     __tablename__ = "mail_message"
@@ -236,6 +248,26 @@ class MailMessage(SourceItem):
     def body(self) -> str:
         return self.parsed_content["body"]
 
+    @property
+    def display_contents(self) -> str | None:
+        content = self.parsed_content
+        return textwrap.dedent(
+            """
+            Subject: {subject}
+            From: {sender}
+            To: {recipients}
+            Date: {date}
+            Body:
+            {body}
+            """
+        ).format(
+            subject=content.get("subject", ""),
+            sender=content.get("from", ""),
+            recipients=content.get("to", ""),
+            date=content.get("date", ""),
+            body=content.get("body", ""),
+        )
+
     # Add indexes
     __table_args__ = (
         Index("mail_sent_idx", "sent_at"),
@@ -341,6 +373,41 @@ class Photo(SourceItem):
     __table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
 
 
+class Comic(SourceItem):
+    __tablename__ = "comic"
+
+    id = Column(
+        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
+    )
+    title = Column(Text)
+    author = Column(Text, nullable=True)
+    published = Column(DateTime(timezone=True), nullable=True)
+    volume = Column(Text, nullable=True)
+    issue = Column(Text, nullable=True)
+    page = Column(Integer, nullable=True)
+    url = Column(Text, nullable=True)
+
+    __mapper_args__ = {
+        "polymorphic_identity": "comic",
+    }
+
+    __table_args__ = (Index("comic_author_idx", "author"),)
+
+    def as_payload(self) -> dict:
+        payload = {
+            "source_id": self.id,
+            "tags": self.tags,
+            "title": self.title,
+            "author": self.author,
+            "published": self.published,
+            "volume": self.volume,
+            "issue": self.issue,
+            "page": self.page,
+            "url": self.url,
+        }
+        return {k: v for k, v in payload.items() if v is not None}
+
+
 class BookDoc(SourceItem):
     __tablename__ = "book_doc"
 
@@ -379,7 +446,6 @@ class MiscDoc(SourceItem):
         BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
     )
     path = Column(Text)
-    mime_type = Column(Text)
 
     __mapper_args__ = {
         "polymorphic_identity": "misc_doc",
@@ -32,7 +32,7 @@ class Collection(TypedDict):
     shards: NotRequired[int]
 
 
-DEFAULT_COLLECTIONS: dict[str, Collection] = {
+ALL_COLLECTIONS: dict[str, Collection] = {
     "mail": {
         "dimension": 1024,
         "distance": "Cosine",
@@ -69,6 +69,11 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = {
         "distance": "Cosine",
         "model": settings.MIXED_EMBEDDING_MODEL,
     },
+    "comic": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "model": settings.MIXED_EMBEDDING_MODEL,
+    },
     "doc": {
         "dimension": 1024,
         "distance": "Cosine",
@@ -77,12 +82,12 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = {
 }
 TEXT_COLLECTIONS = {
     coll
-    for coll, params in DEFAULT_COLLECTIONS.items()
+    for coll, params in ALL_COLLECTIONS.items()
     if params["model"] == settings.TEXT_EMBEDDING_MODEL
 }
 MULTIMODAL_COLLECTIONS = {
     coll
-    for coll, params in DEFAULT_COLLECTIONS.items()
+    for coll, params in ALL_COLLECTIONS.items()
     if params["model"] == settings.MIXED_EMBEDDING_MODEL
 }
 
@@ -99,6 +104,22 @@ TYPES = {
 }
 
 
+def get_mimetype(image: Image.Image) -> str | None:
+    format_to_mime = {
+        "JPEG": "image/jpeg",
+        "PNG": "image/png",
+        "GIF": "image/gif",
+        "BMP": "image/bmp",
+        "TIFF": "image/tiff",
+        "WEBP": "image/webp",
+    }
+
+    if not image.format:
+        return None
+
+    return format_to_mime.get(image.format.upper(), f"image/{image.format.lower()}")
+
+
 def get_modality(mime_type: str) -> str:
     for type, mime_types in TYPES.items():
         if mime_type in mime_types:
@@ -237,3 +258,20 @@ def embed(
         for vector in embed_page(page)
     ]
     return modality, chunks
+
+
+def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
+    image = Image.open(file_path)
+    mime_type = get_mimetype(image)
+    if mime_type is None:
+        raise ValueError("Unsupported image format")
+
+    vector = embed_mixed([image] + texts)[0]
+
+    return Chunk(
+        id=str(uuid.uuid4()),
+        file_path=file_path.absolute().as_posix(),
+        content=None,
+        embedding_model=settings.MIXED_EMBEDDING_MODEL,
+        vector=vector,
+    )
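A sketch of how the two new embedding helpers fit together (the file path and texts below are invented for illustration; get_mimetype and embed_image are the functions added above):

    from pathlib import Path
    from PIL import Image

    path = Path("/tmp/strip.webp")  # hypothetical image on disk
    print(get_mimetype(Image.open(path)))  # -> "image/webp"; unlisted formats fall back to image/<format>

    chunk = embed_image(path, ["strip title", "author"])
    # chunk.file_path points at the image, chunk.content stays None,
    # and chunk.vector holds the mixed image+text embedding.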
src/memory/common/parsers/comics.py (new file, 133 lines)
@@ -0,0 +1,133 @@
+import logging
+from typing import TypedDict, NotRequired
+
+from bs4 import BeautifulSoup, Tag
+import requests
+import json
+
+logger = logging.getLogger(__name__)
+
+
+class ComicInfo(TypedDict):
+    title: str
+    image_url: str
+    published_date: NotRequired[str]
+    url: str
+
+
+def extract_smbc(url: str) -> ComicInfo:
+    """
+    Extract the title, published date, and image URL from a SMBC comic.
+
+    Returns:
+        ComicInfo with title, image_url, published_date, and comic_url
+    """
+    response = requests.get(url)
+    response.raise_for_status()
+    soup = BeautifulSoup(response.text, "html.parser")
+
+    comic_img = soup.find("img", id="cc-comic")
+    title = ""
+    image_url = ""
+
+    if comic_img and isinstance(comic_img, Tag):
+        if comic_img.has_attr("src"):
+            image_url = str(comic_img["src"])
+        if comic_img.has_attr("title"):
+            title = str(comic_img["title"])
+
+    published_date = ""
+    comic_url = ""
+
+    script_ld = soup.find("script", type="application/ld+json")
+    if script_ld and isinstance(script_ld, Tag) and script_ld.string:
+        try:
+            data = json.loads(script_ld.string)
+            published_date = data.get("datePublished", "")
+
+            # Use JSON-LD URL if available
+            title = title or data.get("name", "")
+            comic_url = data.get("url")
+        except (json.JSONDecodeError, AttributeError):
+            pass
+
+    permalink_input = soup.find("input", id="permalinktext")
+    if not comic_url and permalink_input and isinstance(permalink_input, Tag):
+        comic_url = permalink_input.get("value", "")
+
+    return {
+        "title": title,
+        "image_url": image_url,
+        "published_date": published_date,
+        "url": comic_url or url,
+    }
+
+
+def extract_xkcd(url: str) -> ComicInfo:
+    """
+    Extract comic information from an XKCD comic.
+
+    This function parses an XKCD comic page to extract the title from the hover text,
+    the image URL, and the permanent URL of the comic.
+
+    Args:
+        url: The URL of the XKCD comic to parse
+
+    Returns:
+        ComicInfo with title, image_url, and comic_url
+    """
+    response = requests.get(url)
+    response.raise_for_status()
+    soup = BeautifulSoup(response.text, "html.parser")
+
+    def get_comic_img() -> Tag | None:
+        """Extract the comic image tag."""
+        comic_div = soup.find("div", id="comic")
+        if comic_div and isinstance(comic_div, Tag):
+            img = comic_div.find("img")
+            return img if isinstance(img, Tag) else None
+        return None
+
+    def get_title() -> str:
+        """Extract title from image title attribute with fallbacks."""
+        # Primary source: hover text from the image (most informative)
+        img = get_comic_img()
+        if img and img.has_attr("title"):
+            return str(img["title"])
+
+        # Secondary source: og:title meta tag
+        og_title = soup.find("meta", property="og:title")
+        if og_title and isinstance(og_title, Tag) and og_title.has_attr("content"):
+            return str(og_title["content"])
+
+        # Last resort: page title div
+        title_div = soup.find("div", id="ctitle")
+        return title_div.text.strip() if title_div else ""
+
+    def get_image_url() -> str:
+        """Extract and normalize the image URL."""
+        img = get_comic_img()
+        if not img or not img.has_attr("src"):
+            return ""
+
+        image_src = str(img["src"])
+        return f"https:{image_src}" if image_src.startswith("//") else image_src
+
+    def get_permanent_url() -> str:
+        """Extract the permanent URL to the comic."""
+        og_url = soup.find("meta", property="og:url")
+        if og_url and isinstance(og_url, Tag) and og_url.has_attr("content"):
+            return str(og_url["content"])
+
+        # Fallback: look for permanent link text
+        for a_tag in soup.find_all("a"):
+            text = a_tag.get_text()
+            if text.startswith("https://xkcd.com/") and text.strip().endswith("/"):
+                return str(text.strip())
+        return url
+
+    return {
+        "title": get_title(),
+        "image_url": get_image_url(),
+        "url": get_permanent_url(),
+    }
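Both extractors return the same ComicInfo shape, which is what lets the worker tasks treat them interchangeably. A hypothetical call (actual values depend on the live page):

    info = extract_xkcd("https://xkcd.com/3084/")
    # info["title"]     -> hover text, falling back to og:title, then the ctitle div
    # info["image_url"] -> absolute URL; protocol-relative // sources get an https: prefix
    # info["url"]       -> og:url, else a permalink anchor, else the input URL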
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, cast
+from typing import Any, cast, Iterator, Sequence
 
 import qdrant_client
 from qdrant_client.http import models as qdrant_models
@@ -7,7 +7,7 @@ from qdrant_client.http.exceptions import UnexpectedResponse
 from memory.common import settings
 from memory.common.embedding import (
     Collection,
-    DEFAULT_COLLECTIONS,
+    ALL_COLLECTIONS,
     DistanceType,
     Vector,
 )
@@ -91,7 +91,7 @@ def initialize_collections(
         If None, defaults to the DEFAULT_COLLECTIONS.
     """
     if collections is None:
-        collections = DEFAULT_COLLECTIONS
+        collections = ALL_COLLECTIONS
 
     logger.info(f"Initializing collections:")
     for name, params in collections.items():
@@ -184,13 +184,13 @@ def search_vectors(
     )
 
 
-def delete_vectors(
+def delete_points(
     client: qdrant_client.QdrantClient,
     collection_name: str,
     ids: list[str],
 ) -> None:
     """
-    Delete vectors from a collection.
+    Delete points from a collection.
 
     Args:
         client: Qdrant client
@@ -222,3 +222,30 @@ def get_collection_info(
     """
     info = client.get_collection(collection_name)
     return info.model_dump()
+
+
+def batch_ids(
+    client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
+) -> Iterator[list[str]]:
+    """Iterate over all IDs in a collection."""
+    offset = None
+    while resp := client.scroll(
+        collection_name=collection_name,
+        with_vectors=False,
+        offset=offset,
+        limit=batch_size,
+    ):
+        points, offset = resp
+        yield [point.id for point in points]
+
+        if not offset:
+            return
+
+
+def find_missing_points(
+    client: qdrant_client.QdrantClient, collection_name: str, ids: Sequence[str]
+) -> set[str]:
+    found = client.retrieve(
+        collection_name, ids=ids, with_payload=False, with_vectors=False
+    )
+    return set(ids) - {str(r.id) for r in found}
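A sketch of how the two new Qdrant helpers compose for garbage collection, mirroring what the maintenance tasks further down do (the collection name and the Postgres ID set are assumptions for illustration):

    client = get_qdrant_client()
    known_ids = load_chunk_ids_from_postgres()  # hypothetical helper returning a set of chunk IDs
    for batch in batch_ids(client, "comic", batch_size=500):
        # batch is a list of point IDs currently stored in Qdrant
        orphans = [i for i in batch if str(i) not in known_ids]
        if orphans:
            delete_points(client, "comic", orphans)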
@@ -29,13 +29,27 @@ def make_db_url(
 
 DB_URL = os.getenv("DATABASE_URL", make_db_url())
 
+# Celery settings
+RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb")
+RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb")
+RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
+
+CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
+
+
+# File storage settings
 FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
 FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
 CHUNK_STORAGE_DIR = pathlib.Path(
     os.getenv("CHUNK_STORAGE_DIR", FILE_STORAGE_DIR / "chunks")
 )
 CHUNK_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
+
+COMIC_STORAGE_DIR = pathlib.Path(
+    os.getenv("COMIC_STORAGE_DIR", FILE_STORAGE_DIR / "comics")
+)
+COMIC_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
+
 # Maximum attachment size to store directly in the database (10MB)
 MAX_INLINE_ATTACHMENT_SIZE = int(
     os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024)
@@ -51,8 +65,12 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
 
 
 # Worker settings
+# Intervals are in seconds
 EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
+CLEAN_COLLECTION_INTERVAL = int(os.getenv("CLEAN_COLLECTION_INTERVAL", 86400))
+CHUNK_REINGEST_INTERVAL = int(os.getenv("CHUNK_REINGEST_INTERVAL", 3600))
+
+CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60 * 24))
 
 # Embedding settings
 TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
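Since the new intervals are read from the environment at import time, the beat cadence can be tuned without code changes; for example (values and invocation illustrative):

    # check collections twice a day instead of daily
    CLEAN_COLLECTION_INTERVAL=43200 celery -A memory.workers.celery_app beat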
@@ -1,18 +1,15 @@
-import os
 from celery import Celery
 from memory.common import settings
 
 
 def rabbit_url() -> str:
-    user = os.getenv("RABBITMQ_USER", "guest")
-    password = os.getenv("RABBITMQ_PASSWORD", "guest")
-    return f"amqp://{user}:{password}@rabbitmq:5672//"
+    return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//"
 
 
 app = Celery(
     "memory",
     broker=rabbit_url(),
-    backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}")
+    backend=settings.CELERY_RESULT_BACKEND,
 )
@@ -27,16 +24,15 @@ app.conf.update(
         "memory.workers.tasks.text.*": {"queue": "medium_embed"},
         "memory.workers.tasks.email.*": {"queue": "email"},
         "memory.workers.tasks.photo.*": {"queue": "photo_embed"},
-        "memory.workers.tasks.ocr.*": {"queue": "low_ocr"},
-        "memory.workers.tasks.git.*": {"queue": "git_summary"},
-        "memory.workers.tasks.rss.*": {"queue": "rss"},
+        "memory.workers.tasks.comic.*": {"queue": "comic"},
         "memory.workers.tasks.docs.*": {"queue": "docs"},
+        "memory.workers.tasks.maintenance.*": {"queue": "maintenance"},
     },
 )
 
 
 @app.on_after_configure.connect
 def ensure_qdrant_initialised(sender, **_):
     from memory.common import qdrant
 
     qdrant.setup_qdrant()
@@ -1,11 +1,23 @@
+import logging
+
-from memory.workers.celery_app import app
 from memory.common import settings
+from memory.workers.celery_app import app
+from memory.workers.tasks import CLEAN_ALL_COLLECTIONS, REINGEST_MISSING_CHUNKS
+
+logger = logging.getLogger(__name__)
+
 
 app.conf.beat_schedule = {
-    'sync-mail-all': {
-        'task': 'memory.workers.tasks.email.sync_all_accounts',
-        'schedule': settings.EMAIL_SYNC_INTERVAL,
+    "sync-mail-all": {
+        "task": "memory.workers.tasks.email.sync_all_accounts",
+        "schedule": settings.EMAIL_SYNC_INTERVAL,
+    },
+    "clean-all-collections": {
+        "task": CLEAN_ALL_COLLECTIONS,
+        "schedule": settings.CLEAN_COLLECTION_INTERVAL,
+    },
+    "reingest-missing-chunks": {
+        "task": REINGEST_MISSING_CHUNKS,
+        "schedule": settings.CHUNK_REINGEST_INTERVAL,
     },
 }
@@ -1,8 +1,24 @@
 """
 Import sub-modules so Celery can register their @app.task decorators.
 """
-from memory.workers.tasks import docs, email  # noqa
+
+from memory.workers.tasks import docs, email, comic  # noqa
 from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
+from memory.workers.tasks.maintenance import (
+    CLEAN_ALL_COLLECTIONS,
+    CLEAN_COLLECTION,
+    REINGEST_MISSING_CHUNKS,
+)
 
 
-__all__ = ["docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]
+__all__ = [
+    "docs",
+    "email",
+    "comic",
+    "SYNC_ACCOUNT",
+    "SYNC_ALL_ACCOUNTS",
+    "PROCESS_EMAIL",
+    "CLEAN_ALL_COLLECTIONS",
+    "CLEAN_COLLECTION",
+    "REINGEST_MISSING_CHUNKS",
+]
src/memory/workers/tasks/comic.py (new file, 172 lines)
@@ -0,0 +1,172 @@
+import hashlib
+import logging
+from datetime import datetime
+from typing import Callable
+
+import feedparser
+import requests
+
+from memory.common import embedding, qdrant, settings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Comic, clean_filename
+from memory.common.parsers import comics
+from memory.workers.celery_app import app
+
+logger = logging.getLogger(__name__)
+
+
+SYNC_ALL_COMICS = "memory.workers.tasks.comic.sync_all_comics"
+SYNC_SMBC = "memory.workers.tasks.comic.sync_smbc"
+SYNC_XKCD = "memory.workers.tasks.comic.sync_xkcd"
+SYNC_COMIC = "memory.workers.tasks.comic.sync_comic"
+
+
+BASE_SMBC_URL = "https://www.smbc-comics.com/"
+SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"
+
+BASE_XKCD_URL = "https://xkcd.com/"
+XKCD_RSS_URL = "https://xkcd.com/atom.xml"
+
+
+def find_new_urls(base_url: str, rss_url: str) -> set[str]:
+    try:
+        feed = feedparser.parse(rss_url)
+    except Exception as e:
+        logger.error(f"Failed to fetch or parse {rss_url}: {e}")
+        return set()
+
+    urls = {item.get("link") or item.get("id") for item in feed.entries}
+
+    with make_session() as session:
+        known = {
+            c.url
+            for c in session.query(Comic.url).filter(
+                Comic.author == base_url,
+                Comic.url.in_(urls),
+            )
+        }
+
+    return urls - known
+
+
+def fetch_new_comics(
+    base_url: str, rss_url: str, parser: Callable[[str], comics.ComicInfo]
+) -> set[str]:
+    new_urls = find_new_urls(base_url, rss_url)
+
+    for url in new_urls:
+        data = parser(url) | {"author": base_url, "url": url}
+        sync_comic.delay(**data)
+    return new_urls
+
+
+@app.task(name=SYNC_COMIC)
+def sync_comic(
+    url: str,
+    image_url: str,
+    title: str,
+    author: str,
+    published_date: datetime | None = None,
+):
+    """Synchronize a comic from a URL."""
+    with make_session() as session:
+        if session.query(Comic).filter(Comic.url == url).first():
+            return
+
+    response = requests.get(image_url)
+    file_type = image_url.split(".")[-1]
+    mime_type = f"image/{file_type}"
+    filename = (
+        settings.COMIC_STORAGE_DIR / clean_filename(author) / f"{title}.{file_type}"
+    )
+    if response.status_code == 200:
+        filename.parent.mkdir(parents=True, exist_ok=True)
+        filename.write_bytes(response.content)
+
+    sha256 = hashlib.sha256(f"{image_url}{published_date}".encode()).digest()
+    comic = Comic(
+        title=title,
+        url=url,
+        published=published_date,
+        author=author,
+        filename=filename.resolve().as_posix(),
+        mime_type=mime_type,
+        size=len(response.content),
+        sha256=sha256,
+        tags={"comic", author},
+        modality="comic",
+    )
+    chunk = embedding.embed_image(filename, [title, author])
+    comic.chunks = [chunk]
+
+    with make_session() as session:
+        session.add(comic)
+        session.add(chunk)
+        session.flush()
+
+        qdrant.upsert_vectors(
+            client=qdrant.get_qdrant_client(),
+            collection_name="comic",
+            ids=[str(chunk.id)],
+            vectors=[chunk.vector],
+            payloads=[comic.as_payload()],
+        )
+
+        session.commit()
+
+
+@app.task(name=SYNC_SMBC)
+def sync_smbc() -> set[str]:
+    """Synchronize SMBC comics from RSS feed."""
+    return fetch_new_comics(BASE_SMBC_URL, SMBC_RSS_URL, comics.extract_smbc)
+
+
+@app.task(name=SYNC_XKCD)
+def sync_xkcd() -> set[str]:
+    """Synchronize XKCD comics from RSS feed."""
+    return fetch_new_comics(BASE_XKCD_URL, XKCD_RSS_URL, comics.extract_xkcd)
+
+
+@app.task(name=SYNC_ALL_COMICS)
+def sync_all_comics():
+    """Synchronize all active comics."""
+    sync_smbc.delay()
+    sync_xkcd.delay()
+
+
+@app.task(name="memory.workers.tasks.comic.full_sync_comic")
+def trigger_comic_sync():
+    def prev_smbc_comic(url: str) -> str | None:
+        from bs4 import BeautifulSoup
+
+        response = requests.get(url)
+        soup = BeautifulSoup(response.text, "html.parser")
+        if link := soup.find("a", attrs={"class", "cc-prev"}):
+            return link.attrs["href"]
+        return None
+
+    next_url = "https://www.smbc-comics.com"
+    urls = []
+    logger.info(f"syncing {next_url}")
+    while next_url := prev_smbc_comic(next_url):
+        if len(urls) % 10 == 0:
+            logger.info(f"got {len(urls)}")
+        try:
+            data = comics.extract_smbc(next_url) | {
+                "author": "https://www.smbc-comics.com/"
+            }
+            sync_comic.delay(**data)
+        except Exception as e:
+            logger.error(f"failed to sync {next_url}: {e}")
+        urls.append(next_url)
+
+    logger.info(f"syncing {BASE_XKCD_URL}")
+    for i in range(1, 308):
+        if i % 10 == 0:
+            logger.info(f"got {i}")
+        url = f"{BASE_XKCD_URL}/{i}"
+        try:
+            data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"}
+            sync_comic.delay(**data)
+        except Exception as e:
+            logger.error(f"failed to sync {url}: {e}")
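A sketch of how these tasks are meant to be driven (the queue routing comes from celery_app above, and docker-compose's worker-photo service consumes the comic queue):

    from memory.workers.tasks.comic import sync_all_comics, trigger_comic_sync

    sync_all_comics.delay()     # incremental: fans out sync_smbc / sync_xkcd over the RSS feeds
    trigger_comic_sync.delay()  # one-off back-catalogue crawl: walks SMBC's prev links,
                                # then xkcd issues 1 through 307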
src/memory/workers/tasks/maintenance.py (new file, 157 lines)
@@ -0,0 +1,157 @@
+import logging
+from collections import defaultdict
+from datetime import datetime, timedelta
+from typing import Sequence
+
+from sqlalchemy import select
+from sqlalchemy.orm import contains_eager
+
+from memory.common import embedding, qdrant, settings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+from memory.workers.celery_app import app
+
+logger = logging.getLogger(__name__)
+
+
+CLEAN_ALL_COLLECTIONS = "memory.workers.tasks.maintenance.clean_all_collections"
+CLEAN_COLLECTION = "memory.workers.tasks.maintenance.clean_collection"
+REINGEST_MISSING_CHUNKS = "memory.workers.tasks.maintenance.reingest_missing_chunks"
+REINGEST_CHUNK = "memory.workers.tasks.maintenance.reingest_chunk"
+
+
+@app.task(name=CLEAN_COLLECTION)
+def clean_collection(collection: str) -> dict[str, int]:
+    logger.info(f"Cleaning collection {collection}")
+    client = qdrant.get_qdrant_client()
+    batches, deleted, checked = 0, 0, 0
+    for batch in qdrant.batch_ids(client, collection):
+        batches += 1
+        batch_ids = set(batch)
+        with make_session() as session:
+            db_ids = {
+                str(c.id) for c in session.query(Chunk).filter(Chunk.id.in_(batch_ids))
+            }
+            ids_to_delete = batch_ids - db_ids
+            checked += len(batch_ids)
+            if ids_to_delete:
+                qdrant.delete_points(client, collection, list(ids_to_delete))
+                deleted += len(ids_to_delete)
+    return {
+        "batches": batches,
+        "deleted": deleted,
+        "checked": checked,
+    }
+
+
+@app.task(name=CLEAN_ALL_COLLECTIONS)
+def clean_all_collections():
+    logger.info("Cleaning all collections")
+    for collection in embedding.ALL_COLLECTIONS:
+        clean_collection.delay(collection)
+
+
+@app.task(name=REINGEST_CHUNK)
+def reingest_chunk(chunk_id: str, collection: str):
+    logger.info(f"Reingesting chunk {chunk_id}")
+    with make_session() as session:
+        chunk = session.query(Chunk).get(chunk_id)
+        if not chunk:
+            logger.error(f"Chunk {chunk_id} not found")
+            return
+
+        if collection not in embedding.ALL_COLLECTIONS:
+            raise ValueError(f"Unsupported collection {collection}")
+
+        data = chunk.data
+        if collection in embedding.MULTIMODAL_COLLECTIONS:
+            vector = embedding.embed_mixed(data)[0]
+        elif len(data) == 1 and isinstance(data[0], str):
+            vector = embedding.embed_text([data[0]])[0]
+        else:
+            raise ValueError(f"Unsupported data type for collection {collection}")
+
+        client = qdrant.get_qdrant_client()
+        qdrant.upsert_vectors(
+            client,
+            collection,
+            [chunk_id],
+            [vector],
+            [chunk.source.as_payload()],
+        )
+        chunk.checked_at = datetime.now()
+        session.commit()
+
+
+def check_batch(batch: Sequence[Chunk]) -> dict:
+    client = qdrant.get_qdrant_client()
+    by_collection = defaultdict(list)
+    for chunk in batch:
+        by_collection[chunk.source.modality].append(chunk)
+
+    stats = {}
+    for collection, chunks in by_collection.items():
+        missing = qdrant.find_missing_points(
+            client, collection, [str(c.id) for c in chunks]
+        )
+
+        for chunk in chunks:
+            if str(chunk.id) in missing:
+                reingest_chunk.delay(str(chunk.id), collection)
+            else:
+                chunk.checked_at = datetime.now()
+
+        stats[collection] = {
+            "missing": len(missing),
+            "correct": len(chunks) - len(missing),
+            "total": len(chunks),
+        }
+
+    return stats
+
+
+@app.task(name=REINGEST_MISSING_CHUNKS)
+def reingest_missing_chunks(batch_size: int = 1000):
+    logger.info("Reingesting missing chunks")
+    total_stats = defaultdict(lambda: {"missing": 0, "correct": 0, "total": 0})
+    since = datetime.now() - timedelta(minutes=settings.CHUNK_REINGEST_SINCE_MINUTES)
+
+    with make_session() as session:
+        total_count = session.query(Chunk).filter(Chunk.checked_at < since).count()
+
+        logger.info(
+            f"Found {total_count} chunks to check, processing in batches of {batch_size}"
+        )
+
+        num_batches = (total_count + batch_size - 1) // batch_size
+
+        for batch_num in range(num_batches):
+            stmt = (
+                select(Chunk)
+                .join(SourceItem, Chunk.source_id == SourceItem.id)
+                .filter(Chunk.checked_at < since)
+                .options(
+                    contains_eager(Chunk.source).load_only(
+                        SourceItem.id, SourceItem.modality, SourceItem.tags
+                    )
+                )
+                .order_by(Chunk.id)
+                .limit(batch_size)
+            )
+            chunks = session.execute(stmt).scalars().all()
+
+            if not chunks:
+                break
+
+            logger.info(
+                f"Processing batch {batch_num + 1}/{num_batches} with {len(chunks)} chunks"
+            )
+            batch_stats = check_batch(chunks)
+            session.commit()
+
+            for collection, stats in batch_stats.items():
+                total_stats[collection]["missing"] += stats["missing"]
+                total_stats[collection]["correct"] += stats["correct"]
+                total_stats[collection]["total"] += stats["total"]
+
+    return dict(total_stats)
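The task returns a per-collection tally, so with the database result backend configured above its outcome can be inspected directly; a hypothetical result (numbers invented):

    result = reingest_missing_chunks.delay(batch_size=500).get()
    # {"mail":  {"missing": 3, "correct": 1497, "total": 1500},
    #  "comic": {"missing": 0, "correct": 42,   "total": 42}}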
199
tests/memory/common/parsers/test_comic.py
Normal file
199
tests/memory/common/parsers/test_comic.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
import textwrap
|
||||||
|
from memory.common.parsers.comics import extract_smbc, extract_xkcd
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_SMBC_HTML = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Saturday Morning Breakfast Cereal - Time</title>
|
||||||
|
<meta property="og:image" content="https://www.smbc-comics.com/comics/1746375102-20250504.webp" />
|
||||||
|
|
||||||
|
<script type='application/ld+json'>
|
||||||
|
{
|
||||||
|
"@context": "http://www.schema.org",
|
||||||
|
"@type": "ComicStory",
|
||||||
|
"name": "Saturday Morning Breakfast Cereal - Time",
|
||||||
|
"url": "https://www.smbc-comics.com/comic/time-6",
|
||||||
|
"author":"Zach Weinersmith",
|
||||||
|
"about":"Saturday Morning Breakfast Cereal - Time",
|
||||||
|
"image":"https://www.smbc-comics.com/comics/1746375102-20250504.webp",
|
||||||
|
"thumbnailUrl":"https://www.smbc-comics.com/comicsthumbs/1746375102-20250504.webp",
|
||||||
|
"datePublished":"2025-05-04T12:11:21-04:00"
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="cc-comicbody">
|
||||||
|
<img title="I don't know why either, but it was fun to draw."
|
||||||
|
src="https://www.smbc-comics.com/comics/1746375102-20250504.webp"
|
||||||
|
id="cc-comic" />
|
||||||
|
</div>
|
||||||
|
<div id="permalink">
|
||||||
|
<input id="permalinktext" type="text" value="http://www.smbc-comics.com/comic/time-6" />
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"to_remove, overrides",
|
||||||
|
[
|
||||||
|
# Normal case - all data present
|
||||||
|
("", {}),
|
||||||
|
# Missing title attribute on image
|
||||||
|
(
|
||||||
|
'title="I don\'t know why either, but it was fun to draw."',
|
||||||
|
{"title": "Saturday Morning Breakfast Cereal - Time"},
|
||||||
|
),
|
||||||
|
# Missing src attribute on image
|
||||||
|
(
|
||||||
|
'src="https://www.smbc-comics.com/comics/1746375102-20250504.webp"',
|
||||||
|
{"image_url": ""},
|
||||||
|
),
|
||||||
|
# # Missing entire img tag
|
||||||
|
(
|
||||||
|
'<img title="I don\'t know why either, but it was fun to draw." \n src="https://www.smbc-comics.com/comics/1746375102-20250504.webp" \n id="cc-comic" />',
|
||||||
|
{"title": "Saturday Morning Breakfast Cereal - Time", "image_url": ""},
|
||||||
|
),
|
||||||
|
# # Corrupt JSON-LD data
|
||||||
|
(
|
||||||
|
'"datePublished":"2025-05-04T12:11:21-04:00"',
|
||||||
|
{"published_date": "", "url": "http://www.smbc-comics.com/comic/time-6"},
|
||||||
|
),
|
||||||
|
# # Missing JSON-LD script entirely
|
||||||
|
(
|
||||||
|
'<script type=\'application/ld+json\'>\n {\n "@context": "http://www.schema.org",\n "@type": "ComicStory",\n "name": "Saturday Morning Breakfast Cereal - Time",\n "url": "https://www.smbc-comics.com/comic/time-6",\n "author":"Zach Weinersmith",\n "about":"Saturday Morning Breakfast Cereal - Time",\n "image":"https://www.smbc-comics.com/comics/1746375102-20250504.webp",\n "thumbnailUrl":"https://www.smbc-comics.com/comicsthumbs/1746375102-20250504.webp",\n "datePublished":"2025-05-04T12:11:21-04:00"\n }\n </script>',
|
||||||
|
{"published_date": "", "url": "http://www.smbc-comics.com/comic/time-6"},
|
||||||
|
),
|
||||||
|
# # Missing permalink input
|
||||||
|
(
|
||||||
|
'<div id="permalink">\n <input id="permalinktext" type="text" value="http://www.smbc-comics.com/comic/time-6" />\n </div>',
|
||||||
|
{},
|
||||||
|
),
|
||||||
|
# Missing URL in JSON-LD
|
||||||
|
(
|
||||||
|
'"url": "https://www.smbc-comics.com/comic/time-6",',
|
||||||
|
{"url": "http://www.smbc-comics.com/comic/time-6"},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_smbc(to_remove, overrides):
|
||||||
|
"""Test successful extraction of comic info from SMBC."""
|
||||||
|
expected = {
|
||||||
|
"title": "I don't know why either, but it was fun to draw.",
|
||||||
|
"image_url": "https://www.smbc-comics.com/comics/1746375102-20250504.webp",
|
||||||
|
"published_date": "2025-05-04T12:11:21-04:00",
|
||||||
|
"url": "https://www.smbc-comics.com/comic/time-6",
|
||||||
|
}
|
||||||
|
with patch("requests.get") as mock_get:
|
||||||
|
mock_get.return_value.status_code = 200
|
||||||
|
mock_get.return_value.text = MOCK_SMBC_HTML.replace(to_remove, "")
|
||||||
|
assert extract_smbc("https://www.smbc-comics.com/") == expected | overrides
|
||||||
|
|
||||||
|
|
||||||
|
# Create a stripped-down version of the XKCD HTML
|
||||||
|
MOCK_XKCD_HTML = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>xkcd: Unstoppable Force and Immovable Object</title>
|
||||||
|
<meta property="og:title" content="Unstoppable Force and Immovable Object">
|
||||||
|
<meta property="og:url" content="https://xkcd.com/3084/">
|
||||||
|
<meta property="og:image" content="https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object_2x.png">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="ctitle">Unstoppable Force and Immovable Object</div>
|
||||||
|
|
||||||
|
<div id="comic">
|
||||||
|
<img src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png"
|
||||||
|
title="Unstoppable force-carrying particles can't interact with immovable matter by definition."
|
||||||
|
alt="Unstoppable Force and Immovable Object" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
Permanent link to this comic: <a href="https://xkcd.com/3084">https://xkcd.com/3084/</a><br />
|
||||||
|
Image URL (for hotlinking/embedding): <a href="https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png">
|
||||||
|
https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png</a>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "to_remove, overrides",
    [
        # Normal case - all data present
        ("", {}),
        # Missing title attribute on image
        (
            'title="Unstoppable force-carrying particles can\'t interact with immovable matter by definition."',
            {
                "title": "Unstoppable Force and Immovable Object"
            },  # Falls back to og:title
        ),
        # Missing og:title meta tag - falls back to ctitle
        (
            '<meta property="og:title" content="Unstoppable Force and Immovable Object">',
            {},  # Still gets title from image title
        ),
        # Missing both title and og:title - falls back to ctitle
        (
            'title="Unstoppable force-carrying particles can\'t interact with immovable matter by definition."\n<meta property="og:title" content="Unstoppable Force and Immovable Object">',
            {"title": "Unstoppable Force and Immovable Object"},  # Falls back to ctitle
        ),
        # Missing image src attribute
        (
            'src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png"',
            {"image_url": ""},
        ),
        # Missing entire img tag
        (
            '<img src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png" \n title="Unstoppable force-carrying particles can\'t interact with immovable matter by definition." \n alt="Unstoppable Force and Immovable Object" />',
            {
                "image_url": "",
                "title": "Unstoppable Force and Immovable Object",
            },  # Falls back to og:title
        ),
        # Missing entire comic div
        (
            '<div id="comic">\n<img src="//imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png" \n title="Unstoppable force-carrying particles can\'t interact with immovable matter by definition." \n alt="Unstoppable Force and Immovable Object" />\n</div>',
            {"image_url": "", "title": "Unstoppable Force and Immovable Object"},
        ),
        # Missing og:url tag
        (
            '<meta property="og:url" content="https://xkcd.com/3084/">',
            {},  # Should fallback to permalink link
        ),
        # Missing permanent link
        (
            'Permanent link to this comic: <a href="https://xkcd.com/3084">https://xkcd.com/3084/</a><br />',
            {"url": "https://xkcd.com/3084/"},  # Should still get URL from og:url
        ),
        # Missing both og:url and permanent link
        (
            '<meta property="og:url" content="https://xkcd.com/3084/">\nPermanent link to this comic: <a href="https://xkcd.com/3084">https://xkcd.com/3084/</a><br />',
            {"url": "https://xkcd.com/test"},  # Falls back to original URL
        ),
    ],
)
def test_extract_xkcd(to_remove, overrides):
    """Test successful extraction of comic info from XKCD."""
    expected = {
        "title": "Unstoppable force-carrying particles can't interact with immovable matter by definition.",
        "image_url": "https://imgs.xkcd.com/comics/unstoppable_force_and_immovable_object.png",
        "url": "https://xkcd.com/3084/",
    }

    with patch("requests.get") as mock_get:
        mock_get.return_value.status_code = 200
        modified_html = MOCK_XKCD_HTML
        for item in to_remove.split("\n"):
            modified_html = modified_html.replace(item, "")

        mock_get.return_value.text = modified_html
        result = extract_xkcd("https://xkcd.com/test")
        assert result == expected | overrides
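
# Some removal targets span several source lines; splitting `to_remove` on
# "\n" and replacing each piece separately keeps the cases readable while
# still matching the line-wrapped markup in MOCK_XKCD_HTML.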
@ -1,22 +1,25 @@
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, Mock

 import qdrant_client
 from qdrant_client.http import models as qdrant_models
 from qdrant_client.http.exceptions import UnexpectedResponse

 from memory.common.qdrant import (
-    DEFAULT_COLLECTIONS,
+    ALL_COLLECTIONS,
     ensure_collection_exists,
     initialize_collections,
     upsert_vectors,
-    delete_vectors,
+    delete_points,
+    batch_ids,
 )


 @pytest.fixture
 def mock_qdrant_client():
-    with patch.object(qdrant_client, "QdrantClient", return_value=MagicMock()) as mock_client:
+    with patch.object(
+        qdrant_client, "QdrantClient", return_value=MagicMock()
+    ) as mock_client:
         yield mock_client

@ -30,7 +33,7 @@ def test_ensure_collection_exists_existing(mock_qdrant_client):
 def test_ensure_collection_exists_new(mock_qdrant_client):
     mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
-        status_code=404, reason_phrase='asd', content=b'asd', headers=None
+        status_code=404, reason_phrase="asd", content=b"asd", headers=None
     )

     assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
@ -43,7 +46,7 @@ def test_ensure_collection_exists_new(mock_qdrant_client):
 def test_initialize_collections(mock_qdrant_client):
     initialize_collections(mock_qdrant_client)

-    assert mock_qdrant_client.get_collection.call_count == len(DEFAULT_COLLECTIONS)
+    assert mock_qdrant_client.get_collection.call_count == len(ALL_COLLECTIONS)


 def test_upsert_vectors(mock_qdrant_client):
@ -71,7 +74,7 @@ def test_upsert_vectors(mock_qdrant_client):
 def test_delete_vectors(mock_qdrant_client):
     ids = ["1", "2"]

-    delete_vectors(mock_qdrant_client, "test_collection", ids)
+    delete_points(mock_qdrant_client, "test_collection", ids)

     mock_qdrant_client.delete.assert_called_once()
     args, kwargs = mock_qdrant_client.delete.call_args
@ -79,3 +82,15 @@ def test_delete_vectors(mock_qdrant_client):
     assert kwargs["collection_name"] == "test_collection"
     assert isinstance(kwargs["points_selector"], qdrant_models.PointIdsList)
     assert kwargs["points_selector"].points == ids
+
+
+def test_batch_ids(mock_qdrant_client):
+    mock_qdrant_client.scroll.side_effect = [
+        ([Mock(id="1"), Mock(id="2")], "3"),
+        ([Mock(id="3"), Mock(id="4")], None),
+    ]
+
+    assert list(batch_ids(mock_qdrant_client, "test_collection")) == [
+        ["1", "2"],
+        ["3", "4"],
+    ]
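
# Note: QdrantClient.scroll returns a (points, next_page_offset) tuple, and
# batch_ids keeps scrolling until the offset comes back as None - that is why
# the mocked second response ends the iteration above.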
tests/memory/workers/tasks/test_maintenance.py (new file, 310 lines)
@ -0,0 +1,310 @@
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call

import pytest
from PIL import Image

from memory.common import qdrant as qd
from memory.common import embedding, settings
from memory.common.db.models import Chunk, SourceItem
from memory.workers.tasks.maintenance import (
    clean_collection,
    reingest_chunk,
    check_batch,
    reingest_missing_chunks,
)


@pytest.fixture
def source(db_session):
    s = SourceItem(id=1, modality="text", sha256=b"123")
    db_session.add(s)
    db_session.commit()
    return s


@pytest.fixture
def mock_uuid4():
    i = 0

    def uuid4():
        nonlocal i
        i += 1
        return f"00000000-0000-0000-0000-00000000000{i}"

    with patch("uuid.uuid4", side_effect=uuid4):
        yield


@pytest.fixture
def test_image(mock_file_storage):
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    img_path = settings.CHUNK_STORAGE_DIR / "test.png"
    img.save(img_path)
    return img_path


@pytest.fixture(params=["text", "photo"])
def chunk(request, test_image, db_session):
    """Parametrized fixture for chunk configuration"""
    collection = request.param
    if collection == "photo":
        content = None
        file_path = str(test_image)
    else:
        content = "Test content for reingestion"
        file_path = None

    chunk = Chunk(
        id=str(uuid.uuid4()),
        source=SourceItem(id=1, modality=collection, sha256=b"123"),
        content=content,
        file_path=file_path,
        embedding_model="test-model",
        checked_at=datetime(2025, 1, 1),
    )
    db_session.add(chunk)
    db_session.commit()
    return chunk
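
# Two storage shapes are exercised here: "photo" chunks reference their
# payload on disk via file_path, while "text" chunks carry content inline;
# both appear to flow through the same reingestion path below.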


def test_clean_collection_no_mismatches(db_session, qdrant, source):
    """Test when all Qdrant points exist in the database - nothing should be deleted."""
    # Create chunks in the database
    chunk_ids = [str(uuid.uuid4()) for _ in range(3000)]
    collection = "text"

    # Add chunks to the database
    for chunk_id in chunk_ids:
        db_session.add(
            Chunk(
                id=chunk_id,
                source=source,
                content="Test content",
                embedding_model="test-model",
            )
        )
    db_session.commit()
    qd.ensure_collection_exists(qdrant, collection, 1024)
    qd.upsert_vectors(qdrant, collection, chunk_ids, [[1] * 1024] * len(chunk_ids))

    assert set(chunk_ids) == {
        str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch
    }

    clean_collection(collection)

    # Check that the chunks are still in the database - no points were deleted
    assert set(chunk_ids) == {
        str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch
    }


def test_clean_collection_with_orphaned_vectors(db_session, qdrant, source):
    """Test when there are vectors in Qdrant that don't exist in the database."""
    existing_ids = [str(uuid.uuid4()) for _ in range(3000)]
    orphaned_ids = [str(uuid.uuid4()) for _ in range(3000)]
    all_ids = existing_ids + orphaned_ids
    collection = "text"

    # Add only the existing chunks to the database
    for chunk_id in existing_ids:
        db_session.add(
            Chunk(
                id=chunk_id,
                source=source,
                content="Test content",
                embedding_model="test-model",
            )
        )
    db_session.commit()
    qd.ensure_collection_exists(qdrant, collection, 1024)
    qd.upsert_vectors(qdrant, collection, all_ids, [[1] * 1024] * len(all_ids))

    clean_collection(collection)

    # The orphaned vectors should be deleted
    assert set(existing_ids) == {
        str(i) for batch in qd.batch_ids(qdrant, collection) for i in batch
    }
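
# clean_collection is expected to diff Qdrant's point ids against the Chunk
# table and drop anything with no matching row; 3000 ids per side should be
# enough to force batch_ids through more than one scroll page.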


def test_clean_collection_empty_batches(db_session, qdrant):
    collection = "text"
    qd.ensure_collection_exists(qdrant, collection, 1024)

    clean_collection(collection)

    assert not [i for b in qd.batch_ids(qdrant, collection) for i in b]


def test_reingest_chunk(db_session, qdrant, chunk):
    """Test reingesting a chunk using parameterized fixtures"""
    collection = chunk.source.modality
    qd.ensure_collection_exists(qdrant, collection, 1024)

    start = datetime.now()
    test_vector = [0.1] * 1024

    with patch.object(embedding, "embed_chunks", return_value=[test_vector]):
        reingest_chunk(str(chunk.id), collection)

    vectors = qd.search_vectors(qdrant, collection, test_vector, limit=1)
    assert len(vectors) == 1
    assert str(vectors[0].id) == str(chunk.id)
    assert vectors[0].payload == chunk.source.as_payload()
    db_session.refresh(chunk)
    assert chunk.checked_at.isoformat() > start.isoformat()


def test_reingest_chunk_not_found(db_session, qdrant):
    """Test when the chunk to reingest doesn't exist."""
    non_existent_id = str(uuid.uuid4())
    collection = "text"

    reingest_chunk(non_existent_id, collection)


def test_reingest_chunk_unsupported_collection(db_session, qdrant, source):
    """Test reingesting with an unsupported collection type."""
    chunk_id = str(uuid.uuid4())
    chunk = Chunk(
        id=chunk_id,
        source=source,
        content="Test content",
        embedding_model="test-model",
    )
    db_session.add(chunk)
    db_session.commit()

    unsupported_collection = "unsupported"
    qd.ensure_collection_exists(qdrant, unsupported_collection, 1024)

    with pytest.raises(
        ValueError, match=f"Unsupported collection {unsupported_collection}"
    ):
        reingest_chunk(chunk_id, unsupported_collection)


def test_check_batch_empty(db_session, qdrant):
    assert check_batch([]) == {}


def test_check_batch(db_session, qdrant):
    modalities = ["text", "photo", "mail"]
    chunks = [
        Chunk(
            id=f"00000000-0000-0000-0000-0000000000{i:02d}",
            source=SourceItem(modality=modality, sha256=f"123{i}".encode()),
            content="Test content",
            file_path=None,
            embedding_model="test-model",
            checked_at=datetime(2025, 1, 1),
        )
        for modality in modalities
        for i in range(5)
    ]
    db_session.add_all(chunks)
    db_session.commit()
    start_time = datetime.now()

    for modality in modalities:
        qd.ensure_collection_exists(qdrant, modality, 1024)

    for chunk in chunks[::2]:
        qd.upsert_vectors(qdrant, chunk.source.modality, [str(chunk.id)], [[1] * 1024])

    with patch.object(reingest_chunk, "delay") as mock_reingest:
        stats = check_batch(chunks)

    assert mock_reingest.call_args_list == [
        call("00000000-0000-0000-0000-000000000001", "text"),
        call("00000000-0000-0000-0000-000000000003", "text"),
        call("00000000-0000-0000-0000-000000000000", "photo"),
        call("00000000-0000-0000-0000-000000000002", "photo"),
        call("00000000-0000-0000-0000-000000000004", "photo"),
        call("00000000-0000-0000-0000-000000000001", "mail"),
        call("00000000-0000-0000-0000-000000000003", "mail"),
    ]
    assert stats == {
        "mail": {"correct": 3, "missing": 2, "total": 5},
        "text": {"correct": 3, "missing": 2, "total": 5},
        "photo": {"correct": 2, "missing": 3, "total": 5},
    }
    db_session.commit()
    for chunk in chunks[::2]:
        assert chunk.checked_at.isoformat() > start_time.isoformat()
    for chunk in chunks[1::2]:
        assert chunk.checked_at.isoformat()[:19] == "2025-01-01T00:00:00"
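
# chunks[::2] walks even indices of the concatenated 15-chunk list, so vector
# coverage alternates per modality - which is exactly why photo ends up with
# three missing chunks while text and mail each miss two.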


@pytest.mark.parametrize("batch_size", [4, 10, 100])
def test_reingest_missing_chunks(db_session, qdrant, batch_size):
    now = datetime.now()
    old_time = now - timedelta(minutes=120)  # Older than the threshold

    modalities = ["text", "photo", "mail"]
    ids_generator = (f"00000000-0000-0000-0000-00000000{i:04d}" for i in range(1000))

    old_chunks = [
        Chunk(
            id=next(ids_generator),
            source=SourceItem(modality=modality, sha256=f"{modality}-{i}".encode()),
            content="Old content",
            file_path=None,
            embedding_model="test-model",
            checked_at=old_time,
        )
        for modality in modalities
        for i in range(20)
    ]

    recent_chunks = [
        Chunk(
            id=next(ids_generator),
            source=SourceItem(
                modality=modality, sha256=f"recent-{modality}-{i}".encode()
            ),
            content="Recent content",
            file_path=None,
            embedding_model="test-model",
            checked_at=now,
        )
        for modality in modalities
        for i in range(5)
    ]

    db_session.add_all(old_chunks + recent_chunks)
    db_session.commit()

    for modality in modalities:
        qd.ensure_collection_exists(qdrant, modality, 1024)

    for chunk in old_chunks[::2]:
        qd.upsert_vectors(qdrant, chunk.source.modality, [str(chunk.id)], [[1] * 1024])

    with patch.object(reingest_chunk, "delay", reingest_chunk):
        with patch.object(settings, "CHUNK_REINGEST_SINCE_MINUTES", 60):
            with patch.object(embedding, "embed_chunks", return_value=[[1] * 1024]):
                result = reingest_missing_chunks(batch_size=batch_size)

    assert result == {
        "photo": {"correct": 10, "missing": 10, "total": 20},
        "mail": {"correct": 10, "missing": 10, "total": 20},
        "text": {"correct": 10, "missing": 10, "total": 20},
    }

    db_session.commit()
    # All the old chunks should be reingested
    client = qd.get_qdrant_client()
    for modality in modalities:
        qdrant_ids = [
            i for b in qd.batch_ids(client, modality, batch_size=1000) for i in b
        ]
        db_ids = [str(c.id) for c in old_chunks if c.source.modality == modality]
        assert set(qdrant_ids) == set(db_ids)
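
# The same stats come back for batch sizes 4, 10 and 100, suggesting the
# batch size is purely an implementation detail of how many chunks are
# checked per pass, not something that changes what gets reingested.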


def test_reingest_missing_chunks_no_chunks(db_session):
    assert reingest_missing_chunks() == {}