diff --git a/db/migrations/versions/20250427_171537_initial_structure.py b/db/migrations/versions/20250427_171537_initial_structure.py index d1a7a43..56f3b38 100644 --- a/db/migrations/versions/20250427_171537_initial_structure.py +++ b/db/migrations/versions/20250427_171537_initial_structure.py @@ -361,13 +361,13 @@ def downgrade() -> None: # Drop tables op.drop_table("photo") op.drop_table("misc_doc") - op.drop_table("mail_message") op.drop_table("git_commit") op.drop_table("chat_message") op.drop_table("book_doc") op.drop_table("blog_post") op.drop_table("github_item") op.drop_table("email_attachment") + op.drop_table("mail_message") op.drop_table("source_item") op.drop_table("rss_feeds") op.drop_table("email_accounts") diff --git a/docker-compose.yaml b/docker-compose.yaml index 9fb8a91..dd1f07a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -21,6 +21,8 @@ volumes: # ------------------------------ X-templates ---------------------------- x-common-env: &env RABBITMQ_USER: kb + QDRANT_HOST: qdrant + DB_HOST: postgres FILE_STORAGE_DIR: /app/memory_files TZ: "Etc/UTC" @@ -165,6 +167,13 @@ services: # - ./acme.json:/acme.json:rw # ------------------------------------------------------------ Celery workers + worker-email: + <<: *worker-base + environment: + <<: *worker-env + QUEUES: "email" + deploy: {resources: {limits: {cpus: "2", memory: 3g}}} + worker-text: <<: *worker-base environment: @@ -207,6 +216,22 @@ services: QUEUES: "docs" deploy: {resources: {limits: {cpus: "1", memory: 1g}}} + ingest-hub: + <<: *worker-base + build: + context: . 
+ dockerfile: docker/ingest_hub/Dockerfile + environment: + <<: *worker-env + volumes: + - file_storage:/app/memory_files:rw + tmpfs: + - /tmp + - /var/tmp + - /var/log/supervisor + - /var/run/supervisor + deploy: {resources: {limits: {cpus: "0.5", memory: 512m}}} + # ------------------------------------------------------------ watchtower (auto-update) watchtower: image: containrrr/watchtower diff --git a/docker/ingest_hub/Dockerfile b/docker/ingest_hub/Dockerfile new file mode 100644 index 0000000..f2ef5e6 --- /dev/null +++ b/docker/ingest_hub/Dockerfile @@ -0,0 +1,32 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Copy requirements files and setup +COPY requirements-*.txt ./ +COPY setup.py ./ +COPY src/ ./src/ + +# Install dependencies +RUN apt-get update && apt-get install -y \ + libpq-dev gcc supervisor && \ + pip install -e ".[workers]" && \ + apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* + +# Create and copy entrypoint script +COPY docker/ingest_hub/supervisor.conf /etc/supervisor/conf.d/supervisor.conf +COPY docker/workers/entry.sh ./entry.sh +RUN chmod +x entry.sh + +# Create required tmpfs directories for supervisor +RUN mkdir -p /var/log/supervisor /var/run/supervisor + +# Create user and set permissions +RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor +USER kb + +# Default queues to process +ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email" +ENV PYTHONPATH="/app" + +ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"] \ No newline at end of file diff --git a/docker/ingest_hub/supervisor.conf b/docker/ingest_hub/supervisor.conf new file mode 100644 index 0000000..b6e9872 --- /dev/null +++ b/docker/ingest_hub/supervisor.conf @@ -0,0 +1,16 @@ +[supervisord] +nodaemon=true +loglevel=info +logfile=/dev/stdout +logfile_maxbytes=0 +user=kb +pidfile=/dev/null + +[program:celery-beat] +command=celery -A memory.workers.ingest beat --pidfile= 
--schedule=/tmp/celerybeat-schedule +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 +autorestart=true +startsecs=10 diff --git a/docker/workers/Dockerfile b/docker/workers/Dockerfile index b0e9cae..c88d6eb 100644 --- a/docker/workers/Dockerfile +++ b/docker/workers/Dockerfile @@ -22,7 +22,7 @@ RUN useradd -m kb && chown -R kb /app USER kb # Default queues to process -ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs" +ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email" ENV PYTHONPATH="/app" ENTRYPOINT ["./entry.sh"] \ No newline at end of file diff --git a/requirements-common.txt b/requirements-common.txt index 92052c5..bf32273 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -2,6 +2,7 @@ sqlalchemy==2.0.30 psycopg2-binary==2.9.9 pydantic==2.7.1 alembic==1.13.1 -dotenv==1.1.0 +dotenv==0.9.9 voyageai==0.3.2 -qdrant-client==1.9.0 \ No newline at end of file +qdrant-client==1.9.0 +PyMuPDF==1.25.5 \ No newline at end of file diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index bedc110..3353ad5 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,7 +1,9 @@ import pathlib -from typing import Literal, TypedDict, Iterable +from typing import Literal, TypedDict, Iterable, Any import voyageai import re +import uuid +from memory.common import extract, settings # Chunking configuration MAX_TOKENS = 32000 # VoyageAI max context window @@ -11,6 +13,8 @@ CHARS_PER_TOKEN = 4 DistanceType = Literal["Cosine", "Dot", "Euclidean"] Vector = list[float] +Embedding = tuple[str, Vector, dict[str, Any]] + class Collection(TypedDict): dimension: int @@ -163,17 +167,48 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T yield current.strip() -def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]: +def embed_chunks(chunks: 
list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]: vo = voyageai.Client() - return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model) + return vo.embed(chunks, model=model).embeddings -def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]: - return embed_text(file_path.read_text(), model, n_dimensions) +def embed_text(texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]: + chunks = [c for text in texts for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()] + return embed_chunks(chunks, model) -def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> tuple[str, list[Vector]]: - if isinstance(content, bytes): - content = content.decode("utf-8") +def embed_file(file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]: + return embed_text([file_path.read_text()], model) - return get_modality(mime_type), embed_text(content, model, n_dimensions) + +def embed_mixed(items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL) -> list[Vector]: + def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]: + if isinstance(item, str): + return [c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()] + return [item] + + chunks = [c for item in items for c in to_chunks(item)] + return embed_chunks(chunks, model) + + +def embed_page(page: dict[str, Any]) -> list[Vector]: + contents = page["contents"] if isinstance(page["contents"], list) else [page["contents"]] + if all(isinstance(c, str) for c in contents): + return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL) + return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL) + + +def embed( + mime_type: str, + content: bytes | str | pathlib.Path, + metadata: dict[str, Any] | None = None, +) -> tuple[str, list[Embedding]]: + modality = get_modality(mime_type) + + pages = extract.extract_content(mime_type, content) +
vectors = [ + (str(uuid.uuid4()), vector, page.get("metadata", {}) | (metadata or {})) + for page in pages + for vector in embed_page(page) + ] + return modality, vectors diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py new file mode 100644 index 0000000..3fd221d --- /dev/null +++ b/src/memory/common/extract.py @@ -0,0 +1,74 @@ +from contextlib import contextmanager +import io +import pathlib +import tempfile +import pymupdf # PyMuPDF +from PIL import Image +from typing import Any, TypedDict, Generator + + +MulitmodalChunk = Image.Image | str +class Page(TypedDict): + contents: list[MulitmodalChunk] + metadata: dict[str, Any] + + +@contextmanager +def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]: + if isinstance(content, pathlib.Path): + yield content + else: + mode = "w" if isinstance(content, str) else "wb" + with tempfile.NamedTemporaryFile(mode=mode) as f: + f.write(content) + f.flush() + yield pathlib.Path(f.name) + + +def page_to_image(page: pymupdf.Page) -> Image.Image: + pix = page.get_pixmap() + return Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + + +def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]: + with as_file(content) as file_path: + with pymupdf.open(file_path) as pdf: + return [{ + "contents": page_to_image(page), + "metadata": { + "page": page.number, + "width": page.rect.width, + "height": page.rect.height, + } + } for page in pdf.pages()] + + +def extract_image(content: bytes | str | pathlib.Path) -> list[Page]: + if isinstance(content, pathlib.Path): + image = Image.open(content) + elif isinstance(content, bytes): + image = Image.open(io.BytesIO(content)) + else: + raise ValueError(f"Unsupported content type: {type(content)}") + return [{"contents": image, "metadata": {}}] + + +def extract_text(content: bytes | str | pathlib.Path) -> list[Page]: + if isinstance(content, pathlib.Path): + content = content.read_text() + if isinstance(content, bytes): +
content = content.decode("utf-8") + + return [{"contents": [content], "metadata": {}}] + + +def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]: + if mime_type == "application/pdf": + return doc_to_images(content) + if mime_type.startswith("text/"): + return extract_text(content) + if mime_type.startswith("image/"): + return extract_image(content) + + # Return empty list for unknown mime types + return [] diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index a130a74..ff82dbc 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -7,16 +7,19 @@ load_dotenv() def boolean_env(key: str, default: bool = False) -> bool: - return os.getenv(key, "0").lower() in ("1", "true", "yes") + return os.getenv(key, str(default)).lower() in ("1", "true", "yes") - +# Database settings DB_USER = os.getenv("DB_USER", "kb") -DB_PASSWORD = os.getenv("DB_PASSWORD", "kb") +if password_file := os.getenv("POSTGRES_PASSWORD_FILE"): + DB_PASSWORD = pathlib.Path(password_file).read_text().strip() +else: + DB_PASSWORD = os.getenv("DB_PASSWORD", "kb") + DB_HOST = os.getenv("DB_HOST", "postgres") DB_PORT = os.getenv("DB_PORT", "5432") DB_NAME = os.getenv("DB_NAME", "kb") def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME): return f"postgresql://{user}:{password}@{host}:{port}/{db}" - DB_URL = os.getenv("DATABASE_URL", make_db_url()) @@ -32,4 +35,14 @@ QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334")) QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None) QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False) -QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60")) \ No newline at end of file +QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60")) + + +# Worker settings +BEAT_LOOP_INTERVAL = int(os.getenv("BEAT_LOOP_INTERVAL", 3600)) +EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 600)) + + +# Embedding settings +TEXT_EMBEDDING_MODEL =
os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large") +MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3") diff --git a/src/memory/workers/celery_app.py b/src/memory/workers/celery_app.py index 901f298..ca81a76 100644 --- a/src/memory/workers/celery_app.py +++ b/src/memory/workers/celery_app.py @@ -1,6 +1,6 @@ import os from celery import Celery - +from memory.common import settings def rabbit_url() -> str: user = os.getenv("RABBITMQ_USER", "guest") @@ -8,10 +8,11 @@ def rabbit_url() -> str: return f"amqp://{user}:{password}@rabbitmq:5672//" -app = Celery("memory", - broker=rabbit_url(), - backend=os.getenv("CELERY_RESULT_BACKEND", - "db+postgresql://kb:kb@postgres/kb")) +app = Celery( + "memory", + broker=rabbit_url(), + backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}") +) app.autodiscover_tasks(["memory.workers.tasks"]) @@ -22,8 +23,8 @@ app.conf.update( task_reject_on_worker_lost=True, worker_prefetch_multiplier=1, task_routes={ - # Task routing configuration "memory.workers.tasks.text.*": {"queue": "medium_embed"}, + "memory.workers.tasks.email.*": {"queue": "email"}, "memory.workers.tasks.photo.*": {"queue": "photo_embed"}, "memory.workers.tasks.ocr.*": {"queue": "low_ocr"}, "memory.workers.tasks.git.*": {"queue": "git_summary"}, diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index faae8d3..0b9819d 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import Session from collections import defaultdict from memory.common import settings, embedding from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment -from memory.common.qdrant import get_qdrant_client, upsert_vectors +from memory.common import qdrant logger = logging.getLogger(__name__) @@ -471,17 +471,18 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, def vectorize_email(email: MailMessage) -> list[float]: - 
qdrant_client = get_qdrant_client() + qdrant_client = qdrant.get_qdrant_client() - chunks = embedding.embed_text(email.body_raw) - payloads = [email.as_payload()] * len(chunks) - vector_ids = [str(uuid.uuid4()) for _ in chunks] - upsert_vectors( + _, chunks = embedding.embed( + "text/plain", email.body_raw, metadata=email.as_payload(), + ) + vector_ids, vectors, metadata = zip(*chunks) + qdrant.upsert_vectors( client=qdrant_client, collection_name="mail", ids=vector_ids, - vectors=chunks, - payloads=payloads, + vectors=vectors, + payloads=metadata, ) vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids] @@ -491,15 +492,14 @@ def vectorize_email(email: MailMessage) -> list[float]: content = pathlib.Path(attachment.file_path).read_bytes() else: content = attachment.content - collection, vectors = embedding.embed(attachment.content_type, content) - attachment.source.vector_ids = vector_ids - embeds[collection].extend( - (str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors - ) + collection, chunks = embedding.embed(attachment.content_type, content, metadata=attachment.as_payload()) + ids, vectors, metadata = zip(*chunks) + attachment.source.vector_ids = ids + embeds[collection].extend(chunks) - for collection, embeds in embeds.items(): - ids, vectors, payloads = zip(*embeds) - upsert_vectors( + for collection, chunks in embeds.items(): + ids, vectors, payloads = zip(*chunks) + qdrant.upsert_vectors( client=qdrant_client, collection_name=collection, ids=ids, diff --git a/src/memory/workers/ingest.py b/src/memory/workers/ingest.py new file mode 100644 index 0000000..2923016 --- /dev/null +++ b/src/memory/workers/ingest.py @@ -0,0 +1,14 @@ +from celery.schedules import schedule +from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS +from memory.common import settings +from memory.workers.celery_app import app + + +@app.on_after_configure.connect +def register_mail_schedules(sender, **_): + sender.add_periodic_task( + 
schedule=schedule(settings.EMAIL_SYNC_INTERVAL), + sig=app.signature(SYNC_ALL_ACCOUNTS), + name="sync-mail-all", + options={"queue": "email"}, + ) diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index 1b1f3e8..32c9d92 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -1,4 +1,8 @@ """ Import sub-modules so Celery can register their @app.task decorators. """ -from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa \ No newline at end of file +from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa +from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL + + +__all__ = ["text", "photo", "ocr", "git", "rss", "docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"] \ No newline at end of file diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index d6a9995..7577998 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -18,8 +18,12 @@ from memory.workers.email import ( logger = logging.getLogger(__name__) +PROCESS_EMAIL = "memory.workers.tasks.email.process_message" +SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account" +SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts" -@app.task(name="memory.email.process_message") + +@app.task(name=PROCESS_EMAIL) def process_message( account_id: int, message_id: str, folder: str, raw_email: str, ) -> int | None: @@ -35,6 +39,7 @@ def process_message( Returns: source_id if successful, None otherwise """ + logger.info(f"Processing message {message_id} for account {account_id}") if not raw_email.strip(): logger.warning(f"Empty email message received for account {account_id}") return None @@ -76,7 +81,7 @@ def process_message( return source_item.id -@app.task(name="memory.email.sync_account") +@app.task(name=SYNC_ACCOUNT) def sync_account(account_id: int) -> dict: """ 
Synchronize emails from a specific account. @@ -87,6 +92,7 @@ def sync_account(account_id: int) -> dict: Returns: dict with stats about the sync operation """ + logger.info(f"Syncing account {account_id}") with make_session() as db: account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first() if not account or not account.active: @@ -124,7 +130,7 @@ def sync_account(account_id: int) -> dict: } -@app.task(name="memory.email.sync_all_accounts") +@app.task(name=SYNC_ALL_ACCOUNTS) def sync_all_accounts() -> list[dict]: """ Synchronize all active email accounts. diff --git a/tests/conftest.py b/tests/conftest.py index b61916c..e59e5b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from pathlib import Path import pytest import qdrant_client +import voyageai from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from testcontainers.qdrant import QdrantContainer @@ -210,3 +211,8 @@ def qdrant(): initialize_collections(client) yield client + +@pytest.fixture(autouse=True) +def mock_voyage_client(): + with patch.object(voyageai, "Client", autospec=True) as mock_client: + yield mock_client() \ No newline at end of file diff --git a/tests/data/regulamin.pdf b/tests/data/regulamin.pdf new file mode 100644 index 0000000..44e9f58 Binary files /dev/null and b/tests/data/regulamin.pdf differ diff --git a/tests/memory/common/test_embedding.py b/tests/memory/common/test_embedding.py index fa159d3..4960dd6 100644 --- a/tests/memory/common/test_embedding.py +++ b/tests/memory/common/test_embedding.py @@ -1,5 +1,19 @@ +import uuid import pytest -from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN +from unittest.mock import Mock, patch +from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count, get_modality, embed_text, embed_file, embed_mixed, embed_page, embed + + +@pytest.fixture +def mock_embed(mock_voyage_client): 
+ vectors = ([i] for i in range(1000)) + + def embed(texts, model): + return Mock(embeddings=[next(vectors) for _ in texts]) + + mock_voyage_client.embed.side_effect = embed + + return mock_voyage_client @pytest.mark.parametrize( @@ -270,3 +284,77 @@ def test_chunk_text_very_long_sentences(): 'chunks by the function.', ] + +@pytest.mark.parametrize( + "string, expected_count", + [ + ("", 0), + ("a" * CHARS_PER_TOKEN, 1), + ("a" * (CHARS_PER_TOKEN * 2), 2), + ("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation + ("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation + ] +) +def test_approx_token_count(string, expected_count): + assert approx_token_count(string) == expected_count + + +@pytest.mark.parametrize( + "mime_type, expected_modality", + [ + ("text/plain", "doc"), + ("text/html", "doc"), + ("image/jpeg", "photo"), + ("image/png", "photo"), + ("application/pdf", "book"), + ("application/epub+zip", "book"), + ("application/mobi", "book"), + ("application/x-mobipocket-ebook", "book"), + ("audio/mp3", "unknown"), + ("video/mp4", "unknown"), + ("text/something-new", "doc"), # Should match by 'text/' stem + ("image/something-new", "photo"), # Should match by 'image/' stem + ("custom/format", "unknown"), # No matching stem + ] +) +def test_get_modality(mime_type, expected_modality): + assert get_modality(mime_type) == expected_modality + + +def test_embed_text(mock_embed): + texts = ["text1 with words", "text2"] + assert embed_text(texts) == [[0], [1]] + + +def test_embed_file(mock_embed, tmp_path): + mock_file = tmp_path / "test.txt" + mock_file.write_text("file content") + + assert embed_file(mock_file) == [[0]] + + +def test_embed_mixed(mock_embed): + items = ["text", {"type": "image", "data": "base64"}] + assert embed_mixed(items) == [[0], [1]] + + +def test_embed_page_text_only(mock_embed): + page = {"contents": ["text1", "text2"]} + assert embed_page(page) == [[0], [1]] + + +def test_embed_page_mixed_content(mock_embed): + page = {"contents": ["text", {"type": "image", 
"data": "base64"}]} + assert embed_page(page) == [[0], [1]] + + +def test_embed(mock_embed): + mime_type = "text/plain" + content = "sample content" + metadata = {"source": "test"} + + with patch.object(uuid, "uuid4", return_value="id1"): + modality, vectors = embed(mime_type, content, metadata) + + assert modality == "doc" + assert vectors == [('id1', [0], {'source': 'test'})] diff --git a/tests/memory/common/test_extract.py b/tests/memory/common/test_extract.py new file mode 100644 index 0000000..0cd1422 --- /dev/null +++ b/tests/memory/common/test_extract.py @@ -0,0 +1,131 @@ +import pathlib +import pytest +import pymupdf +from PIL import Image +import io +from memory.common.extract import as_file, extract_text, extract_content, Page, doc_to_images, extract_image + + +REGULAMIN = pathlib.Path(__file__).parent.parent.parent / "data" / "regulamin.pdf" + + +def test_as_file_with_path(tmp_path): + test_path = tmp_path / "test.txt" + test_path.write_text("test content") + + with as_file(test_path) as path: + assert path == test_path + assert path.read_text() == "test content" + + +def test_as_file_with_bytes(): + content = b"test content" + with as_file(content) as path: + assert pathlib.Path(path).read_bytes() == content + + +def test_as_file_with_str(): + content = "test content" + with as_file(content) as path: + assert pathlib.Path(path).read_text() == content + + +@pytest.mark.parametrize( + "input_content,expected", + [ + ("simple text", [{"contents": ["simple text"], "metadata": {}}]), + (b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]), + ] +) +def test_extract_text(input_content, expected): + assert extract_text(input_content) == expected + + +def test_extract_text_with_path(tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("file text content") + + assert extract_text(test_file) == [{"contents": ["file text content"], "metadata": {}}] + + +def test_doc_to_images(): + result = doc_to_images(REGULAMIN) + + assert len(result) 
== 2 + with pymupdf.open(REGULAMIN) as pdf: + for page, pdf_page in zip(result, pdf.pages()): + pix = pdf_page.get_pixmap() + img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + assert page["contents"] == img + assert page["metadata"] == { + "page": pdf_page.number, + "width": pdf_page.rect.width, + "height": pdf_page.rect.height, + } + + +def test_extract_image_with_path(tmp_path): + img = Image.new('RGB', (100, 100), color='red') + img_path = tmp_path / "test.png" + img.save(img_path) + + page, = extract_image(img_path) + assert page["contents"].tobytes() == img.convert("RGB").tobytes() + assert page["metadata"] == {} + + +def test_extract_image_with_bytes(): + img = Image.new('RGB', (100, 100), color='blue') + buffer = io.BytesIO() + img.save(buffer, format='PNG') + img_bytes = buffer.getvalue() + + page, = extract_image(img_bytes) + assert page["contents"].tobytes() == img.convert("RGB").tobytes() + assert page["metadata"] == {} + + +def test_extract_image_with_str(): + with pytest.raises(ValueError): + extract_image("test") + + +@pytest.mark.parametrize( + "mime_type,content", + [ + ("text/plain", "Text content"), + ("text/html", "content"), + ("text/markdown", "# Heading"), + ("text/csv", "a,b,c"), + ] +) +def test_extract_content_different_text_types(mime_type, content): + assert extract_content(mime_type, content) == [{"contents": [content], "metadata": {}}] + + +def test_extract_content_pdf(): + result = extract_content("application/pdf", REGULAMIN) + + assert len(result) == 2 + assert all(isinstance(page["contents"], Image.Image) for page in result) + assert all("page" in page["metadata"] for page in result) + assert all("width" in page["metadata"] for page in result) + assert all("height" in page["metadata"] for page in result) + + +def test_extract_content_image(tmp_path): + # Create a test image + img = Image.new('RGB', (100, 100), color='red') + img_path = tmp_path / "test_img.png" + img.save(img_path) + + result = 
extract_content("image/png", img_path) + + assert len(result) == 1 + assert isinstance(result[0]["contents"], Image.Image) + assert result[0]["contents"].size == (100, 100) + assert result[0]["metadata"] == {} + + +def test_extract_content_unsupported_type(): + assert extract_content("unsupported/type", "content") == []