mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
celery beat + image embedding
This commit is contained in:
parent
869e5ac6b4
commit
453aed7c19
@ -361,13 +361,13 @@ def downgrade() -> None:
|
||||
# Drop tables
|
||||
op.drop_table("photo")
|
||||
op.drop_table("misc_doc")
|
||||
op.drop_table("mail_message")
|
||||
op.drop_table("git_commit")
|
||||
op.drop_table("chat_message")
|
||||
op.drop_table("book_doc")
|
||||
op.drop_table("blog_post")
|
||||
op.drop_table("github_item")
|
||||
op.drop_table("email_attachment")
|
||||
op.drop_table("mail_message")
|
||||
op.drop_table("source_item")
|
||||
op.drop_table("rss_feeds")
|
||||
op.drop_table("email_accounts")
|
||||
|
@ -21,6 +21,8 @@ volumes:
|
||||
# ------------------------------ X-templates ----------------------------
|
||||
x-common-env: &env
|
||||
RABBITMQ_USER: kb
|
||||
QDRANT_HOST: qdrant
|
||||
DB_HOST: postgres
|
||||
FILE_STORAGE_DIR: /app/memory_files
|
||||
TZ: "Etc/UTC"
|
||||
|
||||
@ -165,6 +167,13 @@ services:
|
||||
# - ./acme.json:/acme.json:rw
|
||||
|
||||
# ------------------------------------------------------------ Celery workers
|
||||
worker-email:
|
||||
<<: *worker-base
|
||||
environment:
|
||||
<<: *worker-env
|
||||
QUEUES: "email"
|
||||
deploy: {resources: {limits: {cpus: "2", memory: 3g}}}
|
||||
|
||||
worker-text:
|
||||
<<: *worker-base
|
||||
environment:
|
||||
@ -207,6 +216,22 @@ services:
|
||||
QUEUES: "docs"
|
||||
deploy: {resources: {limits: {cpus: "1", memory: 1g}}}
|
||||
|
||||
ingest-hub:
|
||||
<<: *worker-base
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/ingest_hub/Dockerfile
|
||||
environment:
|
||||
<<: *worker-env
|
||||
volumes:
|
||||
- file_storage:/app/memory_files:rw
|
||||
tmpfs:
|
||||
- /tmp
|
||||
- /var/tmp
|
||||
- /var/log/supervisor
|
||||
- /var/run/supervisor
|
||||
deploy: {resources: {limits: {cpus: "0.5", memory: 512m}}}
|
||||
|
||||
# ------------------------------------------------------------ watchtower (auto-update)
|
||||
watchtower:
|
||||
image: containrrr/watchtower
|
||||
|
32
docker/ingest_hub/Dockerfile
Normal file
32
docker/ingest_hub/Dockerfile
Normal file
@ -0,0 +1,32 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy requirements files and setup
|
||||
COPY requirements-*.txt ./
|
||||
COPY setup.py ./
|
||||
COPY src/ ./src/
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libpq-dev gcc supervisor && \
|
||||
pip install -e ".[workers]" && \
|
||||
apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create and copy entrypoint script
|
||||
COPY docker/ingest_hub/supervisor.conf /etc/supervisor/conf.d/supervisor.conf
|
||||
COPY docker/workers/entry.sh ./entry.sh
|
||||
RUN chmod +x entry.sh
|
||||
|
||||
# Create required tmpfs directories for supervisor
|
||||
RUN mkdir -p /var/log/supervisor /var/run/supervisor
|
||||
|
||||
# Create user and set permissions
|
||||
RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor
|
||||
USER kb
|
||||
|
||||
# Default queues to process
|
||||
ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
|
||||
ENV PYTHONPATH="/app"
|
||||
|
||||
ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]
|
16
docker/ingest_hub/supervisor.conf
Normal file
16
docker/ingest_hub/supervisor.conf
Normal file
@ -0,0 +1,16 @@
|
||||
[supervisord]
|
||||
nodaemon=true
|
||||
loglevel=info
|
||||
logfile=/dev/stdout
|
||||
logfile_maxbytes=0
|
||||
user=kb
|
||||
pidfile=/dev/null
|
||||
|
||||
[program:celery-beat]
|
||||
command=celery -A memory.workers.ingest beat --pidfile= --schedule=/tmp/celerybeat-schedule
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile=/dev/stderr
|
||||
stderr_logfile_maxbytes=0
|
||||
autorestart=true
|
||||
startsecs=10
|
@ -22,7 +22,7 @@ RUN useradd -m kb && chown -R kb /app
|
||||
USER kb
|
||||
|
||||
# Default queues to process
|
||||
ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs"
|
||||
ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
|
||||
ENV PYTHONPATH="/app"
|
||||
|
||||
ENTRYPOINT ["./entry.sh"]
|
@ -2,6 +2,7 @@ sqlalchemy==2.0.30
|
||||
psycopg2-binary==2.9.9
|
||||
pydantic==2.7.1
|
||||
alembic==1.13.1
|
||||
dotenv==1.1.0
|
||||
dotenv==0.9.9
|
||||
voyageai==0.3.2
|
||||
qdrant-client==1.9.0
|
||||
qdrant-client==1.9.0
|
||||
PyMuPDF==1.25.5
|
@ -1,7 +1,9 @@
|
||||
import pathlib
|
||||
from typing import Literal, TypedDict, Iterable
|
||||
from typing import Literal, TypedDict, Iterable, Any
|
||||
import voyageai
|
||||
import re
|
||||
import uuid
|
||||
from memory.common import extract, settings
|
||||
|
||||
# Chunking configuration
|
||||
MAX_TOKENS = 32000 # VoyageAI max context window
|
||||
@ -11,6 +13,8 @@ CHARS_PER_TOKEN = 4
|
||||
|
||||
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
|
||||
Vector = list[float]
|
||||
Embedding = tuple[str, Vector, dict[str, Any]]
|
||||
|
||||
|
||||
class Collection(TypedDict):
|
||||
dimension: int
|
||||
@ -163,17 +167,48 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
|
||||
yield current.strip()
|
||||
|
||||
|
||||
def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
|
||||
def embed_chunks(chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
|
||||
vo = voyageai.Client()
|
||||
return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model)
|
||||
return vo.embed(chunks, model=model).embeddings
|
||||
|
||||
|
||||
def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
|
||||
return embed_text(file_path.read_text(), model, n_dimensions)
|
||||
def embed_text(texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
|
||||
chunks = [c for text in texts for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
|
||||
return embed_chunks(chunks, model)
|
||||
|
||||
|
||||
def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> tuple[str, list[Vector]]:
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
def embed_file(file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
|
||||
return embed_text([file_path.read_text()], model)
|
||||
|
||||
return get_modality(mime_type), embed_text(content, model, n_dimensions)
|
||||
|
||||
def embed_mixed(items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL) -> list[Vector]:
|
||||
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
|
||||
if isinstance(item, str):
|
||||
return [c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
|
||||
return [item]
|
||||
|
||||
chunks = [c for item in items for c in to_chunks(item)]
|
||||
return embed_chunks(chunks, model)
|
||||
|
||||
|
||||
def embed_page(page: dict[str, Any]) -> list[Vector]:
|
||||
contents = page["contents"]
|
||||
if all(isinstance(c, str) for c in contents):
|
||||
return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL)
|
||||
return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
def embed(
|
||||
mime_type: str,
|
||||
content: bytes | str | pathlib.Path,
|
||||
metadata: dict[str, Any] = {},
|
||||
) -> tuple[str, list[Embedding]]:
|
||||
modality = get_modality(mime_type)
|
||||
|
||||
pages = extract.extract_content(mime_type, content)
|
||||
vectors = [
|
||||
(str(uuid.uuid4()), vector, page.get("metadata", {}) | metadata)
|
||||
for page in pages
|
||||
for vector in embed_page(page)
|
||||
]
|
||||
return modality, vectors
|
||||
|
74
src/memory/common/extract.py
Normal file
74
src/memory/common/extract.py
Normal file
@ -0,0 +1,74 @@
|
||||
from contextlib import contextmanager
|
||||
import io
|
||||
import pathlib
|
||||
import tempfile
|
||||
import pymupdf # PyMuPDF
|
||||
from PIL import Image
|
||||
from typing import Any, TypedDict, Generator
|
||||
|
||||
|
||||
MulitmodalChunk = Image.Image | str
|
||||
class Page(TypedDict):
|
||||
contents: list[MulitmodalChunk]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
|
||||
if isinstance(content, pathlib.Path):
|
||||
yield content
|
||||
else:
|
||||
mode = "w" if isinstance(content, str) else "wb"
|
||||
with tempfile.NamedTemporaryFile(mode=mode) as f:
|
||||
f.write(content)
|
||||
f.flush()
|
||||
yield pathlib.Path(f.name)
|
||||
|
||||
|
||||
def page_to_image(page: pymupdf.Page) -> Image.Image:
|
||||
pix = page.get_pixmap()
|
||||
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
|
||||
|
||||
def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
with as_file(content) as file_path:
|
||||
with pymupdf.open(file_path) as pdf:
|
||||
return [{
|
||||
"contents": page_to_image(page),
|
||||
"metadata": {
|
||||
"page": page.number,
|
||||
"width": page.rect.width,
|
||||
"height": page.rect.height,
|
||||
}
|
||||
} for page in pdf.pages()]
|
||||
|
||||
|
||||
def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
if isinstance(content, pathlib.Path):
|
||||
image = Image.open(content)
|
||||
elif isinstance(content, bytes):
|
||||
image = Image.open(io.BytesIO(content))
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
return [{"contents": image, "metadata": {}}]
|
||||
|
||||
|
||||
def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
if isinstance(content, pathlib.Path):
|
||||
content = content.read_text()
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
|
||||
return [{"contents": [content], "metadata": {}}]
|
||||
|
||||
|
||||
def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
if mime_type == "application/pdf":
|
||||
return doc_to_images(content)
|
||||
if mime_type.startswith("text/"):
|
||||
return extract_text(content)
|
||||
if mime_type.startswith("image/"):
|
||||
return extract_image(content)
|
||||
|
||||
# Return empty list for unknown mime types
|
||||
return []
|
@ -7,16 +7,19 @@ load_dotenv()
|
||||
def boolean_env(key: str, default: bool = False) -> bool:
|
||||
return os.getenv(key, "0").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
# Database settings
|
||||
DB_USER = os.getenv("DB_USER", "kb")
|
||||
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
|
||||
if password_file := os.getenv("POSTGRES_PASSWORD_FILE"):
|
||||
DB_PASSWORD = pathlib.Path(password_file).read_text().strip()
|
||||
else:
|
||||
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
|
||||
|
||||
DB_HOST = os.getenv("DB_HOST", "postgres")
|
||||
DB_PORT = os.getenv("DB_PORT", "5432")
|
||||
DB_NAME = os.getenv("DB_NAME", "kb")
|
||||
|
||||
def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME):
|
||||
return f"postgresql://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
DB_URL = os.getenv("DATABASE_URL", make_db_url())
|
||||
|
||||
|
||||
@ -32,4 +35,14 @@ QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
|
||||
QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
|
||||
QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
|
||||
QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
|
||||
QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
|
||||
|
||||
|
||||
# Worker settings
|
||||
BEAT_LOOP_INTERVAL = int(os.getenv("BEAT_LOOP_INTERVAL", 3600))
|
||||
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 600))
|
||||
|
||||
|
||||
# Embedding settings
|
||||
TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
|
||||
MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from celery import Celery
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
def rabbit_url() -> str:
|
||||
user = os.getenv("RABBITMQ_USER", "guest")
|
||||
@ -8,10 +8,11 @@ def rabbit_url() -> str:
|
||||
return f"amqp://{user}:{password}@rabbitmq:5672//"
|
||||
|
||||
|
||||
app = Celery("memory",
|
||||
broker=rabbit_url(),
|
||||
backend=os.getenv("CELERY_RESULT_BACKEND",
|
||||
"db+postgresql://kb:kb@postgres/kb"))
|
||||
app = Celery(
|
||||
"memory",
|
||||
broker=rabbit_url(),
|
||||
backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}")
|
||||
)
|
||||
|
||||
|
||||
app.autodiscover_tasks(["memory.workers.tasks"])
|
||||
@ -22,8 +23,8 @@ app.conf.update(
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_prefetch_multiplier=1,
|
||||
task_routes={
|
||||
# Task routing configuration
|
||||
"memory.workers.tasks.text.*": {"queue": "medium_embed"},
|
||||
"memory.workers.tasks.email.*": {"queue": "email"},
|
||||
"memory.workers.tasks.photo.*": {"queue": "photo_embed"},
|
||||
"memory.workers.tasks.ocr.*": {"queue": "low_ocr"},
|
||||
"memory.workers.tasks.git.*": {"queue": "git_summary"},
|
||||
|
@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
|
||||
from collections import defaultdict
|
||||
from memory.common import settings, embedding
|
||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||
from memory.common.qdrant import get_qdrant_client, upsert_vectors
|
||||
from memory.common import qdrant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -471,17 +471,18 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
||||
|
||||
|
||||
def vectorize_email(email: MailMessage) -> list[float]:
|
||||
qdrant_client = get_qdrant_client()
|
||||
qdrant_client = qdrant.get_qdrant_client()
|
||||
|
||||
chunks = embedding.embed_text(email.body_raw)
|
||||
payloads = [email.as_payload()] * len(chunks)
|
||||
vector_ids = [str(uuid.uuid4()) for _ in chunks]
|
||||
upsert_vectors(
|
||||
_, chunks = embedding.embed(
|
||||
"text/plain", email.body_raw, metadata=email.as_payload(),
|
||||
)
|
||||
vector_ids, vectors, metadata = zip(*chunks)
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name="mail",
|
||||
ids=vector_ids,
|
||||
vectors=chunks,
|
||||
payloads=payloads,
|
||||
vectors=vectors,
|
||||
payloads=metadata,
|
||||
)
|
||||
vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
|
||||
|
||||
@ -491,15 +492,14 @@ def vectorize_email(email: MailMessage) -> list[float]:
|
||||
content = pathlib.Path(attachment.file_path).read_bytes()
|
||||
else:
|
||||
content = attachment.content
|
||||
collection, vectors = embedding.embed(attachment.content_type, content)
|
||||
attachment.source.vector_ids = vector_ids
|
||||
embeds[collection].extend(
|
||||
(str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors
|
||||
)
|
||||
collection, chunks = embedding.embed(attachment.content_type, content, metadata=attachment.as_payload())
|
||||
ids, vectors, metadata = zip(*chunks)
|
||||
attachment.source.vector_ids = ids
|
||||
embeds[collection].extend(chunks)
|
||||
|
||||
for collection, embeds in embeds.items():
|
||||
ids, vectors, payloads = zip(*embeds)
|
||||
upsert_vectors(
|
||||
for collection, chunks in embeds.items():
|
||||
ids, vectors, payloads = zip(*chunks)
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name=collection,
|
||||
ids=ids,
|
||||
|
14
src/memory/workers/ingest.py
Normal file
14
src/memory/workers/ingest.py
Normal file
@ -0,0 +1,14 @@
|
||||
from celery.schedules import schedule
|
||||
from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS
|
||||
from memory.common import settings
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
|
||||
@app.on_after_configure.connect
|
||||
def register_mail_schedules(sender, **_):
|
||||
sender.add_periodic_task(
|
||||
schedule=schedule(settings.EMAIL_SYNC_INTERVAL),
|
||||
sig=app.signature(SYNC_ALL_ACCOUNTS),
|
||||
name="sync-mail-all",
|
||||
options={"queue": "email"},
|
||||
)
|
@ -1,4 +1,8 @@
|
||||
"""
|
||||
Import sub-modules so Celery can register their @app.task decorators.
|
||||
"""
|
||||
from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa
|
||||
from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa
|
||||
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
|
||||
|
||||
|
||||
__all__ = ["text", "photo", "ocr", "git", "rss", "docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]
|
@ -18,8 +18,12 @@ from memory.workers.email import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROCESS_EMAIL = "memory.workers.tasks.email.process_message"
|
||||
SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account"
|
||||
SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
|
||||
|
||||
@app.task(name="memory.email.process_message")
|
||||
|
||||
@app.task(name=PROCESS_EMAIL)
|
||||
def process_message(
|
||||
account_id: int, message_id: str, folder: str, raw_email: str,
|
||||
) -> int | None:
|
||||
@ -35,6 +39,7 @@ def process_message(
|
||||
Returns:
|
||||
source_id if successful, None otherwise
|
||||
"""
|
||||
logger.info(f"Processing message {message_id} for account {account_id}")
|
||||
if not raw_email.strip():
|
||||
logger.warning(f"Empty email message received for account {account_id}")
|
||||
return None
|
||||
@ -76,7 +81,7 @@ def process_message(
|
||||
return source_item.id
|
||||
|
||||
|
||||
@app.task(name="memory.email.sync_account")
|
||||
@app.task(name=SYNC_ACCOUNT)
|
||||
def sync_account(account_id: int) -> dict:
|
||||
"""
|
||||
Synchronize emails from a specific account.
|
||||
@ -87,6 +92,7 @@ def sync_account(account_id: int) -> dict:
|
||||
Returns:
|
||||
dict with stats about the sync operation
|
||||
"""
|
||||
logger.info(f"Syncing account {account_id}")
|
||||
with make_session() as db:
|
||||
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
|
||||
if not account or not account.active:
|
||||
@ -124,7 +130,7 @@ def sync_account(account_id: int) -> dict:
|
||||
}
|
||||
|
||||
|
||||
@app.task(name="memory.email.sync_all_accounts")
|
||||
@app.task(name=SYNC_ALL_ACCOUNTS)
|
||||
def sync_all_accounts() -> list[dict]:
|
||||
"""
|
||||
Synchronize all active email accounts.
|
||||
|
@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import qdrant_client
|
||||
import voyageai
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from testcontainers.qdrant import QdrantContainer
|
||||
@ -210,3 +211,8 @@ def qdrant():
|
||||
initialize_collections(client)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_voyage_client():
|
||||
with patch.object(voyageai, "Client", autospec=True) as mock_client:
|
||||
yield mock_client()
|
BIN
tests/data/regulamin.pdf
Normal file
BIN
tests/data/regulamin.pdf
Normal file
Binary file not shown.
@ -1,5 +1,19 @@
|
||||
import uuid
|
||||
import pytest
|
||||
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN
|
||||
from unittest.mock import Mock, patch
|
||||
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count, get_modality, embed_text, embed_file, embed_mixed, embed_page, embed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed(mock_voyage_client):
|
||||
vectors = ([i] for i in range(1000))
|
||||
|
||||
def embed(texts, model):
|
||||
return Mock(embeddings=[next(vectors) for _ in texts])
|
||||
|
||||
mock_voyage_client.embed.side_effect = embed
|
||||
|
||||
return mock_voyage_client
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -270,3 +284,77 @@ def test_chunk_text_very_long_sentences():
|
||||
'chunks by the function.',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"string, expected_count",
|
||||
[
|
||||
("", 0),
|
||||
("a" * CHARS_PER_TOKEN, 1),
|
||||
("a" * (CHARS_PER_TOKEN * 2), 2),
|
||||
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
|
||||
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
|
||||
]
|
||||
)
|
||||
def test_approx_token_count(string, expected_count):
|
||||
assert approx_token_count(string) == expected_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mime_type, expected_modality",
|
||||
[
|
||||
("text/plain", "doc"),
|
||||
("text/html", "doc"),
|
||||
("image/jpeg", "photo"),
|
||||
("image/png", "photo"),
|
||||
("application/pdf", "book"),
|
||||
("application/epub+zip", "book"),
|
||||
("application/mobi", "book"),
|
||||
("application/x-mobipocket-ebook", "book"),
|
||||
("audio/mp3", "unknown"),
|
||||
("video/mp4", "unknown"),
|
||||
("text/something-new", "doc"), # Should match by 'text/' stem
|
||||
("image/something-new", "photo"), # Should match by 'image/' stem
|
||||
("custom/format", "unknown"), # No matching stem
|
||||
]
|
||||
)
|
||||
def test_get_modality(mime_type, expected_modality):
|
||||
assert get_modality(mime_type) == expected_modality
|
||||
|
||||
|
||||
def test_embed_text(mock_embed):
|
||||
texts = ["text1 with words", "text2"]
|
||||
assert embed_text(texts) == [[0], [1]]
|
||||
|
||||
|
||||
def test_embed_file(mock_embed, tmp_path):
|
||||
mock_file = tmp_path / "test.txt"
|
||||
mock_file.write_text("file content")
|
||||
|
||||
assert embed_file(mock_file) == [[0]]
|
||||
|
||||
|
||||
def test_embed_mixed(mock_embed):
|
||||
items = ["text", {"type": "image", "data": "base64"}]
|
||||
assert embed_mixed(items) == [[0], [1]]
|
||||
|
||||
|
||||
def test_embed_page_text_only(mock_embed):
|
||||
page = {"contents": ["text1", "text2"]}
|
||||
assert embed_page(page) == [[0], [1]]
|
||||
|
||||
|
||||
def test_embed_page_mixed_content(mock_embed):
|
||||
page = {"contents": ["text", {"type": "image", "data": "base64"}]}
|
||||
assert embed_page(page) == [[0], [1]]
|
||||
|
||||
|
||||
def test_embed(mock_embed):
|
||||
mime_type = "text/plain"
|
||||
content = "sample content"
|
||||
metadata = {"source": "test"}
|
||||
|
||||
with patch.object(uuid, "uuid4", return_value="id1"):
|
||||
modality, vectors = embed(mime_type, content, metadata)
|
||||
|
||||
assert modality == "doc"
|
||||
assert vectors == [('id1', [0], {'source': 'test'})]
|
||||
|
131
tests/memory/common/test_extract.py
Normal file
131
tests/memory/common/test_extract.py
Normal file
@ -0,0 +1,131 @@
|
||||
import pathlib
|
||||
import pytest
|
||||
import pymupdf
|
||||
from PIL import Image
|
||||
import io
|
||||
from memory.common.extract import as_file, extract_text, extract_content, Page, doc_to_images, extract_image
|
||||
|
||||
|
||||
REGULAMIN = pathlib.Path(__file__).parent.parent.parent / "data" / "regulamin.pdf"
|
||||
|
||||
|
||||
def test_as_file_with_path(tmp_path):
|
||||
test_path = tmp_path / "test.txt"
|
||||
test_path.write_text("test content")
|
||||
|
||||
with as_file(test_path) as path:
|
||||
assert path == test_path
|
||||
assert path.read_text() == "test content"
|
||||
|
||||
|
||||
def test_as_file_with_bytes():
|
||||
content = b"test content"
|
||||
with as_file(content) as path:
|
||||
assert pathlib.Path(path).read_bytes() == content
|
||||
|
||||
|
||||
def test_as_file_with_str():
|
||||
content = "test content"
|
||||
with as_file(content) as path:
|
||||
assert pathlib.Path(path).read_text() == content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_content,expected",
|
||||
[
|
||||
("simple text", [{"contents": ["simple text"], "metadata": {}}]),
|
||||
(b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]),
|
||||
]
|
||||
)
|
||||
def test_extract_text(input_content, expected):
|
||||
assert extract_text(input_content) == expected
|
||||
|
||||
|
||||
def test_extract_text_with_path(tmp_path):
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("file text content")
|
||||
|
||||
assert extract_text(test_file) == [{"contents": ["file text content"], "metadata": {}}]
|
||||
|
||||
|
||||
def test_doc_to_images():
|
||||
result = doc_to_images(REGULAMIN)
|
||||
|
||||
assert len(result) == 2
|
||||
with pymupdf.open(REGULAMIN) as pdf:
|
||||
for page, pdf_page in zip(result, pdf.pages()):
|
||||
pix = pdf_page.get_pixmap()
|
||||
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
assert page["contents"] == img
|
||||
assert page["metadata"] == {
|
||||
"page": pdf_page.number,
|
||||
"width": pdf_page.rect.width,
|
||||
"height": pdf_page.rect.height,
|
||||
}
|
||||
|
||||
|
||||
def test_extract_image_with_path(tmp_path):
|
||||
img = Image.new('RGB', (100, 100), color='red')
|
||||
img_path = tmp_path / "test.png"
|
||||
img.save(img_path)
|
||||
|
||||
page, = extract_image(img_path)
|
||||
assert page["contents"].tobytes() == img.convert("RGB").tobytes()
|
||||
assert page["metadata"] == {}
|
||||
|
||||
|
||||
def test_extract_image_with_bytes():
|
||||
img = Image.new('RGB', (100, 100), color='blue')
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
img_bytes = buffer.getvalue()
|
||||
|
||||
page, = extract_image(img_bytes)
|
||||
assert page["contents"].tobytes() == img.convert("RGB").tobytes()
|
||||
assert page["metadata"] == {}
|
||||
|
||||
|
||||
def test_extract_image_with_str():
|
||||
with pytest.raises(ValueError):
|
||||
extract_image("test")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mime_type,content",
|
||||
[
|
||||
("text/plain", "Text content"),
|
||||
("text/html", "<html>content</html>"),
|
||||
("text/markdown", "# Heading"),
|
||||
("text/csv", "a,b,c"),
|
||||
]
|
||||
)
|
||||
def test_extract_content_different_text_types(mime_type, content):
|
||||
assert extract_content(mime_type, content) == [{"contents": [content], "metadata": {}}]
|
||||
|
||||
|
||||
def test_extract_content_pdf():
|
||||
result = extract_content("application/pdf", REGULAMIN)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(page["contents"], Image.Image) for page in result)
|
||||
assert all("page" in page["metadata"] for page in result)
|
||||
assert all("width" in page["metadata"] for page in result)
|
||||
assert all("height" in page["metadata"] for page in result)
|
||||
|
||||
|
||||
def test_extract_content_image(tmp_path):
|
||||
# Create a test image
|
||||
img = Image.new('RGB', (100, 100), color='red')
|
||||
img_path = tmp_path / "test_img.png"
|
||||
img.save(img_path)
|
||||
|
||||
result = extract_content("image/png", img_path)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0]["contents"], Image.Image)
|
||||
assert result[0]["contents"].size == (100, 100)
|
||||
assert result[0]["metadata"] == {}
|
||||
|
||||
|
||||
def test_extract_content_unsupported_type():
|
||||
assert extract_content("unsupported/type", "content") == []
|
Loading…
x
Reference in New Issue
Block a user