Commit e505f9b53c: summarize before chunking
Parent commit: ed8033bdd3
Repository: https://github.com/mruwnik/memory.git
@@ -11,6 +1,7 @@ secrets:
  postgres_password: { file: ./secrets/postgres_password.txt }
  jwt_secret: { file: ./secrets/jwt_secret.txt }
  openai_key: { file: ./secrets/openai_key.txt }
  anthropic_key: { file: ./secrets/anthropic_key.txt }

# --------------------------------------------------------------------- volumes
volumes:
@@ -42,7 +43,8 @@ x-worker-base: &worker-base
    # DSNs are built in worker entrypoint from user + pw files
    QDRANT_URL: http://qdrant:6333
    OPENAI_API_KEY_FILE: /run/secrets/openai_key
  secrets: [ postgres_password, openai_key ]
    ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
  secrets: [ postgres_password, openai_key, anthropic_key ]
  read_only: true
  tmpfs: [ /tmp, /var/tmp ]
  cap_drop: [ ALL ]
@@ -76,9 +78,6 @@ services:
    mem_limit: 4g
    cpus: "1.5"
    security_opt: [ "no-new-privileges=true" ]
    ports:
      # PostgreSQL port for local Celery result backend
      - "15432:5432"

  rabbitmq:
    image: rabbitmq:3.13-management
@@ -98,11 +97,6 @@ services:
    mem_limit: 512m
    cpus: "0.5"
    security_opt: [ "no-new-privileges=true" ]
    ports:
      # UI only on localhost
      - "15672:15672"
      # AMQP port for local Celery clients
      - "15673:5672"

  qdrant:
    image: qdrant/qdrant:v1.14.0
@@ -119,8 +113,6 @@ services:
      interval: 15s
      timeout: 5s
      retries: 5
    ports:
      - "6333:6333"
    mem_limit: 4g
    cpus: "2"
    security_opt: [ "no-new-privileges=true" ]
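The compose file only notes that database DSNs are assembled in the worker entrypoint from the user and password secret files; that entrypoint is not part of this diff. A rough sketch of the pattern, where the secret names, user and database names are placeholders rather than values from the repo:

    import os
    from pathlib import Path

    def read_secret(name: str, default: str = "") -> str:
        # Docker secrets are mounted as files under /run/secrets
        path = Path(f"/run/secrets/{name}")
        if path.exists():
            return path.read_text().strip()
        return os.getenv(name.upper(), default)

    # Hypothetical DSN assembly; the real entrypoint may use different names.
    password = read_secret("postgres_password")
    os.environ["DATABASE_URL"] = f"postgresql://memory:{password}@postgres:5432/memory"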
@@ -4,4 +4,5 @@ pydantic==2.7.1
alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
qdrant-client==1.9.0
anthropic==0.18.1
run_celery_task.py (new file, 449 lines)
@@ -0,0 +1,449 @@
#!/usr/bin/env python3
"""
Script to run Celery tasks on the Docker Compose setup from your local machine.

This script connects to the RabbitMQ broker running in Docker and sends tasks
to the workers. It requires the same dependencies as the workers to import
the task definitions.

Usage:
    python run_celery_task.py --help
    python run_celery_task.py email sync-all-accounts
    python run_celery_task.py email sync-account --account-id 1
    python run_celery_task.py ebook sync-book --file-path "/path/to/book.epub" --tags "fiction,scifi"
    python run_celery_task.py maintenance clean-all-collections
    python run_celery_task.py blogs sync-webpage --url "https://example.com"
    python run_celery_task.py comic sync-all-comics
    python run_celery_task.py forums sync-lesswrong --since-date "2025-01-01" --min-karma 10 --limit 50 --cooldown 0.5 --max-items 1000
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from memory.workers.tasks.blogs import (
|
||||
SYNC_ALL_ARTICLE_FEEDS,
|
||||
SYNC_ARTICLE_FEED,
|
||||
SYNC_WEBPAGE,
|
||||
SYNC_WEBSITE_ARCHIVE,
|
||||
)
|
||||
from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_COMIC, SYNC_SMBC, SYNC_XKCD
|
||||
from memory.workers.tasks.ebook import SYNC_BOOK
|
||||
from memory.workers.tasks.email import PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
|
||||
from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
|
||||
from memory.workers.tasks.maintenance import (
|
||||
CLEAN_ALL_COLLECTIONS,
|
||||
CLEAN_COLLECTION,
|
||||
REINGEST_CHUNK,
|
||||
REINGEST_EMPTY_SOURCE_ITEMS,
|
||||
REINGEST_ITEM,
|
||||
REINGEST_MISSING_CHUNKS,
|
||||
UPDATE_METADATA_FOR_ITEM,
|
||||
UPDATE_METADATA_FOR_SOURCE_ITEMS,
|
||||
)
|
||||
|
||||
# Add the src directory to Python path so we can import memory modules
|
||||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
from celery import Celery
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
TASK_MAPPINGS = {
|
||||
"email": {
|
||||
"sync_all_accounts": SYNC_ALL_ACCOUNTS,
|
||||
"sync_account": SYNC_ACCOUNT,
|
||||
"process_message": PROCESS_EMAIL,
|
||||
},
|
||||
"ebook": {
|
||||
"sync_book": SYNC_BOOK,
|
||||
},
|
||||
"maintenance": {
|
||||
"clean_all_collections": CLEAN_ALL_COLLECTIONS,
|
||||
"clean_collection": CLEAN_COLLECTION,
|
||||
"reingest_missing_chunks": REINGEST_MISSING_CHUNKS,
|
||||
"reingest_chunk": REINGEST_CHUNK,
|
||||
"reingest_item": REINGEST_ITEM,
|
||||
"reingest_empty_source_items": REINGEST_EMPTY_SOURCE_ITEMS,
|
||||
"update_metadata_for_item": UPDATE_METADATA_FOR_ITEM,
|
||||
"update_metadata_for_source_items": UPDATE_METADATA_FOR_SOURCE_ITEMS,
|
||||
},
|
||||
"blogs": {
|
||||
"sync_webpage": SYNC_WEBPAGE,
|
||||
"sync_article_feed": SYNC_ARTICLE_FEED,
|
||||
"sync_all_article_feeds": SYNC_ALL_ARTICLE_FEEDS,
|
||||
"sync_website_archive": SYNC_WEBSITE_ARCHIVE,
|
||||
},
|
||||
"comic": {
|
||||
"sync_all_comics": SYNC_ALL_COMICS,
|
||||
"sync_smbc": SYNC_SMBC,
|
||||
"sync_xkcd": SYNC_XKCD,
|
||||
"sync_comic": SYNC_COMIC,
|
||||
},
|
||||
"forums": {
|
||||
"sync_lesswrong": SYNC_LESSWRONG,
|
||||
"sync_lesswrong_post": SYNC_LESSWRONG_POST,
|
||||
},
|
||||
}
|
||||
QUEUE_MAPPINGS = {
|
||||
"email": "email",
|
||||
"ebook": "ebooks",
|
||||
"photo": "photo_embed",
|
||||
}
|
||||
|
||||
|
||||
def create_local_celery_app() -> Celery:
|
||||
"""Create a Celery app configured to connect to the Docker RabbitMQ."""
|
||||
# Override settings for local connection to Docker services
|
||||
rabbitmq_url = f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@localhost:15673//"
|
||||
|
||||
app = Celery(
|
||||
"memory-local",
|
||||
broker=rabbitmq_url,
|
||||
backend=settings.CELERY_RESULT_BACKEND.replace(
|
||||
"postgres:5432", "localhost:15432"
|
||||
),
|
||||
)
|
||||
|
||||
# Import task modules so they're registered
|
||||
app.autodiscover_tasks(["memory.workers.tasks"])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_task(app: Celery, category: str, task_name: str, **kwargs) -> str:
|
||||
"""Run a task using the task mappings."""
|
||||
if category not in TASK_MAPPINGS:
|
||||
raise ValueError(f"Unknown category: {category}")
|
||||
|
||||
if task_name not in TASK_MAPPINGS[category]:
|
||||
raise ValueError(f"Unknown {category} task: {task_name}")
|
||||
|
||||
task_path = TASK_MAPPINGS[category][task_name]
|
||||
queue_name = QUEUE_MAPPINGS.get(category) or category
|
||||
|
||||
result = app.send_task(task_path, kwargs=kwargs, queue=queue_name)
|
||||
return result.id
|
||||
|
||||
|
||||
def get_task_result(app: Celery, task_id: str, timeout: int = 300) -> Any:
|
||||
"""Get the result of a task."""
|
||||
result = app.AsyncResult(task_id)
|
||||
try:
|
||||
return result.get(timeout=timeout)
|
||||
except Exception as e:
|
||||
return {"error": str(e), "status": result.status}
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option("--wait", is_flag=True, help="Wait for task completion and show result")
|
||||
@click.option(
|
||||
"--timeout", default=300, help="Timeout in seconds when waiting for result"
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx, wait, timeout):
|
||||
"""Run Celery tasks on Docker Compose setup."""
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["wait"] = wait
|
||||
ctx.obj["timeout"] = timeout
|
||||
|
||||
try:
|
||||
ctx.obj["app"] = create_local_celery_app()
|
||||
except Exception as e:
|
||||
click.echo(f"Error connecting to Celery broker: {e}")
|
||||
click.echo(
|
||||
"Make sure Docker Compose is running and RabbitMQ is accessible on localhost:15673"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def execute_task(ctx, category: str, task_name: str, **kwargs):
|
||||
"""Helper to execute a task and handle results."""
|
||||
app = ctx.obj["app"]
|
||||
wait = ctx.obj["wait"]
|
||||
timeout = ctx.obj["timeout"]
|
||||
|
||||
# Filter out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
try:
|
||||
task_id = run_task(app, category, task_name, **kwargs)
|
||||
click.echo("Task submitted successfully!")
|
||||
click.echo(f"Task ID: {task_id}")
|
||||
|
||||
if wait:
|
||||
click.echo(f"Waiting for task completion (timeout: {timeout}s)...")
|
||||
result = get_task_result(app, task_id, timeout)
|
||||
click.echo("Task result:")
|
||||
click.echo(json.dumps(result, indent=2, default=str))
|
||||
except Exception as e:
|
||||
click.echo(f"Error running task: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def email(ctx):
|
||||
"""Email-related tasks."""
|
||||
pass
|
||||
|
||||
|
||||
@email.command("sync-all-accounts")
|
||||
@click.option("--since-date", help="Sync items since this date (ISO format)")
|
||||
@click.pass_context
|
||||
def email_sync_all_accounts(ctx, since_date):
|
||||
"""Sync all email accounts."""
|
||||
execute_task(ctx, "email", "sync_all_accounts", since_date=since_date)
|
||||
|
||||
|
||||
@email.command("sync-account")
|
||||
@click.option("--account-id", type=int, required=True, help="Email account ID")
|
||||
@click.option("--since-date", help="Sync items since this date (ISO format)")
|
||||
@click.pass_context
|
||||
def email_sync_account(ctx, account_id, since_date):
|
||||
"""Sync a specific email account."""
|
||||
execute_task(
|
||||
ctx, "email", "sync_account", account_id=account_id, since_date=since_date
|
||||
)
|
||||
|
||||
|
||||
@email.command("process-message")
|
||||
@click.option("--message-id", required=True, help="Email message ID")
|
||||
@click.option("--folder", help="Email folder name")
|
||||
@click.option("--raw-email", help="Raw email content")
|
||||
@click.pass_context
|
||||
def email_process_message(ctx, message_id, folder, raw_email):
|
||||
"""Process a specific email message."""
|
||||
execute_task(
|
||||
ctx,
|
||||
"email",
|
||||
"process_message",
|
||||
message_id=message_id,
|
||||
folder=folder,
|
||||
raw_email=raw_email,
|
||||
)
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def ebook(ctx):
|
||||
"""Ebook-related tasks."""
|
||||
pass
|
||||
|
||||
|
||||
@ebook.command("sync-book")
|
||||
@click.option("--file-path", required=True, help="Path to ebook file")
|
||||
@click.option("--tags", help="Comma-separated tags")
|
||||
@click.pass_context
|
||||
def ebook_sync_book(ctx, file_path, tags):
|
||||
"""Sync an ebook."""
|
||||
execute_task(ctx, "ebook", "sync_book", file_path=file_path, tags=tags)
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def maintenance(ctx):
|
||||
"""Maintenance tasks."""
|
||||
pass
|
||||
|
||||
|
||||
@maintenance.command("clean-all-collections")
|
||||
@click.pass_context
|
||||
def maintenance_clean_all_collections(ctx):
|
||||
"""Clean all collections."""
|
||||
execute_task(ctx, "maintenance", "clean_all_collections")
|
||||
|
||||
|
||||
@maintenance.command("clean-collection")
|
||||
@click.option("--collection", required=True, help="Collection name to clean")
|
||||
@click.pass_context
|
||||
def maintenance_clean_collection(ctx, collection):
|
||||
"""Clean a specific collection."""
|
||||
execute_task(ctx, "maintenance", "clean_collection", collection=collection)
|
||||
|
||||
|
||||
@maintenance.command("reingest-missing-chunks")
|
||||
@click.option("--minutes-ago", type=int, help="Minutes ago to reingest chunks")
|
||||
@click.pass_context
|
||||
def maintenance_reingest_missing_chunks(ctx, minutes_ago):
|
||||
"""Reingest missing chunks."""
|
||||
execute_task(ctx, "maintenance", "reingest_missing_chunks", minutes_ago=minutes_ago)
|
||||
|
||||
|
||||
@maintenance.command("reingest-item")
|
||||
@click.option("--item-id", required=True, help="Item ID to reingest")
|
||||
@click.option("--item-type", required=True, help="Item type to reingest")
|
||||
@click.pass_context
|
||||
def maintenance_reingest_item(ctx, item_id, item_type):
|
||||
"""Reingest a specific item."""
|
||||
execute_task(
|
||||
ctx, "maintenance", "reingest_item", item_id=item_id, item_type=item_type
|
||||
)
|
||||
|
||||
|
||||
@maintenance.command("update-metadata-for-item")
|
||||
@click.option("--item-id", required=True, help="Item ID to update metadata for")
|
||||
@click.option("--item-type", required=True, help="Item type to update metadata for")
|
||||
@click.pass_context
|
||||
def maintenance_update_metadata_for_item(ctx, item_id, item_type):
|
||||
"""Update metadata for a specific item."""
|
||||
execute_task(
|
||||
ctx,
|
||||
"maintenance",
|
||||
"update_metadata_for_item",
|
||||
item_id=item_id,
|
||||
item_type=item_type,
|
||||
)
|
||||
|
||||
|
||||
@maintenance.command("update-metadata-for-source-items")
|
||||
@click.option("--item-type", required=True, help="Item type to update metadata for")
|
||||
@click.pass_context
|
||||
def maintenance_update_metadata_for_source_items(ctx, item_type):
|
||||
"""Update metadata for all items of a specific type."""
|
||||
execute_task(
|
||||
ctx, "maintenance", "update_metadata_for_source_items", item_type=item_type
|
||||
)
|
||||
|
||||
|
||||
@maintenance.command("reingest-empty-source-items")
|
||||
@click.option("--item-type", required=True, help="Item type to reingest")
|
||||
@click.pass_context
|
||||
def maintenance_reingest_empty_source_items(ctx, item_type):
|
||||
"""Reingest empty source items."""
|
||||
execute_task(ctx, "maintenance", "reingest_empty_source_items", item_type=item_type)
|
||||
|
||||
|
||||
@maintenance.command("reingest-chunk")
|
||||
@click.option("--chunk-id", required=True, help="Chunk ID to reingest")
|
||||
@click.pass_context
|
||||
def maintenance_reingest_chunk(ctx, chunk_id):
|
||||
"""Reingest a specific chunk."""
|
||||
execute_task(ctx, "maintenance", "reingest_chunk", chunk_id=chunk_id)
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def blogs(ctx):
|
||||
"""Blog-related tasks."""
|
||||
pass
|
||||
|
||||
|
||||
@blogs.command("sync-webpage")
|
||||
@click.option("--url", required=True, help="URL to sync")
|
||||
@click.pass_context
|
||||
def blogs_sync_webpage(ctx, url):
|
||||
"""Sync a webpage."""
|
||||
execute_task(ctx, "blogs", "sync_webpage", url=url)
|
||||
|
||||
|
||||
@blogs.command("sync-article-feed")
|
||||
@click.option("--feed-id", type=int, required=True, help="Feed ID to sync")
|
||||
@click.pass_context
|
||||
def blogs_sync_article_feed(ctx, feed_id):
|
||||
"""Sync an article feed."""
|
||||
execute_task(ctx, "blogs", "sync_article_feed", feed_id=feed_id)
|
||||
|
||||
|
||||
@blogs.command("sync-all-article-feeds")
|
||||
@click.pass_context
|
||||
def blogs_sync_all_article_feeds(ctx):
|
||||
"""Sync all article feeds."""
|
||||
execute_task(ctx, "blogs", "sync_all_article_feeds")
|
||||
|
||||
|
||||
@blogs.command("sync-website-archive")
|
||||
@click.option("--url", required=True, help="URL to sync")
|
||||
@click.pass_context
|
||||
def blogs_sync_website_archive(ctx, url):
|
||||
"""Sync a website archive."""
|
||||
execute_task(ctx, "blogs", "sync_website_archive", url=url)
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def comic(ctx):
|
||||
"""Comic-related tasks."""
|
||||
pass
|
||||
|
||||
|
||||
@comic.command("sync-all-comics")
|
||||
@click.pass_context
|
||||
def comic_sync_all_comics(ctx):
|
||||
"""Sync all comics."""
|
||||
execute_task(ctx, "comic", "sync_all_comics")
|
||||
|
||||
|
||||
@comic.command("sync-smbc")
|
||||
@click.pass_context
|
||||
def comic_sync_smbc(ctx):
|
||||
"""Sync SMBC comics."""
|
||||
execute_task(ctx, "comic", "sync_smbc")
|
||||
|
||||
|
||||
@comic.command("sync-xkcd")
|
||||
@click.pass_context
|
||||
def comic_sync_xkcd(ctx):
|
||||
"""Sync XKCD comics."""
|
||||
execute_task(ctx, "comic", "sync_xkcd")
|
||||
|
||||
|
||||
@comic.command("sync-comic")
|
||||
@click.option("--image-url", required=True, help="Image URL to sync")
|
||||
@click.option("--title", help="Comic title")
|
||||
@click.option("--author", help="Comic author")
|
||||
@click.option("--published-date", help="Comic published date")
|
||||
@click.pass_context
|
||||
def comic_sync_comic(ctx, image_url, title, author, published_date):
|
||||
"""Sync a specific comic."""
|
||||
execute_task(
|
||||
ctx,
|
||||
"comic",
|
||||
"sync_comic",
|
||||
image_url=image_url,
|
||||
title=title,
|
||||
author=author,
|
||||
published_date=published_date,
|
||||
)
|
||||
|
||||
|
||||
@cli.group()
|
||||
@click.pass_context
|
||||
def forums(ctx):
|
||||
"""Forum-related tasks."""
|
||||
pass
|
||||
|
||||
|
||||
@forums.command("sync-lesswrong")
|
||||
@click.option("--since-date", help="Sync items since this date (ISO format)")
|
||||
@click.option("--min-karma", type=int, help="Minimum karma to sync")
|
||||
@click.option("--limit", type=int, help="Limit the number of posts to sync")
|
||||
@click.option("--cooldown", type=float, help="Cooldown between posts")
|
||||
@click.option("--max-items", type=int, help="Maximum number of posts to sync")
|
||||
@click.pass_context
|
||||
def forums_sync_lesswrong(ctx, since_date, min_karma, limit, cooldown, max_items):
|
||||
"""Sync LessWrong posts."""
|
||||
execute_task(
|
||||
ctx,
|
||||
"forums",
|
||||
"sync_lesswrong",
|
||||
since_date=since_date,
|
||||
min_karma=min_karma,
|
||||
limit=limit,
|
||||
cooldown=cooldown,
|
||||
max_items=max_items,
|
||||
)
|
||||
|
||||
|
||||
@forums.command("sync-lesswrong-post")
|
||||
@click.option("--url", required=True, help="LessWrong post URL")
|
||||
@click.pass_context
|
||||
def forums_sync_lesswrong_post(ctx, url):
|
||||
"""Sync a specific LessWrong post."""
|
||||
execute_task(ctx, "forums", "sync_lesswrong_post", url=url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
@@ -2,13 +2,15 @@ import logging
from typing import Iterable, Any
import re

from memory.common import settings

logger = logging.getLogger(__name__)


# Chunking configuration
EMBEDDING_MAX_TOKENS = 32000  # VoyageAI max context window
DEFAULT_CHUNK_TOKENS = 512  # Optimal chunk size for semantic search
OVERLAP_TOKENS = 50  # Default overlap between chunks (10% of chunk size)
EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
OVERLAP_TOKENS = settings.OVERLAP_TOKENS
CHARS_PER_TOKEN = 4
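The chunk-size constants now come from settings, while CHARS_PER_TOKEN = 4 stays as a local heuristic. The body of chunker.approx_token_count is not shown in this diff, but the constant implies an estimate along these lines, which is what chunk_mixed later compares against DEFAULT_CHUNK_TOKENS * 2 (a sketch under that assumption, not the repo's implementation):

    CHARS_PER_TOKEN = 4
    DEFAULT_CHUNK_TOKENS = 512

    def approx_token_count(text: str) -> int:
        # assumed heuristic: roughly 4 characters per token
        return len(text) // CHARS_PER_TOKEN

    content = "word " * 2_000             # ~10,000 characters
    tokens = approx_token_count(content)  # ~2,500 estimated tokens
    needs_split = tokens > DEFAULT_CHUNK_TOKENS * 2  # True: summarize and chunk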
@@ -35,6 +35,7 @@ from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer

Base = declarative_base()

@@ -105,19 +106,36 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh
    ]


def chunk_mixed(
    content: str, image_paths: Sequence[str]
) -> list[list[extract.MulitmodalChunk]]:
    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
    full_text: list[extract.MulitmodalChunk] = [content.strip(), *images]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
    final = {}
    for m in metadata:
        if tags := set(m.pop("tags", [])):
            final["tags"] = tags | final.get("tags", set())
        final |= m
    return final

    chunks = []

def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
    if not content.strip():
        return []

    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]

    summary, tags = summarizer.summarize(content)
    full_text: extract.DataChunk = extract.DataChunk(
        data=[content.strip(), *images], metadata={"tags": tags}
    )

    chunks: list[extract.DataChunk] = [full_text]
    tokens = chunker.approx_token_count(content)
    if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
        chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
        chunks += [
            extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
            for c in chunker.chunk_text(content)
        ]
        chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))

    all_chunks = [full_text] + chunks
    return [c for c in all_chunks if c and all(i for i in c)]
    return [c for c in chunks if c.data]


class Chunk(Base):
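A small usage sketch for merge_metadata as defined above: later dicts win for plain keys, while tags are popped from each input (so the inputs are mutated) and unioned into a set:

    base = {"source_id": 7, "tags": ["blog"]}
    chunk_meta = {"type": "summary", "tags": ["ai", "blog"]}
    extra = {"size": 1024}

    merged = merge_metadata(base, chunk_meta, extra)
    # {"source_id": 7, "type": "summary", "size": 1024, "tags": {"ai", "blog"}}
    # note: base and chunk_meta no longer contain a "tags" key afterwards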
@@ -224,22 +242,27 @@ class SourceItem(Base):
        """Get vector IDs from associated chunks."""
        return [chunk.id for chunk in self.chunks]

    def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
        chunks: list[list[extract.MulitmodalChunk]] = []
        if cast(str | None, self.content):
            chunks = [[c] for c in chunker.chunk_text(cast(str, self.content))]
    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        chunks: list[extract.DataChunk] = []
        content = cast(str | None, self.content)
        if content:
            chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]

        if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
            summary, tags = summarizer.summarize(content)
            chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))

        mime_type = cast(str | None, self.mime_type)
        if mime_type and mime_type.startswith("image/"):
            chunks.append([Image.open(self.filename)])
            chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
        return chunks

    def _make_chunk(
        self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {}
    ):
        self, data: extract.DataChunk, metadata: dict[str, Any] = {}
    ) -> Chunk:
        chunk_id = str(uuid.uuid4())
        text = "\n\n".join(c for c in data if isinstance(c, str) and c.strip())
        images = [c for c in data if isinstance(c, Image.Image)]
        text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
        images = [c for c in data.data if isinstance(c, Image.Image)]
        image_names = image_filenames(chunk_id, images)

        chunk = Chunk(
@@ -249,17 +272,18 @@ class SourceItem(Base):
            images=images,
            file_paths=image_names,
            embedding_model=collections.collection_model(cast(str, self.modality)),
            item_metadata=self.as_payload() | metadata,
            item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
        )
        return chunk

    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
        return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
        return [self._make_chunk(data) for data in self._chunk_contents()]

    def as_payload(self) -> dict:
        return {
            "source_id": self.id,
            "tags": self.tags,
            "size": self.size,
        }

    @property
@ -312,7 +336,7 @@ class MailMessage(SourceItem):
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
**super().as_payload(),
|
||||
"message_id": self.message_id,
|
||||
"subject": self.subject,
|
||||
"sender": self.sender,
|
||||
@ -381,13 +405,12 @@ class EmailAttachment(SourceItem):
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
**super().as_payload(),
|
||||
"filename": self.filename,
|
||||
"content_type": self.mime_type,
|
||||
"size": self.size,
|
||||
"created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore
|
||||
"mail_message_id": self.mail_message_id,
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
@ -397,7 +420,7 @@ class EmailAttachment(SourceItem):
|
||||
contents = cast(str, self.content)
|
||||
|
||||
chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
|
||||
return [self._make_chunk(c.data, metadata | c.metadata) for c in chunks]
|
||||
return [self._make_chunk(c, metadata) for c in chunks]
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
|
||||
@ -488,8 +511,7 @@ class Comic(SourceItem):
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
payload = {
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
**super().as_payload(),
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
"published": self.published,
|
||||
@ -500,10 +522,10 @@ class Comic(SourceItem):
|
||||
}
|
||||
return {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
image = Image.open(pathlib.Path(cast(str, self.filename)))
|
||||
description = f"{self.title} by {self.author}"
|
||||
return [[image, description]]
|
||||
return [extract.DataChunk(data=[image, description])]
|
||||
|
||||
|
||||
class Book(Base):
|
||||
@ -538,7 +560,7 @@ class Book(Base):
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
**super().as_payload(),
|
||||
"isbn": self.isbn,
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
@ -548,7 +570,6 @@ class Book(Base):
|
||||
"edition": self.edition,
|
||||
"series": self.series,
|
||||
"series_number": self.series_number,
|
||||
"tags": self.tags,
|
||||
} | (cast(dict, self.book_metadata) or {})
|
||||
|
||||
|
||||
@@ -591,29 +612,41 @@ class BookSection(SourceItem):

    def as_payload(self) -> dict:
        vals = {
            **super().as_payload(),
            "title": self.book.title,
            "author": self.book.author,
            "source_id": self.id,
            "book_id": self.book_id,
            "section_title": self.section_title,
            "section_number": self.section_number,
            "section_level": self.section_level,
            "start_page": self.start_page,
            "end_page": self.end_page,
            "tags": self.tags,
        }
        return {k: v for k, v in vals.items() if v}

    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
        if not cast(str, self.content.strip()):
        content = cast(str, self.content.strip())
        if not content:
            return []

        texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
        if len([p for p in self.pages if p.strip()]) == 1:
            return [
                self._make_chunk(
                    extract.DataChunk(data=[content]), metadata | {"type": "page"}
                )
            ]

        summary, tags = summarizer.summarize(content)
        return [
            self._make_chunk([text.strip()], metadata | {"page": page_number})
            for text, page_number in texts
            if text and text.strip()
        ] + [self._make_chunk([cast(str, self.content.strip())], metadata)]
            self._make_chunk(
                extract.DataChunk(data=[content]),
                merge_metadata(metadata, {"type": "section", "tags": tags}),
            ),
            self._make_chunk(
                extract.DataChunk(data=[summary]),
                merge_metadata(metadata, {"type": "summary", "tags": tags}),
            ),
        ]


class BlogPost(SourceItem):
@ -652,7 +685,7 @@ class BlogPost(SourceItem):
|
||||
metadata = cast(dict | None, self.webpage_metadata) or {}
|
||||
|
||||
payload = {
|
||||
"source_id": self.id,
|
||||
**super().as_payload(),
|
||||
"url": self.url,
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
@ -660,12 +693,11 @@ class BlogPost(SourceItem):
|
||||
"description": self.description,
|
||||
"domain": self.domain,
|
||||
"word_count": self.word_count,
|
||||
"tags": self.tags,
|
||||
**metadata,
|
||||
}
|
||||
return {k: v for k, v in payload.items() if v}
|
||||
|
||||
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
|
||||
|
||||
|
||||
@ -701,7 +733,7 @@ class ForumPost(SourceItem):
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
**super().as_payload(),
|
||||
"url": self.url,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
@ -712,10 +744,9 @@ class ForumPost(SourceItem):
|
||||
"votes": self.votes,
|
||||
"score": self.score,
|
||||
"comments": self.comments,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
|
||||
|
||||
|
||||
|
@@ -245,3 +245,46 @@ def find_missing_points(
        collection_name, ids=ids, with_payload=False, with_vectors=False
    )
    return set(ids) - {str(r.id) for r in found}


def set_payload(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    point_id: str,
    payload: dict[str, Any],
) -> None:
    """Set payload for a single point without modifying its vector.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        point_id: Vector ID (as string)
        payload: New payload to set
    """
    client.set_payload(
        collection_name=collection_name,
        payload=payload,
        points=[point_id],
    )

    logger.debug(f"Set payload for point {point_id} in {collection_name}")


def get_payloads(
    client: qdrant_client.QdrantClient, collection_name: str, ids: list[str]
) -> dict[str, dict[str, Any]]:
    """Retrieve payloads for multiple points.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        ids: List of vector IDs (as strings)

    Returns:
        Dictionary mapping point IDs to their payloads
    """
    points = client.retrieve(
        collection_name=collection_name, ids=ids, with_payload=True, with_vectors=False
    )

    return {str(point.id): point.payload or {} for point in points}
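A usage sketch for the two helpers above; the collection name and point id are illustrative, and the import path of this module is assumed to be memory.common.qdrant:

    import qdrant_client
    from memory.common import qdrant

    client = qdrant_client.QdrantClient(url="http://localhost:6333")

    chunk_ids = ["1b9d6bcd-bbfd-4b2d-9b5d-ab8dfbbd4bed"]  # hypothetical point id
    payloads = qdrant.get_payloads(client, "blog", chunk_ids)
    for point_id, payload in payloads.items():
        payload["tags"] = sorted(set(payload.get("tags", [])) | {"reviewed"})
        qdrant.set_payload(client, "blog", point_id, payload)  # vector stays untouched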
@@ -98,3 +98,22 @@ CHUNK_REINGEST_SINCE_MINUTES = int(os.getenv("CHUNK_REINGEST_SINCE_MINUTES", 60
TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")
EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", 50))

# VoyageAI max context window
EMBEDDING_MAX_TOKENS = int(os.getenv("EMBEDDING_MAX_TOKENS", 32000))
# Optimal chunk size for semantic search
DEFAULT_CHUNK_TOKENS = int(os.getenv("DEFAULT_CHUNK_TOKENS", 512))
OVERLAP_TOKENS = int(os.getenv("OVERLAP_TOKENS", 50))


# LLM settings
if openai_key_file := os.getenv("OPENAI_API_KEY_FILE"):
    OPENAI_API_KEY = pathlib.Path(openai_key_file).read_text().strip()
else:
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
    ANTHROPIC_API_KEY = pathlib.Path(anthropic_key_file).read_text().strip()
else:
    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
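Since these are plain environment overrides, the chunking and summarization behaviour can be tuned without code changes; a sketch with arbitrary values (the variables must be set before memory.common.settings is first imported):

    import os

    os.environ["DEFAULT_CHUNK_TOKENS"] = "1024"   # larger chunks
    os.environ["OVERLAP_TOKENS"] = "100"
    os.environ["SUMMARIZER_MODEL"] = "anthropic/claude-3-haiku-20240307"

    from memory.common import settings
    assert settings.DEFAULT_CHUNK_TOKENS == 1024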
src/memory/common/summarizer.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import json
import logging
from typing import Any

from memory.common import settings, chunker

logger = logging.getLogger(__name__)

TAGS_PROMPT = """
The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.

Return your response as JSON with this format:
{{
    "summary": "{summary}",
    "tags": ["tag1", "tag2", "tag3"]
}}

Text:
{content}
"""

SUMMARY_PROMPT = """
Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
Also provide 3-5 relevant tags that capture the main topics or themes.

Return your response as JSON with this format:
{{
    "summary": "your summary here",
    "tags": ["tag1", "tag2", "tag3"]
}}

Text to summarize:
{content}
"""


def _call_openai(prompt: str) -> dict[str, Any]:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        response = client.chat.completions.create(
            model=settings.SUMMARIZER_MODEL.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
                },
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0.3,
            max_tokens=2048,
        )
        return json.loads(response.choices[0].message.content or "{}")
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def _call_anthropic(prompt: str) -> dict[str, Any]:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        response = client.messages.create(
            model=settings.SUMMARIZER_MODEL.split("/")[1],
            messages=[{"role": "user", "content": prompt}],
            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid JSON.",
            temperature=0.3,
            max_tokens=2048,
        )
        return json.loads(response.content[0].text)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content


def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
    """
    Summarize content to approximately target_tokens length and generate tags.

    Args:
        content: Text to summarize
        target_tokens: Target length in tokens (defaults to DEFAULT_CHUNK_TOKENS)

    Returns:
        Tuple of (summary, tags)
    """
    if not content or not content.strip():
        return "", []

    if target_tokens is None:
        target_tokens = settings.DEFAULT_CHUNK_TOKENS

    summary, tags = content, []

    # If content is already short enough, just extract tags
    current_tokens = chunker.approx_token_count(content)
    if current_tokens <= target_tokens:
        logger.info(
            f"Content already under {target_tokens} tokens, extracting tags only"
        )
        prompt = TAGS_PROMPT.format(content=content, summary=summary[:1000])
    else:
        prompt = SUMMARY_PROMPT.format(
            target_tokens=target_tokens,
            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
            content=content,
        )

    try:
        if settings.SUMMARIZER_MODEL.startswith("anthropic"):
            result = _call_anthropic(prompt)
        else:
            result = _call_openai(prompt)

        summary = result.get("summary", "")
        tags = result.get("tags", [])
    except Exception as e:
        logger.error(f"Summarization failed: {e}")

    tokens = chunker.approx_token_count(summary)
    if tokens > target_tokens * 1.5:
        logger.warning(f"Summary too long ({tokens} tokens), truncating")
        summary = truncate(content, target_tokens)

    return summary, tags
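A usage sketch for the new module; what comes back depends on the configured SUMMARIZER_MODEL and the matching API key, so the values in the comments are only indicative:

    from memory.common import summarizer

    article = "Qdrant stores vector embeddings for semantic search. " * 200

    summary, tags = summarizer.summarize(article)
    # summary is aimed at settings.DEFAULT_CHUNK_TOKENS (512) tokens;
    # tags is a short list such as ["vector search", "qdrant", "embeddings"]

    short_summary, _ = summarizer.summarize(article, target_tokens=128)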
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Sequence
from typing import Sequence, Any

from memory.workers.tasks.content_processing import process_content_item
from sqlalchemy import select
@@ -21,6 +21,10 @@ REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
UPDATE_METADATA_FOR_SOURCE_ITEMS = (
    f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
)
UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"


@app.task(name=CLEAN_COLLECTION)
@@ -97,6 +101,8 @@ def get_item_class(item_type: str):
        raise ValueError(
            f"Unsupported item type {item_type}. Available types: {available_types}"
        )
    if not hasattr(class_, "chunks"):
        raise ValueError(f"Item type {item_type} does not have chunks")
    return class_


@@ -226,3 +232,99 @@ def reingest_missing_chunks(
            total_stats[collection]["total"] += stats["total"]

    return dict(total_stats)


def _payloads_equal(current: dict[str, Any], new: dict[str, Any]) -> bool:
    """Compare two payloads to see if they're effectively equal."""
    # Handle tags specially - compare as sets since order doesn't matter
    current_tags = set(current.get("tags", []))
    new_tags = set(new.get("tags", []))

    if current_tags != new_tags:
        return False

    # Compare all other fields
    current_without_tags = {k: v for k, v in current.items() if k != "tags"}
    new_without_tags = {k: v for k, v in new.items() if k != "tags"}

    return current_without_tags == new_without_tags


@app.task(name=UPDATE_METADATA_FOR_ITEM)
def update_metadata_for_item(item_id: str, item_type: str):
    """Update metadata in Qdrant for all chunks of a single source item, merging tags."""
    logger.info(f"Updating metadata for {item_type} {item_id}")
    try:
        class_ = get_item_class(item_type)
    except ValueError as e:
        logger.error(f"Error getting item class: {e}")
        return {"status": "error", "error": str(e)}

    client = qdrant.get_qdrant_client()
    updated_chunks = 0
    errors = 0

    with make_session() as session:
        item = session.query(class_).get(item_id)
        if not item:
            return {"status": "error", "error": f"Item {item_id} not found"}

        chunk_ids = [str(chunk.id) for chunk in item.chunks if chunk.id]
        if not chunk_ids:
            return {"status": "success", "updated_chunks": 0, "errors": 0}

        collection = item.modality

        try:
            current_payloads = qdrant.get_payloads(client, collection, chunk_ids)

            # Get new metadata from source item
            new_metadata = item.as_payload()
            new_tags = set(new_metadata.get("tags", []))

            for chunk_id in chunk_ids:
                if chunk_id not in current_payloads:
                    logger.warning(
                        f"Chunk {chunk_id} not found in Qdrant collection {collection}"
                    )
                    continue

                current_payload = current_payloads[chunk_id]
                current_tags = set(current_payload.get("tags", []))

                # Merge tags (combine existing and new tags)
                merged_tags = list(current_tags | new_tags)
                updated_metadata = new_metadata.copy()
                updated_metadata["tags"] = merged_tags

                if _payloads_equal(current_payload, updated_metadata):
                    continue

                qdrant.set_payload(client, collection, chunk_id, updated_metadata)
                updated_chunks += 1

        except Exception as e:
            logger.error(f"Error updating metadata for item {item.id}: {e}")
            errors += 1

    return {"status": "success", "updated_chunks": updated_chunks, "errors": errors}


@app.task(name=UPDATE_METADATA_FOR_SOURCE_ITEMS)
def update_metadata_for_source_items(item_type: str):
    """Update metadata in Qdrant for all chunks of all items of a given source type."""
    logger.info(f"Updating metadata for all {item_type} source items")
    try:
        class_ = get_item_class(item_type)
    except ValueError as e:
        logger.error(f"Error getting item class: {e}")
        return {"status": "error", "error": str(e)}

    with make_session() as session:
        item_ids = session.query(class_.id).all()
        logger.info(f"Found {len(item_ids)} items to update metadata for")

        for item_id in item_ids:
            update_metadata_for_item.delay(item_id.id, item_type)  # type: ignore

    return {"status": "success", "items": len(item_ids)}
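To illustrate the helpers above (the values are made up; this assumes the workers package is importable and a broker is reachable for the .delay call):

    from memory.workers.tasks.maintenance import _payloads_equal, update_metadata_for_item

    current = {"source_id": 42, "title": "Post", "tags": ["blog", "ai"]}
    proposed = {"source_id": 42, "title": "Post", "tags": ["ai", "blog"]}
    _payloads_equal(current, proposed)  # True: tag order is ignored, so no write is issued

    # queue a metadata refresh for one item on the workers
    update_metadata_for_item.delay("42", "BlogPost")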
@ -5,6 +5,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import anthropic
|
||||
import openai
|
||||
import pytest
|
||||
import qdrant_client
|
||||
import voyageai
|
||||
@ -237,3 +239,35 @@ def mock_voyage_client():
|
||||
client.embed = embeder
|
||||
client.multimodal_embed = embeder
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_openai_client():
|
||||
with patch.object(openai, "OpenAI", autospec=True) as mock_client:
|
||||
client = mock_client()
|
||||
client.chat = Mock()
|
||||
client.chat.completions.create = Mock(
|
||||
return_value=Mock(
|
||||
choices=[
|
||||
Mock(
|
||||
message=Mock(
|
||||
content='{"summary": "test", "tags": ["tag1", "tag2"]}'
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_anthropic_client():
|
||||
with patch.object(anthropic, "Anthropic", autospec=True) as mock_client:
|
||||
client = mock_client()
|
||||
client.messages = Mock()
|
||||
client.messages.create = Mock(
|
||||
return_value=Mock(
|
||||
content=[Mock(text='{"summary": "test", "tags": ["tag1", "tag2"]}')]
|
||||
)
|
||||
)
|
||||
yield client
|
||||
|
@ -5,8 +5,7 @@ from typing import cast
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
from memory.common import settings
|
||||
from memory.common import chunker
|
||||
from memory.common import settings, chunker, extract
|
||||
from memory.common.db.models import (
|
||||
Chunk,
|
||||
clean_filename,
|
||||
@ -17,6 +16,7 @@ from memory.common.db.models import (
|
||||
BookSection,
|
||||
BlogPost,
|
||||
Book,
|
||||
merge_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -54,6 +54,114 @@ def test_clean_filename(input_filename, expected):
|
||||
assert clean_filename(input_filename) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dicts,expected",
|
||||
[
|
||||
# Empty input
|
||||
([], {}),
|
||||
# Single dict without tags
|
||||
([{"key": "value"}], {"key": "value"}),
|
||||
# Single dict with tags as list
|
||||
(
|
||||
[{"key": "value", "tags": ["tag1", "tag2"]}],
|
||||
{"key": "value", "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
# Single dict with tags as set
|
||||
(
|
||||
[{"key": "value", "tags": {"tag1", "tag2"}}],
|
||||
{"key": "value", "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
# Multiple dicts without tags
|
||||
(
|
||||
[{"key1": "value1"}, {"key2": "value2"}],
|
||||
{"key1": "value1", "key2": "value2"},
|
||||
),
|
||||
# Multiple dicts with non-overlapping tags
|
||||
(
|
||||
[
|
||||
{"key1": "value1", "tags": ["tag1"]},
|
||||
{"key2": "value2", "tags": ["tag2"]},
|
||||
],
|
||||
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
# Multiple dicts with overlapping tags
|
||||
(
|
||||
[
|
||||
{"key1": "value1", "tags": ["tag1", "tag2"]},
|
||||
{"key2": "value2", "tags": ["tag2", "tag3"]},
|
||||
],
|
||||
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2", "tag3"}},
|
||||
),
|
||||
# Overlapping keys - later dict wins
|
||||
(
|
||||
[
|
||||
{"key": "value1", "other": "data1"},
|
||||
{"key": "value2", "another": "data2"},
|
||||
],
|
||||
{"key": "value2", "other": "data1", "another": "data2"},
|
||||
),
|
||||
# Mixed tags types (list and set)
|
||||
(
|
||||
[
|
||||
{"key1": "value1", "tags": ["tag1", "tag2"]},
|
||||
{"key2": "value2", "tags": {"tag3", "tag4"}},
|
||||
],
|
||||
{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
"tags": {"tag1", "tag2", "tag3", "tag4"},
|
||||
},
|
||||
),
|
||||
# Empty tags
|
||||
(
|
||||
[{"key": "value", "tags": []}, {"key2": "value2", "tags": []}],
|
||||
{"key": "value", "key2": "value2"},
|
||||
),
|
||||
# None values
|
||||
(
|
||||
[{"key1": None, "key2": "value"}, {"key3": None}],
|
||||
{"key1": None, "key2": "value", "key3": None},
|
||||
),
|
||||
# Complex nested structures
|
||||
(
|
||||
[
|
||||
{"nested": {"inner": "value1"}, "list": [1, 2, 3], "tags": ["tag1"]},
|
||||
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": ["tag2"]},
|
||||
],
|
||||
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
# Boolean and numeric values
|
||||
(
|
||||
[
|
||||
{"bool": True, "int": 42, "float": 3.14, "tags": ["numeric"]},
|
||||
{"bool": False, "int": 100},
|
||||
],
|
||||
{"bool": False, "int": 100, "float": 3.14, "tags": {"numeric"}},
|
||||
),
|
||||
# Three or more dicts
|
||||
(
|
||||
[
|
||||
{"a": 1, "tags": ["t1"]},
|
||||
{"b": 2, "tags": ["t2", "t3"]},
|
||||
{"c": 3, "a": 10, "tags": ["t3", "t4"]},
|
||||
],
|
||||
{"a": 10, "b": 2, "c": 3, "tags": {"t1", "t2", "t3", "t4"}},
|
||||
),
|
||||
# Dict with only tags
|
||||
([{"tags": ["tag1", "tag2"]}], {"tags": {"tag1", "tag2"}}),
|
||||
# Empty dicts
|
||||
([{}, {}], {}),
|
||||
# Mix of empty and non-empty dicts
|
||||
(
|
||||
[{}, {"key": "value", "tags": ["tag"]}, {}],
|
||||
{"key": "value", "tags": {"tag"}},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_merge_metadata(dicts, expected):
|
||||
assert merge_metadata(*dicts) == expected
|
||||
|
||||
|
||||
def test_image_filenames_with_existing_filenames(tmp_path):
|
||||
"""Test image_filenames when images already have filenames"""
|
||||
chunk_id = "test_chunk_123"
|
||||
@ -262,7 +370,7 @@ def test_source_item_chunk_contents_text(chunk_length, expected, default_chunk_s
|
||||
)
|
||||
|
||||
default_chunk_size(chunk_length)
|
||||
assert source._chunk_contents() == expected
|
||||
assert source._chunk_contents() == [extract.DataChunk(data=e) for e in expected]
|
||||
|
||||
|
||||
def test_source_item_chunk_contents_image(tmp_path):
|
||||
@ -281,8 +389,8 @@ def test_source_item_chunk_contents_image(tmp_path):
|
||||
result = source._chunk_contents()
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 1
|
||||
assert isinstance(result[0][0], Image.Image)
|
||||
assert len(result[0].data) == 1
|
||||
assert isinstance(result[0].data[0], Image.Image)
|
||||
|
||||
|
||||
def test_source_item_chunk_contents_mixed(tmp_path):
|
||||
@ -302,8 +410,8 @@ def test_source_item_chunk_contents_mixed(tmp_path):
|
||||
result = source._chunk_contents()
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0][0] == "Bla bla"
|
||||
assert isinstance(result[1][0], Image.Image)
|
||||
assert result[0].data[0] == "Bla bla"
|
||||
assert isinstance(result[1].data[0], Image.Image)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -323,7 +431,11 @@ def test_source_item_chunk_contents_mixed(tmp_path):
|
||||
def test_source_item_make_chunk(tmp_path, texts, expected_content):
|
||||
"""Test SourceItem._make_chunk method"""
|
||||
source = SourceItem(
|
||||
sha256=b"test123", content="test", modality="text", tags=["tag1"]
|
||||
sha256=b"test123",
|
||||
content="test",
|
||||
modality="text",
|
||||
tags=["tag1"],
|
||||
size=1024,
|
||||
)
|
||||
# Create actual image
|
||||
image_file = tmp_path / "test.png"
|
||||
@ -335,7 +447,7 @@ def test_source_item_make_chunk(tmp_path, texts, expected_content):
|
||||
data = [*texts, img]
|
||||
metadata = {"extra": "data"}
|
||||
|
||||
chunk = source._make_chunk(data, metadata)
|
||||
chunk = source._make_chunk(extract.DataChunk(data=data), metadata)
|
||||
|
||||
assert chunk.id is not None
|
||||
assert chunk.source == source
|
||||
@ -344,7 +456,12 @@ def test_source_item_make_chunk(tmp_path, texts, expected_content):
|
||||
assert chunk.embedding_model is not None
|
||||
|
||||
# Check that metadata is merged correctly
|
||||
expected_payload = {"source_id": source.id, "tags": ["tag1"], "extra": "data"}
|
||||
expected_payload = {
|
||||
"source_id": source.id,
|
||||
"tags": {"tag1"},
|
||||
"extra": "data",
|
||||
"size": 1024,
|
||||
}
|
||||
assert chunk.item_metadata == expected_payload
|
||||
|
||||
|
||||
@ -355,10 +472,11 @@ def test_source_item_as_payload():
|
||||
content="test",
|
||||
modality="text",
|
||||
tags=["tag1", "tag2"],
|
||||
size=1024,
|
||||
)
|
||||
|
||||
payload = source.as_payload()
|
||||
assert payload == {"source_id": 123, "tags": ["tag1", "tag2"]}
|
||||
assert payload == {"source_id": 123, "tags": ["tag1", "tag2"], "size": 1024}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -544,6 +662,7 @@ def test_mail_message_as_payload(sent_at, expected_date):
|
||||
folder="INBOX",
|
||||
sent_at=sent_at,
|
||||
tags=["tag1", "tag2"],
|
||||
size=1024,
|
||||
)
|
||||
# Manually set id for testing
|
||||
object.__setattr__(mail_message, "id", 123)
|
||||
@ -552,6 +671,7 @@ def test_mail_message_as_payload(sent_at, expected_date):
|
||||
|
||||
expected = {
|
||||
"source_id": 123,
|
||||
"size": 1024,
|
||||
"message_id": "<test@example.com>",
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
@ -646,12 +766,12 @@ def test_email_attachment_as_payload(created_at, expected_date):
|
||||
payload = attachment.as_payload()
|
||||
|
||||
expected = {
|
||||
"source_id": 456,
|
||||
"filename": "document.pdf",
|
||||
"content_type": "application/pdf",
|
||||
"size": 1024,
|
||||
"created_at": expected_date,
|
||||
"mail_message_id": 123,
|
||||
"source_id": 456,
|
||||
"tags": ["pdf", "document"],
|
||||
}
|
||||
assert payload == expected
|
||||
@ -702,7 +822,8 @@ def test_email_attachment_data_chunks(
|
||||
# Verify the method calls
|
||||
mock_extract.assert_called_once_with("text/plain", expected_content)
|
||||
mock_make.assert_called_once_with(
|
||||
["extracted text"], {"extra": "metadata", "source": content_source}
|
||||
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
|
||||
{"extra": "metadata"},
|
||||
)
|
||||
assert result == [mock_chunk]
|
||||
|
||||
@ -813,18 +934,20 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
|
||||
# No pages
|
||||
([], []),
|
||||
# Single page
|
||||
(["Page 1 content"], [("Page 1 content", 10)]),
|
||||
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
|
||||
# Multiple pages
|
||||
(
|
||||
["Page 1", "Page 2", "Page 3"],
|
||||
[
|
||||
("Page 1", 10),
|
||||
("Page 2", 11),
|
||||
("Page 3", 12),
|
||||
(
|
||||
"Page 1\n\nPage 2\n\nPage 3",
|
||||
{"type": "section", "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
|
||||
],
|
||||
),
|
||||
# Empty/whitespace pages filtered out
|
||||
(["", " ", "Page 3"], [("Page 3", 12)]),
|
||||
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
|
||||
# All empty - no chunks created
|
||||
(["", " ", " "], []),
|
||||
],
|
||||
@ -845,11 +968,8 @@ def test_book_section_data_chunks(pages, expected_chunks):
|
||||
|
||||
chunks = book_section.data_chunks()
|
||||
expected = [
|
||||
(p, book_section.as_payload() | {"page": i}) for p, i in expected_chunks
|
||||
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
|
||||
]
|
||||
if content:
|
||||
expected.append((content, book_section.as_payload()))
|
||||
|
||||
assert [(c.content, c.item_metadata) for c in chunks] == expected
|
||||
for c in chunks:
|
||||
assert cast(list, c.file_paths) == []
|
||||
@ -859,16 +979,39 @@ def test_book_section_data_chunks(pages, expected_chunks):
|
||||
"content,expected",
|
||||
[
|
||||
("", []),
|
||||
("Short content", [["Short content"]]),
|
||||
(
|
||||
"Short content",
|
||||
[
|
||||
extract.DataChunk(
|
||||
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
|
||||
[
|
||||
[
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed."
|
||||
],
|
||||
["This is a very long piece of content that"],
|
||||
["should be chunked into multiple pieces when"],
|
||||
["processed."],
|
||||
extract.DataChunk(
|
||||
data=[
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed."
|
||||
],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["This is a very long piece of content that"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["should be chunked into multiple pieces when"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["processed."],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["test"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
@ -906,7 +1049,8 @@ def test_blog_post_chunk_contents_with_images(tmp_path):
|
||||
|
||||
result = blog_post._chunk_contents()
|
||||
result = [
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||
for c in result
|
||||
]
|
||||
assert result == [
|
||||
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
|
||||
@ -933,9 +1077,9 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
result = blog_post._chunk_contents()
|
||||
|
||||
result = [
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||
for c in result
|
||||
]
|
||||
print(result)
|
||||
assert result == [
|
||||
[
|
||||
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||
@ -950,4 +1094,5 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
f"Second picture is here: {img2_path.as_posix()}",
|
||||
img2_path.as_posix(),
|
||||
],
|
||||
["test"],
|
||||
]
|
||||
|
@ -218,6 +218,7 @@ def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
|
||||
"title": "Test Comic",
|
||||
"url": "https://example.com/comic/1",
|
||||
"source_id": 1,
|
||||
"size": 90,
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
@ -186,7 +186,7 @@ def test_sync_book_success(mock_parse, mock_ebook, db_session, tmp_path, qdrant)
|
||||
"author": "Test Author",
|
||||
"status": "processed",
|
||||
"total_sections": 4,
|
||||
"sections_embedded": 8,
|
||||
"sections_embedded": 4,
|
||||
}
|
||||
|
||||
book = db_session.query(Book).filter(Book.title == "Test Book").first()
|
||||
@ -325,8 +325,4 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
|
||||
|
||||
# Check that the full content was passed to the embedding function
|
||||
texts = mock_voyage_client.embed.call_args[0][0]
|
||||
assert texts == [
|
||||
[large_page_1.strip()],
|
||||
[large_page_2.strip()],
|
||||
[large_section_content.strip()],
|
||||
]
|
||||
assert texts == [[large_section_content.strip()], ["test"]]
|
||||
|
@ -335,38 +335,19 @@ def test_sync_lesswrong_max_items_limit(mock_fetch, db_session):
    assert result["max_items"] == 3


@pytest.mark.parametrize(
    "since_str,expected_days_ago",
    [
        ("2024-01-01T00:00:00", None),  # Specific date
        (None, 30),  # Default should be 30 days ago
    ],
)
@patch("memory.workers.tasks.forums.fetch_lesswrong_posts")
def test_sync_lesswrong_since_parameter(
    mock_fetch, since_str, expected_days_ago, db_session
):
def test_sync_lesswrong_since_parameter(mock_fetch, db_session):
    """Test that since parameter is handled correctly."""
    mock_fetch.return_value = []

    if since_str:
        forums.sync_lesswrong(since=since_str)
        expected_since = datetime.fromisoformat(since_str)
    else:
        forums.sync_lesswrong()
        # Should default to 30 days ago
        expected_since = datetime.now() - timedelta(days=30)
    forums.sync_lesswrong(since="2024-01-01T00:00:00")
    expected_since = datetime.fromisoformat("2024-01-01T00:00:00")

    # Verify fetch was called with correct since date
    call_args = mock_fetch.call_args[0]
    actual_since = call_args[0]

    if expected_days_ago:
        # For default case, check it's approximately 30 days ago
        time_diff = abs((actual_since - expected_since).total_seconds())
        assert time_diff < 120  # Within 2 minute tolerance for slow tests
    else:
        assert actual_since == expected_since
    assert actual_since == expected_since


@pytest.mark.parametrize(
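The simplified test pins since to an explicit ISO timestamp; the removed branch is what covered the roughly-30-days-ago default. A sketch of the kind of normalisation sync_lesswrong presumably performs on that argument (hypothetical helper; the actual code in memory.workers.tasks.forums may differ):

from datetime import datetime, timedelta


def resolve_since(since: str | None) -> datetime:
    # Parse an ISO-8601 string when given, otherwise fall back to 30 days ago.
    if since:
        return datetime.fromisoformat(since)
    return datetime.now() - timedelta(days=30)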
@ -1,3 +1,4 @@
# FIXME: Most of this was vibe-coded
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
@ -15,6 +16,9 @@ from memory.workers.tasks.maintenance import (
    reingest_missing_chunks,
    reingest_item,
    reingest_empty_source_items,
    update_metadata_for_item,
    update_metadata_for_source_items,
    _payloads_equal,
)
@ -596,3 +600,372 @@ def test_reingest_empty_source_items_invalid_type(db_session):

def test_reingest_missing_chunks_no_chunks(db_session):
    assert reingest_missing_chunks() == {}

@pytest.mark.parametrize(
    "payload1,payload2,expected",
    [
        # Identical payloads
        (
            {"tags": ["a", "b"], "source_id": 1, "title": "Test"},
            {"tags": ["a", "b"], "source_id": 1, "title": "Test"},
            True,
        ),
        # Different tag order (should be equal)
        (
            {"tags": ["a", "b"], "source_id": 1},
            {"tags": ["b", "a"], "source_id": 1},
            True,
        ),
        # Different tags
        (
            {"tags": ["a", "b"], "source_id": 1},
            {"tags": ["a", "c"], "source_id": 1},
            False,
        ),
        # Different non-tag fields
        ({"tags": ["a"], "source_id": 1}, {"tags": ["a"], "source_id": 2}, False),
        # Missing tags in one payload
        ({"tags": ["a"], "source_id": 1}, {"source_id": 1}, False),
        # Empty tags (should be equal)
        ({"tags": [], "source_id": 1}, {"source_id": 1}, True),
    ],
)
def test_payloads_equal(payload1, payload2, expected):
    """Test the _payloads_equal helper function."""
    assert _payloads_equal(payload1, payload2) == expected

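The parametrized cases pin down the comparison semantics: tags are compared as unordered sets, a missing tags key counts as empty, and every other field must match exactly. A sketch that satisfies those cases (the real _payloads_equal in memory.workers.tasks.maintenance may be written differently):

from typing import Any


def payloads_equal(payload1: dict[str, Any], payload2: dict[str, Any]) -> bool:
    # Tags are order-insensitive; a missing "tags" key is treated as an empty list.
    if set(payload1.get("tags") or []) != set(payload2.get("tags") or []):
        return False
    # Every non-tag field must match exactly.
    rest1 = {k: v for k, v in payload1.items() if k != "tags"}
    rest2 = {k: v for k, v in payload2.items() if k != "tags"}
    return rest1 == rest2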
@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
def test_update_metadata_for_item_success(db_session, qdrant, item_type):
    """Test successful metadata update for an item with chunks."""
    if item_type == "MailMessage":
        item = MailMessage(
            sha256=b"test_hash" + bytes(24),
            tags=["original", "test"],
            size=100,
            mime_type="message/rfc822",
            embed_status="STORED",
            message_id="<test@example.com>",
            subject="Test Subject",
            sender="sender@example.com",
            recipients=["recipient@example.com"],
            content="Test content",
            folder="INBOX",
            modality="mail",
        )
        modality = "mail"
    else:
        item = BlogPost(
            sha256=b"test_hash" + bytes(24),
            tags=["original", "test"],
            size=100,
            mime_type="text/html",
            embed_status="STORED",
            url="https://example.com/post",
            title="Test Blog Post",
            author="Author Name",
            content="Test blog content",
            modality="blog",
        )
        modality = "blog"

    db_session.add(item)
    db_session.flush()

    # Add chunks to the item
    chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
    chunks = [
        Chunk(
            id=chunk_id,
            source=item,
            content=f"Test chunk content {i}",
            embedding_model="test-model",
        )
        for i, chunk_id in enumerate(chunk_ids)
    ]
    db_session.add_all(chunks)
    db_session.commit()

    # Setup Qdrant with existing payloads
    qd.ensure_collection_exists(qdrant, modality, 1024)

    existing_payloads = [
        {"tags": ["existing", "qdrant"], "source_id": item.id, "old_field": "value"}
        for _ in chunk_ids
    ]
    qd.upsert_vectors(
        qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids), existing_payloads
    )

    # Mock the qdrant functions to track calls
    with (
        patch(
            "memory.workers.tasks.maintenance.qdrant.get_payloads"
        ) as mock_get_payloads,
        patch(
            "memory.workers.tasks.maintenance.qdrant.set_payload"
        ) as mock_set_payload,
    ):
        # Return the existing payloads
        mock_get_payloads.return_value = {
            chunk_id: payload for chunk_id, payload in zip(chunk_ids, existing_payloads)
        }

        result = update_metadata_for_item(str(item.id), item_type)

        # Verify result
        assert result["status"] == "success"
        assert result["updated_chunks"] == 3
        assert result["errors"] == 0

        # Verify batch retrieval was called once
        mock_get_payloads.assert_called_once_with(qdrant, modality, chunk_ids)

        # Verify set_payload was called for each chunk with merged tags
        assert mock_set_payload.call_count == 3
        for call in mock_set_payload.call_args_list:
            args, kwargs = call
            client, collection, chunk_id, updated_payload = args

            # Check that tags were merged (existing + new)
            expected_tags = set(
                [
                    "existing",
                    "qdrant",
                    "original",
                    "test",
                ]
            )
            if item_type == "MailMessage":
                expected_tags.update(["recipient@example.com", "sender@example.com"])

            actual_tags = set(updated_payload["tags"])
            assert actual_tags == expected_tags

            # Check that new metadata is present
            assert updated_payload["source_id"] == item.id

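The success case expects each chunk's new payload to carry the union of the tags already stored in Qdrant and the tags reported by the item itself (with sender and recipient addresses also expected for mail). The merge step it implies might look like this (sketch only; the real task also batches the Qdrant reads and writes):

from typing import Any


def merge_payload(existing: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]:
    # Overlay the freshly generated payload on the stored one, but keep the
    # union of both tag lists rather than letting one side win.
    merged = {**existing, **new}
    merged["tags"] = sorted(
        set(existing.get("tags") or []) | set(new.get("tags") or [])
    )
    return merged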
def test_update_metadata_for_item_no_changes(db_session, qdrant):
    """Test that no updates are made when metadata hasn't changed."""
    item = MailMessage(
        sha256=b"test_hash" + bytes(24),
        tags=["test"],
        size=100,
        mime_type="message/rfc822",
        embed_status="STORED",
        message_id="<test@example.com>",
        subject="Test Subject",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        content="Test content",
        folder="INBOX",
        modality="mail",
    )
    db_session.add(item)
    db_session.flush()

    chunk_id = str(uuid.uuid4())
    chunk = Chunk(
        id=chunk_id,
        source=item,
        content="Test chunk content",
        embedding_model="test-model",
    )
    db_session.add(chunk)
    db_session.commit()

    # Setup payload that matches what the item would generate
    item_payload = item.as_payload()
    existing_payload = {chunk_id: item_payload}

    with (
        patch(
            "memory.workers.tasks.maintenance.qdrant.get_payloads"
        ) as mock_get_payloads,
        patch(
            "memory.workers.tasks.maintenance.qdrant.set_payload"
        ) as mock_set_payload,
    ):
        mock_get_payloads.return_value = existing_payload

        result = update_metadata_for_item(str(item.id), "MailMessage")

        # Verify no updates were made
        assert result["status"] == "success"
        assert result["updated_chunks"] == 0
        assert result["errors"] == 0

        # Verify set_payload was never called since nothing changed
        mock_set_payload.assert_not_called()

def test_update_metadata_for_item_no_chunks(db_session):
    """Test updating metadata for an item with no chunks."""
    item = MailMessage(
        sha256=b"test_hash" + bytes(24),
        tags=["test"],
        size=100,
        mime_type="message/rfc822",
        embed_status="RAW",
        message_id="<test@example.com>",
        subject="Test Subject",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        content="Test content",
        folder="INBOX",
        modality="mail",
    )
    db_session.add(item)
    db_session.commit()

    result = update_metadata_for_item(str(item.id), "MailMessage")

    assert result["status"] == "success"
    assert result["updated_chunks"] == 0
    assert result["errors"] == 0

def test_update_metadata_for_item_not_found(db_session):
    """Test updating metadata for a non-existent item."""
    non_existent_id = "999"
    result = update_metadata_for_item(non_existent_id, "MailMessage")

    assert result == {"status": "error", "error": f"Item {non_existent_id} not found"}

def test_update_metadata_for_item_invalid_type(db_session):
    """Test updating metadata with an invalid item type."""
    result = update_metadata_for_item("1", "invalid_type")

    assert result["status"] == "error"
    assert "Unsupported item type invalid_type" in result["error"]
    assert "Available types:" in result["error"]

def test_update_metadata_for_item_missing_chunks_in_qdrant(db_session, qdrant):
    """Test when some chunks exist in DB but not in Qdrant."""
    item = MailMessage(
        sha256=b"test_hash" + bytes(24),
        tags=["test"],
        size=100,
        mime_type="message/rfc822",
        embed_status="STORED",
        message_id="<test@example.com>",
        subject="Test Subject",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        content="Test content",
        folder="INBOX",
        modality="mail",
    )
    db_session.add(item)
    db_session.flush()

    chunk_ids = [str(uuid.uuid4()) for _ in range(3)]
    chunks = [
        Chunk(
            id=chunk_id,
            source=item,
            content=f"Test chunk content {i}",
            embedding_model="test-model",
        )
        for i, chunk_id in enumerate(chunk_ids)
    ]
    db_session.add_all(chunks)
    db_session.commit()

    # Mock qdrant to return payloads for only some chunks
    with (
        patch(
            "memory.workers.tasks.maintenance.qdrant.get_payloads"
        ) as mock_get_payloads,
        patch(
            "memory.workers.tasks.maintenance.qdrant.set_payload"
        ) as mock_set_payload,
    ):
        # Only return payload for first chunk
        mock_get_payloads.return_value = {
            chunk_ids[0]: {"tags": ["existing"], "source_id": item.id}
        }

        result = update_metadata_for_item(str(item.id), "MailMessage")

        # Should only update the chunk that was found in Qdrant
        assert result["status"] == "success"
        assert result["updated_chunks"] == 1
        assert result["errors"] == 0

        # Only one set_payload call for the found chunk
        assert mock_set_payload.call_count == 1

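The missing-chunk case implies the task only rewrites payloads for chunk ids that actually come back from Qdrant and quietly skips the rest. A sketch of that update loop (hypothetical helper; the real task also skips chunks whose payload is already up to date, as the no-changes test shows):

from typing import Any, Callable


def update_stored_payloads(
    stored: dict[str, dict[str, Any]],
    desired: dict[str, dict[str, Any]],
    set_payload: Callable[[str, dict[str, Any]], None],
) -> dict[str, Any]:
    # Only touch chunks that exist in Qdrant; count how many were rewritten.
    # (A fuller version would also skip chunks whose payload is unchanged.)
    updated = 0
    for chunk_id, new_payload in desired.items():
        if chunk_id not in stored:
            continue  # chunk has no point in Qdrant, nothing to update
        set_payload(chunk_id, new_payload)
        updated += 1
    return {"status": "success", "updated_chunks": updated, "errors": 0}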
@pytest.mark.parametrize("item_type", ["MailMessage", "BlogPost"])
def test_update_metadata_for_source_items_success(db_session, item_type):
    """Test updating metadata for all items of a given type."""
    if item_type == "MailMessage":
        items = [
            MailMessage(
                sha256=f"test_hash_{i}".encode() + bytes(32 - len(f"test_hash_{i}")),
                tags=["test"],
                size=100,
                mime_type="message/rfc822",
                embed_status="STORED",
                message_id=f"<test{i}@example.com>",
                subject=f"Test Subject {i}",
                sender="sender@example.com",
                recipients=["recipient@example.com"],
                content=f"Test content {i}",
                folder="INBOX",
                modality="mail",
            )
            for i in range(3)
        ]
    else:
        items = [
            BlogPost(
                sha256=f"test_hash_{i}".encode() + bytes(32 - len(f"test_hash_{i}")),
                tags=["test"],
                size=100,
                mime_type="text/html",
                embed_status="STORED",
                url=f"https://example.com/post{i}",
                title=f"Test Blog Post {i}",
                author="Author Name",
                content=f"Test blog content {i}",
                modality="blog",
            )
            for i in range(3)
        ]

    db_session.add_all(items)
    db_session.commit()

    with patch.object(update_metadata_for_item, "delay") as mock_update:
        result = update_metadata_for_source_items(item_type)

    assert result == {"status": "success", "items": 3}

    # Verify that update_metadata_for_item.delay was called for each item
    assert mock_update.call_count == 3
    expected_calls = [call(item.id, item_type) for item in items]
    mock_update.assert_has_calls(expected_calls, any_order=True)

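The fan-out test only checks that one sub-task is queued per stored item of the requested type. A sketch of that dispatch (illustrative; the real task resolves the SQLAlchemy model from the type name and returns the error shape shown in the invalid-type test below):

from typing import Callable, Sequence


def dispatch_metadata_updates(
    item_ids: Sequence[int],
    item_type: str,
    enqueue: Callable[[int, str], None],
) -> dict:
    # Queue one per-item task; `enqueue` stands in for update_metadata_for_item.delay.
    for item_id in item_ids:
        enqueue(item_id, item_type)
    return {"status": "success", "items": len(item_ids)}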
def test_update_metadata_for_source_items_no_items(db_session):
    """Test when there are no items of the specified type."""
    with patch.object(update_metadata_for_item, "delay") as mock_update:
        result = update_metadata_for_source_items("MailMessage")

    assert result == {"status": "success", "items": 0}
    mock_update.assert_not_called()

def test_update_metadata_for_source_items_invalid_type(db_session):
    """Test updating metadata for an invalid item type."""
    result = update_metadata_for_source_items("invalid_type")

    assert result["status"] == "error"
    assert "Unsupported item type invalid_type" in result["error"]
    assert "Available types:" in result["error"]