mirror of https://github.com/mruwnik/memory.git
synced 2025-08-01 15:36:55 +02:00

celery beat + image embedding

This commit is contained in:
parent 869e5ac6b4
commit 453aed7c19

alembic migration (modified)
@@ -361,13 +361,13 @@ def downgrade() -> None:
     # Drop tables
     op.drop_table("photo")
     op.drop_table("misc_doc")
-    op.drop_table("mail_message")
     op.drop_table("git_commit")
     op.drop_table("chat_message")
     op.drop_table("book_doc")
     op.drop_table("blog_post")
     op.drop_table("github_item")
     op.drop_table("email_attachment")
+    op.drop_table("mail_message")
     op.drop_table("source_item")
     op.drop_table("rss_feeds")
     op.drop_table("email_accounts")

docker-compose (modified)
@@ -21,6 +21,8 @@ volumes:
 # ------------------------------ X-templates ----------------------------
 x-common-env: &env
   RABBITMQ_USER: kb
+  QDRANT_HOST: qdrant
+  DB_HOST: postgres
   FILE_STORAGE_DIR: /app/memory_files
   TZ: "Etc/UTC"

@@ -165,6 +167,13 @@ services:
     # - ./acme.json:/acme.json:rw

 # ------------------------------------------------------------ Celery workers
+  worker-email:
+    <<: *worker-base
+    environment:
+      <<: *worker-env
+      QUEUES: "email"
+    deploy: {resources: {limits: {cpus: "2", memory: 3g}}}
+
   worker-text:
     <<: *worker-base
     environment:

@@ -207,6 +216,22 @@ services:
       QUEUES: "docs"
     deploy: {resources: {limits: {cpus: "1", memory: 1g}}}

+  ingest-hub:
+    <<: *worker-base
+    build:
+      context: .
+      dockerfile: docker/ingest_hub/Dockerfile
+    environment:
+      <<: *worker-env
+    volumes:
+      - file_storage:/app/memory_files:rw
+    tmpfs:
+      - /tmp
+      - /var/tmp
+      - /var/log/supervisor
+      - /var/run/supervisor
+    deploy: {resources: {limits: {cpus: "0.5", memory: 512m}}}
+
 # ------------------------------------------------------------ watchtower (auto-update)
   watchtower:
     image: containrrr/watchtower

docker/ingest_hub/Dockerfile (new file, 32 lines)
@@ -0,0 +1,32 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Copy requirements files and setup
+COPY requirements-*.txt ./
+COPY setup.py ./
+COPY src/ ./src/
+
+# Install dependencies
+RUN apt-get update && apt-get install -y \
+    libpq-dev gcc supervisor && \
+    pip install -e ".[workers]" && \
+    apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
+
+# Create and copy entrypoint script
+COPY docker/ingest_hub/supervisor.conf /etc/supervisor/conf.d/supervisor.conf
+COPY docker/workers/entry.sh ./entry.sh
+RUN chmod +x entry.sh
+
+# Create required tmpfs directories for supervisor
+RUN mkdir -p /var/log/supervisor /var/run/supervisor
+
+# Create user and set permissions
+RUN useradd -m kb && chown -R kb /app /var/log/supervisor /var/run/supervisor
+USER kb
+
+# Default queues to process
+ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
+ENV PYTHONPATH="/app"
+
+ENTRYPOINT ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisor.conf"]

docker/ingest_hub/supervisor.conf (new file, 16 lines)
@@ -0,0 +1,16 @@
+[supervisord]
+nodaemon=true
+loglevel=info
+logfile=/dev/stdout
+logfile_maxbytes=0
+user=kb
+pidfile=/dev/null
+
+[program:celery-beat]
+command=celery -A memory.workers.ingest beat --pidfile= --schedule=/tmp/celerybeat-schedule
+stdout_logfile=/dev/stdout
+stdout_logfile_maxbytes=0
+stderr_logfile=/dev/stderr
+stderr_logfile_maxbytes=0
+autorestart=true
+startsecs=10

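The conf above only runs the beat scheduler; the schedule itself is registered dynamically in memory.workers.ingest (added further down). For intuition, a roughly equivalent static declaration — a sketch only, not what the commit does — would be:

# Hypothetical static equivalent of the schedule that this celery-beat process serves.
from celery.schedules import schedule
from memory.workers.celery_app import app

app.conf.beat_schedule = {
    "sync-mail-all": {
        "task": "memory.workers.tasks.email.sync_all_accounts",
        "schedule": schedule(600.0),  # EMAIL_SYNC_INTERVAL default, in seconds
        "options": {"queue": "email"},
    },
}
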
docker/workers/Dockerfile (modified)
@@ -22,7 +22,7 @@ RUN useradd -m kb && chown -R kb /app
 USER kb

 # Default queues to process
-ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs"
+ENV QUEUES="medium_embed,photo_embed,low_ocr,git_summary,rss,docs,email"
 ENV PYTHONPATH="/app"

 ENTRYPOINT ["./entry.sh"]

requirements (modified)
@@ -2,6 +2,7 @@ sqlalchemy==2.0.30
 psycopg2-binary==2.9.9
 pydantic==2.7.1
 alembic==1.13.1
-dotenv==1.1.0
+dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
+PyMuPDF==1.25.5

src/memory/common/embedding.py (modified)
@@ -1,7 +1,9 @@
 import pathlib
-from typing import Literal, TypedDict, Iterable
+from typing import Literal, TypedDict, Iterable, Any
 import voyageai
 import re
+import uuid
+from memory.common import extract, settings

 # Chunking configuration
 MAX_TOKENS = 32000  # VoyageAI max context window

@@ -11,6 +13,8 @@ CHARS_PER_TOKEN = 4

 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
+Embedding = tuple[str, Vector, dict[str, Any]]
+

 class Collection(TypedDict):
     dimension: int

@@ -163,17 +167,48 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
             yield current.strip()


-def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
+def embed_chunks(chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
     vo = voyageai.Client()
-    return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model)
+    return vo.embed(chunks, model=model).embeddings
+
+
+def embed_text(texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
+    chunks = [c for text in texts for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
+    return embed_chunks(chunks, model)


-def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
-    return embed_text(file_path.read_text(), model, n_dimensions)
+def embed_file(file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
+    return embed_text([file_path.read_text()], model)
+
+
+def embed_mixed(items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL) -> list[Vector]:
+    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
+        if isinstance(item, str):
+            return [c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
+        return [item]
+
+    chunks = [c for item in items for c in to_chunks(item)]
+    return embed_chunks(chunks, model)
+
+
+def embed_page(page: dict[str, Any]) -> list[Vector]:
+    contents = page["contents"]
+    if all(isinstance(c, str) for c in contents):
+        return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL)
+    return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)


-def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> tuple[str, list[Vector]]:
-    if isinstance(content, bytes):
-        content = content.decode("utf-8")
-    return get_modality(mime_type), embed_text(content, model, n_dimensions)
+def embed(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    metadata: dict[str, Any] = {},
+) -> tuple[str, list[Embedding]]:
+    modality = get_modality(mime_type)
+
+    pages = extract.extract_content(mime_type, content)
+    vectors = [
+        (str(uuid.uuid4()), vector, page.get("metadata", {}) | metadata)
+        for page in pages
+        for vector in embed_page(page)
+    ]
+    return modality, vectors

74
src/memory/common/extract.py
Normal file
74
src/memory/common/extract.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
import pymupdf # PyMuPDF
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Any, TypedDict, Generator
|
||||||
|
|
||||||
|
|
||||||
|
MulitmodalChunk = Image.Image | str
|
||||||
|
class Page(TypedDict):
|
||||||
|
contents: list[MulitmodalChunk]
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
|
||||||
|
if isinstance(content, pathlib.Path):
|
||||||
|
yield content
|
||||||
|
else:
|
||||||
|
mode = "w" if isinstance(content, str) else "wb"
|
||||||
|
with tempfile.NamedTemporaryFile(mode=mode) as f:
|
||||||
|
f.write(content)
|
||||||
|
f.flush()
|
||||||
|
yield pathlib.Path(f.name)
|
||||||
|
|
||||||
|
|
||||||
|
def page_to_image(page: pymupdf.Page) -> Image.Image:
|
||||||
|
pix = page.get_pixmap()
|
||||||
|
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||||
|
|
||||||
|
|
||||||
|
def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||||
|
with as_file(content) as file_path:
|
||||||
|
with pymupdf.open(file_path) as pdf:
|
||||||
|
return [{
|
||||||
|
"contents": page_to_image(page),
|
||||||
|
"metadata": {
|
||||||
|
"page": page.number,
|
||||||
|
"width": page.rect.width,
|
||||||
|
"height": page.rect.height,
|
||||||
|
}
|
||||||
|
} for page in pdf.pages()]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||||
|
if isinstance(content, pathlib.Path):
|
||||||
|
image = Image.open(content)
|
||||||
|
elif isinstance(content, bytes):
|
||||||
|
image = Image.open(io.BytesIO(content))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||||
|
return [{"contents": image, "metadata": {}}]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||||
|
if isinstance(content, pathlib.Path):
|
||||||
|
content = content.read_text()
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content = content.decode("utf-8")
|
||||||
|
|
||||||
|
return [{"contents": [content], "metadata": {}}]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
|
||||||
|
if mime_type == "application/pdf":
|
||||||
|
return doc_to_images(content)
|
||||||
|
if mime_type.startswith("text/"):
|
||||||
|
return extract_text(content)
|
||||||
|
if mime_type.startswith("image/"):
|
||||||
|
return extract_image(content)
|
||||||
|
|
||||||
|
# Return empty list for unknown mime types
|
||||||
|
return []
|
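extract_content() is a plain dispatch on MIME type; a short sketch of the three branches (inputs hypothetical):

import pathlib
from memory.common import extract

# text/* wraps the raw string in a single Page
assert extract.extract_content("text/plain", "hello") == [{"contents": ["hello"], "metadata": {}}]

# application/pdf yields one Page per PDF page, rendered to a PIL image
pages = extract.extract_content("application/pdf", pathlib.Path("tests/data/regulamin.pdf"))

# anything unrecognised falls through to an empty list
assert extract.extract_content("audio/mp3", b"") == []
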
src/memory/common/settings.py (modified)
@@ -7,16 +7,19 @@ load_dotenv()
 def boolean_env(key: str, default: bool = False) -> bool:
     return os.getenv(key, "0").lower() in ("1", "true", "yes")

+# Database settings
 DB_USER = os.getenv("DB_USER", "kb")
-DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
+if password_file := os.getenv("POSTGRES_PASSWORD_FILE"):
+    DB_PASSWORD = pathlib.Path(password_file).read_text().strip()
+else:
+    DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
+
 DB_HOST = os.getenv("DB_HOST", "postgres")
 DB_PORT = os.getenv("DB_PORT", "5432")
 DB_NAME = os.getenv("DB_NAME", "kb")

 def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME):
     return f"postgresql://{user}:{password}@{host}:{port}/{db}"

 DB_URL = os.getenv("DATABASE_URL", make_db_url())

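This lets containers read the database password from a Docker-style secret file instead of a plain env var. A quick sketch of the intended wiring (the mount path is hypothetical):

import os

# e.g. docker-compose mounts the secret at /run/secrets/postgres_password
os.environ["POSTGRES_PASSWORD_FILE"] = "/run/secrets/postgres_password"

from memory.common import settings  # the file is read once, at import time
# settings.DB_PASSWORD is now the stripped file contents, not DB_PASSWORD
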
@@ -33,3 +36,13 @@ QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
 QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
 QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
 QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
+
+
+# Worker settings
+BEAT_LOOP_INTERVAL = int(os.getenv("BEAT_LOOP_INTERVAL", 3600))
+EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 600))
+
+
+# Embedding settings
+TEXT_EMBEDDING_MODEL = os.getenv("TEXT_EMBEDDING_MODEL", "voyage-3-large")
+MIXED_EMBEDDING_MODEL = os.getenv("MIXED_EMBEDDING_MODEL", "voyage-multimodal-3")

src/memory/workers/celery_app.py (modified)
@@ -1,6 +1,6 @@
 import os
 from celery import Celery
+from memory.common import settings

 def rabbit_url() -> str:
     user = os.getenv("RABBITMQ_USER", "guest")

@@ -8,10 +8,11 @@ def rabbit_url() -> str:
     return f"amqp://{user}:{password}@rabbitmq:5672//"


-app = Celery("memory",
-             broker=rabbit_url(),
-             backend=os.getenv("CELERY_RESULT_BACKEND",
-                               "db+postgresql://kb:kb@postgres/kb"))
+app = Celery(
+    "memory",
+    broker=rabbit_url(),
+    backend=os.getenv("CELERY_RESULT_BACKEND", f"db+{settings.DB_URL}")
+)

 app.autodiscover_tasks(["memory.workers.tasks"])

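The result backend is now derived from settings.DB_URL instead of a second hard-coded DSN, so a single set of env vars controls both. With no overrides the derived value matches the old one, now with the port spelled out:

from memory.common import settings

backend = f"db+{settings.DB_URL}"
# with defaults: "db+postgresql://kb:kb@postgres:5432/kb"
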
@@ -22,8 +23,8 @@ app.conf.update(
     task_reject_on_worker_lost=True,
     worker_prefetch_multiplier=1,
     task_routes={
-        # Task routing configuration
         "memory.workers.tasks.text.*": {"queue": "medium_embed"},
+        "memory.workers.tasks.email.*": {"queue": "email"},
         "memory.workers.tasks.photo.*": {"queue": "photo_embed"},
         "memory.workers.tasks.ocr.*": {"queue": "low_ocr"},
         "memory.workers.tasks.git.*": {"queue": "git_summary"},

src/memory/workers/email.py (modified)
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
 from collections import defaultdict
 from memory.common import settings, embedding
 from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
-from memory.common.qdrant import get_qdrant_client, upsert_vectors
+from memory.common import qdrant

 logger = logging.getLogger(__name__)

@@ -471,17 +471,18 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,


 def vectorize_email(email: MailMessage) -> list[float]:
-    qdrant_client = get_qdrant_client()
+    qdrant_client = qdrant.get_qdrant_client()

-    chunks = embedding.embed_text(email.body_raw)
-    payloads = [email.as_payload()] * len(chunks)
-    vector_ids = [str(uuid.uuid4()) for _ in chunks]
-    upsert_vectors(
+    _, chunks = embedding.embed(
+        "text/plain", email.body_raw, metadata=email.as_payload(),
+    )
+    vector_ids, vectors, metadata = zip(*chunks)
+    qdrant.upsert_vectors(
         client=qdrant_client,
         collection_name="mail",
         ids=vector_ids,
-        vectors=chunks,
-        payloads=payloads,
+        vectors=vectors,
+        payloads=metadata,
     )
     vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]

@@ -491,15 +492,14 @@ def vectorize_email(email: MailMessage) -> list[float]:
             content = pathlib.Path(attachment.file_path).read_bytes()
         else:
             content = attachment.content
-        collection, vectors = embedding.embed(attachment.content_type, content)
-        attachment.source.vector_ids = vector_ids
-        embeds[collection].extend(
-            (str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors
-        )
+        collection, chunks = embedding.embed(attachment.content_type, content, metadata=attachment.as_payload())
+        ids, vectors, metadata = zip(*chunks)
+        attachment.source.vector_ids = ids
+        embeds[collection].extend(chunks)

-    for collection, embeds in embeds.items():
-        ids, vectors, payloads = zip(*embeds)
-        upsert_vectors(
+    for collection, chunks in embeds.items():
+        ids, vectors, payloads = zip(*chunks)
+        qdrant.upsert_vectors(
             client=qdrant_client,
             collection_name=collection,
             ids=ids,

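The zip(*chunks) idiom above transposes the (id, vector, payload) triples into the parallel sequences that upsert_vectors() takes; note that zip(*[]) unpacks to nothing and would raise ValueError, so this path implicitly assumes at least one chunk:

chunks = [("id1", [0.1, 0.2], {"page": 0}), ("id2", [0.3, 0.4], {"page": 1})]
ids, vectors, payloads = zip(*chunks)
assert ids == ("id1", "id2")
assert vectors == ([0.1, 0.2], [0.3, 0.4])
assert payloads == ({"page": 0}, {"page": 1})
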
|
14
src/memory/workers/ingest.py
Normal file
14
src/memory/workers/ingest.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from celery.schedules import schedule
|
||||||
|
from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.workers.celery_app import app
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_after_configure.connect
|
||||||
|
def register_mail_schedules(sender, **_):
|
||||||
|
sender.add_periodic_task(
|
||||||
|
schedule=schedule(settings.EMAIL_SYNC_INTERVAL),
|
||||||
|
sig=app.signature(SYNC_ALL_ACCOUNTS),
|
||||||
|
name="sync-mail-all",
|
||||||
|
options={"queue": "email"},
|
||||||
|
)
|
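Because the task is registered by name, the wiring can be smoke-tested without a broker by resolving it from the registry and running it eagerly — a sketch, assuming a reachable database:

from memory.workers.celery_app import app
from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS
import memory.workers.ingest  # noqa: F401  (connects the on_after_configure handler)

task = app.tasks[SYNC_ALL_ACCOUNTS]  # registered via @app.task(name=SYNC_ALL_ACCOUNTS)
result = task.apply()                # run in-process, bypassing the broker
print(result.get())                  # per-account sync stats
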
src/memory/workers/tasks/__init__.py (modified)
@@ -2,3 +2,7 @@
 Import sub-modules so Celery can register their @app.task decorators.
 """
 from memory.workers.tasks import text, photo, ocr, git, rss, docs, email  # noqa
+from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
+
+
+__all__ = ["text", "photo", "ocr", "git", "rss", "docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]

src/memory/workers/tasks/email.py (modified)
@@ -18,8 +18,12 @@ from memory.workers.email import (

 logger = logging.getLogger(__name__)

+PROCESS_EMAIL = "memory.workers.tasks.email.process_message"
+SYNC_ACCOUNT = "memory.workers.tasks.email.sync_account"
+SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
+

-@app.task(name="memory.email.process_message")
+@app.task(name=PROCESS_EMAIL)
 def process_message(
     account_id: int, message_id: str, folder: str, raw_email: str,
 ) -> int | None:

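Exporting the names as constants lets callers enqueue work without importing the task functions themselves — a sketch:

from memory.workers.celery_app import app
from memory.workers.tasks.email import SYNC_ACCOUNT

# send_task() dispatches purely by name; no import of sync_account() needed
app.send_task(SYNC_ACCOUNT, args=[1], queue="email")
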
@@ -35,6 +39,7 @@ def process_message(
     Returns:
         source_id if successful, None otherwise
     """
+    logger.info(f"Processing message {message_id} for account {account_id}")
     if not raw_email.strip():
         logger.warning(f"Empty email message received for account {account_id}")
         return None

@@ -76,7 +81,7 @@ def process_message(
     return source_item.id


-@app.task(name="memory.email.sync_account")
+@app.task(name=SYNC_ACCOUNT)
 def sync_account(account_id: int) -> dict:
     """
     Synchronize emails from a specific account.

@@ -87,6 +92,7 @@ def sync_account(account_id: int) -> dict:
     Returns:
         dict with stats about the sync operation
     """
+    logger.info(f"Syncing account {account_id}")
     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
         if not account or not account.active:

@@ -124,7 +130,7 @@ def sync_account(account_id: int) -> dict:
     }


-@app.task(name="memory.email.sync_all_accounts")
+@app.task(name=SYNC_ALL_ACCOUNTS)
 def sync_all_accounts() -> list[dict]:
     """
     Synchronize all active email accounts.

tests conftest (modified)
@@ -7,16 +7,19 @@ from pathlib import Path

 import pytest
 import qdrant_client
+import voyageai
 from sqlalchemy import create_engine, text
 from sqlalchemy.orm import sessionmaker
 from testcontainers.qdrant import QdrantContainer

@@ -210,3 +211,8 @@ def qdrant():
     initialize_collections(client)
     yield client
+
+
+@pytest.fixture(autouse=True)
+def mock_voyage_client():
+    with patch.object(voyageai, "Client", autospec=True) as mock_client:
+        yield mock_client()

tests/data/regulamin.pdf (new binary file, not shown)

embedding tests (modified)
@@ -1,5 +1,19 @@
+import uuid
 import pytest
-from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN
+from unittest.mock import Mock, patch
+from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count, get_modality, embed_text, embed_file, embed_mixed, embed_page, embed
+
+
+@pytest.fixture
+def mock_embed(mock_voyage_client):
+    vectors = ([i] for i in range(1000))
+
+    def embed(texts, model):
+        return Mock(embeddings=[next(vectors) for _ in texts])
+
+    mock_voyage_client.embed.side_effect = embed
+
+    return mock_voyage_client


 @pytest.mark.parametrize(

@@ -270,3 +284,77 @@ def test_chunk_text_very_long_sentences():
         'chunks by the function.',
     ]
+
+
+@pytest.mark.parametrize(
+    "string, expected_count",
+    [
+        ("", 0),
+        ("a" * CHARS_PER_TOKEN, 1),
+        ("a" * (CHARS_PER_TOKEN * 2), 2),
+        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
+        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
+    ]
+)
+def test_approx_token_count(string, expected_count):
+    assert approx_token_count(string) == expected_count
+
+
+@pytest.mark.parametrize(
+    "mime_type, expected_modality",
+    [
+        ("text/plain", "doc"),
+        ("text/html", "doc"),
+        ("image/jpeg", "photo"),
+        ("image/png", "photo"),
+        ("application/pdf", "book"),
+        ("application/epub+zip", "book"),
+        ("application/mobi", "book"),
+        ("application/x-mobipocket-ebook", "book"),
+        ("audio/mp3", "unknown"),
+        ("video/mp4", "unknown"),
+        ("text/something-new", "doc"),  # Should match by 'text/' stem
+        ("image/something-new", "photo"),  # Should match by 'image/' stem
+        ("custom/format", "unknown"),  # No matching stem
+    ]
+)
+def test_get_modality(mime_type, expected_modality):
+    assert get_modality(mime_type) == expected_modality
+
+
+def test_embed_text(mock_embed):
+    texts = ["text1 with words", "text2"]
+    assert embed_text(texts) == [[0], [1]]
+
+
+def test_embed_file(mock_embed, tmp_path):
+    mock_file = tmp_path / "test.txt"
+    mock_file.write_text("file content")
+
+    assert embed_file(mock_file) == [[0]]
+
+
+def test_embed_mixed(mock_embed):
+    items = ["text", {"type": "image", "data": "base64"}]
+    assert embed_mixed(items) == [[0], [1]]
+
+
+def test_embed_page_text_only(mock_embed):
+    page = {"contents": ["text1", "text2"]}
+    assert embed_page(page) == [[0], [1]]
+
+
+def test_embed_page_mixed_content(mock_embed):
+    page = {"contents": ["text", {"type": "image", "data": "base64"}]}
+    assert embed_page(page) == [[0], [1]]
+
+
+def test_embed(mock_embed):
+    mime_type = "text/plain"
+    content = "sample content"
+    metadata = {"source": "test"}
+
+    with patch.object(uuid, "uuid4", return_value="id1"):
+        modality, vectors = embed(mime_type, content, metadata)
+
+    assert modality == "doc"
+    assert vectors == [('id1', [0], {'source': 'test'})]

tests/memory/common/test_extract.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+import pathlib
+import pytest
+import pymupdf
+from PIL import Image
+import io
+from memory.common.extract import as_file, extract_text, extract_content, Page, doc_to_images, extract_image
+
+
+REGULAMIN = pathlib.Path(__file__).parent.parent.parent / "data" / "regulamin.pdf"
+
+
+def test_as_file_with_path(tmp_path):
+    test_path = tmp_path / "test.txt"
+    test_path.write_text("test content")
+
+    with as_file(test_path) as path:
+        assert path == test_path
+        assert path.read_text() == "test content"
+
+
+def test_as_file_with_bytes():
+    content = b"test content"
+    with as_file(content) as path:
+        assert pathlib.Path(path).read_bytes() == content
+
+
+def test_as_file_with_str():
+    content = "test content"
+    with as_file(content) as path:
+        assert pathlib.Path(path).read_text() == content
+
+
+@pytest.mark.parametrize(
+    "input_content,expected",
+    [
+        ("simple text", [{"contents": ["simple text"], "metadata": {}}]),
+        (b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]),
+    ]
+)
+def test_extract_text(input_content, expected):
+    assert extract_text(input_content) == expected
+
+
+def test_extract_text_with_path(tmp_path):
+    test_file = tmp_path / "test.txt"
+    test_file.write_text("file text content")
+
+    assert extract_text(test_file) == [{"contents": ["file text content"], "metadata": {}}]
+
+
+def test_doc_to_images():
+    result = doc_to_images(REGULAMIN)
+
+    assert len(result) == 2
+    with pymupdf.open(REGULAMIN) as pdf:
+        for page, pdf_page in zip(result, pdf.pages()):
+            pix = pdf_page.get_pixmap()
+            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+            assert page["contents"] == img
+            assert page["metadata"] == {
+                "page": pdf_page.number,
+                "width": pdf_page.rect.width,
+                "height": pdf_page.rect.height,
+            }
+
+
+def test_extract_image_with_path(tmp_path):
+    img = Image.new('RGB', (100, 100), color='red')
+    img_path = tmp_path / "test.png"
+    img.save(img_path)
+
+    page, = extract_image(img_path)
+    assert page["contents"].tobytes() == img.convert("RGB").tobytes()
+    assert page["metadata"] == {}
+
+
+def test_extract_image_with_bytes():
+    img = Image.new('RGB', (100, 100), color='blue')
+    buffer = io.BytesIO()
+    img.save(buffer, format='PNG')
+    img_bytes = buffer.getvalue()
+
+    page, = extract_image(img_bytes)
+    assert page["contents"].tobytes() == img.convert("RGB").tobytes()
+    assert page["metadata"] == {}
+
+
+def test_extract_image_with_str():
+    with pytest.raises(ValueError):
+        extract_image("test")
+
+
+@pytest.mark.parametrize(
+    "mime_type,content",
+    [
+        ("text/plain", "Text content"),
+        ("text/html", "<html>content</html>"),
+        ("text/markdown", "# Heading"),
+        ("text/csv", "a,b,c"),
+    ]
+)
+def test_extract_content_different_text_types(mime_type, content):
+    assert extract_content(mime_type, content) == [{"contents": [content], "metadata": {}}]
+
+
+def test_extract_content_pdf():
+    result = extract_content("application/pdf", REGULAMIN)
+
+    assert len(result) == 2
+    assert all(isinstance(page["contents"], Image.Image) for page in result)
+    assert all("page" in page["metadata"] for page in result)
+    assert all("width" in page["metadata"] for page in result)
+    assert all("height" in page["metadata"] for page in result)
+
+
+def test_extract_content_image(tmp_path):
+    # Create a test image
+    img = Image.new('RGB', (100, 100), color='red')
+    img_path = tmp_path / "test_img.png"
+    img.save(img_path)
+
+    result = extract_content("image/png", img_path)
+
+    assert len(result) == 1
+    assert isinstance(result[0]["contents"], Image.Image)
+    assert result[0]["contents"].size == (100, 100)
+    assert result[0]["metadata"] == {}
+
+
+def test_extract_content_unsupported_type():
+    assert extract_content("unsupported/type", "content") == []