mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 13:24:41 +02:00)

commit ab87bced81 "fix linting" (parent 1291ca9d08)
.github/workflows/ci.yml (vendored, 4 changes)
@@ -17,10 +17,10 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .[all]
-        pip install ruff==0.11.10 pylint==1.1.400
+        pip install ruff==0.11.10 pyright==1.1.327
     - name: Run linters
       run: |
         ruff check .
-        pylint $(git ls-files '*.py')
+        pyright $(git ls-files '*.py')
     - name: Run tests
       run: pytest -vv
@@ -1,4 +1,5 @@
 fastapi==0.112.2
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
+sqladmin
src/memory/api/admin.py (new file, 170 lines)
"""
SQLAdmin views for the knowledge base database models.
"""

from sqladmin import Admin, ModelView

from memory.common.db.models import (
    Chunk,
    SourceItem,
    MailMessage,
    EmailAttachment,
    Photo,
    Comic,
    Book,
    BookSection,
    BlogPost,
    MiscDoc,
    ArticleFeed,
    EmailAccount,
)


DEFAULT_COLUMNS = (
    "modality",
    "embed_status",
    "inserted_at",
    "tags",
    "size",
    "mime_type",
    "filename",
    "content",
)


def source_columns(model: type[SourceItem], *columns: str):
    return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]


# Create admin views for all models
class SourceItemAdmin(ModelView, model=SourceItem):
    column_list = source_columns(SourceItem)
    column_searchable_list = [
        "modality",
        "filename",
        "embed_status",
    ]


class ChunkAdmin(ModelView, model=Chunk):
    column_list = ["id", "source_id", "embedding_model", "created_at"]


class MailMessageAdmin(ModelView, model=MailMessage):
    column_list = source_columns(
        MailMessage,
        "subject",
        "sender",
        "recipients",
        "folder",
        "message_id",
        "tags",
        "embed_status",
        "inserted_at",
    )
    column_searchable_list = [
        "subject",
        "sender",
        "recipients",
        "folder",
        "message_id",
    ]


class EmailAttachmentAdmin(ModelView, model=EmailAttachment):
    column_list = source_columns(EmailAttachment, "filename", "mime_type", "size")
    column_searchable_list = [
        "filename",
        "mime_type",
    ]


class BlogPostAdmin(ModelView, model=BlogPost):
    column_list = source_columns(
        BlogPost, "title", "author", "url", "published", "domain"
    )
    column_searchable_list = ["title", "author", "domain"]


class PhotoAdmin(ModelView, model=Photo):
    column_list = source_columns(Photo, "exif_taken_at", "camera")


class ComicAdmin(ModelView, model=Comic):
    column_list = source_columns(Comic, "title", "author", "published", "volume")
    column_searchable_list = ["title", "author"]


class BookSectionAdmin(ModelView, model=BookSection):
    column_list = source_columns(
        BookSection,
        "section_title",
        "section_number",
        "section_level",
        "start_page",
        "end_page",
    )
    column_searchable_list = ["section_title"]


class MiscDocAdmin(ModelView, model=MiscDoc):
    column_list = source_columns(MiscDoc, "path")
    column_searchable_list = ["path"]


class BookAdmin(ModelView, model=Book):
    column_list = [
        "id",
        "title",
        "author",
        "series",
        "series_number",
        "published",
    ]
    column_searchable_list = ["title", "author"]


class ArticleFeedAdmin(ModelView, model=ArticleFeed):
    column_list = [
        "id",
        "title",
        "description",
        "url",
        "tags",
        "active",
        "created_at",
        "updated_at",
    ]
    column_searchable_list = ["title", "url"]


class EmailAccountAdmin(ModelView, model=EmailAccount):
    column_list = [
        "id",
        "name",
        "tags",
        "email_address",
        "username",
        "use_ssl",
        "folders",
        "active",
        "created_at",
        "updated_at",
    ]
    column_searchable_list = ["name", "email_address"]


def setup_admin(admin: Admin):
    """Add all admin views to the admin instance."""
    admin.add_view(SourceItemAdmin)
    admin.add_view(ChunkAdmin)
    admin.add_view(EmailAccountAdmin)
    admin.add_view(MailMessageAdmin)
    admin.add_view(EmailAttachmentAdmin)
    admin.add_view(BookAdmin)
    admin.add_view(BookSectionAdmin)
    admin.add_view(MiscDocAdmin)
    admin.add_view(ArticleFeedAdmin)
    admin.add_view(BlogPostAdmin)
    admin.add_view(ComicAdmin)
    admin.add_view(PhotoAdmin)
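These views take effect once they are attached to an Admin instance bound to the application's engine, which is exactly what this commit does in the FastAPI module below. A condensed sketch of the wiring, mirroring the diff that follows:

from fastapi import FastAPI
from sqladmin import Admin

from memory.common.db.connection import get_engine
from memory.api.admin import setup_admin

app = FastAPI(title="Knowledge Base API")
admin = Admin(app, get_engine())  # mounts the SQLAdmin UI on the app
setup_admin(admin)                # registers every ModelView defined above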
@@ -2,49 +2,28 @@
 FastAPI application for the knowledge base.
 """

-import base64
-import io
-from collections import defaultdict
-import pathlib
-from typing import Annotated, List, Optional, Callable
-from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form
-from fastapi.responses import FileResponse
-import qdrant_client
-from qdrant_client.http import models as qdrant_models
-from PIL import Image
-from pydantic import BaseModel
-
-from memory.common import embedding, qdrant, extract, settings
-from memory.common.collections import get_modality, TEXT_COLLECTIONS, ALL_COLLECTIONS
-from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, SourceItem
-
+import logging
+from typing import Annotated, Optional
+
+from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form
+from fastapi.responses import FileResponse
+from sqladmin import Admin
+
+from memory.common import settings
+from memory.common import extract
+from memory.common.db.connection import get_engine
+from memory.api.admin import setup_admin
+from memory.api.search import search, SearchResult

 logger = logging.getLogger(__name__)

 app = FastAPI(title="Knowledge Base API")


-class AnnotatedChunk(BaseModel):
-    id: str
-    score: float
-    metadata: dict
-    preview: Optional[str | None] = None
-
-
-class SearchResponse(BaseModel):
-    collection: str
-    results: List[dict]
-
-
-class SearchResult(BaseModel):
-    id: int
-    size: int
-    mime_type: str
-    chunks: list[AnnotatedChunk]
-    content: Optional[str] = None
-    filename: Optional[str] = None
+# SQLAdmin setup
+engine = get_engine()
+admin = Admin(app, engine)
+setup_admin(admin)


 @app.get("/health")
@@ -53,94 +32,6 @@ def health_check():
     return {"status": "healthy"}


-def annotated_chunk(
-    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
-    def serialize_item(item: bytes | str | Image.Image) -> str | None:
-        if not previews and not isinstance(item, str):
-            return None
-
-        if isinstance(item, Image.Image):
-            buffer = io.BytesIO()
-            format = item.format or "PNG"
-            item.save(buffer, format=format)
-            mime_type = f"image/{format.lower()}"
-            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
-        elif isinstance(item, bytes):
-            return base64.b64encode(item).decode("utf-8")
-        elif isinstance(item, str):
-            return item
-        else:
-            raise ValueError(f"Unsupported item type: {type(item)}")
-
-    metadata = search_result.payload or {}
-    metadata = {
-        k: v
-        for k, v in metadata.items()
-        if k not in ["content", "filename", "size", "content_type", "tags"]
-    }
-    return chunk.source, AnnotatedChunk(
-        id=str(chunk.id),
-        score=search_result.score,
-        metadata=metadata,
-        preview=serialize_item(chunk.data[0]) if chunk.data else None,
-    )
-
-
-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
-    items = defaultdict(list)
-    for source, chunk in chunks:
-        items[source].append(chunk)
-
-    return [
-        SearchResult(
-            id=source.id,
-            size=source.size or len(source.content),
-            mime_type=source.mime_type or "text/plain",
-            filename=source.filename
-            and source.filename.replace(
-                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
-            ),
-            content=source.display_contents,
-            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
-        )
-        for source, chunks in items.items()
-    ]
-
-
-def query_chunks(
-    client: qdrant_client.QdrantClient,
-    upload_data: list[extract.DataChunk],
-    allowed_modalities: set[str],
-    embedder: Callable,
-    min_score: float = 0.0,
-    limit: int = 10,
-) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
-        return {}
-
-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
-    if not chunks:
-        logger.error(f"No chunks to embed for {allowed_modalities}")
-        return {}
-
-    vector = embedder(chunks, input_type="query")[0]
-
-    return {
-        collection: [
-            r
-            for r in qdrant.search_vectors(
-                client=client,
-                collection_name=collection,
-                query_vector=vector,
-                limit=limit,
-            )
-            if r.score >= min_score
-        ]
-        for collection in allowed_modalities
-    }
-
-
 async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
     if not item:
         return []
@@ -152,7 +43,7 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:


 @app.post("/search", response_model=list[SearchResult])
-async def search(
+async def search_endpoint(
     query: Optional[str] = Form(None),
     previews: Optional[bool] = Form(False),
     modalities: Annotated[list[str], Query()] = [],
@@ -161,62 +52,21 @@
     min_text_score: float = Query(0.3, ge=0.0, le=1.0),
     min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
-    """
-    Search across knowledge base using text query and optional files.
-
-    Parameters:
-    - query: Optional text search query
-    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
-    - files: Optional files to include in the search context
-    - limit: Maximum number of results per modality
-
-    Returns:
-    - List of search results sorted by score
-    """
+    """Search endpoint - delegates to search module"""
     upload_data = [
         chunk for item in [query, *files] for chunk in await input_type(item)
     ]
-    logger.error(
-        f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
-    )
-
-    client = qdrant.get_qdrant_client()
-    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
+    return await search(
         upload_data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
+        previews=previews,
+        modalities=modalities,
         limit=limit,
+        min_text_score=min_text_score,
+        min_multimodal_score=min_multimodal_score,
     )
-    multimodal_results = query_chunks(
-        client,
-        upload_data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-
-        results = group_chunks(
-            [
-                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-                for chunk in chunks
-            ]
-        )
-    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)


 @app.get("/files/{path:path}")
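With the endpoint reduced to a thin wrapper around the search module, a client call looks roughly like the sketch below. The host and port, and the exact placement of the elided files and limit parameters, are assumptions; the diff does not show the full signature.

import requests

resp = requests.post(
    "http://localhost:8000/search",
    params={"modalities": ["text"], "limit": 5},  # query-string parameters
    data={"query": "knowledge base search"},      # form-encoded fields
)
for result in resp.json():  # list[SearchResult], best-scoring sources first
    print(result["id"], result["mime_type"], len(result["chunks"]))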
src/memory/api/search.py (new file, 189 lines)
"""
Search endpoints for the knowledge base API.
"""

import base64
import io
from collections import defaultdict
from typing import Callable, Optional
import logging

from PIL import Image
from pydantic import BaseModel
import qdrant_client
from qdrant_client.http import models as qdrant_models

from memory.common import embedding, qdrant, extract, settings
from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem

logger = logging.getLogger(__name__)


class AnnotatedChunk(BaseModel):
    id: str
    score: float
    metadata: dict
    preview: Optional[str | None] = None


class SearchResponse(BaseModel):
    collection: str
    results: list[dict]


class SearchResult(BaseModel):
    id: int
    size: int
    mime_type: str
    chunks: list[AnnotatedChunk]
    content: Optional[str] = None
    filename: Optional[str] = None


def annotated_chunk(
    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceItem, AnnotatedChunk]:
    def serialize_item(item: bytes | str | Image.Image) -> str | None:
        if not previews and not isinstance(item, str):
            return None

        if isinstance(item, Image.Image):
            buffer = io.BytesIO()
            format = item.format or "PNG"
            item.save(buffer, format=format)
            mime_type = f"image/{format.lower()}"
            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
        elif isinstance(item, bytes):
            return base64.b64encode(item).decode("utf-8")
        elif isinstance(item, str):
            return item
        else:
            raise ValueError(f"Unsupported item type: {type(item)}")

    metadata = search_result.payload or {}
    metadata = {
        k: v
        for k, v in metadata.items()
        if k not in ["content", "filename", "size", "content_type", "tags"]
    }
    return chunk.source, AnnotatedChunk(
        id=str(chunk.id),
        score=search_result.score,
        metadata=metadata,
        preview=serialize_item(chunk.data[0]) if chunk.data else None,
    )


def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
    items = defaultdict(list)
    for source, chunk in chunks:
        items[source].append(chunk)

    return [
        SearchResult(
            id=source.id,
            size=source.size or len(source.content),
            mime_type=source.mime_type or "text/plain",
            filename=source.filename
            and source.filename.replace(
                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
            ),
            content=source.display_contents,
            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
        )
        for source, chunks in items.items()
    ]


def query_chunks(
    client: qdrant_client.QdrantClient,
    upload_data: list[extract.DataChunk],
    allowed_modalities: set[str],
    embedder: Callable,
    min_score: float = 0.0,
    limit: int = 10,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
    if not upload_data:
        return {}

    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
    if not chunks:
        logger.error(f"No chunks to embed for {allowed_modalities}")
        return {}

    vector = embedder(chunks, input_type="query")[0]

    return {
        collection: [
            r
            for r in qdrant.search_vectors(
                client=client,
                collection_name=collection,
                query_vector=vector,
                limit=limit,
            )
            if r.score >= min_score
        ]
        for collection in allowed_modalities
    }


async def search(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: list[str] = [],
    limit: int = 10,
    min_text_score: float = 0.3,
    min_multimodal_score: float = 0.3,
) -> list[SearchResult]:
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - data: Chunks of extracted data (from the query text and any files) to search with
    - previews: Whether to include content previews in the results
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - limit: Maximum number of results per modality

    Returns:
    - List of search results sorted by score
    """
    client = qdrant.get_qdrant_client()
    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
    text_results = query_chunks(
        client,
        data,
        allowed_modalities & TEXT_COLLECTIONS,
        embedding.embed_text,
        min_score=min_text_score,
        limit=limit,
    )
    multimodal_results = query_chunks(
        client,
        data,
        allowed_modalities,
        embedding.embed_mixed,
        min_score=min_multimodal_score,
        limit=limit,
    )
    search_results = {
        k: text_results.get(k, []) + multimodal_results.get(k, [])
        for k in allowed_modalities
    }

    found_chunks = {
        str(r.id): r for results in search_results.values() for r in results
    }
    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
        logger.error(f"Found chunks: {chunks}")

        results = group_chunks(
            [
                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
                for chunk in chunks
            ]
        )
    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
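Since the search logic now lives in its own module, it can also be driven directly, for example from a script or test. A sketch, assuming a reachable Qdrant instance and that extract.DataChunk accepts a data list in its constructor (the diff only shows that DataChunk instances carry a .data attribute):

import asyncio

from memory.common import extract
from memory.api.search import search

chunk = extract.DataChunk(data=["how are book sections chunked?"])  # assumed constructor
results = asyncio.run(search([chunk], modalities=["text"], limit=5))
for r in results:
    print(r.id, r.mime_type, max(c.score for c in r.chunks))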
@@ -1,12 +1,17 @@
 """
 Database utilities package.
 """

 from memory.common.db.models import Base
-from memory.common.db.connection import get_engine, get_session_factory, get_scoped_session
+from memory.common.db.connection import (
+    get_engine,
+    get_session_factory,
+    get_scoped_session,
+)

 __all__ = [
     "Base",
     "get_engine",
     "get_session_factory",
     "get_scoped_session",
 ]
@@ -2,12 +2,11 @@
 Database models for the knowledge base system.
 """

 from dataclasses import dataclass
 import pathlib
 import re
 import textwrap
 from datetime import datetime
-from typing import Any, ClassVar, Iterable, Sequence, cast
+from typing import Any, Sequence, cast
 import uuid

 from PIL import Image
@@ -4,7 +4,7 @@ import pathlib
 import re
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any
+from typing import Any, cast
 from urllib.parse import urljoin, urlparse

 import requests
@@ -157,7 +157,7 @@ def process_image(url: str, image_dir: pathlib.Path) -> PILImage.Image | None:

     # Download if not already cached
     if not local_path.exists():
-        local_path.write_bytes(fetch_html(url, as_bytes=True))
+        local_path.write_bytes(cast(bytes, fetch_html(url, as_bytes=True)))

     try:
         return PILImage.open(local_path)
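The cast is purely for the type checker: fetch_html presumably returns str | bytes depending on as_bytes, so pyright cannot narrow the return type at this call site even though the runtime value is bytes. A minimal illustration of the pattern, with fetch standing in for fetch_html:

from typing import cast

def fetch(as_bytes: bool = False) -> str | bytes:  # stand-in for fetch_html
    return b"payload" if as_bytes else "payload"

data: bytes = cast(bytes, fetch(as_bytes=True))  # no runtime effect, only narrows the type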
@@ -128,14 +128,3 @@ def test_docx_to_pdf_default_output():

     assert result_path == SAMPLE_DOCX.with_suffix(".pdf")
     assert result_path.exists()
-
-
-@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
-def test_extract_docx():
-    pages = extract_docx(SAMPLE_DOCX)
-
-    assert len(pages) > 0
-    assert all(isinstance(page, dict) for page in pages)
-    assert all("contents" in page for page in pages)
-    assert all("metadata" in page for page in pages)
-    assert all(isinstance(page["contents"][0], Image.Image) for page in pages)
@@ -1,6 +1,6 @@
 import pytest
 from datetime import datetime, timedelta
-from unittest.mock import Mock, patch, MagicMock
+from unittest.mock import Mock, patch

 from memory.common.db.models import ArticleFeed, BlogPost
 from memory.workers.tasks import blogs
@@ -2,7 +2,7 @@ import pytest
 from pathlib import Path
 from unittest.mock import patch, Mock

-from memory.common.db.models import Book, BookSection, Chunk
+from memory.common.db.models import Book, BookSection
 from memory.parsers.ebook import Ebook, Section
 from memory.workers.tasks import ebook
@@ -6,7 +6,7 @@ import pytest
 from PIL import Image

 from memory.common import qdrant as qd
-from memory.common import embedding, settings
+from memory.common import settings
 from memory.common.db.models import Chunk, SourceItem
 from memory.workers.tasks.maintenance import (
     clean_collection,