Mirror of https://github.com/mruwnik/memory.git, synced 2025-06-28 23:24:43 +02:00.

Commit fe15442a6d ("handle duplicates and docx"), parent c6cd809eb7.
Dockerfile:

@@ -9,7 +9,11 @@ COPY src/ ./src/
 # Install dependencies
 RUN apt-get update && apt-get install -y \
-    libpq-dev gcc && \
+    libpq-dev gcc pandoc \
+    texlive-full texlive-fonts-recommended texlive-plain-generic \
+    # For optional LibreOffice support (uncomment if needed)
+    # libreoffice-writer \
+    && \
     pip install -e ".[workers]" && \
     apt-get purge -y gcc && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
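The pandoc and TeX Live packages added here are what the DOCX handling below relies on: pypandoc drives pandoc, and pandoc's PDF writer needs a LaTeX engine such as pdflatex. A minimal sanity check that the toolchain is present in the built image (this script is illustrative, not part of the commit; pypandoc.get_pandoc_version and shutil.which are standard APIs):

import shutil

import pypandoc

# Confirm pandoc is installed and reachable through pypandoc.
print("pandoc:", pypandoc.get_pandoc_version())

# pandoc's PDF output needs a LaTeX engine; texlive-full provides pdflatex.
assert shutil.which("pdflatex"), "pdflatex missing: DOCX -> PDF conversion will fail"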
Dependency pins:

@@ -1,3 +1,4 @@
 celery==5.3.6
 openai==1.25.0
 pillow==10.3.0
+pypandoc==1.15.0
memory.common.db.models:

@@ -22,11 +22,12 @@ from sqlalchemy import (
     Numeric,
     String,
     Text,
+    event,
     func,
 )
 from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
 from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, Session
 
 from memory.common import settings
 from memory.common.parsers.email import parse_email_message
@@ -34,6 +35,50 @@ from memory.common.parsers.email import parse_email_message
 Base = declarative_base()
 
 
+@event.listens_for(Session, "before_flush")
+def handle_duplicate_sha256(session, flush_context, instances):
+    """
+    Event listener that efficiently checks for duplicate sha256 values before flush
+    and removes items with duplicate sha256 from the session.
+
+    Uses a single query to identify all duplicates rather than querying for each item.
+    """
+    # Find all SourceItem objects being added
+    new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
+    if not new_items:
+        return
+
+    items = {}
+    for item in new_items:
+        try:
+            if (sha256 := item.sha256) is None:
+                continue
+
+            if sha256 in items:
+                session.expunge(item)
+                continue
+
+            items[sha256] = item
+        except (AttributeError, TypeError):
+            continue
+
+    if not items:
+        return
+
+    # Query database for existing items with these sha256 values in a single query
+    existing_sha256s = set(
+        row[0]
+        for row in session.query(SourceItem.sha256).filter(
+            SourceItem.sha256.in_(items.keys())
+        )
+    )
+
+    # Remove objects with duplicate sha256 values from the session
+    for sha256 in existing_sha256s:
+        if sha256 in items:
+            session.expunge(items[sha256])
+
+
 def clean_filename(filename: str) -> str:
     return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
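The listener expunges duplicates at two levels: within the pending flush (first occurrence wins) and against rows already committed. For readers unfamiliar with SQLAlchemy's before_flush hook, here is a self-contained sketch of the same pattern against an in-memory SQLite database; the Doc model and its key column are hypothetical stand-ins for SourceItem and sha256, not part of the commit:

from sqlalchemy import Column, Integer, String, create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()


class Doc(Base):  # hypothetical stand-in for SourceItem
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    key = Column(String)  # plays the role of sha256


@event.listens_for(Session, "before_flush")
def drop_duplicate_keys(session, flush_context, instances):
    new = [o for o in session.new if isinstance(o, Doc) and o.key is not None]
    seen = {}
    for doc in new:
        if doc.key in seen:
            session.expunge(doc)  # duplicate within this flush
        else:
            seen[doc.key] = doc
    if not seen:
        return
    # One query for all keys instead of one query per object
    existing = {
        row[0] for row in session.query(Doc.key).filter(Doc.key.in_(seen.keys()))
    }
    for key in existing:
        session.expunge(seen[key])  # duplicate of an already-committed row


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as s:
    s.add_all([Doc(key="a"), Doc(key="a"), Doc(key="b")])
    s.commit()  # only one "a" and one "b" are inserted
    s.add_all([Doc(key="b"), Doc(key="c")])
    s.commit()  # "b" already exists, so only "c" is inserted
    assert sorted(d.key for d in s.query(Doc)) == ["a", "b", "c"]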
@@ -65,7 +110,7 @@ class Chunk(Base):
 
     @property
     def data(self) -> list[bytes | str | Image.Image]:
-        if not self.file_path:
+        if self.file_path is None:
             return [self.content]
 
         path = pathlib.Path(self.file_path)
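The switch from a truthiness test to `is None` is not cosmetic: an empty-string file_path is falsy, so the old check could not tell "no file" apart from a set-but-empty path. A tiny illustration of the difference:

for file_path in (None, ""):
    old_branch = "content" if not file_path else "file"       # old check
    new_branch = "content" if file_path is None else "file"   # new check
    print(repr(file_path), old_branch, new_branch)
# None -> old: content, new: content
# ""   -> old: content, new: file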
@@ -178,7 +223,7 @@ class MailMessage(SourceItem):
             "sender": self.sender,
             "recipients": self.recipients,
             "folder": self.folder,
-            "tags": self.tags,
+            "tags": self.tags + [self.sender] + self.recipients,
             "date": self.sent_at and self.sent_at.isoformat() or None,
         }
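Folding the sender and recipients into tags lets address-based filtering reuse whatever tag index already exists. The resulting list, with hypothetical values:

tags = ["inbox"]
sender = "alice@example.com"
recipients = ["bob@example.com", "carol@example.com"]
print(tags + [sender] + recipients)
# ['inbox', 'alice@example.com', 'bob@example.com', 'carol@example.com']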
memory.common.extract:

@@ -2,15 +2,17 @@ from contextlib import contextmanager
 import io
 import pathlib
 import tempfile
+import pypandoc
 import pymupdf  # PyMuPDF
 from PIL import Image
-from typing import Any, TypedDict, Generator
+from typing import Any, TypedDict, Generator, Sequence
 
 
 MulitmodalChunk = Image.Image | str
 
 
 class Page(TypedDict):
-    contents: list[MulitmodalChunk]
+    contents: Sequence[MulitmodalChunk]
     metadata: dict[str, Any]
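Widening contents from list to Sequence matters for type checking: list is invariant, so a list[Image.Image] was not previously a valid list[MulitmodalChunk], whereas Sequence is covariant and read-only. A small illustration, restating the definitions above so it stands alone:

from typing import Any, Sequence, TypedDict

from PIL import Image

MulitmodalChunk = Image.Image | str


class Page(TypedDict):
    contents: Sequence[MulitmodalChunk]
    metadata: dict[str, Any]


# With contents: list[MulitmodalChunk] this assignment fails type checking,
# because list is invariant in its element type. With Sequence it passes.
images: list[Image.Image] = [Image.new("RGB", (8, 8))]
page: Page = {"contents": images, "metadata": {"page": 0}}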
@@ -27,7 +29,7 @@ def as_file(content: bytes | str | pathlib.Path) -> Generator[pathlib.Path, None, None]:
 
 
 def page_to_image(page: pymupdf.Page) -> Image.Image:
-    pix = page.get_pixmap()
+    pix = page.get_pixmap()  # type: ignore
     return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
@@ -36,16 +38,37 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
     with pymupdf.open(file_path) as pdf:
         return [
             {
-                "contents": page_to_image(page),
+                "contents": [page_to_image(page)],
                 "metadata": {
                     "page": page.number,
                     "width": page.rect.width,
                     "height": page.rect.height,
                 },
-            } for page in pdf.pages()
+            }
+            for page in pdf.pages()
         ]
 
 
+def docx_to_pdf(
+    docx_path: pathlib.Path,
+    output_path: pathlib.Path | None = None,
+) -> pathlib.Path:
+    """Convert DOCX to PDF using pypandoc"""
+    if output_path is None:
+        output_path = docx_path.with_suffix(".pdf")
+
+    pypandoc.convert_file(str(docx_path), "pdf", outputfile=str(output_path))
+
+    return output_path
+
+
+def extract_docx(docx_path: pathlib.Path) -> list[Page]:
+    """Extract content from DOCX by converting to PDF first, then processing"""
+    with as_file(docx_path) as file_path:
+        pdf_path = docx_to_pdf(file_path)
+        return doc_to_images(pdf_path)
+
+
 def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
     if isinstance(content, pathlib.Path):
         image = Image.open(content)
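Together these give the DOCX path its pipeline: pandoc renders the document to PDF, then PyMuPDF rasterizes each PDF page into the same Page shape the PDF path produces. A minimal usage sketch (report.docx is a hypothetical input; pandoc and a LaTeX engine must be on PATH, per the Dockerfile change above):

import pathlib

from memory.common.extract import extract_docx

pages = extract_docx(pathlib.Path("report.docx"))  # hypothetical input file
for page in pages:
    image = page["contents"][0]  # PIL.Image rendered from one PDF page
    print(page["metadata"]["page"], image.size)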
@@ -68,6 +91,11 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
 def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
     if mime_type == "application/pdf":
         return doc_to_images(content)
+    if isinstance(content, pathlib.Path) and mime_type in [
+        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+        "application/msword",
+    ]:
+        return extract_docx(content)
     if mime_type.startswith("text/"):
         return extract_text(content)
     if mime_type.startswith("image/"):
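extract_content is now the single MIME dispatch point; note the DOCX branch also requires the content to arrive as a pathlib.Path, since pandoc works on files. Callers holding only a filename can pair it with the standard library's mimetypes (a sketch; whether guess_type knows the DOCX type depends on the platform's MIME tables):

import mimetypes
import pathlib

from memory.common.extract import extract_content

path = pathlib.Path("report.docx")  # hypothetical input file
mime_type, _ = mimetypes.guess_type(path.name)
# On most systems this yields the wordprocessingml type handled above.
pages = extract_content(mime_type or "application/octet-stream", path)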
BIN tests/data/sample.docx (new file; binary file not shown)

tests/memory/common/db/test_models.py (new file, 45 lines):
@@ -0,0 +1,45 @@
+from memory.common.db.models import SourceItem
+from sqlalchemy.orm import Session
+
+
+def test_unique_source_items_same_commit(db_session: Session):
+    source_item1 = SourceItem(sha256=b"1234567890", content="test1", modality="email")
+    source_item2 = SourceItem(sha256=b"1234567890", content="test2", modality="email")
+    source_item3 = SourceItem(sha256=b"1234567891", content="test3", modality="email")
+    db_session.add(source_item1)
+    db_session.add(source_item2)
+    db_session.add(source_item3)
+    db_session.commit()
+
+    assert db_session.query(SourceItem.sha256, SourceItem.content).all() == [
+        (b"1234567890", "test1"),
+        (b"1234567891", "test3"),
+    ]
+
+
+def test_unique_source_items_previous_commit(db_session: Session):
+    db_session.add_all(
+        [
+            SourceItem(sha256=b"1234567890", content="test1", modality="email"),
+            SourceItem(sha256=b"1234567891", content="test2", modality="email"),
+            SourceItem(sha256=b"1234567892", content="test3", modality="email"),
+        ]
+    )
+    db_session.commit()
+
+    db_session.add_all(
+        [
+            SourceItem(sha256=b"1234567890", content="test4", modality="email"),
+            SourceItem(sha256=b"1234567893", content="test5", modality="email"),
+            SourceItem(sha256=b"1234567894", content="test6", modality="email"),
+        ]
+    )
+    db_session.commit()
+
+    assert db_session.query(SourceItem.sha256, SourceItem.content).all() == [
+        (b"1234567890", "test1"),
+        (b"1234567891", "test2"),
+        (b"1234567892", "test3"),
+        (b"1234567893", "test5"),
+        (b"1234567894", "test6"),
+    ]
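Both tests assume a db_session pytest fixture, presumably provided by the repo's conftest and bound to a real Postgres test database; the models import Postgres-specific column types such as BYTEA and TSVECTOR, so an in-memory SQLite session would not exercise them faithfully.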
Tests for memory.common.extract:

@@ -3,10 +3,26 @@ import pytest
 import pymupdf
 from PIL import Image
 import io
-from memory.common.extract import as_file, extract_text, extract_content, Page, doc_to_images, extract_image
+import shutil
+from memory.common.extract import (
+    as_file,
+    extract_text,
+    extract_content,
+    Page,
+    doc_to_images,
+    extract_image,
+    docx_to_pdf,
+    extract_docx,
+)
 
 
 REGULAMIN = pathlib.Path(__file__).parent.parent.parent / "data" / "regulamin.pdf"
+SAMPLE_DOCX = pathlib.Path(__file__).parent.parent.parent / "data" / "sample.docx"
 
 
+# Helper to check if pdflatex is available
+def is_pdflatex_available():
+    return shutil.which("pdflatex") is not None
+
+
 def test_as_file_with_path(tmp_path):
@@ -35,7 +51,7 @@ def test_as_file_with_str():
     [
         ("simple text", [{"contents": ["simple text"], "metadata": {}}]),
         (b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]),
-    ]
+    ],
 )
 def test_extract_text(input_content, expected):
     assert extract_text(input_content) == expected
@@ -45,7 +61,9 @@ def test_extract_text_with_path(tmp_path):
     test_file = tmp_path / "test.txt"
     test_file.write_text("file text content")
 
-    assert extract_text(test_file) == [{"contents": ["file text content"], "metadata": {}}]
+    assert extract_text(test_file) == [
+        {"contents": ["file text content"], "metadata": {}}
+    ]
 
 
 def test_doc_to_images():
@@ -56,7 +74,7 @@ def test_doc_to_images():
     for page, pdf_page in zip(result, pdf.pages()):
         pix = pdf_page.get_pixmap()
         img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-        assert page["contents"] == img
+        assert page["contents"] == [img]
         assert page["metadata"] == {
             "page": pdf_page.number,
             "width": pdf_page.rect.width,
@@ -65,22 +83,22 @@
 
 
 def test_extract_image_with_path(tmp_path):
-    img = Image.new('RGB', (100, 100), color='red')
+    img = Image.new("RGB", (100, 100), color="red")
     img_path = tmp_path / "test.png"
     img.save(img_path)
 
-    page, = extract_image(img_path)
+    (page,) = extract_image(img_path)
     assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
     assert page["metadata"] == {}
 
 
 def test_extract_image_with_bytes():
-    img = Image.new('RGB', (100, 100), color='blue')
+    img = Image.new("RGB", (100, 100), color="blue")
     buffer = io.BytesIO()
-    img.save(buffer, format='PNG')
+    img.save(buffer, format="PNG")
     img_bytes = buffer.getvalue()
 
-    page, = extract_image(img_bytes)
+    (page,) = extract_image(img_bytes)
     assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
     assert page["metadata"] == {}
@@ -97,17 +115,23 @@ def test_extract_image_with_str():
         ("text/html", "<html>content</html>"),
         ("text/markdown", "# Heading"),
         ("text/csv", "a,b,c"),
-    ]
+    ],
 )
 def test_extract_content_different_text_types(mime_type, content):
-    assert extract_content(mime_type, content) == [{"contents": [content], "metadata": {}}]
+    assert extract_content(mime_type, content) == [
+        {"contents": [content], "metadata": {}}
+    ]
 
 
 def test_extract_content_pdf():
     result = extract_content("application/pdf", REGULAMIN)
 
     assert len(result) == 2
-    assert all(isinstance(page["contents"], Image.Image) for page in result)
+    assert all(
+        isinstance(page["contents"], list)
+        and all(isinstance(c, Image.Image) for c in page["contents"])
+        for page in result
+    )
     assert all("page" in page["metadata"] for page in result)
     assert all("width" in page["metadata"] for page in result)
     assert all("height" in page["metadata"] for page in result)
@@ -115,11 +139,11 @@ def test_extract_content_pdf():
 
 def test_extract_content_image(tmp_path):
     # Create a test image
-    img = Image.new('RGB', (100, 100), color='red')
+    img = Image.new("RGB", (100, 100), color="red")
     img_path = tmp_path / "test_img.png"
     img.save(img_path)
 
-    result, = extract_content("image/png", img_path)
+    (result,) = extract_content("image/png", img_path)
 
     assert isinstance(result["contents"][0], Image.Image)
     assert result["contents"][0].size == (100, 100)
@@ -128,3 +152,37 @@ def test_extract_content_image(tmp_path):
 
 def test_extract_content_unsupported_type():
     assert extract_content("unsupported/type", "content") == []
+
+
+@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
+def test_docx_to_pdf(tmp_path):
+    output_path = tmp_path / "output.pdf"
+    result_path = docx_to_pdf(SAMPLE_DOCX, output_path)
+
+    assert result_path == output_path
+    assert result_path.exists()
+    assert result_path.suffix == ".pdf"
+
+    # Verify the PDF is valid by opening it
+    with pymupdf.open(result_path) as pdf:
+        assert pdf.page_count > 0
+
+
+@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
+def test_docx_to_pdf_default_output():
+    # Test with default output path
+    result_path = docx_to_pdf(SAMPLE_DOCX)
+
+    assert result_path == SAMPLE_DOCX.with_suffix(".pdf")
+    assert result_path.exists()
+
+
+@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
+def test_extract_docx():
+    pages = extract_docx(SAMPLE_DOCX)
+
+    assert len(pages) > 0
+    assert all(isinstance(page, dict) for page in pages)
+    assert all("contents" in page for page in pages)
+    assert all("metadata" in page for page in pages)
+    assert all(isinstance(page["contents"][0], Image.Image) for page in pages)