celery beat + image embedding

Daniel O'Connell 2025-04-28 22:10:18 +02:00
parent 869e5ac6b4
commit 453aed7c19
18 changed files with 490 additions and 44 deletions

View File

@@ -361,13 +361,13 @@ def downgrade() -> None:
     # Drop tables
     op.drop_table("photo")
     op.drop_table("misc_doc")
-    op.drop_table("mail_message")
     op.drop_table("git_commit")
     op.drop_table("chat_message")
     op.drop_table("book_doc")
     op.drop_table("blog_post")
     op.drop_table("github_item")
     op.drop_table("email_attachment")
+    op.drop_table("mail_message")
     op.drop_table("source_item")
     op.drop_table("rss_feeds")
     op.drop_table("email_accounts")

View File

@@ -21,6 +21,8 @@ volumes:

 # ------------------------------ X-templates ----------------------------
 x-common-env: &env
   RABBITMQ_USER: kb
+  QDRANT_HOST: qdrant
+  DB_HOST: postgres
   FILE_STORAGE_DIR: /app/memory_files
   TZ: "Etc/UTC"
@@ -165,6 +167,13 @@ services:
     # - ./acme.json:/acme.json:rw

   # ------------------------------------------------------------ Celery workers
+  worker-email:
+    <<: *worker-base
+    environment:
+      <<: *worker-env
+      QUEUES: "email"
+    deploy: {resources: {limits: {cpus: "2", memory: 3g}}}
+
   worker-text:
     <<: *worker-base
     environment:
@@ -207,6 +216,22 @@ services:
       QUEUES: "docs"
     deploy: {resources: {limits: {cpus: "1", memory: 1g}}}

+  ingest-hub:
+    <<: *worker-base
+    build:
+      context: .
+      dockerfile: docker/ingest_hub/Dockerfile
+    environment:
+      <<: *worker-env
+    volumes:
+      - file_storage:/app/memory_files:rw
+    tmpfs:
+      - /tmp
+      - /var/tmp
+      - /var/log/supervisor
+      - /var/run/supervisor
+    deploy: {resources: {limits: {cpus: "0.5", memory: 512m}}}
+
   # ------------------------------------------------------------ watchtower (auto-update)
   watchtower:
     image: containrrr/watchtower

View File

@@ -0,0 +1,32 @@
FROM python:3.11-slim

WORKDIR /app

# Copy requirements files and setup
COPY requirements-*.txt ./
COPY setup.py ./
COPY src/ ./src/

# Install dependencies
RUN apt-get update && apt-get install -y \
    libpq-dev gcc supervisor && \
    pip install -e ".[workers]" && \
    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

# Create and copy entrypoint script
COPY docker/ingest_hub/supervisor.conf /etc/supervisor/conf.d/supervisor.conf
COPY docker/workers/entry.sh ./entry.sh
RUN chmod +x entry.sh

# Create required tmpfs directories for supervisor
RUN mkdir -p /var/log/supervisor /var/run/supervisor

# Create user and set permissions
RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor
USER kb

# Default queues to process
ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
ENV PYTHONPATH="/app"

ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]

View File

@@ -0,0 +1,16 @@
[supervisord]
nodaemon=true
loglevel=info
logfile=/dev/stdout
logfile_maxbytes=0
user=kb
pidfile=/dev/null

[program:celery-beat]
command=celery -A memory.workers.ingest beat --pidfile= --schedule=/tmp/celerybeat-schedule
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
stderr_logfile=/dev/stderr
stderr_logfile_maxbytes=0
autorestart=true
startsecs=10

View File

@@ -22,7 +22,7 @@ RUN useradd -m kb && chown -R kb /app
 USER kb

 # Default queues to process
-ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs"
+ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
 ENV PYTHONPATH="/app"

 ENTRYPOINT ["./entry.sh"]

View File

@@ -2,6 +2,7 @@ sqlalchemy==2.0.30
 psycopg2-binary==2.9.9
 pydantic==2.7.1
 alembic==1.13.1
-dotenv==1.1.0
+dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
+PyMuPDF==1.25.5

View File

@@ -1,7 +1,9 @@
 import pathlib
-from typing import Literal, TypedDict, Iterable
+from typing import Literal, TypedDict, Iterable, Any
 import voyageai
 import re
+import uuid
+from memory.common import extract, settings

 # Chunking configuration
 MAX_TOKENS = 32000  # VoyageAI max context window
@@ -11,6 +13,8 @@ CHARS_PER_TOKEN = 4
 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
+Embedding = tuple[str, Vector, dict[str, Any]]
+

 class Collection(TypedDict):
     dimension: int
@@ -163,17 +167,48 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
         yield current.strip()


-def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
+def embed_chunks(chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
     vo = voyageai.Client()
-    return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model)
+    return vo.embed(chunks, model=model).embeddings


-def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
-    return embed_text(file_path.read_text(), model, n_dimensions)
+def embed_text(texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
+    chunks = [c for text in texts for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
+    return embed_chunks(chunks, model)


-def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> tuple[str, list[Vector]]:
-    if isinstance(content, bytes):
-        content = content.decode("utf-8")
-    return get_modality(mime_type), embed_text(content, model, n_dimensions)
+def embed_file(file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
+    return embed_text([file_path.read_text()], model)
+
+
+def embed_mixed(items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL) -> list[Vector]:
+    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
+        if isinstance(item, str):
+            return [c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
+        return [item]
+
+    chunks = [c for item in items for c in to_chunks(item)]
+    return embed_chunks(chunks, model)
+
+
+def embed_page(page: dict[str, Any]) -> list[Vector]:
+    contents = page["contents"]
+    if all(isinstance(c, str) for c in contents):
+        return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL)
+    return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
+
+
+def embed(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    metadata: dict[str, Any] = {},
+) -> tuple[str, list[Embedding]]:
+    modality = get_modality(mime_type)
+    pages = extract.extract_content(mime_type, content)
+    vectors = [
+        (str(uuid.uuid4()), vector, page.get("metadata", {}) | metadata)
+        for page in pages
+        for vector in embed_page(page)
+    ]
+    return modality, vectors

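A hedged usage sketch of the reworked embed() API, based only on the signatures in this diff (the PDF path points at the test fixture added in this commit; the metadata value is made up):

    import pathlib
    from memory.common import embedding

    modality, chunks = embedding.embed(
        "application/pdf",
        pathlib.Path("tests/data/regulamin.pdf"),
        metadata={"source": "example"},  # merged into every page's payload
    )
    # modality == "book" (per get_modality); each chunk is an Embedding tuple:
    # (uuid_str, vector, page_metadata | {"source": "example"})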
View File

@@ -0,0 +1,74 @@
from contextlib import contextmanager
import io
import pathlib
import tempfile

import pymupdf  # PyMuPDF
from PIL import Image
from typing import Any, TypedDict, Generator


MulitmodalChunk = Image.Image | str


class Page(TypedDict):
    contents: list[MulitmodalChunk]
    metadata: dict[str, Any]


@contextmanager
def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
    if isinstance(content, pathlib.Path):
        yield content
    else:
        mode = "w" if isinstance(content, str) else "wb"
        with tempfile.NamedTemporaryFile(mode=mode) as f:
            f.write(content)
            f.flush()
            yield pathlib.Path(f.name)


def page_to_image(page: pymupdf.Page) -> Image.Image:
    pix = page.get_pixmap()
    return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
    with as_file(content) as file_path:
        with pymupdf.open(file_path) as pdf:
            return [{
                "contents": page_to_image(page),
                "metadata": {
                    "page": page.number,
                    "width": page.rect.width,
                    "height": page.rect.height,
                }
            } for page in pdf.pages()]


def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
    if isinstance(content, pathlib.Path):
        image = Image.open(content)
    elif isinstance(content, bytes):
        image = Image.open(io.BytesIO(content))
    else:
        raise ValueError(f"Unsupported content type: {type(content)}")
    return [{"contents": image, "metadata": {}}]


def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
    if isinstance(content, pathlib.Path):
        content = content.read_text()
    if isinstance(content, bytes):
        content = content.decode("utf-8")
    return [{"contents": [content], "metadata": {}}]


def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
    if mime_type == "application/pdf":
        return doc_to_images(content)
    if mime_type.startswith("text/"):
        return extract_text(content)
    if mime_type.startswith("image/"):
        return extract_image(content)

    # Return empty list for unknown mime types
    return []

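A minimal sketch of how the new extract module is used (it mirrors what the tests added below assert; the text example is made up):

    import pathlib
    from memory.common import extract

    pages = extract.extract_content("application/pdf", pathlib.Path("tests/data/regulamin.pdf"))
    for page in pages:
        # each Page holds a rendered PIL image plus page number and dimensions
        print(page["metadata"])

    print(extract.extract_content("text/plain", "hello"))
    # [{'contents': ['hello'], 'metadata': {}}]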
View File

@@ -7,16 +7,19 @@ load_dotenv()
 def boolean_env(key: str, default: bool = False) -> bool:
     return os.getenv(key, "0").lower() in ("1", "true", "yes")


+# Database settings
 DB_USER = os.getenv("DB_USER", "kb")
-DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
+if password_file := os.getenv("POSTGRES_PASSWORD_FILE"):
+    DB_PASSWORD = pathlib.Path(password_file).read_text().strip()
+else:
+    DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
 DB_HOST = os.getenv("DB_HOST", "postgres")
 DB_PORT = os.getenv("DB_PORT", "5432")
 DB_NAME = os.getenv("DB_NAME", "kb")


 def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME):
     return f"postgresql://{user}:{password}@{host}:{port}/{db}"


 DB_URL = os.getenv("DATABASE_URL", make_db_url())
@@ -33,3 +36,13 @@ QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
 QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
 QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
 QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
+
+
+# Worker settings
+BEAT_LOOP_INTERVAL = int(os.getenv("BEAT_LOOP_INTERVAL", 3600))
+EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 600))
+
+
+# Embedding settings
+TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
+MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")

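As a worked example (not output from the code), the defaults above make the Celery result backend configured in celery_app.py below resolve to:

    user, password, host, port, db = "kb", "kb", "postgres", "5432", "kb"  # defaults from settings.py
    db_url = f"postgresql://{user}:{password}@{host}:{port}/{db}"
    print(f"db+{db_url}")  # db+postgresql://kb:kb@postgres:5432/kb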
View File

@@ -1,6 +1,6 @@
 import os
 from celery import Celery
+from memory.common import settings


 def rabbit_url() -> str:
     user = os.getenv("RABBITMQ_USER", "guest")
@@ -8,10 +8,11 @@ def rabbit_url() -> str:
     return f"amqp://{user}:{password}@rabbitmq:5672//"


-app = Celery("memory",
-             broker=rabbit_url(),
-             backend=os.getenv("CELERY_RESULT_BACKEND",
-                               "db+postgresql://kb:kb@postgres/kb"))
+app = Celery(
+    "memory",
+    broker=rabbit_url(),
+    backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}")
+)

 app.autodiscover_tasks(["memory.workers.tasks"])
@@ -22,8 +23,8 @@ app.conf.update(
     task_reject_on_worker_lost=True,
     worker_prefetch_multiplier=1,
     task_routes={
-        # Task routing configuration
         "memory.workers.tasks.text.*": {"queue": "medium_embed"},
+        "memory.workers.tasks.email.*": {"queue": "email"},
         "memory.workers.tasks.photo.*": {"queue": "photo_embed"},
         "memory.workers.tasks.ocr.*": {"queue": "low_ocr"},
         "memory.workers.tasks.git.*": {"queue": "git_summary"},

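With the route added above, any task under memory.workers.tasks.email.* lands on the new "email" queue consumed by the worker-email service; a hedged sketch using Celery's standard send_task (the call site itself is hypothetical):

    from memory.workers.celery_app import app

    # matches "memory.workers.tasks.email.*", so the message is routed to queue "email"
    result = app.send_task("memory.workers.tasks.email.sync_all_accounts")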
View File

@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
 from collections import defaultdict

 from memory.common import settings, embedding
 from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
-from memory.common.qdrant import get_qdrant_client, upsert_vectors
+from memory.common import qdrant

 logger = logging.getLogger(__name__)
@@ -471,17 +471,18 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
 def vectorize_email(email: MailMessage) -> list[float]:
-    qdrant_client = get_qdrant_client()
+    qdrant_client = qdrant.get_qdrant_client()

-    chunks = embedding.embed_text(email.body_raw)
-    payloads = [email.as_payload()] * len(chunks)
-    vector_ids = [str(uuid.uuid4()) for _ in chunks]
-    upsert_vectors(
+    _, chunks = embedding.embed(
+        "text/plain", email.body_raw, metadata=email.as_payload(),
+    )
+    vector_ids, vectors, metadata = zip(*chunks)
+    qdrant.upsert_vectors(
         client=qdrant_client,
         collection_name="mail",
         ids=vector_ids,
-        vectors=chunks,
-        payloads=payloads,
+        vectors=vectors,
+        payloads=metadata,
     )
     vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
@@ -491,15 +492,14 @@ def vectorize_email(email: MailMessage) -> list[float]:
             content = pathlib.Path(attachment.file_path).read_bytes()
         else:
             content = attachment.content
-        collection, vectors = embedding.embed(attachment.content_type, content)
-        attachment.source.vector_ids = vector_ids
-        embeds[collection].extend(
-            (str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors
-        )
+        collection, chunks = embedding.embed(attachment.content_type, content, metadata=attachment.as_payload())
+        ids, vectors, metadata = zip(*chunks)
+        attachment.source.vector_ids = ids
+        embeds[collection].extend(chunks)

-    for collection, embeds in embeds.items():
-        ids, vectors, payloads = zip(*embeds)
-        upsert_vectors(
+    for collection, chunks in embeds.items():
+        ids, vectors, payloads = zip(*chunks)
+        qdrant.upsert_vectors(
             client=qdrant_client,
             collection_name=collection,
             ids=ids,

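For clarity, the zip(*chunks) unpacking used in vectorize_email, shown on made-up data:

    chunks = [
        ("id-1", [0.1, 0.2], {"subject": "hi"}),
        ("id-2", [0.3, 0.4], {"subject": "hi"}),
    ]
    ids, vectors, payloads = zip(*chunks)
    # ids == ("id-1", "id-2"); vectors == ([0.1, 0.2], [0.3, 0.4])
    # payloads == ({"subject": "hi"}, {"subject": "hi"})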
View File

@@ -0,0 +1,14 @@
from celery.schedules import schedule

from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS
from memory.common import settings
from memory.workers.celery_app import app

@app.on_after_configure.connect
def register_mail_schedules(sender, **_):
    sender.add_periodic_task(
        schedule=schedule(settings.EMAIL_SYNC_INTERVAL),
        sig=app.signature(SYNC_ALL_ACCOUNTS),
        name="sync-mail-all",
        options={"queue": "email"},
    )

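The schedule above is registered through Celery's on_after_configure hook; a declarative beat_schedule equivalent, shown only as a sketch of what it amounts to (not what this commit uses):

    from memory.common import settings
    from memory.workers.celery_app import app
    from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS

    app.conf.beat_schedule = {
        "sync-mail-all": {
            "task": SYNC_ALL_ACCOUNTS,
            "schedule": settings.EMAIL_SYNC_INTERVAL,  # seconds between runs
            "options": {"queue": "email"},
        },
    }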
View File

@@ -2,3 +2,7 @@
 Import sub-modules so Celery can register their @app.task decorators.
 """
 from memory.workers.tasks import text, photo, ocr, git, rss, docs, email  # noqa
+from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
+
+
+__all__ = ["text", "photo", "ocr", "git", "rss", "docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]

View File

@@ -18,8 +18,12 @@ from memory.workers.email import (

 logger = logging.getLogger(__name__)

+PROCESS_EMAIL = "memory.workers.tasks.email.process_message"
+SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account"
+SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
+

-@app.task(name="memory.email.process_message")
+@app.task(name=PROCESS_EMAIL)
 def process_message(
     account_id: int, message_id: str, folder: str, raw_email: str,
 ) -> int | None:
@@ -35,6 +39,7 @@ def process_message(
     Returns:
         source_id if successful, None otherwise
     """
+    logger.info(f"Processing message {message_id} for account {account_id}")
     if not raw_email.strip():
         logger.warning(f"Empty email message received for account {account_id}")
         return None
@@ -76,7 +81,7 @@ def process_message(
     return source_item.id


-@app.task(name="memory.email.sync_account")
+@app.task(name=SYNC_ACCOUNT)
 def sync_account(account_id: int) -> dict:
     """
     Synchronize emails from a specific account.
@@ -87,6 +92,7 @@ def sync_account(account_id: int) -> dict:
     Returns:
         dict with stats about the sync operation
     """
+    logger.info(f"Syncing account {account_id}")
     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
         if not account or not account.active:
@@ -124,7 +130,7 @@ def sync_account(account_id: int) -> dict:
     }


-@app.task(name="memory.email.sync_all_accounts")
+@app.task(name=SYNC_ALL_ACCOUNTS)
 def sync_all_accounts() -> list[dict]:
     """
     Synchronize all active email accounts.

View File

@@ -7,6 +7,7 @@ from pathlib import Path

 import pytest
 import qdrant_client
+import voyageai
 from sqlalchemy import create_engine, text
 from sqlalchemy.orm import sessionmaker
 from testcontainers.qdrant import QdrantContainer
@@ -210,3 +211,8 @@ def qdrant():
     initialize_collections(client)

     yield client
+
+@pytest.fixture(autouse=True)
+def mock_voyage_client():
+    with patch.object(voyageai, "Client", autospec=True) as mock_client:
+        yield mock_client()

BIN
tests/data/regulamin.pdf Normal file

Binary file not shown.

View File

@@ -1,5 +1,19 @@
+import uuid
 import pytest
-from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN
+from unittest.mock import Mock, patch
+
+from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count, get_modality, embed_text, embed_file, embed_mixed, embed_page, embed
+
+
+@pytest.fixture
+def mock_embed(mock_voyage_client):
+    vectors = ([i] for i in range(1000))
+
+    def embed(texts, model):
+        return Mock(embeddings=[next(vectors) for _ in texts])
+
+    mock_voyage_client.embed.side_effect = embed
+    return mock_voyage_client


 @pytest.mark.parametrize(
@@ -270,3 +284,77 @@ def test_chunk_text_very_long_sentences():
         'chunks by the function.',
     ]
+
+
+@pytest.mark.parametrize(
+    "string, expected_count",
+    [
+        ("", 0),
+        ("a" * CHARS_PER_TOKEN, 1),
+        ("a" * (CHARS_PER_TOKEN * 2), 2),
+        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
+        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
+    ]
+)
+def test_approx_token_count(string, expected_count):
+    assert approx_token_count(string) == expected_count
+
+
+@pytest.mark.parametrize(
+    "mime_type, expected_modality",
+    [
+        ("text/plain", "doc"),
+        ("text/html", "doc"),
+        ("image/jpeg", "photo"),
+        ("image/png", "photo"),
+        ("application/pdf", "book"),
+        ("application/epub+zip", "book"),
+        ("application/mobi", "book"),
+        ("application/x-mobipocket-ebook", "book"),
+        ("audio/mp3", "unknown"),
+        ("video/mp4", "unknown"),
+        ("text/something-new", "doc"),  # Should match by 'text/' stem
+        ("image/something-new", "photo"),  # Should match by 'image/' stem
+        ("custom/format", "unknown"),  # No matching stem
+    ]
+)
+def test_get_modality(mime_type, expected_modality):
+    assert get_modality(mime_type) == expected_modality
+
+
+def test_embed_text(mock_embed):
+    texts = ["text1 with words", "text2"]
+    assert embed_text(texts) == [[0], [1]]
+
+
+def test_embed_file(mock_embed, tmp_path):
+    mock_file = tmp_path / "test.txt"
+    mock_file.write_text("file content")
+    assert embed_file(mock_file) == [[0]]
+
+
+def test_embed_mixed(mock_embed):
+    items = ["text", {"type": "image", "data": "base64"}]
+    assert embed_mixed(items) == [[0], [1]]
+
+
+def test_embed_page_text_only(mock_embed):
+    page = {"contents": ["text1", "text2"]}
+    assert embed_page(page) == [[0], [1]]
+
+
+def test_embed_page_mixed_content(mock_embed):
+    page = {"contents": ["text", {"type": "image", "data": "base64"}]}
+    assert embed_page(page) == [[0], [1]]
+
+
+def test_embed(mock_embed):
+    mime_type = "text/plain"
+    content = "sample content"
+    metadata = {"source": "test"}
+
+    with patch.object(uuid, "uuid4", return_value="id1"):
+        modality, vectors = embed(mime_type, content, metadata)
+
+    assert modality == "doc"
+    assert vectors == [('id1', [0], {'source': 'test'})]

View File

@@ -0,0 +1,131 @@
import pathlib
import pytest
import pymupdf
from PIL import Image
import io

from memory.common.extract import as_file, extract_text, extract_content, Page, doc_to_images, extract_image


REGULAMIN = pathlib.Path(__file__).parent.parent.parent / "data" / "regulamin.pdf"


def test_as_file_with_path(tmp_path):
    test_path = tmp_path / "test.txt"
    test_path.write_text("test content")

    with as_file(test_path) as path:
        assert path == test_path
        assert path.read_text() == "test content"


def test_as_file_with_bytes():
    content = b"test content"
    with as_file(content) as path:
        assert pathlib.Path(path).read_bytes() == content


def test_as_file_with_str():
    content = "test content"
    with as_file(content) as path:
        assert pathlib.Path(path).read_text() == content


@pytest.mark.parametrize(
    "input_content,expected",
    [
        ("simple text", [{"contents": ["simple text"], "metadata": {}}]),
        (b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]),
    ]
)
def test_extract_text(input_content, expected):
    assert extract_text(input_content) == expected


def test_extract_text_with_path(tmp_path):
    test_file = tmp_path / "test.txt"
    test_file.write_text("file text content")

    assert extract_text(test_file) == [{"contents": ["file text content"], "metadata": {}}]


def test_doc_to_images():
    result = doc_to_images(REGULAMIN)
    assert len(result) == 2

    with pymupdf.open(REGULAMIN) as pdf:
        for page, pdf_page in zip(result, pdf.pages()):
            pix = pdf_page.get_pixmap()
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            assert page["contents"] == img
            assert page["metadata"] == {
                "page": pdf_page.number,
                "width": pdf_page.rect.width,
                "height": pdf_page.rect.height,
            }


def test_extract_image_with_path(tmp_path):
    img = Image.new('RGB', (100, 100), color='red')
    img_path = tmp_path / "test.png"
    img.save(img_path)

    page, = extract_image(img_path)
    assert page["contents"].tobytes() == img.convert("RGB").tobytes()
    assert page["metadata"] == {}


def test_extract_image_with_bytes():
    img = Image.new('RGB', (100, 100), color='blue')
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    img_bytes = buffer.getvalue()

    page, = extract_image(img_bytes)
    assert page["contents"].tobytes() == img.convert("RGB").tobytes()
    assert page["metadata"] == {}


def test_extract_image_with_str():
    with pytest.raises(ValueError):
        extract_image("test")


@pytest.mark.parametrize(
    "mime_type,content",
    [
        ("text/plain", "Text content"),
        ("text/html", "<html>content</html>"),
        ("text/markdown", "# Heading"),
        ("text/csv", "a,b,c"),
    ]
)
def test_extract_content_different_text_types(mime_type, content):
    assert extract_content(mime_type, content) == [{"contents": [content], "metadata": {}}]


def test_extract_content_pdf():
    result = extract_content("application/pdf", REGULAMIN)
    assert len(result) == 2
    assert all(isinstance(page["contents"], Image.Image) for page in result)
    assert all("page" in page["metadata"] for page in result)
    assert all("width" in page["metadata"] for page in result)
    assert all("height" in page["metadata"] for page in result)


def test_extract_content_image(tmp_path):
    # Create a test image
    img = Image.new('RGB', (100, 100), color='red')
    img_path = tmp_path / "test_img.png"
    img.save(img_path)

    result = extract_content("image/png", img_path)
    assert len(result) == 1
    assert isinstance(result[0]["contents"], Image.Image)
    assert result[0]["contents"].size == (100, 100)
    assert result[0]["metadata"] == {}


def test_extract_content_unsupported_type():
    assert extract_content("unsupported/type", "content") == []