diff --git a/docker-compose.yaml b/docker-compose.yaml
index c5094ad..98bf745 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -11,6 +11,7 @@ secrets:
   postgres_password: { file: ./secrets/postgres_password.txt }
   jwt_secret: { file: ./secrets/jwt_secret.txt }
   openai_key: { file: ./secrets/openai_key.txt }
+  anthropic_key: { file: ./secrets/anthropic_key.txt }
 
 # --------------------------------------------------------------------- volumes
 volumes:
@@ -42,7 +43,8 @@ x-worker-base: &worker-base
     # DSNs are built in worker entrypoint from user + pw files
     QDRANT_URL: http://qdrant:6333
     OPENAI_API_KEY_FILE: /run/secrets/openai_key
-  secrets: [ postgres_password, openai_key ]
+    ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
+  secrets: [ postgres_password, openai_key, anthropic_key ]
   read_only: true
   tmpfs: [ /tmp, /var/tmp ]
   cap_drop: [ ALL ]
@@ -76,9 +78,6 @@ services:
     mem_limit: 4g
     cpus: "1.5"
     security_opt: [ "no-new-privileges=true" ]
-    ports:
-      # PostgreSQL port for local Celery result backend
-      - "15432:5432"
 
   rabbitmq:
     image: rabbitmq:3.13-management
@@ -98,11 +97,6 @@ services:
     mem_limit: 512m
     cpus: "0.5"
     security_opt: [ "no-new-privileges=true" ]
-    ports:
-      # UI only on localhost
-      - "15672:15672"
-      # AMQP port for local Celery clients
-      - "15673:5672"
 
   qdrant:
     image: qdrant/qdrant:v1.14.0
@@ -119,8 +113,6 @@ services:
       interval: 15s
       timeout: 5s
       retries: 5
-    ports:
-      - "6333:6333"
     mem_limit: 4g
     cpus: "2"
     security_opt: [ "no-new-privileges=true" ]
diff --git a/requirements-common.txt b/requirements-common.txt
index 27b9e47..c65d20c 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -4,4 +4,5 @@ pydantic==2.7.1
 alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
-qdrant-client==1.9.0
\ No newline at end of file
+qdrant-client==1.9.0
+anthropic==0.18.1
\ No newline at end of file
diff --git a/run_celery_task.py b/run_celery_task.py
new file mode 100644
index 0000000..e9a13f2
--- /dev/null
+++ b/run_celery_task.py
@@ -0,0 +1,449 @@
+#!/usr/bin/env python3
+"""
+Script to run Celery tasks on the Docker Compose setup from your local machine.
+
+This script connects to the RabbitMQ broker running in Docker and sends tasks
+to the workers. It requires the same dependencies as the workers to import
+the task definitions.
+
+Usage:
+    python run_celery_task.py --help
+    python run_celery_task.py email sync-all-accounts
+    python run_celery_task.py email sync-account --account-id 1
+    python run_celery_task.py ebook sync-book --file-path "/path/to/book.epub" --tags "fiction,scifi"
+    python run_celery_task.py maintenance clean-all-collections
+    python run_celery_task.py blogs sync-webpage --url "https://example.com"
+    python run_celery_task.py comic sync-all-comics
+    python run_celery_task.py forums sync-lesswrong --since-date "2025-01-01" --min-karma 10 --limit 50 --cooldown 0.5 --max-items 1000
+"""
+
+import json
+import sys
+from pathlib import Path
+from typing import Any
+
+# Add the src directory to Python path so we can import memory modules
+sys.path.insert(0, str(Path(__file__).parent / "src"))
+
+import click
+from memory.workers.tasks.blogs import (
+    SYNC_ALL_ARTICLE_FEEDS,
+    SYNC_ARTICLE_FEED,
+    SYNC_WEBPAGE,
+    SYNC_WEBSITE_ARCHIVE,
+)
+from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_COMIC, SYNC_SMBC, SYNC_XKCD
+from memory.workers.tasks.ebook import SYNC_BOOK
+from memory.workers.tasks.email import PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
+from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
+from memory.workers.tasks.maintenance import (
+    CLEAN_ALL_COLLECTIONS,
+    CLEAN_COLLECTION,
+    REINGEST_CHUNK,
+    REINGEST_EMPTY_SOURCE_ITEMS,
+    REINGEST_ITEM,
+    REINGEST_MISSING_CHUNKS,
+    UPDATE_METADATA_FOR_ITEM,
+    UPDATE_METADATA_FOR_SOURCE_ITEMS,
+)
+
+from celery import Celery
+from memory.common import settings
+
+
+TASK_MAPPINGS = {
+    "email": {
+        "sync_all_accounts": SYNC_ALL_ACCOUNTS,
+        "sync_account": SYNC_ACCOUNT,
+        "process_message": PROCESS_EMAIL,
+    },
+    "ebook": {
+        "sync_book": SYNC_BOOK,
+    },
+    "maintenance": {
+        "clean_all_collections": CLEAN_ALL_COLLECTIONS,
+        "clean_collection": CLEAN_COLLECTION,
+        "reingest_missing_chunks": REINGEST_MISSING_CHUNKS,
+        "reingest_chunk": REINGEST_CHUNK,
+        "reingest_item": REINGEST_ITEM,
+        "reingest_empty_source_items": REINGEST_EMPTY_SOURCE_ITEMS,
+        "update_metadata_for_item": UPDATE_METADATA_FOR_ITEM,
+        "update_metadata_for_source_items": UPDATE_METADATA_FOR_SOURCE_ITEMS,
+    },
+    "blogs": {
+        "sync_webpage": SYNC_WEBPAGE,
+        "sync_article_feed": SYNC_ARTICLE_FEED,
+        "sync_all_article_feeds": SYNC_ALL_ARTICLE_FEEDS,
+        "sync_website_archive": SYNC_WEBSITE_ARCHIVE,
+    },
+    "comic": {
+        "sync_all_comics": SYNC_ALL_COMICS,
+        "sync_smbc": SYNC_SMBC,
+        "sync_xkcd": SYNC_XKCD,
+        "sync_comic": SYNC_COMIC,
+    },
+    "forums": {
+        "sync_lesswrong": SYNC_LESSWRONG,
+        "sync_lesswrong_post": SYNC_LESSWRONG_POST,
+    },
+}
+QUEUE_MAPPINGS = {
+    "email": "email",
+    "ebook": "ebooks",
+    "photo": "photo_embed",
+}
+
+
+def create_local_celery_app() -> Celery:
+    """Create a Celery app configured to connect to the Docker RabbitMQ."""
+    # Override settings for local connection to Docker services
+    rabbitmq_url = f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@localhost:15673//"
+
+    app = Celery(
+        "memory-local",
+        broker=rabbitmq_url,
+        backend=settings.CELERY_RESULT_BACKEND.replace(
+            "postgres:5432", "localhost:15432"
+        ),
+    )
+
+    # Import task modules so they're registered
+    app.autodiscover_tasks(["memory.workers.tasks"])
+
+    return app
+
+
+def run_task(app: Celery, category: str, task_name: str, **kwargs) -> str:
+    """Run a task using the task mappings."""
+    if category not in TASK_MAPPINGS:
+        raise ValueError(f"Unknown category: {category}")
+
+    if task_name not in TASK_MAPPINGS[category]:
+        raise ValueError(f"Unknown {category} task: {task_name}")
+
+    task_path = TASK_MAPPINGS[category][task_name]
+    queue_name = QUEUE_MAPPINGS.get(category) or category
+
+    result = app.send_task(task_path, kwargs=kwargs, queue=queue_name)
+    return result.id
+
+
+def get_task_result(app: Celery, task_id: str, timeout: int = 300) -> Any:
+    """Get the result of a task."""
+    result = app.AsyncResult(task_id)
+    try:
+        return result.get(timeout=timeout)
+    except Exception as e:
+        return {"error": str(e), "status": result.status}
+
+
+@click.group()
+@click.option("--wait", is_flag=True, help="Wait for task completion and show result")
+@click.option(
+    "--timeout", default=300, help="Timeout in seconds when waiting for result"
+)
+@click.pass_context
+def cli(ctx, wait, timeout):
+    """Run Celery tasks on Docker Compose setup."""
+    ctx.ensure_object(dict)
+    ctx.obj["wait"] = wait
+    ctx.obj["timeout"] = timeout
+
+    try:
+        ctx.obj["app"] = create_local_celery_app()
+    except Exception as e:
+        click.echo(f"Error connecting to Celery broker: {e}")
+        click.echo(
+            "Make sure Docker Compose is running and RabbitMQ is accessible on localhost:15673"
+        )
+        sys.exit(1)
+
+
+def execute_task(ctx, category: str, task_name: str, **kwargs):
+    """Helper to execute a task and handle results."""
+    app = ctx.obj["app"]
+    wait = ctx.obj["wait"]
+    timeout = ctx.obj["timeout"]
+
+    # Filter out None values
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    try:
+        task_id = run_task(app, category, task_name, **kwargs)
+        click.echo("Task submitted successfully!")
+        click.echo(f"Task ID: {task_id}")
+
+        if wait:
+            click.echo(f"Waiting for task completion (timeout: {timeout}s)...")
+            result = get_task_result(app, task_id, timeout)
+            click.echo("Task result:")
+            click.echo(json.dumps(result, indent=2, default=str))
+    except Exception as e:
+        click.echo(f"Error running task: {e}")
+        sys.exit(1)
+
+
+@cli.group()
+@click.pass_context
+def email(ctx):
+    """Email-related tasks."""
+    pass
+
+
+@email.command("sync-all-accounts")
+@click.option("--since-date", help="Sync items since this date (ISO format)")
+@click.pass_context
+def email_sync_all_accounts(ctx, since_date):
+    """Sync all email accounts."""
+    execute_task(ctx, "email", "sync_all_accounts", since_date=since_date)
+
+
+@email.command("sync-account")
+@click.option("--account-id", type=int, required=True, help="Email account ID")
+@click.option("--since-date", help="Sync items since this date (ISO format)")
+@click.pass_context
+def email_sync_account(ctx, account_id, since_date):
+    """Sync a specific email account."""
+    execute_task(
+        ctx, "email", "sync_account", account_id=account_id, since_date=since_date
+    )
+
+
+@email.command("process-message")
+@click.option("--message-id", required=True, help="Email message ID")
+@click.option("--folder", help="Email folder name")
+@click.option("--raw-email", help="Raw email content")
+@click.pass_context
+def email_process_message(ctx, message_id, folder, raw_email):
+    """Process a specific email message."""
+    execute_task(
+        ctx,
+        "email",
+        "process_message",
+        message_id=message_id,
+        folder=folder,
+        raw_email=raw_email,
+    )
+
+
+@cli.group()
+@click.pass_context
+def ebook(ctx):
+    """Ebook-related tasks."""
+    pass
+
+
+@ebook.command("sync-book")
+@click.option("--file-path", required=True, help="Path to ebook file")
+@click.option("--tags", help="Comma-separated tags")
+@click.pass_context
+def ebook_sync_book(ctx, file_path, tags):
+    """Sync an ebook."""
+    execute_task(ctx, "ebook", "sync_book", file_path=file_path, tags=tags)
+
+
+@cli.group()
+@click.pass_context
+def maintenance(ctx):
+    """Maintenance tasks."""
+    pass
+
+
+@maintenance.command("clean-all-collections")
+@click.pass_context
+def maintenance_clean_all_collections(ctx):
+    """Clean all collections."""
+    execute_task(ctx, "maintenance", "clean_all_collections")
+
+
+@maintenance.command("clean-collection")
+@click.option("--collection", required=True, help="Collection name to clean")
+@click.pass_context
+def maintenance_clean_collection(ctx, collection):
+    """Clean a specific collection."""
+    execute_task(ctx, "maintenance", "clean_collection", collection=collection)
+
+
+@maintenance.command("reingest-missing-chunks")
+@click.option("--minutes-ago", type=int, help="Minutes ago to reingest chunks")
+@click.pass_context
+def maintenance_reingest_missing_chunks(ctx, minutes_ago):
+    """Reingest missing chunks."""
+    execute_task(ctx, "maintenance", "reingest_missing_chunks", minutes_ago=minutes_ago)
+
+
+@maintenance.command("reingest-item")
+@click.option("--item-id", required=True, help="Item ID to reingest")
+@click.option("--item-type", required=True, help="Item type to reingest")
+@click.pass_context
+def maintenance_reingest_item(ctx, item_id, item_type):
+    """Reingest a specific item."""
+    execute_task(
+        ctx, "maintenance", "reingest_item", item_id=item_id, item_type=item_type
+    )
+
+
+@maintenance.command("update-metadata-for-item")
+@click.option("--item-id", required=True, help="Item ID to update metadata for")
+@click.option("--item-type", required=True, help="Item type to update metadata for")
+@click.pass_context
+def maintenance_update_metadata_for_item(ctx, item_id, item_type):
+    """Update metadata for a specific item."""
+    execute_task(
+        ctx,
+        "maintenance",
+        "update_metadata_for_item",
+        item_id=item_id,
+        item_type=item_type,
+    )
+
+
+@maintenance.command("update-metadata-for-source-items")
+@click.option("--item-type", required=True, help="Item type to update metadata for")
+@click.pass_context
+def maintenance_update_metadata_for_source_items(ctx, item_type):
+    """Update metadata for all items of a specific type."""
+    execute_task(
+        ctx, "maintenance", "update_metadata_for_source_items", item_type=item_type
+    )
+
+
+@maintenance.command("reingest-empty-source-items")
+@click.option("--item-type", required=True, help="Item type to reingest")
+@click.pass_context
+def maintenance_reingest_empty_source_items(ctx, item_type):
+    """Reingest empty source items."""
+    execute_task(ctx, "maintenance", "reingest_empty_source_items", item_type=item_type)
+
+
+@maintenance.command("reingest-chunk")
+@click.option("--chunk-id", required=True, help="Chunk ID to reingest")
+@click.pass_context
+def maintenance_reingest_chunk(ctx, chunk_id):
+    """Reingest a specific chunk."""
+    execute_task(ctx, "maintenance", "reingest_chunk", chunk_id=chunk_id)
+
+
+@cli.group()
+@click.pass_context
+def blogs(ctx):
+    """Blog-related tasks."""
+    pass
+
+
+@blogs.command("sync-webpage")
+@click.option("--url", required=True, help="URL to sync")
+@click.pass_context
+def blogs_sync_webpage(ctx, url):
+    """Sync a webpage."""
+    execute_task(ctx, "blogs", "sync_webpage", url=url)
+
+
+@blogs.command("sync-article-feed")
+@click.option("--feed-id", type=int, required=True, help="Feed ID to sync")
+@click.pass_context
+def blogs_sync_article_feed(ctx, feed_id):
+    """Sync an article feed."""
+    execute_task(ctx, "blogs", "sync_article_feed", feed_id=feed_id)
+
+
+@blogs.command("sync-all-article-feeds")
+@click.pass_context
+def blogs_sync_all_article_feeds(ctx):
"""Sync all article feeds.""" + execute_task(ctx, "blogs", "sync_all_article_feeds") + + +@blogs.command("sync-website-archive") +@click.option("--url", required=True, help="URL to sync") +@click.pass_context +def blogs_sync_website_archive(ctx, url): + """Sync a website archive.""" + execute_task(ctx, "blogs", "sync_website_archive", url=url) + + +@cli.group() +@click.pass_context +def comic(ctx): + """Comic-related tasks.""" + pass + + +@comic.command("sync-all-comics") +@click.pass_context +def comic_sync_all_comics(ctx): + """Sync all comics.""" + execute_task(ctx, "comic", "sync_all_comics") + + +@comic.command("sync-smbc") +@click.pass_context +def comic_sync_smbc(ctx): + """Sync SMBC comics.""" + execute_task(ctx, "comic", "sync_smbc") + + +@comic.command("sync-xkcd") +@click.pass_context +def comic_sync_xkcd(ctx): + """Sync XKCD comics.""" + execute_task(ctx, "comic", "sync_xkcd") + + +@comic.command("sync-comic") +@click.option("--image-url", required=True, help="Image URL to sync") +@click.option("--title", help="Comic title") +@click.option("--author", help="Comic author") +@click.option("--published-date", help="Comic published date") +@click.pass_context +def comic_sync_comic(ctx, image_url, title, author, published_date): + """Sync a specific comic.""" + execute_task( + ctx, + "comic", + "sync_comic", + image_url=image_url, + title=title, + author=author, + published_date=published_date, + ) + + +@cli.group() +@click.pass_context +def forums(ctx): + """Forum-related tasks.""" + pass + + +@forums.command("sync-lesswrong") +@click.option("--since-date", help="Sync items since this date (ISO format)") +@click.option("--min-karma", type=int, help="Minimum karma to sync") +@click.option("--limit", type=int, help="Limit the number of posts to sync") +@click.option("--cooldown", type=float, help="Cooldown between posts") +@click.option("--max-items", type=int, help="Maximum number of posts to sync") +@click.pass_context +def forums_sync_lesswrong(ctx, since_date, min_karma, limit, cooldown, max_items): + """Sync LessWrong posts.""" + execute_task( + ctx, + "forums", + "sync_lesswrong", + since_date=since_date, + min_karma=min_karma, + limit=limit, + cooldown=cooldown, + max_items=max_items, + ) + + +@forums.command("sync-lesswrong-post") +@click.option("--url", required=True, help="LessWrong post URL") +@click.pass_context +def forums_sync_lesswrong_post(ctx, url): + """Sync a specific LessWrong post.""" + execute_task(ctx, "forums", "sync_lesswrong_post", url=url) + + +if __name__ == "__main__": + cli() diff --git a/src/memory/common/chunker.py b/src/memory/common/chunker.py index 418ae1e..71bb322 100644 --- a/src/memory/common/chunker.py +++ b/src/memory/common/chunker.py @@ -2,13 +2,15 @@ import logging from typing import Iterable, Any import re +from memory.common import settings + logger = logging.getLogger(__name__) # Chunking configuration -EMBEDDING_MAX_TOKENS = 32000 # VoyageAI max context window -DEFAULT_CHUNK_TOKENS = 512 # Optimal chunk size for semantic search -OVERLAP_TOKENS = 50 # Default overlap between chunks (10% of chunk size) +EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS +DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS +OVERLAP_TOKENS = settings.OVERLAP_TOKENS CHARS_PER_TOKEN = 4 diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 62596cd..52e7735 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -35,6 +35,7 @@ from memory.common import settings import memory.common.extract as extract 
 import memory.common.collections as collections
 import memory.common.chunker as chunker
+import memory.common.summarizer as summarizer
 
 Base = declarative_base()
 
@@ -105,19 +106,36 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh
     ]
 
 
-def chunk_mixed(
-    content: str, image_paths: Sequence[str]
-) -> list[list[extract.MulitmodalChunk]]:
-    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
-    full_text: list[extract.MulitmodalChunk] = [content.strip(), *images]
+def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
+    final = {}
+    for m in metadata:
+        if tags := set(m.pop("tags", [])):
+            final["tags"] = tags | final.get("tags", set())
+        final |= m
+    return final
 
-    chunks = []
+
+def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
+    if not content.strip():
+        return []
+
+    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
+
+    summary, tags = summarizer.summarize(content)
+    full_text: extract.DataChunk = extract.DataChunk(
+        data=[content.strip(), *images], metadata={"tags": tags}
+    )
+
+    chunks: list[extract.DataChunk] = [full_text]
     tokens = chunker.approx_token_count(content)
     if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
-        chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
+        chunks += [
+            extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
+            for c in chunker.chunk_text(content)
+        ]
+    chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
 
-    all_chunks = [full_text] + chunks
-    return [c for c in all_chunks if c and all(i for i in c)]
+    return [c for c in chunks if c.data]
 
 
 class Chunk(Base):
@@ -224,22 +242,27 @@ class SourceItem(Base):
         """Get vector IDs from associated chunks."""
         return [chunk.id for chunk in self.chunks]
 
-    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
-        chunks: list[list[extract.MulitmodalChunk]] = []
-        if cast(str | None, self.content):
-            chunks = [[c] for c in chunker.chunk_text(cast(str, self.content))]
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
+        chunks: list[extract.DataChunk] = []
+        content = cast(str | None, self.content)
+        if content:
+            chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
+
+        if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
+            summary, tags = summarizer.summarize(content)
+            chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
 
         mime_type = cast(str | None, self.mime_type)
         if mime_type and mime_type.startswith("image/"):
-            chunks.append([Image.open(self.filename)])
+            chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
         return chunks
 
     def _make_chunk(
-        self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {}
-    ):
+        self, data: extract.DataChunk, metadata: dict[str, Any] = {}
+    ) -> Chunk:
         chunk_id = str(uuid.uuid4())
-        text = "\n\n".join(c for c in data if isinstance(c, str) and c.strip())
-        images = [c for c in data if isinstance(c, Image.Image)]
+        text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
+        images = [c for c in data.data if isinstance(c, Image.Image)]
         image_names = image_filenames(chunk_id, images)
 
         chunk = Chunk(
@@ -249,17 +272,18 @@ class SourceItem(Base):
             images=images,
             file_paths=image_names,
             embedding_model=collections.collection_model(cast(str, self.modality)),
-            item_metadata=self.as_payload() | metadata,
+            item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
         )
         return chunk
 
     def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
-        return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
+        return [self._make_chunk(data) for data in self._chunk_contents()]
 
     def as_payload(self) -> dict:
         return {
             "source_id": self.id,
             "tags": self.tags,
+            "size": self.size,
         }
 
     @property
@@ -312,7 +336,7 @@ class MailMessage(SourceItem):
 
     def as_payload(self) -> dict:
         return {
-            "source_id": self.id,
+            **super().as_payload(),
             "message_id": self.message_id,
             "subject": self.subject,
             "sender": self.sender,
@@ -381,13 +405,12 @@ class EmailAttachment(SourceItem):
 
     def as_payload(self) -> dict:
         return {
+            **super().as_payload(),
             "filename": self.filename,
             "content_type": self.mime_type,
             "size": self.size,
             "created_at": (self.created_at and self.created_at.isoformat() or None),  # type: ignore
             "mail_message_id": self.mail_message_id,
-            "source_id": self.id,
-            "tags": self.tags,
         }
 
     def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
@@ -397,7 +420,7 @@ class EmailAttachment(SourceItem):
             contents = cast(str, self.content)
 
         chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
-        return [self._make_chunk(c.data, metadata | c.metadata) for c in chunks]
+        return [self._make_chunk(c, metadata) for c in chunks]
 
     # Add indexes
     __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
@@ -488,8 +511,7 @@ class Comic(SourceItem):
 
     def as_payload(self) -> dict:
         payload = {
-            "source_id": self.id,
-            "tags": self.tags,
+            **super().as_payload(),
             "title": self.title,
             "author": self.author,
             "published": self.published,
@@ -500,10 +522,10 @@ class Comic(SourceItem):
         }
         return {k: v for k, v in payload.items() if v is not None}
 
-    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         image = Image.open(pathlib.Path(cast(str, self.filename)))
         description = f"{self.title} by {self.author}"
-        return [[image, description]]
+        return [extract.DataChunk(data=[image, description])]
 
 
 class Book(Base):
@@ -538,7 +560,7 @@ class Book(Base):
 
     def as_payload(self) -> dict:
         return {
-            "source_id": self.id,
+            **super().as_payload(),
             "isbn": self.isbn,
             "title": self.title,
             "author": self.author,
@@ -548,7 +570,6 @@ class Book(Base):
             "edition": self.edition,
             "series": self.series,
             "series_number": self.series_number,
-            "tags": self.tags,
         } | (cast(dict, self.book_metadata) or {})
 
 
@@ -591,29 +612,41 @@ class BookSection(SourceItem):
 
     def as_payload(self) -> dict:
         vals = {
+            **super().as_payload(),
             "title": self.book.title,
             "author": self.book.author,
-            "source_id": self.id,
             "book_id": self.book_id,
             "section_title": self.section_title,
             "section_number": self.section_number,
             "section_level": self.section_level,
             "start_page": self.start_page,
             "end_page": self.end_page,
-            "tags": self.tags,
         }
         return {k: v for k, v in vals.items() if v}
 
     def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
-        if not cast(str, self.content.strip()):
+        content = cast(str, self.content.strip())
+        if not content:
             return []
 
-        texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
+        if len([p for p in self.pages if p.strip()]) == 1:
+            return [
+                self._make_chunk(
+                    extract.DataChunk(data=[content]), metadata | {"type": "page"}
+                )
+            ]
+
+        summary, tags = summarizer.summarize(content)
         return [
-            self._make_chunk([text.strip()], metadata | {"page": page_number})
-            for text, page_number in texts
-            if text and text.strip()
-        ] + [self._make_chunk([cast(str, self.content.strip())], metadata)]
+            self._make_chunk(
+                extract.DataChunk(data=[content]),
+                merge_metadata(metadata, {"type": "section", "tags": tags}),
+            ),
+            self._make_chunk(
+                extract.DataChunk(data=[summary]),
+                merge_metadata(metadata, {"type": "summary", "tags": tags}),
+            ),
+        ]
 
 
 class BlogPost(SourceItem):
@@ -652,7 +685,7 @@ class BlogPost(SourceItem):
         metadata = cast(dict | None, self.webpage_metadata) or {}
 
         payload = {
-            "source_id": self.id,
+            **super().as_payload(),
             "url": self.url,
             "title": self.title,
             "author": self.author,
@@ -660,12 +693,11 @@ class BlogPost(SourceItem):
             "description": self.description,
             "domain": self.domain,
             "word_count": self.word_count,
-            "tags": self.tags,
             **metadata,
         }
         return {k: v for k, v in payload.items() if v}
 
-    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
 
 
@@ -701,7 +733,7 @@ class ForumPost(SourceItem):
 
     def as_payload(self) -> dict:
         return {
-            "source_id": self.id,
+            **super().as_payload(),
             "url": self.url,
             "title": self.title,
             "description": self.description,
@@ -712,10 +744,9 @@ class ForumPost(SourceItem):
             "votes": self.votes,
             "score": self.score,
             "comments": self.comments,
-            "tags": self.tags,
         }
 
-    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
 
diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py
index 92b489b..106ee22 100644
--- a/src/memory/common/qdrant.py
+++ b/src/memory/common/qdrant.py
@@ -245,3 +245,46 @@ def find_missing_points(
         collection_name, ids=ids, with_payload=False, with_vectors=False
     )
     return set(ids) - {str(r.id) for r in found}
+
+
+def set_payload(
+    client: qdrant_client.QdrantClient,
+    collection_name: str,
+    point_id: str,
+    payload: dict[str, Any],
+) -> None:
+    """Set payload for a single point without modifying its vector.
+
+    Args:
+        client: Qdrant client
+        collection_name: Name of the collection
+        point_id: Vector ID (as string)
+        payload: New payload to set
+    """
+    client.set_payload(
+        collection_name=collection_name,
+        payload=payload,
+        points=[point_id],
+    )
+
+    logger.debug(f"Set payload for point {point_id} in {collection_name}")
+
+
+def get_payloads(
+    client: qdrant_client.QdrantClient, collection_name: str, ids: list[str]
+) -> dict[str, dict[str, Any]]:
+    """Retrieve payloads for multiple points.
+
+    Args:
+        client: Qdrant client
+        collection_name: Name of the collection
+        ids: List of vector IDs (as strings)
+
+    Returns:
+        Dictionary mapping point IDs to their payloads
+    """
+    points = client.retrieve(
+        collection_name=collection_name, ids=ids, with_payload=True, with_vectors=False
+    )
+
+    return {str(point.id): point.payload or {} for point in points}
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 05d6d3a..7ee6481 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -98,3 +98,22 @@ CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60
 TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
 MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")
 EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", 50))
+
+# VoyageAI max context window
+EMBEDDING_MAX_TOKENS = int(os.getenv("EMBEDDING_MAX_TOKENS", 32000))
+# Optimal chunk size for semantic search
+DEFAULT_CHUNK_TOKENS = int(os.getenv("DEFAULT_CHUNK_TOKENS", 512))
+OVERLAP_TOKENS = int(os.getenv("OVERLAP_TOKENS", 50))
+
+
+# LLM settings
+if openai_key_file := os.getenv("OPENAI_API_KEY_FILE"):
+    OPENAI_API_KEY = pathlib.Path(openai_key_file).read_text().strip()
+else:
+    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+
+if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
+    ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
+else:
+    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
+SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
diff --git a/src/memory/common/summarizer.py b/src/memory/common/summarizer.py
new file mode 100644
index 0000000..7fdcac7
--- /dev/null
+++ b/src/memory/common/summarizer.py
@@ -0,0 +1,137 @@
+import json
+import logging
+from typing import Any
+
+from memory.common import settings, chunker
+
+logger = logging.getLogger(__name__)
+
+TAGS_PROMPT = """
+The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
+
+Return your response as JSON with this format:
+{{
+"summary": "{summary}",
+"tags": ["tag1", "tag2", "tag3"]
+}}
+
+Text:
+{content}
+"""
+
+SUMMARY_PROMPT = """
+Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
+Also provide 3-5 relevant tags that capture the main topics or themes.
+
+Return your response as JSON with this format:
+{{
+    "summary": "your summary here",
+    "tags": ["tag1", "tag2", "tag3"]
+}}
+
+Text to summarize:
+{content}
+"""
+
+
+def _call_openai(prompt: str) -> dict[str, Any]:
+    """Call OpenAI API for summarization."""
+    import openai
+
+    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
+    try:
+        response = client.chat.completions.create(
+            model=settings.SUMMARIZER_MODEL.split("/")[1],
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
+                },
+                {"role": "user", "content": prompt},
+            ],
+            response_format={"type": "json_object"},
+            temperature=0.3,
+            max_tokens=2048,
+        )
+        return json.loads(response.choices[0].message.content or "{}")
+    except Exception as e:
+        logger.error(f"OpenAI API error: {e}")
+        raise
+
+
+def _call_anthropic(prompt: str) -> dict[str, Any]:
+    """Call Anthropic API for summarization."""
+    import anthropic
+
+    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
+    try:
+        response = client.messages.create(
+            model=settings.SUMMARIZER_MODEL.split("/")[1],
+            messages=[{"role": "user", "content": prompt}],
+            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid JSON.",
+            temperature=0.3,
+            max_tokens=2048,
+        )
+        return json.loads(response.content[0].text)
+    except Exception as e:
+        logger.error(f"Anthropic API error: {e}")
+        raise
+
+
+def truncate(content: str, target_tokens: int) -> str:
+    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
+    if len(content) > target_chars:
+        return content[:target_chars].rsplit(" ", 1)[0] + "..."
+    return content
+
+
+def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
+    """
+    Summarize content to approximately target_tokens length and generate tags.
+
+    Args:
+        content: Text to summarize
+        target_tokens: Target length in tokens (defaults to DEFAULT_CHUNK_TOKENS)
+
+    Returns:
+        Tuple of (summary, tags)
+    """
+    if not content or not content.strip():
+        return "", []
+
+    if target_tokens is None:
+        target_tokens = settings.DEFAULT_CHUNK_TOKENS
+
+    summary, tags = content, []
+
+    # If content is already short enough, just extract tags
+    current_tokens = chunker.approx_token_count(content)
+    if current_tokens <= target_tokens:
+        logger.info(
+            f"Content already under {target_tokens} tokens, extracting tags only"
+        )
+        prompt = TAGS_PROMPT.format(content=content, summary=summary[:1000])
+    else:
+        prompt = SUMMARY_PROMPT.format(
+            target_tokens=target_tokens,
+            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
+            content=content,
+        )
+
+    try:
+        if settings.SUMMARIZER_MODEL.startswith("anthropic"):
+            result = _call_anthropic(prompt)
+        else:
+            result = _call_openai(prompt)
+
+        summary = result.get("summary", "")
+        tags = result.get("tags", [])
+    except Exception as e:
+        logger.error(f"Summarization failed: {e}")
+
+    tokens = chunker.approx_token_count(summary)
+    if tokens > target_tokens * 1.5:
+        logger.warning(f"Summary too long ({tokens} tokens), truncating")
+        summary = truncate(content, target_tokens)
+
+    return summary, tags
diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py
index 88bd057..998de9b 100644
--- a/src/memory/workers/tasks/maintenance.py
+++ b/src/memory/workers/tasks/maintenance.py
@@ -1,7 +1,7 @@
 import logging
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Sequence
+from typing import Sequence, Any
 
 from memory.workers.tasks.content_processing import process_content_item
 from sqlalchemy import select
@@ -21,6 +21,10 @@ REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
 REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
 REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
 REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
+UPDATE_METADATA_FOR_SOURCE_ITEMS = (
+    f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
+)
+UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
 
 
 @app.task(name=CLEAN_COLLECTION)
@@ -97,6 +101,8 @@ def get_item_class(item_type: str):
         raise ValueError(
             f"Unsupported item type {item_type}. Available types: {available_types}"
         )
+    if not hasattr(class_, "chunks"):
+        raise ValueError(f"Item type {item_type} does not have chunks")
     return class_
 
 
@@ -226,3 +232,99 @@ def reingest_missing_chunks(
             total_stats[collection]["total"] += stats["total"]
 
     return dict(total_stats)
+
+
+def _payloads_equal(current: dict[str, Any], new: dict[str, Any]) -> bool:
+    """Compare two payloads to see if they're effectively equal."""
+    # Handle tags specially - compare as sets since order doesn't matter
+    current_tags = set(current.get("tags", []))
+    new_tags = set(new.get("tags", []))
+
+    if current_tags != new_tags:
+        return False
+
+    # Compare all other fields
+    current_without_tags = {k: v for k, v in current.items() if k != "tags"}
+    new_without_tags = {k: v for k, v in new.items() if k != "tags"}
+
+    return current_without_tags == new_without_tags
+
+
+@app.task(name=UPDATE_METADATA_FOR_ITEM)
+def update_metadata_for_item(item_id: str, item_type: str):
+    """Update metadata in Qdrant for all chunks of a single source item, merging tags."""
+    logger.info(f"Updating metadata for {item_type} {item_id}")
+    try:
+        class_ = get_item_class(item_type)
+    except ValueError as e:
+        logger.error(f"Error getting item class: {e}")
+        return {"status": "error", "error": str(e)}
+
+    client = qdrant.get_qdrant_client()
+    updated_chunks = 0
+    errors = 0
+
+    with make_session() as session:
+        item = session.query(class_).get(item_id)
+        if not item:
+            return {"status": "error", "error": f"Item {item_id} not found"}
+
+        chunk_ids = [str(chunk.id) for chunk in item.chunks if chunk.id]
+        if not chunk_ids:
+            return {"status": "success", "updated_chunks": 0, "errors": 0}
+
+        collection = item.modality
+
+        try:
+            current_payloads = qdrant.get_payloads(client, collection, chunk_ids)
+
+            # Get new metadata from source item
+            new_metadata = item.as_payload()
+            new_tags = set(new_metadata.get("tags", []))
+
+            for chunk_id in chunk_ids:
+                if chunk_id not in current_payloads:
+                    logger.warning(
+                        f"Chunk {chunk_id} not found in Qdrant collection {collection}"
+                    )
+                    continue
+
+                current_payload = current_payloads[chunk_id]
+                current_tags = set(current_payload.get("tags", []))
+
+                # Merge tags (combine existing and new tags)
+                merged_tags = list(current_tags | new_tags)
+                updated_metadata = new_metadata.copy()
+                updated_metadata["tags"] = merged_tags
+
+                if _payloads_equal(current_payload, updated_metadata):
+                    continue
+
+                qdrant.set_payload(client, collection, chunk_id, updated_metadata)
+                updated_chunks += 1
+
+        except Exception as e:
+            logger.error(f"Error updating metadata for item {item.id}: {e}")
+            errors += 1
+
+    return {"status": "success", "updated_chunks": updated_chunks, "errors": errors}
+
+
+@app.task(name=UPDATE_METADATA_FOR_SOURCE_ITEMS)
+def update_metadata_for_source_items(item_type: str):
+    """Update metadata in Qdrant for all chunks of all items of a given source type."""
+    logger.info(f"Updating metadata for all {item_type} source items")
+    try:
+        class_ = get_item_class(item_type)
+    except ValueError as e:
+        logger.error(f"Error getting item class: {e}")
+        return {"status": "error", "error": str(e)}
+
+    with make_session() as session:
+        item_ids = session.query(class_.id).all()
+        logger.info(f"Found {len(item_ids)} items to update metadata for")
+
+        for item_id in item_ids:
+            update_metadata_for_item.delay(item_id.id, item_type)  # type: ignore
+
+    return {"status": "success", "items": len(item_ids)}
diff --git a/tests/conftest.py b/tests/conftest.py
index 24874c1..aae9b51 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,8 @@ from datetime import datetime
 from pathlib import Path
 from unittest.mock import Mock, patch
 
+import anthropic
+import openai
 import pytest
 import qdrant_client
 import voyageai
@@ -237,3 +239,35 @@ def mock_voyage_client():
     client.embed = embeder
     client.multimodal_embed = embeder
     yield client
+
+
+@pytest.fixture(autouse=True)
+def mock_openai_client():
+    with patch.object(openai, "OpenAI", autospec=True) as mock_client:
+        client = mock_client()
+        client.chat = Mock()
+        client.chat.completions.create = Mock(
+            return_value=Mock(
+                choices=[
+                    Mock(
+                        message=Mock(
+                            content='{"summary": "test", "tags": ["tag1", "tag2"]}'
+                        )
+                    )
+                ]
+            )
+        )
+        yield client
+
+
+@pytest.fixture(autouse=True)
+def mock_anthropic_client():
+    with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
+        client = mock_client()
+        client.messages = Mock()
+        client.messages.create = Mock(
+            return_value=Mock(
+                content=[Mock(text='{"summary": "test", "tags": ["tag1", "tag2"]}')]
+            )
+        )
+        yield client
diff --git a/tests/memory/common/db/test_models.py b/tests/memory/common/db/test_models.py
index c2ae0e9..099d7a1 100644
--- a/tests/memory/common/db/test_models.py
+++ b/tests/memory/common/db/test_models.py
@@ -5,8 +5,7 @@ from typing import cast
 import pytest
 from PIL import Image
 from datetime import datetime
-from memory.common import settings
-from memory.common import chunker
+from memory.common import settings, chunker, extract
 from memory.common.db.models import (
     Chunk,
     clean_filename,
@@ -17,6 +16,7 @@ from memory.common.db.models import (
     BookSection,
     BlogPost,
     Book,
+    merge_metadata,
 )
 
 
@@ -54,6 +54,114 @@ def test_clean_filename(input_filename, expected):
     assert clean_filename(input_filename) == expected
 
 
+@pytest.mark.parametrize(
+    "dicts,expected",
+    [
+        # Empty input
+        ([], {}),
+        # Single dict without tags
+        ([{"key": "value"}], {"key": "value"}),
+        # Single dict with tags as list
+        (
+            [{"key": "value", "tags": ["tag1", "tag2"]}],
+            {"key": "value", "tags": {"tag1", "tag2"}},
+        ),
+        # Single dict with tags as set
+        (
+            [{"key": "value", "tags": {"tag1", "tag2"}}],
+            {"key": "value", "tags": {"tag1", "tag2"}},
+        ),
+        # Multiple dicts without tags
+        (
+            [{"key1": "value1"}, {"key2": "value2"}],
+            {"key1": "value1", "key2": "value2"},
+        ),
+        # Multiple dicts with non-overlapping tags
+        (
+            [
+                {"key1": "value1", "tags": ["tag1"]},
+                {"key2": "value2", "tags": ["tag2"]},
+            ],
+            {"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2"}},
+        ),
+        # Multiple dicts with overlapping tags
+        (
+            [
+                {"key1": "value1", "tags": ["tag1", "tag2"]},
+                {"key2": "value2", "tags": ["tag2", "tag3"]},
+            ],
+            {"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2", "tag3"}},
+        ),
+        # Overlapping keys - later dict wins
+        (
+            [
+                {"key": "value1", "other": "data1"},
+                {"key": "value2", "another": "data2"},
+            ],
+            {"key": "value2", "other": "data1", "another": "data2"},
+        ),
+        # Mixed tags types (list and set)
+        (
+            [
+                {"key1": "value1", "tags": ["tag1", "tag2"]},
+                {"key2": "value2", "tags": {"tag3", "tag4"}},
+            ],
+            {
+                "key1": "value1",
+                "key2": "value2",
+                "tags": {"tag1", "tag2", "tag3", "tag4"},
+            },
+        ),
+        # Empty tags
+        (
+            [{"key": "value", "tags": []}, {"key2": "value2", "tags": []}],
+            {"key": "value", "key2": "value2"},
+        ),
+        # None values
+        (
+            [{"key1": None, "key2": "value"}, {"key3": None}],
+            {"key1": None, "key2": "value", "key3": None},
+        ),
+        # Complex nested structures
+        (
+            [
+                {"nested": {"inner": "value1"}, "list": [1, 2, 3], "tags": ["tag1"]},
["tag1"]}, + {"nested": {"inner": "value2"}, "list": [4, 5], "tags": ["tag2"]}, + ], + {"nested": {"inner": "value2"}, "list": [4, 5], "tags": {"tag1", "tag2"}}, + ), + # Boolean and numeric values + ( + [ + {"bool": True, "int": 42, "float": 3.14, "tags": ["numeric"]}, + {"bool": False, "int": 100}, + ], + {"bool": False, "int": 100, "float": 3.14, "tags": {"numeric"}}, + ), + # Three or more dicts + ( + [ + {"a": 1, "tags": ["t1"]}, + {"b": 2, "tags": ["t2", "t3"]}, + {"c": 3, "a": 10, "tags": ["t3", "t4"]}, + ], + {"a": 10, "b": 2, "c": 3, "tags": {"t1", "t2", "t3", "t4"}}, + ), + # Dict with only tags + ([{"tags": ["tag1", "tag2"]}], {"tags": {"tag1", "tag2"}}), + # Empty dicts + ([{}, {}], {}), + # Mix of empty and non-empty dicts + ( + [{}, {"key": "value", "tags": ["tag"]}, {}], + {"key": "value", "tags": {"tag"}}, + ), + ], +) +def test_merge_metadata(dicts, expected): + assert merge_metadata(*dicts) == expected + + def test_image_filenames_with_existing_filenames(tmp_path): """Test image_filenames when images already have filenames""" chunk_id = "test_chunk_123" @@ -262,7 +370,7 @@ def test_source_item_chunk_contents_text(chunk_length, expected, default_chunk_s ) default_chunk_size(chunk_length) - assert source._chunk_contents() == expected + assert source._chunk_contents() == [extract.DataChunk(data=e) for e in expected] def test_source_item_chunk_contents_image(tmp_path): @@ -281,8 +389,8 @@ def test_source_item_chunk_contents_image(tmp_path): result = source._chunk_contents() assert len(result) == 1 - assert len(result[0]) == 1 - assert isinstance(result[0][0], Image.Image) + assert len(result[0].data) == 1 + assert isinstance(result[0].data[0], Image.Image) def test_source_item_chunk_contents_mixed(tmp_path): @@ -302,8 +410,8 @@ def test_source_item_chunk_contents_mixed(tmp_path): result = source._chunk_contents() assert len(result) == 2 - assert result[0][0] == "Bla bla" - assert isinstance(result[1][0], Image.Image) + assert result[0].data[0] == "Bla bla" + assert isinstance(result[1].data[0], Image.Image) @pytest.mark.parametrize( @@ -323,7 +431,11 @@ def test_source_item_chunk_contents_mixed(tmp_path): def test_source_item_make_chunk(tmp_path, texts, expected_content): """Test SourceItem._make_chunk method""" source = SourceItem( - sha256=b"test123", content="test", modality="text", tags=["tag1"] + sha256=b"test123", + content="test", + modality="text", + tags=["tag1"], + size=1024, ) # Create actual image image_file = tmp_path / "test.png" @@ -335,7 +447,7 @@ def test_source_item_make_chunk(tmp_path, texts, expected_content): data = [*texts, img] metadata = {"extra": "data"} - chunk = source._make_chunk(data, metadata) + chunk = source._make_chunk(extract.DataChunk(data=data), metadata) assert chunk.id is not None assert chunk.source == source @@ -344,7 +456,12 @@ def test_source_item_make_chunk(tmp_path, texts, expected_content): assert chunk.embedding_model is not None # Check that metadata is merged correctly - expected_payload = {"source_id": source.id, "tags": ["tag1"], "extra": "data"} + expected_payload = { + "source_id": source.id, + "tags": {"tag1"}, + "extra": "data", + "size": 1024, + } assert chunk.item_metadata == expected_payload @@ -355,10 +472,11 @@ def test_source_item_as_payload(): content="test", modality="text", tags=["tag1", "tag2"], + size=1024, ) payload = source.as_payload() - assert payload == {"source_id": 123, "tags": ["tag1", "tag2"]} + assert payload == {"source_id": 123, "tags": ["tag1", "tag2"], "size": 1024} @pytest.mark.parametrize( @@ 
         folder="INBOX",
         sent_at=sent_at,
         tags=["tag1", "tag2"],
+        size=1024,
     )
     # Manually set id for testing
     object.__setattr__(mail_message, "id", 123)
@@ -552,6 +671,7 @@ def test_mail_message_as_payload(sent_at, expected_date):
 
     expected = {
         "source_id": 123,
+        "size": 1024,
         "message_id": "",
         "subject": "Test Subject",
         "sender": "sender@example.com",
@@ -646,12 +766,12 @@ def test_email_attachment_as_payload(created_at, expected_date):
     payload = attachment.as_payload()
 
     expected = {
+        "source_id": 456,
         "filename": "document.pdf",
        "content_type": "application/pdf",
         "size": 1024,
         "created_at": expected_date,
         "mail_message_id": 123,
-        "source_id": 456,
         "tags": ["pdf", "document"],
     }
     assert payload == expected
@@ -702,7 +822,8 @@ def test_email_attachment_data_chunks(
     # Verify the method calls
     mock_extract.assert_called_once_with("text/plain", expected_content)
     mock_make.assert_called_once_with(
-        ["extracted text"], {"extra": "metadata", "source": content_source}
+        extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
+        {"extra": "metadata"},
     )
     assert result == [mock_chunk]
 
@@ -813,18 +934,20 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
         # No pages
         ([], []),
         # Single page
-        (["Page 1 content"], [("Page 1 content", 10)]),
+        (["Page 1 content"], [("Page 1 content", {"type": "page"})]),
         # Multiple pages
         (
            ["Page 1", "Page 2", "Page 3"],
             [
-                ("Page 1", 10),
-                ("Page 2", 11),
-                ("Page 3", 12),
+                (
+                    "Page 1\n\nPage 2\n\nPage 3",
+                    {"type": "section", "tags": {"tag1", "tag2"}},
+                ),
+                ("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
             ],
         ),
         # Empty/whitespace pages filtered out
-        (["", " ", "Page 3"], [("Page 3", 12)]),
+        (["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
         # All empty - no chunks created
         (["", " ", " "], []),
     ],
@@ -845,11 +968,8 @@ def test_book_section_data_chunks(pages, expected_chunks):
     chunks = book_section.data_chunks()
 
     expected = [
-        (p, book_section.as_payload() | {"page": i}) for p, i in expected_chunks
+        (c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
     ]
-    if content:
-        expected.append((content, book_section.as_payload()))
-
     assert [(c.content, c.item_metadata) for c in chunks] == expected
     for c in chunks:
         assert cast(list, c.file_paths) == []
@@ -859,16 +979,39 @@
     "content,expected",
     [
         ("", []),
-        ("Short content", [["Short content"]]),
+        (
+            "Short content",
+            [
+                extract.DataChunk(
+                    data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
+                )
+            ],
+        ),
         (
             "This is a very long piece of content that should be chunked into multiple pieces when processed.",
             [
-                [
-                    "This is a very long piece of content that should be chunked into multiple pieces when processed."
-                ],
-                ["This is a very long piece of content that"],
-                ["should be chunked into multiple pieces when"],
-                ["processed."],
+                extract.DataChunk(
+                    data=[
+                        "This is a very long piece of content that should be chunked into multiple pieces when processed."
+                    ],
+                    metadata={"tags": ["tag1", "tag2"]},
+                ),
+                extract.DataChunk(
+                    data=["This is a very long piece of content that"],
+                    metadata={"tags": ["tag1", "tag2"]},
+                ),
+                extract.DataChunk(
+                    data=["should be chunked into multiple pieces when"],
+                    metadata={"tags": ["tag1", "tag2"]},
+                ),
+                extract.DataChunk(
+                    data=["processed."],
+                    metadata={"tags": ["tag1", "tag2"]},
+                ),
+                extract.DataChunk(
+                    data=["test"],
+                    metadata={"tags": ["tag1", "tag2"]},
+                ),
             ],
         ),
     ],
@@ -906,7 +1049,8 @@ def test_blog_post_chunk_contents_with_images(tmp_path):
     result = blog_post._chunk_contents()
 
     result = [
-        [i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
+        [i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
+        for c in result
     ]
     assert result == [
         ["Content with images", img1_path.as_posix(), img2_path.as_posix()]
@@ -933,9 +1077,9 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
     result = blog_post._chunk_contents()
 
     result = [
-        [i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
+        [i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
+        for c in result
     ]
-    print(result)
     assert result == [
         [
             f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
@@ -950,4 +1094,5 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
             f"Second picture is here: {img2_path.as_posix()}",
             img2_path.as_posix(),
         ],
+        ["test"],
     ]
diff --git a/tests/memory/workers/tasks/test_comic_tasks.py b/tests/memory/workers/tasks/test_comic_tasks.py
index 2cfa098..76a63e1 100644
--- a/tests/memory/workers/tasks/test_comic_tasks.py
+++ b/tests/memory/workers/tasks/test_comic_tasks.py
@@ -218,6 +218,7 @@ def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
             "title": "Test Comic",
             "url": "https://example.com/comic/1",
             "source_id": 1,
+            "size": 90,
         },
         None,
     )
diff --git a/tests/memory/workers/tasks/test_ebook_tasks.py b/tests/memory/workers/tasks/test_ebook_tasks.py
index 70017a5..a437cab 100644
--- a/tests/memory/workers/tasks/test_ebook_tasks.py
+++ b/tests/memory/workers/tasks/test_ebook_tasks.py
@@ -186,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path, qdrant)
         "author": "Test Author",
         "status": "processed",
         "total_sections": 4,
-        "sections_embedded": 8,
+        "sections_embedded": 4,
     }
 
     book = db_session.query(Book).filter(Book.title == "Test Book").first()
@@ -325,8 +325,4 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
 
     # Check that the full content was passed to the embedding function
     texts = mock_voyage_client.embed.call_args[0][0]
-    assert texts == [
-        [large_page_1.strip()],
-        [large_page_2.strip()],
-        [large_section_content.strip()],
-    ]
+    assert texts == [[large_section_content.strip()], ["test"]]
diff --git a/tests/memory/workers/tasks/test_forums_tasks.py b/tests/memory/workers/tasks/test_forums_tasks.py
index 8644448..ee085d7 100644
--- a/tests/memory/workers/tasks/test_forums_tasks.py
+++ b/tests/memory/workers/tasks/test_forums_tasks.py
@@ -335,38 +335,19 @@ def test_sync_lesswrong_max_items_limit(mock_fetch, db_session):
     assert result["max_items"] == 3
 
 
-@pytest.mark.parametrize(
-    "since_str,expected_days_ago",
-    [
-        ("2024-01-01T00:00:00", None),  # Specific date
-        (None, 30),  # Default should be 30 days ago
-    ],
-)
 @patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
-def test_sync_lesswrong_since_parameter(
-    mock_fetch, since_str, expected_days_ago, db_session
-):
+def test_sync_lesswrong_since_parameter(mock_fetch, db_session):
     """Test that since parameter is handled correctly."""
     mock_fetch.return_value = []
 
-    if since_str:
-        forums.sync_lesswrong(since=since_str)
-        expected_since = datetime.fromisoformat(since_str)
-    else:
-        forums.sync_lesswrong()
-        # Should default to 30 days ago
-        expected_since = datetime.now() - timedelta(days=30)
+    forums.sync_lesswrong(since="2024-01-01T00:00:00")
+    expected_since = datetime.fromisoformat("2024-01-01T00:00:00")
 
     # Verify fetch was called with correct since date
     call_args = mock_fetch.call_args[0]
     actual_since = call_args[0]
 
-    if expected_days_ago:
-        # For default case, check it's approximately 30 days ago
-        time_diff = abs((actual_since - expected_since).total_seconds())
-        assert time_diff < 120  # Within 2 minute tolerance for slow tests
-    else:
-        assert actual_since == expected_since
+    assert actual_since == expected_since
 
 
 @pytest.mark.parametrize(
diff --git a/tests/memory/workers/tasks/test_maintenance.py b/tests/memory/workers/tasks/test_maintenance.py
index c06c6cb..6a48c8b 100644
--- a/tests/memory/workers/tasks/test_maintenance.py
+++ b/tests/memory/workers/tasks/test_maintenance.py
@@ -1,3 +1,4 @@
+# FIXME: Most of this was vibe-coded
 import uuid
 from datetime import datetime, timedelta
 from unittest.mock import patch, call
@@ -15,6 +16,9 @@ from memory.workers.tasks.maintenance import (
     reingest_missing_chunks,
     reingest_item,
     reingest_empty_source_items,
+    update_metadata_for_item,
+    update_metadata_for_source_items,
+    _payloads_equal,
 )
 
 
@@ -596,3 +600,372 @@ def test_reingest_empty_source_items_invalid_type(db_session):
 
 def test_reingest_missing_chunks_no_chunks(db_session):
     assert reingest_missing_chunks() == {}
+
+
+@pytest.mark.parametrize(
+    "payload1,payload2,expected",
+    [
+        # Identical payloads
+        (
+            {"tags": ["a", "b"], "source_id": 1, "title": "Test"},
+            {"tags": ["a", "b"], "source_id": 1, "title": "Test"},
+            True,
+        ),
+        # Different tag order (should be equal)
+        (
+            {"tags": ["a", "b"], "source_id": 1},
+            {"tags": ["b", "a"], "source_id": 1},
+            True,
+        ),
+        # Different tags
+        (
+            {"tags": ["a", "b"], "source_id": 1},
+            {"tags": ["a", "c"], "source_id": 1},
+            False,
+        ),
+        # Different non-tag fields
+        ({"tags": ["a"], "source_id": 1}, {"tags": ["a"], "source_id": 2}, False),
+        # Missing tags in one payload
+        ({"tags": ["a"], "source_id": 1}, {"source_id": 1}, False),
+        # Empty tags (should be equal)
+        ({"tags": [], "source_id": 1}, {"source_id": 1}, True),
+    ],
+)
+def test_payloads_equal(payload1, payload2, expected):
+    """Test the _payloads_equal helper function."""
+    assert _payloads_equal(payload1, payload2) == expected
+
+
+@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
+def test_update_metadata_for_item_success(db_session, qdrant, item_type):
+    """Test successful metadata update for an item with chunks."""
+    if item_type == "MailMessage":
+        item = MailMessage(
+            sha256=b"test_hash" + bytes(24),
+            tags=["original", "test"],
+            size=100,
+            mime_type="message/rfc822",
+            embed_status="STORED",
+            message_id="",
+            subject="Test Subject",
+            sender="sender@example.com",
+            recipients=["recipient@example.com"],
+            content="Test content",
+            folder="INBOX",
+            modality="mail",
+        )
+        modality = "mail"
+    else:
+        item = BlogPost(
+            sha256=b"test_hash" + bytes(24),
+            tags=["original", "test"],
+            size=100,
+            mime_type="text/html",
+            embed_status="STORED",
+            url="https://example.com/post",
+            title="Test Blog Post",
+            author="Author Name",
+            content="Test blog content",
+            modality="blog",
+        )
+        modality = "blog"
+
+    db_session.add(item)
+    db_session.flush()
+
+    # Add chunks to the item
+    chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
+    chunks = [
+        Chunk(
+            id=chunk_id,
+            source=item,
+            content=f"Test chunk content {i}",
+            embedding_model="test-model",
+        )
+        for i, chunk_id in enumerate(chunk_ids)
+    ]
+    db_session.add_all(chunks)
+    db_session.commit()
+
+    # Setup Qdrant with existing payloads
+    qd.ensure_collection_exists(qdrant, modality, 1024)
+
+    existing_payloads = [
+        {"tags": ["existing", "qdrant"], "source_id": item.id, "old_field": "value"}
+        for _ in chunk_ids
+    ]
+    qd.upsert_vectors(
+        qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids), existing_payloads
+    )
+
+    # Mock the qdrant functions to track calls
+    with (
+        patch(
+            "memory.workers.tasks.maintenance.qdrant.get_payloads"
+        ) as mock_get_payloads,
+        patch(
+            "memory.workers.tasks.maintenance.qdrant.set_payload"
+        ) as mock_set_payload,
+    ):
+        # Return the existing payloads
+        mock_get_payloads.return_value = {
+            chunk_id: payload for chunk_id, payload in zip(chunk_ids, existing_payloads)
+        }
+
+        result = update_metadata_for_item(str(item.id), item_type)
+
+        # Verify result
+        assert result["status"] == "success"
+        assert result["updated_chunks"] == 3
+        assert result["errors"] == 0
+
+        # Verify batch retrieval was called once
+        mock_get_payloads.assert_called_once_with(qdrant, modality, chunk_ids)
+
+        # Verify set_payload was called for each chunk with merged tags
+        assert mock_set_payload.call_count == 3
+        for call in mock_set_payload.call_args_list:
+            args, kwargs = call
+            client, collection, chunk_id, updated_payload = args
+
+            # Check that tags were merged (existing + new)
+            expected_tags = set(
+                [
+                    "existing",
+                    "qdrant",
+                    "original",
+                    "test",
+                ]
+            )
+            if item_type == "MailMessage":
+                expected_tags.update(["recipient@example.com", "sender@example.com"])
+
+            actual_tags = set(updated_payload["tags"])
+            assert actual_tags == expected_tags
+
+            # Check that new metadata is present
+            assert updated_payload["source_id"] == item.id
+
+
+def test_update_metadata_for_item_no_changes(db_session, qdrant):
+    """Test that no updates are made when metadata hasn't changed."""
+    item = MailMessage(
+        sha256=b"test_hash" + bytes(24),
+        tags=["test"],
+        size=100,
+        mime_type="message/rfc822",
+        embed_status="STORED",
+        message_id="",
+        subject="Test Subject",
+        sender="sender@example.com",
+        recipients=["recipient@example.com"],
+        content="Test content",
+        folder="INBOX",
+        modality="mail",
+    )
+    db_session.add(item)
+    db_session.flush()
+
+    chunk_id = str(uuid.uuid4())
+    chunk = Chunk(
+        id=chunk_id,
+        source=item,
+        content="Test chunk content",
+        embedding_model="test-model",
+    )
+    db_session.add(chunk)
+    db_session.commit()
+
+    # Setup payload that matches what the item would generate
+    item_payload = item.as_payload()
+    existing_payload = {chunk_id: item_payload}
+
+    with (
+        patch(
+            "memory.workers.tasks.maintenance.qdrant.get_payloads"
+        ) as mock_get_payloads,
+        patch(
+            "memory.workers.tasks.maintenance.qdrant.set_payload"
+        ) as mock_set_payload,
+    ):
+        mock_get_payloads.return_value = existing_payload
+
+        result = update_metadata_for_item(str(item.id), "MailMessage")
+
+        # Verify no updates were made
+        assert result["status"] == "success"
+        assert result["updated_chunks"] == 0
+        assert result["errors"] == 0
+
+        # Verify set_payload was never called since nothing changed
+        mock_set_payload.assert_not_called()
+
+
+def test_update_metadata_for_item_no_chunks(db_session):
+    """Test updating metadata for an item with no chunks."""
+    item = MailMessage(
+        sha256=b"test_hash" + bytes(24),
+        tags=["test"],
+        size=100,
+        mime_type="message/rfc822",
+        embed_status="RAW",
+        message_id="",
+        subject="Test Subject",
+        sender="sender@example.com",
+        recipients=["recipient@example.com"],
+        content="Test content",
+        folder="INBOX",
+        modality="mail",
+    )
+    db_session.add(item)
+    db_session.commit()
+
+    result = update_metadata_for_item(str(item.id), "MailMessage")
+
+    assert result["status"] == "success"
+    assert result["updated_chunks"] == 0
+    assert result["errors"] == 0
+
+
+def test_update_metadata_for_item_not_found(db_session):
+    """Test updating metadata for a non-existent item."""
+    non_existent_id = "999"
+    result = update_metadata_for_item(non_existent_id, "MailMessage")
+
+    assert result == {"status": "error", "error": f"Item {non_existent_id} not found"}
+
+
+def test_update_metadata_for_item_invalid_type(db_session):
+    """Test updating metadata with an invalid item type."""
+    result = update_metadata_for_item("1", "invalid_type")
+
+    assert result["status"] == "error"
+    assert "Unsupported item type invalid_type" in result["error"]
+    assert "Available types:" in result["error"]
+
+
+def test_update_metadata_for_item_missing_chunks_in_qdrant(db_session, qdrant):
+    """Test when some chunks exist in DB but not in Qdrant."""
+    item = MailMessage(
+        sha256=b"test_hash" + bytes(24),
+        tags=["test"],
+        size=100,
+        mime_type="message/rfc822",
+        embed_status="STORED",
+        message_id="",
+        subject="Test Subject",
+        sender="sender@example.com",
+        recipients=["recipient@example.com"],
+        content="Test content",
+        folder="INBOX",
+        modality="mail",
+    )
+    db_session.add(item)
+    db_session.flush()
+
+    chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
+    chunks = [
+        Chunk(
+            id=chunk_id,
+            source=item,
+            content=f"Test chunk content {i}",
+            embedding_model="test-model",
+        )
+        for i, chunk_id in enumerate(chunk_ids)
+    ]
+    db_session.add_all(chunks)
+    db_session.commit()
+
+    # Mock qdrant to return payloads for only some chunks
+    with (
+        patch(
+            "memory.workers.tasks.maintenance.qdrant.get_payloads"
+        ) as mock_get_payloads,
+        patch(
+            "memory.workers.tasks.maintenance.qdrant.set_payload"
+        ) as mock_set_payload,
+    ):
+        # Only return payload for first chunk
+        mock_get_payloads.return_value = {
+            chunk_ids[0]: {"tags": ["existing"], "source_id": item.id}
+        }
+
+        result = update_metadata_for_item(str(item.id), "MailMessage")
+
+        # Should only update the chunk that was found in Qdrant
+        assert result["status"] == "success"
+        assert result["updated_chunks"] == 1
+        assert result["errors"] == 0
+
+        # Only one set_payload call for the found chunk
+        assert mock_set_payload.call_count == 1
+
+
+@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
+def test_update_metadata_for_source_items_success(db_session, item_type):
+    """Test updating metadata for all items of a given type."""
+    if item_type == "MailMessage":
+        items = [
+            MailMessage(
+                sha256=f"test_hash_{i}".encode() + bytes(32 - len(f"test_hash_{i}")),
+                tags=["test"],
+                size=100,
+                mime_type="message/rfc822",
+                embed_status="STORED",
+                message_id=f"",
+                subject=f"Test Subject {i}",
+                sender="sender@example.com",
+                recipients=["recipient@example.com"],
+                content=f"Test content {i}",
+                folder="INBOX",
+                modality="mail",
+            )
+            for i in range(3)
+        ]
+    else:
+        items = [
+            BlogPost(
+                sha256=f"test_hash_{i}".encode() + bytes(32 - len(f"test_hash_{i}")),
+                tags=["test"],
+                size=100,
+                mime_type="text/html",
+                embed_status="STORED",
+                url=f"https://example.com/post{i}",
+                title=f"Test Blog Post {i}",
+                author="Author Name",
+                content=f"Test blog content {i}",
+                modality="blog",
+            )
+            for i in range(3)
+        ]
+
+    db_session.add_all(items)
+    db_session.commit()
+
+    with patch.object(update_metadata_for_item, "delay") as mock_update:
+        result = update_metadata_for_source_items(item_type)
+
+    assert result == {"status": "success", "items": 3}
+
+    # Verify that update_metadata_for_item.delay was called for each item
+    assert mock_update.call_count == 3
+    expected_calls = [call(item.id, item_type) for item in items]
+    mock_update.assert_has_calls(expected_calls, any_order=True)
+
+
+def test_update_metadata_for_source_items_no_items(db_session):
+    """Test when there are no items of the specified type."""
+    with patch.object(update_metadata_for_item, "delay") as mock_update:
+        result = update_metadata_for_source_items("MailMessage")
+
+    assert result == {"status": "success", "items": 0}
+    mock_update.assert_not_called()
+
+
+def test_update_metadata_for_source_items_invalid_type(db_session):
+    """Test updating metadata for an invalid item type."""
+    result = update_metadata_for_source_items("invalid_type")
+
+    assert result["status"] == "error"
+    assert "Unsupported item type invalid_type" in result["error"]
+    assert "Available types:" in result["error"]