summarize before chunking

Daniel O'Connell 2025-05-29 01:21:24 +02:00
parent ed8033bdd3
commit e505f9b53c
15 changed files with 1427 additions and 121 deletions
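In short: before content gets chunked for embedding, it is now summarized and tagged via a new memory.common.summarizer module (Anthropic or OpenAI, chosen by SUMMARIZER_MODEL), and the summary is stored as an extra chunk with the tags merged into each chunk's metadata. Roughly, the new flow in chunk_mixed looks like this (simplified sketch using names from the diff below; images omitted):

    summary, tags = summarizer.summarize(content)
    chunks = [extract.DataChunk(data=[content], metadata={"tags": tags})]
    if chunker.approx_token_count(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
        chunks += [extract.DataChunk(data=[piece], metadata={"tags": tags})
                   for piece in chunker.chunk_text(content)]
        chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))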

View File

@ -11,6 +11,7 @@ secrets:
postgres_password: { file: ./secrets/postgres_password.txt }
jwt_secret: { file: ./secrets/jwt_secret.txt }
openai_key: { file: ./secrets/openai_key.txt }
anthropic_key: { file: ./secrets/anthropic_key.txt }
# --------------------------------------------------------------------- volumes
volumes:
@ -42,7 +43,8 @@ x-worker-base: &worker-base
# DSNs are built in worker entrypoint from user + pw files
QDRANT_URL: http://qdrant:6333
OPENAI_API_KEY_FILE: /run/secrets/openai_key
secrets: [ postgres_password, openai_key ]
ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
secrets: [ postgres_password, openai_key, anthropic_key ]
read_only: true
tmpfs: [ /tmp, /var/tmp ]
cap_drop: [ ALL ]
@ -76,9 +78,6 @@ services:
mem_limit: 4g
cpus: "1.5"
security_opt: [ "no-new-privileges=true" ]
ports:
# PostgreSQL port for local Celery result backend
- "15432:5432"
rabbitmq:
image: rabbitmq:3.13-management
@ -98,11 +97,6 @@ services:
mem_limit: 512m
cpus: "0.5"
security_opt: [ "no-new-privileges=true" ]
ports:
# UI only on localhost
- "15672:15672"
# AMQP port for local Celery clients
- "15673:5672"
qdrant:
image: qdrant/qdrant:v1.14.0
@ -119,8 +113,6 @@ services:
interval: 15s
timeout: 5s
retries: 5
ports:
- "6333:6333"
mem_limit: 4g
cpus: "2"
security_opt: [ "no-new-privileges=true" ]

View File

@ -4,4 +4,5 @@ pydantic==2.7.1
alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
qdrant-client==1.9.0
anthropic==0.18.1

run_celery_task.py (new file, 449 lines)
View File

@ -0,0 +1,449 @@
#!/usr/bin/env python3
"""
Script to run Celery tasks on the Docker Compose setup from your local machine.
This script connects to the RabbitMQ broker running in Docker and sends tasks
to the workers. It requires the same dependencies as the workers to import
the task definitions.
Usage:
python run_celery_task.py --help
python run_celery_task.py email sync-all-accounts
python run_celery_task.py email sync-account --account-id 1
python run_celery_task.py ebook sync-book --file-path "/path/to/book.epub" --tags "fiction,scifi"
python run_celery_task.py maintenance clean-all-collections
python run_celery_task.py blogs sync-webpage --url "https://example.com"
python run_celery_task.py comic sync-all-comics
python run_celery_task.py forums sync-lesswrong --since-date "2025-01-01" --min-karma 10 --limit 50 --cooldown 0.5 --max-items 1000
"""
import json
import sys
from pathlib import Path
from typing import Any
import click
# Add the src directory to Python path so we can import memory modules
sys.path.insert(0, str(Path(__file__).parent / "src"))
from memory.workers.tasks.blogs import (
SYNC_ALL_ARTICLE_FEEDS,
SYNC_ARTICLE_FEED,
SYNC_WEBPAGE,
SYNC_WEBSITE_ARCHIVE,
)
from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_COMIC, SYNC_SMBC, SYNC_XKCD
from memory.workers.tasks.ebook import SYNC_BOOK
from memory.workers.tasks.email import PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
from memory.workers.tasks.maintenance import (
CLEAN_ALL_COLLECTIONS,
CLEAN_COLLECTION,
REINGEST_CHUNK,
REINGEST_EMPTY_SOURCE_ITEMS,
REINGEST_ITEM,
REINGEST_MISSING_CHUNKS,
UPDATE_METADATA_FOR_ITEM,
UPDATE_METADATA_FOR_SOURCE_ITEMS,
)
from celery import Celery
from memory.common import settings
TASK_MAPPINGS = {
"email": {
"sync_all_accounts": SYNC_ALL_ACCOUNTS,
"sync_account": SYNC_ACCOUNT,
"process_message": PROCESS_EMAIL,
},
"ebook": {
"sync_book": SYNC_BOOK,
},
"maintenance": {
"clean_all_collections": CLEAN_ALL_COLLECTIONS,
"clean_collection": CLEAN_COLLECTION,
"reingest_missing_chunks": REINGEST_MISSING_CHUNKS,
"reingest_chunk": REINGEST_CHUNK,
"reingest_item": REINGEST_ITEM,
"reingest_empty_source_items": REINGEST_EMPTY_SOURCE_ITEMS,
"update_metadata_for_item": UPDATE_METADATA_FOR_ITEM,
"update_metadata_for_source_items": UPDATE_METADATA_FOR_SOURCE_ITEMS,
},
"blogs": {
"sync_webpage": SYNC_WEBPAGE,
"sync_article_feed": SYNC_ARTICLE_FEED,
"sync_all_article_feeds": SYNC_ALL_ARTICLE_FEEDS,
"sync_website_archive": SYNC_WEBSITE_ARCHIVE,
},
"comic": {
"sync_all_comics": SYNC_ALL_COMICS,
"sync_smbc": SYNC_SMBC,
"sync_xkcd": SYNC_XKCD,
"sync_comic": SYNC_COMIC,
},
"forums": {
"sync_lesswrong": SYNC_LESSWRONG,
"sync_lesswrong_post": SYNC_LESSWRONG_POST,
},
}
QUEUE_MAPPINGS = {
"email": "email",
"ebook": "ebooks",
"photo": "photo_embed",
}
def create_local_celery_app() -> Celery:
"""Create a Celery app configured to connect to the Docker RabbitMQ."""
# Override settings for local connection to Docker services
rabbitmq_url = f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@localhost:15673//"
app = Celery(
"memory-local",
broker=rabbitmq_url,
backend=settings.CELERY_RESULT_BACKEND.replace(
"postgres:5432", "localhost:15432"
),
)
# Import task modules so they're registered
app.autodiscover_tasks(["memory.workers.tasks"])
return app
def run_task(app: Celery, category: str, task_name: str, **kwargs) -> str:
"""Run a task using the task mappings."""
if category not in TASK_MAPPINGS:
raise ValueError(f"Unknown category: {category}")
if task_name not in TASK_MAPPINGS[category]:
raise ValueError(f"Unknown {category} task: {task_name}")
task_path = TASK_MAPPINGS[category][task_name]
queue_name = QUEUE_MAPPINGS.get(category) or category
result = app.send_task(task_path, kwargs=kwargs, queue=queue_name)
return result.id
def get_task_result(app: Celery, task_id: str, timeout: int = 300) -> Any:
"""Get the result of a task."""
result = app.AsyncResult(task_id)
try:
return result.get(timeout=timeout)
except Exception as e:
return {"error": str(e), "status": result.status}
@click.group()
@click.option("--wait", is_flag=True, help="Wait for task completion and show result")
@click.option(
"--timeout", default=300, help="Timeout in seconds when waiting for result"
)
@click.pass_context
def cli(ctx, wait, timeout):
"""Run Celery tasks on Docker Compose setup."""
ctx.ensure_object(dict)
ctx.obj["wait"] = wait
ctx.obj["timeout"] = timeout
try:
ctx.obj["app"] = create_local_celery_app()
except Exception as e:
click.echo(f"Error connecting to Celery broker: {e}")
click.echo(
"Make sure Docker Compose is running and RabbitMQ is accessible on localhost:15673"
)
sys.exit(1)
def execute_task(ctx, category: str, task_name: str, **kwargs):
"""Helper to execute a task and handle results."""
app = ctx.obj["app"]
wait = ctx.obj["wait"]
timeout = ctx.obj["timeout"]
# Filter out None values
kwargs = {k: v for k, v in kwargs.items() if v is not None}
try:
task_id = run_task(app, category, task_name, **kwargs)
click.echo("Task submitted successfully!")
click.echo(f"Task ID: {task_id}")
if wait:
click.echo(f"Waiting for task completion (timeout: {timeout}s)...")
result = get_task_result(app, task_id, timeout)
click.echo("Task result:")
click.echo(json.dumps(result, indent=2, default=str))
except Exception as e:
click.echo(f"Error running task: {e}")
sys.exit(1)
@cli.group()
@click.pass_context
def email(ctx):
"""Email-related tasks."""
pass
@email.command("sync-all-accounts")
@click.option("--since-date", help="Sync items since this date (ISO format)")
@click.pass_context
def email_sync_all_accounts(ctx, since_date):
"""Sync all email accounts."""
execute_task(ctx, "email", "sync_all_accounts", since_date=since_date)
@email.command("sync-account")
@click.option("--account-id", type=int, required=True, help="Email account ID")
@click.option("--since-date", help="Sync items since this date (ISO format)")
@click.pass_context
def email_sync_account(ctx, account_id, since_date):
"""Sync a specific email account."""
execute_task(
ctx, "email", "sync_account", account_id=account_id, since_date=since_date
)
@email.command("process-message")
@click.option("--message-id", required=True, help="Email message ID")
@click.option("--folder", help="Email folder name")
@click.option("--raw-email", help="Raw email content")
@click.pass_context
def email_process_message(ctx, message_id, folder, raw_email):
"""Process a specific email message."""
execute_task(
ctx,
"email",
"process_message",
message_id=message_id,
folder=folder,
raw_email=raw_email,
)
@cli.group()
@click.pass_context
def ebook(ctx):
"""Ebook-related tasks."""
pass
@ebook.command("sync-book")
@click.option("--file-path", required=True, help="Path to ebook file")
@click.option("--tags", help="Comma-separated tags")
@click.pass_context
def ebook_sync_book(ctx, file_path, tags):
"""Sync an ebook."""
execute_task(ctx, "ebook", "sync_book", file_path=file_path, tags=tags)
@cli.group()
@click.pass_context
def maintenance(ctx):
"""Maintenance tasks."""
pass
@maintenance.command("clean-all-collections")
@click.pass_context
def maintenance_clean_all_collections(ctx):
"""Clean all collections."""
execute_task(ctx, "maintenance", "clean_all_collections")
@maintenance.command("clean-collection")
@click.option("--collection", required=True, help="Collection name to clean")
@click.pass_context
def maintenance_clean_collection(ctx, collection):
"""Clean a specific collection."""
execute_task(ctx, "maintenance", "clean_collection", collection=collection)
@maintenance.command("reingest-missing-chunks")
@click.option("--minutes-ago", type=int, help="Minutes ago to reingest chunks")
@click.pass_context
def maintenance_reingest_missing_chunks(ctx, minutes_ago):
"""Reingest missing chunks."""
execute_task(ctx, "maintenance", "reingest_missing_chunks", minutes_ago=minutes_ago)
@maintenance.command("reingest-item")
@click.option("--item-id", required=True, help="Item ID to reingest")
@click.option("--item-type", required=True, help="Item type to reingest")
@click.pass_context
def maintenance_reingest_item(ctx, item_id, item_type):
"""Reingest a specific item."""
execute_task(
ctx, "maintenance", "reingest_item", item_id=item_id, item_type=item_type
)
@maintenance.command("update-metadata-for-item")
@click.option("--item-id", required=True, help="Item ID to update metadata for")
@click.option("--item-type", required=True, help="Item type to update metadata for")
@click.pass_context
def maintenance_update_metadata_for_item(ctx, item_id, item_type):
"""Update metadata for a specific item."""
execute_task(
ctx,
"maintenance",
"update_metadata_for_item",
item_id=item_id,
item_type=item_type,
)
@maintenance.command("update-metadata-for-source-items")
@click.option("--item-type", required=True, help="Item type to update metadata for")
@click.pass_context
def maintenance_update_metadata_for_source_items(ctx, item_type):
"""Update metadata for all items of a specific type."""
execute_task(
ctx, "maintenance", "update_metadata_for_source_items", item_type=item_type
)
@maintenance.command("reingest-empty-source-items")
@click.option("--item-type", required=True, help="Item type to reingest")
@click.pass_context
def maintenance_reingest_empty_source_items(ctx, item_type):
"""Reingest empty source items."""
execute_task(ctx, "maintenance", "reingest_empty_source_items", item_type=item_type)
@maintenance.command("reingest-chunk")
@click.option("--chunk-id", required=True, help="Chunk ID to reingest")
@click.pass_context
def maintenance_reingest_chunk(ctx, chunk_id):
"""Reingest a specific chunk."""
execute_task(ctx, "maintenance", "reingest_chunk", chunk_id=chunk_id)
@cli.group()
@click.pass_context
def blogs(ctx):
"""Blog-related tasks."""
pass
@blogs.command("sync-webpage")
@click.option("--url", required=True, help="URL to sync")
@click.pass_context
def blogs_sync_webpage(ctx, url):
"""Sync a webpage."""
execute_task(ctx, "blogs", "sync_webpage", url=url)
@blogs.command("sync-article-feed")
@click.option("--feed-id", type=int, required=True, help="Feed ID to sync")
@click.pass_context
def blogs_sync_article_feed(ctx, feed_id):
"""Sync an article feed."""
execute_task(ctx, "blogs", "sync_article_feed", feed_id=feed_id)
@blogs.command("sync-all-article-feeds")
@click.pass_context
def blogs_sync_all_article_feeds(ctx):
"""Sync all article feeds."""
execute_task(ctx, "blogs", "sync_all_article_feeds")
@blogs.command("sync-website-archive")
@click.option("--url", required=True, help="URL to sync")
@click.pass_context
def blogs_sync_website_archive(ctx, url):
"""Sync a website archive."""
execute_task(ctx, "blogs", "sync_website_archive", url=url)
@cli.group()
@click.pass_context
def comic(ctx):
"""Comic-related tasks."""
pass
@comic.command("sync-all-comics")
@click.pass_context
def comic_sync_all_comics(ctx):
"""Sync all comics."""
execute_task(ctx, "comic", "sync_all_comics")
@comic.command("sync-smbc")
@click.pass_context
def comic_sync_smbc(ctx):
"""Sync SMBC comics."""
execute_task(ctx, "comic", "sync_smbc")
@comic.command("sync-xkcd")
@click.pass_context
def comic_sync_xkcd(ctx):
"""Sync XKCD comics."""
execute_task(ctx, "comic", "sync_xkcd")
@comic.command("sync-comic")
@click.option("--image-url", required=True, help="Image URL to sync")
@click.option("--title", help="Comic title")
@click.option("--author", help="Comic author")
@click.option("--published-date", help="Comic published date")
@click.pass_context
def comic_sync_comic(ctx, image_url, title, author, published_date):
"""Sync a specific comic."""
execute_task(
ctx,
"comic",
"sync_comic",
image_url=image_url,
title=title,
author=author,
published_date=published_date,
)
@cli.group()
@click.pass_context
def forums(ctx):
"""Forum-related tasks."""
pass
@forums.command("sync-lesswrong")
@click.option("--since-date", help="Sync items since this date (ISO format)")
@click.option("--min-karma", type=int, help="Minimum karma to sync")
@click.option("--limit", type=int, help="Limit the number of posts to sync")
@click.option("--cooldown", type=float, help="Cooldown between posts")
@click.option("--max-items", type=int, help="Maximum number of posts to sync")
@click.pass_context
def forums_sync_lesswrong(ctx, since_date, min_karma, limit, cooldown, max_items):
"""Sync LessWrong posts."""
execute_task(
ctx,
"forums",
"sync_lesswrong",
since_date=since_date,
min_karma=min_karma,
limit=limit,
cooldown=cooldown,
max_items=max_items,
)
@forums.command("sync-lesswrong-post")
@click.option("--url", required=True, help="LessWrong post URL")
@click.pass_context
def forums_sync_lesswrong_post(ctx, url):
"""Sync a specific LessWrong post."""
execute_task(ctx, "forums", "sync_lesswrong_post", url=url)
if __name__ == "__main__":
cli()
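The CLI wraps a few small helpers, so the same thing can be driven from Python directly if needed (a minimal sketch, assuming the worker dependencies are installed locally and the Docker stack is reachable on the ports used above):

    from run_celery_task import create_local_celery_app, run_task, get_task_result

    app = create_local_celery_app()
    task_id = run_task(app, "blogs", "sync_webpage", url="https://example.com")
    print(get_task_result(app, task_id, timeout=120))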

View File

@ -2,13 +2,15 @@ import logging
from typing import Iterable, Any
import re
from memory.common import settings
logger = logging.getLogger(__name__)
# Chunking configuration
EMBEDDING_MAX_TOKENS = 32000 # VoyageAI max context window
DEFAULT_CHUNK_TOKENS = 512 # Optimal chunk size for semantic search
OVERLAP_TOKENS = 50 # Default overlap between chunks (10% of chunk size)
EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
OVERLAP_TOKENS = settings.OVERLAP_TOKENS
CHARS_PER_TOKEN = 4

View File

@ -35,6 +35,7 @@ from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer
Base = declarative_base()
@ -105,19 +106,36 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh
]
def chunk_mixed(
content: str, image_paths: Sequence[str]
) -> list[list[extract.MulitmodalChunk]]:
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
full_text: list[extract.MulitmodalChunk] = [content.strip(), *images]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
if tags := set(m.pop("tags", [])):
final["tags"] = tags | final.get("tags", set())
final |= m
return final
chunks = []
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
if not content.strip():
return []
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
summary, tags = summarizer.summarize(content)
full_text: extract.DataChunk = extract.DataChunk(
data=[content.strip(), *images], metadata={"tags": tags}
)
chunks: list[extract.DataChunk] = [full_text]
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
chunks += [
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
for c in chunker.chunk_text(content)
]
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
all_chunks = [full_text] + chunks
return [c for c in all_chunks if c and all(i for i in c)]
return [c for c in chunks if c.data]
class Chunk(Base):
@ -224,22 +242,27 @@ class SourceItem(Base):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
chunks: list[list[extract.MulitmodalChunk]] = []
if cast(str | None, self.content):
chunks = [[c] for c in chunker.chunk_text(cast(str, self.content))]
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
chunks: list[extract.DataChunk] = []
content = cast(str | None, self.content)
if content:
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
summary, tags = summarizer.summarize(content)
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
mime_type = cast(str | None, self.mime_type)
if mime_type and mime_type.startswith("image/"):
chunks.append([Image.open(self.filename)])
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
return chunks
def _make_chunk(
self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {}
):
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
) -> Chunk:
chunk_id = str(uuid.uuid4())
text = "\n\n".join(c for c in data if isinstance(c, str) and c.strip())
images = [c for c in data if isinstance(c, Image.Image)]
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
images = [c for c in data.data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
chunk = Chunk(
@ -249,17 +272,18 @@ class SourceItem(Base):
images=images,
file_paths=image_names,
embedding_model=collections.collection_model(cast(str, self.modality)),
item_metadata=self.as_payload() | metadata,
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
)
return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
return [self._make_chunk(data) for data in self._chunk_contents()]
def as_payload(self) -> dict:
return {
"source_id": self.id,
"tags": self.tags,
"size": self.size,
}
@property
@ -312,7 +336,7 @@ class MailMessage(SourceItem):
def as_payload(self) -> dict:
return {
"source_id": self.id,
**super().as_payload(),
"message_id": self.message_id,
"subject": self.subject,
"sender": self.sender,
@ -381,13 +405,12 @@ class EmailAttachment(SourceItem):
def as_payload(self) -> dict:
return {
**super().as_payload(),
"filename": self.filename,
"content_type": self.mime_type,
"size": self.size,
"created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore
"mail_message_id": self.mail_message_id,
"source_id": self.id,
"tags": self.tags,
}
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
@ -397,7 +420,7 @@ class EmailAttachment(SourceItem):
contents = cast(str, self.content)
chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
return [self._make_chunk(c.data, metadata | c.metadata) for c in chunks]
return [self._make_chunk(c, metadata) for c in chunks]
# Add indexes
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
@ -488,8 +511,7 @@ class Comic(SourceItem):
def as_payload(self) -> dict:
payload = {
"source_id": self.id,
"tags": self.tags,
**super().as_payload(),
"title": self.title,
"author": self.author,
"published": self.published,
@ -500,10 +522,10 @@ class Comic(SourceItem):
}
return {k: v for k, v in payload.items() if v is not None}
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
image = Image.open(pathlib.Path(cast(str, self.filename)))
description = f"{self.title} by {self.author}"
return [[image, description]]
return [extract.DataChunk(data=[image, description])]
class Book(Base):
@ -538,7 +560,7 @@ class Book(Base):
def as_payload(self) -> dict:
return {
"source_id": self.id,
**super().as_payload(),
"isbn": self.isbn,
"title": self.title,
"author": self.author,
@ -548,7 +570,6 @@ class Book(Base):
"edition": self.edition,
"series": self.series,
"series_number": self.series_number,
"tags": self.tags,
} | (cast(dict, self.book_metadata) or {})
@ -591,29 +612,41 @@ class BookSection(SourceItem):
def as_payload(self) -> dict:
vals = {
**super().as_payload(),
"title": self.book.title,
"author": self.book.author,
"source_id": self.id,
"book_id": self.book_id,
"section_title": self.section_title,
"section_number": self.section_number,
"section_level": self.section_level,
"start_page": self.start_page,
"end_page": self.end_page,
"tags": self.tags,
}
return {k: v for k, v in vals.items() if v}
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
if not cast(str, self.content.strip()):
content = cast(str, self.content.strip())
if not content:
return []
texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
if len([p for p in self.pages if p.strip()]) == 1:
return [
self._make_chunk(
extract.DataChunk(data=[content]), metadata | {"type": "page"}
)
]
summary, tags = summarizer.summarize(content)
return [
self._make_chunk([text.strip()], metadata | {"page": page_number})
for text, page_number in texts
if text and text.strip()
] + [self._make_chunk([cast(str, self.content.strip())], metadata)]
self._make_chunk(
extract.DataChunk(data=[content]),
merge_metadata(metadata, {"type": "section", "tags": tags}),
),
self._make_chunk(
extract.DataChunk(data=[summary]),
merge_metadata(metadata, {"type": "summary", "tags": tags}),
),
]
class BlogPost(SourceItem):
@ -652,7 +685,7 @@ class BlogPost(SourceItem):
metadata = cast(dict | None, self.webpage_metadata) or {}
payload = {
"source_id": self.id,
**super().as_payload(),
"url": self.url,
"title": self.title,
"author": self.author,
@ -660,12 +693,11 @@ class BlogPost(SourceItem):
"description": self.description,
"domain": self.domain,
"word_count": self.word_count,
"tags": self.tags,
**metadata,
}
return {k: v for k, v in payload.items() if v}
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
@ -701,7 +733,7 @@ class ForumPost(SourceItem):
def as_payload(self) -> dict:
return {
"source_id": self.id,
**super().as_payload(),
"url": self.url,
"title": self.title,
"description": self.description,
@ -712,10 +744,9 @@ class ForumPost(SourceItem):
"votes": self.votes,
"score": self.score,
"comments": self.comments,
"tags": self.tags,
}
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))

View File

@ -245,3 +245,46 @@ def find_missing_points(
collection_name, ids=ids, with_payload=False, with_vectors=False
)
return set(ids) - {str(r.id) for r in found}
def set_payload(
client: qdrant_client.QdrantClient,
collection_name: str,
point_id: str,
payload: dict[str, Any],
) -> None:
"""Set payload for a single point without modifying its vector.
Args:
client: Qdrant client
collection_name: Name of the collection
point_id: Vector ID (as string)
payload: New payload to set
"""
client.set_payload(
collection_name=collection_name,
payload=payload,
points=[point_id],
)
logger.debug(f"Set payload for point {point_id} in {collection_name}")
def get_payloads(
client: qdrant_client.QdrantClient, collection_name: str, ids: list[str]
) -> dict[str, dict[str, Any]]:
"""Retrieve payloads for multiple points.
Args:
client: Qdrant client
collection_name: Name of the collection
ids: List of vector IDs (as strings)
Returns:
Dictionary mapping point IDs to their payloads
"""
points = client.retrieve(
collection_name=collection_name, ids=ids, with_payload=True, with_vectors=False
)
return {str(point.id): point.payload or {} for point in points}
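Together, these two helpers allow metadata-only updates without touching vectors, roughly like this (sketch; the collection name, tag, and variables are placeholders):

    # client: qdrant_client.QdrantClient, chunk_ids: list[str]
    payloads = get_payloads(client, "mail", chunk_ids)
    for point_id, payload in payloads.items():
        # merge an extra tag into each point's payload, then write it back
        payload["tags"] = sorted(set(payload.get("tags", [])) | {"extra-tag"})
        set_payload(client, "mail", point_id, payload)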

View File

@ -98,3 +98,22 @@ CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60
TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")
EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", 50))
# VoyageAI max context window
EMBEDDING_MAX_TOKENS = int(os.getenv("EMBEDDING_MAX_TOKENS", 32000))
# Optimal chunk size for semantic search
DEFAULT_CHUNK_TOKENS = int(os.getenv("DEFAULT_CHUNK_TOKENS", 512))
OVERLAP_TOKENS = int(os.getenv("OVERLAP_TOKENS", 50))
# LLM settings
if openai_key_file := os.getenv("OPENAI_API_KEY_FILE"):
OPENAI_API_KEY = pathlib.Path(openai_key_file).read_text().strip()
else:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
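The provider prefix of SUMMARIZER_MODEL decides which backend summarizer.summarize (below) calls: anything starting with "anthropic" goes to the Anthropic client, everything else falls through to OpenAI, and the part after the slash is passed as the model name. Switching providers is therefore just an env var change, e.g. (placeholder model name, not from this diff):

    SUMMARIZER_MODEL=openai/<model-name>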

View File

@ -0,0 +1,137 @@
import json
import logging
from typing import Any
from memory.common import settings, chunker
logger = logging.getLogger(__name__)
TAGS_PROMPT = """
The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
Return your response as JSON with this format:
{{
"summary": "{summary}",
"tags": ["tag1", "tag2", "tag3"]
}}
Text:
{content}
"""
SUMMARY_PROMPT = """
Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
Also provide 3-5 relevant tags that capture the main topics or themes.
Return your response as JSON with this format:
{{
"summary": "your summary here",
"tags": ["tag1", "tag2", "tag3"]
}}
Text to summarize:
{content}
"""
def _call_openai(prompt: str) -> dict[str, Any]:
"""Call OpenAI API for summarization."""
import openai
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
try:
response = client.chat.completions.create(
model=settings.SUMMARIZER_MODEL.split("/")[1],
messages=[
{
"role": "system",
"content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
temperature=0.3,
max_tokens=2048,
)
return json.loads(response.choices[0].message.content or "{}")
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
def _call_anthropic(prompt: str) -> dict[str, Any]:
"""Call Anthropic API for summarization."""
import anthropic
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
try:
response = client.messages.create(
model=settings.SUMMARIZER_MODEL.split("/")[1],
messages=[{"role": "user", "content": prompt}],
system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid JSON.",
temperature=0.3,
max_tokens=2048,
)
return json.loads(response.content[0].text)
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
def truncate(content: str, target_tokens: int) -> str:
target_chars = target_tokens * chunker.CHARS_PER_TOKEN
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content
def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
"""
Summarize content to approximately target_tokens length and generate tags.
Args:
content: Text to summarize
target_tokens: Target length in tokens (defaults to DEFAULT_CHUNK_TOKENS)
Returns:
Tuple of (summary, tags)
"""
if not content or not content.strip():
return "", []
if target_tokens is None:
target_tokens = settings.DEFAULT_CHUNK_TOKENS
summary, tags = content, []
# If content is already short enough, just extract tags
current_tokens = chunker.approx_token_count(content)
if current_tokens <= target_tokens:
logger.info(
f"Content already under {target_tokens} tokens, extracting tags only"
)
prompt = TAGS_PROMPT.format(content=content, summary=summary[:1000])
else:
prompt = SUMMARY_PROMPT.format(
target_tokens=target_tokens,
target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
content=content,
)
try:
if settings.SUMMARIZER_MODEL.startswith("anthropic"):
result = _call_anthropic(prompt)
else:
result = _call_openai(prompt)
summary = result.get("summary", "")
tags = result.get("tags", [])
except Exception as e:
logger.error(f"Summarization failed: {e}")
tokens = chunker.approx_token_count(summary)
if tokens > target_tokens * 1.5:
logger.warning(f"Summary too long ({tokens} tokens), truncating")
summary = truncate(content, target_tokens)
return summary, tags
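Typical call sites in the models above use it like this (sketch):

    summary, tags = summarizer.summarize(content)                     # target defaults to settings.DEFAULT_CHUNK_TOKENS
    summary, tags = summarizer.summarize(content, target_tokens=256)  # explicit target length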

View File

@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Sequence
from typing import Sequence, Any
from memory.workers.tasks.content_processing import process_content_item
from sqlalchemy import select
@ -21,6 +21,10 @@ REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
UPDATE_METADATA_FOR_SOURCE_ITEMS = (
f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
)
UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
@app.task(name=CLEAN_COLLECTION)
@ -97,6 +101,8 @@ def get_item_class(item_type: str):
raise ValueError(
f"Unsupported item type {item_type}. Available types: {available_types}"
)
if not hasattr(class_, "chunks"):
raise ValueError(f"Item type {item_type} does not have chunks")
return class_
@ -226,3 +232,99 @@ def reingest_missing_chunks(
total_stats[collection]["total"] += stats["total"]
return dict(total_stats)
def _payloads_equal(current: dict[str, Any], new: dict[str, Any]) -> bool:
"""Compare two payloads to see if they're effectively equal."""
# Handle tags specially - compare as sets since order doesn't matter
current_tags = set(current.get("tags", []))
new_tags = set(new.get("tags", []))
if current_tags != new_tags:
return False
# Compare all other fields
current_without_tags = {k: v for k, v in current.items() if k != "tags"}
new_without_tags = {k: v for k, v in new.items() if k != "tags"}
return current_without_tags == new_without_tags
@app.task(name=UPDATE_METADATA_FOR_ITEM)
def update_metadata_for_item(item_id: str, item_type: str):
"""Update metadata in Qdrant for all chunks of a single source item, merging tags."""
logger.info(f"Updating metadata for {item_type} {item_id}")
try:
class_ = get_item_class(item_type)
except ValueError as e:
logger.error(f"Error getting item class: {e}")
return {"status": "error", "error": str(e)}
client = qdrant.get_qdrant_client()
updated_chunks = 0
errors = 0
with make_session() as session:
item = session.query(class_).get(item_id)
if not item:
return {"status": "error", "error": f"Item {item_id} not found"}
chunk_ids = [str(chunk.id) for chunk in item.chunks if chunk.id]
if not chunk_ids:
return {"status": "success", "updated_chunks": 0, "errors": 0}
collection = item.modality
try:
current_payloads = qdrant.get_payloads(client, collection, chunk_ids)
# Get new metadata from source item
new_metadata = item.as_payload()
new_tags = set(new_metadata.get("tags", []))
for chunk_id in chunk_ids:
if chunk_id not in current_payloads:
logger.warning(
f"Chunk {chunk_id} not found in Qdrant collection {collection}"
)
continue
current_payload = current_payloads[chunk_id]
current_tags = set(current_payload.get("tags", []))
# Merge tags (combine existing and new tags)
merged_tags = list(current_tags | new_tags)
updated_metadata = new_metadata.copy()
updated_metadata["tags"] = merged_tags
if _payloads_equal(current_payload, updated_metadata):
continue
qdrant.set_payload(client, collection, chunk_id, updated_metadata)
updated_chunks += 1
except Exception as e:
logger.error(f"Error updating metadata for item {item.id}: {e}")
errors += 1
return {"status": "success", "updated_chunks": updated_chunks, "errors": errors}
@app.task(name=UPDATE_METADATA_FOR_SOURCE_ITEMS)
def update_metadata_for_source_items(item_type: str):
"""Update metadata in Qdrant for all chunks of all items of a given source type."""
logger.info(f"Updating metadata for all {item_type} source items")
try:
class_ = get_item_class(item_type)
except ValueError as e:
logger.error(f"Error getting item class: {e}")
return {"status": "error", "error": str(e)}
with make_session() as session:
item_ids = session.query(class_.id).all()
logger.info(f"Found {len(item_ids)} items to update metadata for")
for item_id in item_ids:
update_metadata_for_item.delay(item_id.id, item_type) # type: ignore
return {"status": "success", "items": len(item_ids)}

View File

@ -5,6 +5,8 @@ from datetime import datetime
from pathlib import Path
from unittest.mock import Mock, patch
import anthropic
import openai
import pytest
import qdrant_client
import voyageai
@ -237,3 +239,35 @@ def mock_voyage_client():
client.embed = embeder
client.multimodal_embed = embeder
yield client
@pytest.fixture(autouse=True)
def mock_openai_client():
with patch.object(openai, "OpenAI", autospec=True) as mock_client:
client = mock_client()
client.chat = Mock()
client.chat.completions.create = Mock(
return_value=Mock(
choices=[
Mock(
message=Mock(
content='{"summary": "test", "tags": ["tag1", "tag2"]}'
)
)
]
)
)
yield client
@pytest.fixture(autouse=True)
def mock_anthropic_client():
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
client = mock_client()
client.messages = Mock()
client.messages.create = Mock(
return_value=Mock(
content=[Mock(text='{"summary": "test", "tags": ["tag1", "tag2"]}')]
)
)
yield client

View File

@ -5,8 +5,7 @@ from typing import cast
import pytest
from PIL import Image
from datetime import datetime
from memory.common import settings
from memory.common import chunker
from memory.common import settings, chunker, extract
from memory.common.db.models import (
Chunk,
clean_filename,
@ -17,6 +16,7 @@ from memory.common.db.models import (
BookSection,
BlogPost,
Book,
merge_metadata,
)
@ -54,6 +54,114 @@ def test_clean_filename(input_filename, expected):
assert clean_filename(input_filename) == expected
@pytest.mark.parametrize(
"dicts,expected",
[
# Empty input
([], {}),
# Single dict without tags
([{"key": "value"}], {"key": "value"}),
# Single dict with tags as list
(
[{"key": "value", "tags": ["tag1", "tag2"]}],
{"key": "value", "tags": {"tag1", "tag2"}},
),
# Single dict with tags as set
(
[{"key": "value", "tags": {"tag1", "tag2"}}],
{"key": "value", "tags": {"tag1", "tag2"}},
),
# Multiple dicts without tags
(
[{"key1": "value1"}, {"key2": "value2"}],
{"key1": "value1", "key2": "value2"},
),
# Multiple dicts with non-overlapping tags
(
[
{"key1": "value1", "tags": ["tag1"]},
{"key2": "value2", "tags": ["tag2"]},
],
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2"}},
),
# Multiple dicts with overlapping tags
(
[
{"key1": "value1", "tags": ["tag1", "tag2"]},
{"key2": "value2", "tags": ["tag2", "tag3"]},
],
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2", "tag3"}},
),
# Overlapping keys - later dict wins
(
[
{"key": "value1", "other": "data1"},
{"key": "value2", "another": "data2"},
],
{"key": "value2", "other": "data1", "another": "data2"},
),
# Mixed tags types (list and set)
(
[
{"key1": "value1", "tags": ["tag1", "tag2"]},
{"key2": "value2", "tags": {"tag3", "tag4"}},
],
{
"key1": "value1",
"key2": "value2",
"tags": {"tag1", "tag2", "tag3", "tag4"},
},
),
# Empty tags
(
[{"key": "value", "tags": []}, {"key2": "value2", "tags": []}],
{"key": "value", "key2": "value2"},
),
# None values
(
[{"key1": None, "key2": "value"}, {"key3": None}],
{"key1": None, "key2": "value", "key3": None},
),
# Complex nested structures
(
[
{"nested": {"inner": "value1"}, "list": [1, 2, 3], "tags": ["tag1"]},
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": ["tag2"]},
],
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": {"tag1", "tag2"}},
),
# Boolean and numeric values
(
[
{"bool": True, "int": 42, "float": 3.14, "tags": ["numeric"]},
{"bool": False, "int": 100},
],
{"bool": False, "int": 100, "float": 3.14, "tags": {"numeric"}},
),
# Three or more dicts
(
[
{"a": 1, "tags": ["t1"]},
{"b": 2, "tags": ["t2", "t3"]},
{"c": 3, "a": 10, "tags": ["t3", "t4"]},
],
{"a": 10, "b": 2, "c": 3, "tags": {"t1", "t2", "t3", "t4"}},
),
# Dict with only tags
([{"tags": ["tag1", "tag2"]}], {"tags": {"tag1", "tag2"}}),
# Empty dicts
([{}, {}], {}),
# Mix of empty and non-empty dicts
(
[{}, {"key": "value", "tags": ["tag"]}, {}],
{"key": "value", "tags": {"tag"}},
),
],
)
def test_merge_metadata(dicts, expected):
assert merge_metadata(*dicts) == expected
def test_image_filenames_with_existing_filenames(tmp_path):
"""Test image_filenames when images already have filenames"""
chunk_id = "test_chunk_123"
@ -262,7 +370,7 @@ def test_source_item_chunk_contents_text(chunk_length, expected, default_chunk_s
)
default_chunk_size(chunk_length)
assert source._chunk_contents() == expected
assert source._chunk_contents() == [extract.DataChunk(data=e) for e in expected]
def test_source_item_chunk_contents_image(tmp_path):
@ -281,8 +389,8 @@ def test_source_item_chunk_contents_image(tmp_path):
result = source._chunk_contents()
assert len(result) == 1
assert len(result[0]) == 1
assert isinstance(result[0][0], Image.Image)
assert len(result[0].data) == 1
assert isinstance(result[0].data[0], Image.Image)
def test_source_item_chunk_contents_mixed(tmp_path):
@ -302,8 +410,8 @@ def test_source_item_chunk_contents_mixed(tmp_path):
result = source._chunk_contents()
assert len(result) == 2
assert result[0][0] == "Bla bla"
assert isinstance(result[1][0], Image.Image)
assert result[0].data[0] == "Bla bla"
assert isinstance(result[1].data[0], Image.Image)
@pytest.mark.parametrize(
@ -323,7 +431,11 @@ def test_source_item_chunk_contents_mixed(tmp_path):
def test_source_item_make_chunk(tmp_path, texts, expected_content):
"""Test SourceItem._make_chunk method"""
source = SourceItem(
sha256=b"test123", content="test", modality="text", tags=["tag1"]
sha256=b"test123",
content="test",
modality="text",
tags=["tag1"],
size=1024,
)
# Create actual image
image_file = tmp_path / "test.png"
@ -335,7 +447,7 @@ def test_source_item_make_chunk(tmp_path, texts, expected_content):
data = [*texts, img]
metadata = {"extra": "data"}
chunk = source._make_chunk(data, metadata)
chunk = source._make_chunk(extract.DataChunk(data=data), metadata)
assert chunk.id is not None
assert chunk.source == source
@ -344,7 +456,12 @@ def test_source_item_make_chunk(tmp_path, texts, expected_content):
assert chunk.embedding_model is not None
# Check that metadata is merged correctly
expected_payload = {"source_id": source.id, "tags": ["tag1"], "extra": "data"}
expected_payload = {
"source_id": source.id,
"tags": {"tag1"},
"extra": "data",
"size": 1024,
}
assert chunk.item_metadata == expected_payload
@ -355,10 +472,11 @@ def test_source_item_as_payload():
content="test",
modality="text",
tags=["tag1", "tag2"],
size=1024,
)
payload = source.as_payload()
assert payload == {"source_id": 123, "tags": ["tag1", "tag2"]}
assert payload == {"source_id": 123, "tags": ["tag1", "tag2"], "size": 1024}
@pytest.mark.parametrize(
@ -544,6 +662,7 @@ def test_mail_message_as_payload(sent_at, expected_date):
folder="INBOX",
sent_at=sent_at,
tags=["tag1", "tag2"],
size=1024,
)
# Manually set id for testing
object.__setattr__(mail_message, "id", 123)
@ -552,6 +671,7 @@ def test_mail_message_as_payload(sent_at, expected_date):
expected = {
"source_id": 123,
"size": 1024,
"message_id": "<test@example.com>",
"subject": "Test Subject",
"sender": "sender@example.com",
@ -646,12 +766,12 @@ def test_email_attachment_as_payload(created_at, expected_date):
payload = attachment.as_payload()
expected = {
"source_id": 456,
"filename": "document.pdf",
"content_type": "application/pdf",
"size": 1024,
"created_at": expected_date,
"mail_message_id": 123,
"source_id": 456,
"tags": ["pdf", "document"],
}
assert payload == expected
@ -702,7 +822,8 @@ def test_email_attachment_data_chunks(
# Verify the method calls
mock_extract.assert_called_once_with("text/plain", expected_content)
mock_make.assert_called_once_with(
["extracted text"], {"extra": "metadata", "source": content_source}
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
{"extra": "metadata"},
)
assert result == [mock_chunk]
@ -813,18 +934,20 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
# No pages
([], []),
# Single page
(["Page 1 content"], [("Page 1 content", 10)]),
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
# Multiple pages
(
["Page 1", "Page 2", "Page 3"],
[
("Page 1", 10),
("Page 2", 11),
("Page 3", 12),
(
"Page 1\n\nPage 2\n\nPage 3",
{"type": "section", "tags": {"tag1", "tag2"}},
),
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
],
),
# Empty/whitespace pages filtered out
(["", " ", "Page 3"], [("Page 3", 12)]),
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
# All empty - no chunks created
(["", " ", " "], []),
],
@ -845,11 +968,8 @@ def test_book_section_data_chunks(pages, expected_chunks):
chunks = book_section.data_chunks()
expected = [
(p, book_section.as_payload() | {"page": i}) for p, i in expected_chunks
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
]
if content:
expected.append((content, book_section.as_payload()))
assert [(c.content, c.item_metadata) for c in chunks] == expected
for c in chunks:
assert cast(list, c.file_paths) == []
@ -859,16 +979,39 @@ def test_book_section_data_chunks(pages, expected_chunks):
"content,expected",
[
("", []),
("Short content", [["Short content"]]),
(
"Short content",
[
extract.DataChunk(
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
)
],
),
(
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
[
[
"This is a very long piece of content that should be chunked into multiple pieces when processed."
],
["This is a very long piece of content that"],
["should be chunked into multiple pieces when"],
["processed."],
extract.DataChunk(
data=[
"This is a very long piece of content that should be chunked into multiple pieces when processed."
],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["This is a very long piece of content that"],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["should be chunked into multiple pieces when"],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["processed."],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["test"],
metadata={"tags": ["tag1", "tag2"]},
),
],
),
],
@ -906,7 +1049,8 @@ def test_blog_post_chunk_contents_with_images(tmp_path):
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
for c in result
]
assert result == [
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
@ -933,9 +1077,9 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
for c in result
]
print(result)
assert result == [
[
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
@ -950,4 +1094,5 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
f"Second picture is here: {img2_path.as_posix()}",
img2_path.as_posix(),
],
["test"],
]

View File

@ -218,6 +218,7 @@ def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
"title": "Test Comic",
"url": "https://example.com/comic/1",
"source_id": 1,
"size": 90,
},
None,
)

View File

@ -186,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path, qdrant)
"author": "Test Author",
"status": "processed",
"total_sections": 4,
"sections_embedded": 8,
"sections_embedded": 4,
}
book = db_session.query(Book).filter(Book.title == "Test Book").first()
@ -325,8 +325,4 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
# Check that the full content was passed to the embedding function
texts = mock_voyage_client.embed.call_args[0][0]
assert texts == [
[large_page_1.strip()],
[large_page_2.strip()],
[large_section_content.strip()],
]
assert texts == [[large_section_content.strip()], ["test"]]

View File

@ -335,38 +335,19 @@ def test_sync_lesswrong_max_items_limit(mock_fetch, db_session):
assert result["max_items"] == 3
@pytest.mark.parametrize(
"since_str,expected_days_ago",
[
("2024-01-01T00:00:00", None), # Specific date
(None, 30), # Default should be 30 days ago
],
)
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_since_parameter(
mock_fetch, since_str, expected_days_ago, db_session
):
def test_sync_lesswrong_since_parameter(mock_fetch, db_session):
"""Test that since parameter is handled correctly."""
mock_fetch.return_value = []
if since_str:
forums.sync_lesswrong(since=since_str)
expected_since = datetime.fromisoformat(since_str)
else:
forums.sync_lesswrong()
# Should default to 30 days ago
expected_since = datetime.now() - timedelta(days=30)
forums.sync_lesswrong(since="2024-01-01T00:00:00")
expected_since = datetime.fromisoformat("2024-01-01T00:00:00")
# Verify fetch was called with correct since date
call_args = mock_fetch.call_args[0]
actual_since = call_args[0]
if expected_days_ago:
# For default case, check it's approximately 30 days ago
time_diff = abs((actual_since - expected_since).total_seconds())
assert time_diff < 120 # Within 2 minute tolerance for slow tests
else:
assert actual_since == expected_since
assert actual_since == expected_since
@pytest.mark.parametrize(

View File

@ -1,3 +1,4 @@
# FIXME: Most of this was vibe-coded
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
@ -15,6 +16,9 @@ from memory.workers.tasks.maintenance import (
reingest_missing_chunks,
reingest_item,
reingest_empty_source_items,
update_metadata_for_item,
update_metadata_for_source_items,
_payloads_equal,
)
@ -596,3 +600,372 @@ def test_reingest_empty_source_items_invalid_type(db_session):
def test_reingest_missing_chunks_no_chunks(db_session):
assert reingest_missing_chunks() == {}
@pytest.mark.parametrize(
"payload1,payload2,expected",
[
# Identical payloads
(
{"tags": ["a", "b"], "source_id": 1, "title": "Test"},
{"tags": ["a", "b"], "source_id": 1, "title": "Test"},
True,
),
# Different tag order (should be equal)
(
{"tags": ["a", "b"], "source_id": 1},
{"tags": ["b", "a"], "source_id": 1},
True,
),
# Different tags
(
{"tags": ["a", "b"], "source_id": 1},
{"tags": ["a", "c"], "source_id": 1},
False,
),
# Different non-tag fields
({"tags": ["a"], "source_id": 1}, {"tags": ["a"], "source_id": 2}, False),
# Missing tags in one payload
({"tags": ["a"], "source_id": 1}, {"source_id": 1}, False),
# Empty tags (should be equal)
({"tags": [], "source_id": 1}, {"source_id": 1}, True),
],
)
def test_payloads_equal(payload1, payload2, expected):
"""Test the _payloads_equal helper function."""
assert _payloads_equal(payload1, payload2) == expected
@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
def test_update_metadata_for_item_success(db_session, qdrant, item_type):
"""Test successful metadata update for an item with chunks."""
if item_type == "MailMessage":
item = MailMessage(
sha256=b"test_hash" + bytes(24),
tags=["original", "test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Test content",
folder="INBOX",
modality="mail",
)
modality = "mail"
else:
item = BlogPost(
sha256=b"test_hash" + bytes(24),
tags=["original", "test"],
size=100,
mime_type="text/html",
embed_status="STORED",
url="https://example.com/post",
title="Test Blog Post",
author="Author Name",
content="Test blog content",
modality="blog",
)
modality = "blog"
db_session.add(item)
db_session.flush()
# Add chunks to the item
chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
chunks = [
Chunk(
id=chunk_id,
source=item,
content=f"Test chunk content {i}",
embedding_model="test-model",
)
for i, chunk_id in enumerate(chunk_ids)
]
db_session.add_all(chunks)
db_session.commit()
# Setup Qdrant with existing payloads
qd.ensure_collection_exists(qdrant, modality, 1024)
existing_payloads = [
{"tags": ["existing", "qdrant"], "source_id": item.id, "old_field": "value"}
for _ in chunk_ids
]
qd.upsert_vectors(
qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids), existing_payloads
)
# Mock the qdrant functions to track calls
with (
patch(
"memory.workers.tasks.maintenance.qdrant.get_payloads"
) as mock_get_payloads,
patch(
"memory.workers.tasks.maintenance.qdrant.set_payload"
) as mock_set_payload,
):
# Return the existing payloads
mock_get_payloads.return_value = {
chunk_id: payload for chunk_id, payload in zip(chunk_ids, existing_payloads)
}
result = update_metadata_for_item(str(item.id), item_type)
# Verify result
assert result["status"] == "success"
assert result["updated_chunks"] == 3
assert result["errors"] == 0
# Verify batch retrieval was called once
mock_get_payloads.assert_called_once_with(qdrant, modality, chunk_ids)
# Verify set_payload was called for each chunk with merged tags
assert mock_set_payload.call_count == 3
for call in mock_set_payload.call_args_list:
args, kwargs = call
client, collection, chunk_id, updated_payload = args
# Check that tags were merged (existing + new)
expected_tags = set(
[
"existing",
"qdrant",
"original",
"test",
]
)
if item_type == "MailMessage":
expected_tags.update(["recipient@example.com", "sender@example.com"])
actual_tags = set(updated_payload["tags"])
assert actual_tags == expected_tags
# Check that new metadata is present
assert updated_payload["source_id"] == item.id
def test_update_metadata_for_item_no_changes(db_session, qdrant):
"""Test that no updates are made when metadata hasn't changed."""
item = MailMessage(
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Test content",
folder="INBOX",
modality="mail",
)
db_session.add(item)
db_session.flush()
chunk_id = str(uuid.uuid4())
chunk = Chunk(
id=chunk_id,
source=item,
content="Test chunk content",
embedding_model="test-model",
)
db_session.add(chunk)
db_session.commit()
# Setup payload that matches what the item would generate
item_payload = item.as_payload()
existing_payload = {chunk_id: item_payload}
with (
patch(
"memory.workers.tasks.maintenance.qdrant.get_payloads"
) as mock_get_payloads,
patch(
"memory.workers.tasks.maintenance.qdrant.set_payload"
) as mock_set_payload,
):
mock_get_payloads.return_value = existing_payload
result = update_metadata_for_item(str(item.id), "MailMessage")
# Verify no updates were made
assert result["status"] == "success"
assert result["updated_chunks"] == 0
assert result["errors"] == 0
# Verify set_payload was never called since nothing changed
mock_set_payload.assert_not_called()
def test_update_metadata_for_item_no_chunks(db_session):
"""Test updating metadata for an item with no chunks."""
item = MailMessage(
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="RAW",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Test content",
folder="INBOX",
modality="mail",
)
db_session.add(item)
db_session.commit()
result = update_metadata_for_item(str(item.id), "MailMessage")
assert result["status"] == "success"
assert result["updated_chunks"] == 0
assert result["errors"] == 0
def test_update_metadata_for_item_not_found(db_session):
"""Test updating metadata for a non-existent item."""
non_existent_id = "999"
result = update_metadata_for_item(non_existent_id, "MailMessage")
assert result == {"status": "error", "error": f"Item {non_existent_id} not found"}
def test_update_metadata_for_item_invalid_type(db_session):
"""Test updating metadata with an invalid item type."""
result = update_metadata_for_item("1", "invalid_type")
assert result["status"] == "error"
assert "Unsupported item type invalid_type" in result["error"]
assert "Available types:" in result["error"]
def test_update_metadata_for_item_missing_chunks_in_qdrant(db_session, qdrant):
"""Test when some chunks exist in DB but not in Qdrant."""
item = MailMessage(
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient@example.com"],
content="Test content",
folder="INBOX",
modality="mail",
)
db_session.add(item)
db_session.flush()
chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
chunks = [
Chunk(
id=chunk_id,
source=item,
content=f"Test chunk content {i}",
embedding_model="test-model",
)
for i, chunk_id in enumerate(chunk_ids)
]
db_session.add_all(chunks)
db_session.commit()
# Mock qdrant to return payloads for only some chunks
with (
patch(
"memory.workers.tasks.maintenance.qdrant.get_payloads"
) as mock_get_payloads,
patch(
"memory.workers.tasks.maintenance.qdrant.set_payload"
) as mock_set_payload,
):
# Only return payload for first chunk
mock_get_payloads.return_value = {
chunk_ids[0]: {"tags": ["existing"], "source_id": item.id}
}
result = update_metadata_for_item(str(item.id), "MailMessage")
# Should only update the chunk that was found in Qdrant
assert result["status"] == "success"
assert result["updated_chunks"] == 1
assert result["errors"] == 0
# Only one set_payload call for the found chunk
assert mock_set_payload.call_count == 1
@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
def test_update_metadata_for_source_items_success(db_session, item_type):
"""Test updating metadata for all items of a given type."""
if item_type == "MailMessage":
items = [
MailMessage(
sha256=f"test_hash_{i}".encode() + bytes(32 - len(f"test_hash_{i}")),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="STORED",
message_id=f"<test{i}@example.com>",
subject=f"Test Subject {i}",
sender="sender@example.com",
recipients=["recipient@example.com"],
content=f"Test content {i}",
folder="INBOX",
modality="mail",
)
for i in range(3)
]
else:
items = [
BlogPost(
sha256=f"test_hash_{i}".encode() + bytes(32 - len(f"test_hash_{i}")),
tags=["test"],
size=100,
mime_type="text/html",
embed_status="STORED",
url=f"https://example.com/post{i}",
title=f"Test Blog Post {i}",
author="Author Name",
content=f"Test blog content {i}",
modality="blog",
)
for i in range(3)
]
db_session.add_all(items)
db_session.commit()
with patch.object(update_metadata_for_item, "delay") as mock_update:
result = update_metadata_for_source_items(item_type)
assert result == {"status": "success", "items": 3}
# Verify that update_metadata_for_item.delay was called for each item
assert mock_update.call_count == 3
expected_calls = [call(item.id, item_type) for item in items]
mock_update.assert_has_calls(expected_calls, any_order=True)
def test_update_metadata_for_source_items_no_items(db_session):
"""Test when there are no items of the specified type."""
with patch.object(update_metadata_for_item, "delay") as mock_update:
result = update_metadata_for_source_items("MailMessage")
assert result == {"status": "success", "items": 0}
mock_update.assert_not_called()
def test_update_metadata_for_source_items_invalid_type(db_session):
"""Test updating metadata for an invalid item type."""
result = update_metadata_for_source_items("invalid_type")
assert result["status"] == "error"
assert "Unsupported item type invalid_type" in result["error"]
assert "Available types:" in result["error"]