fix linting

Daniel O'Connell 2025-05-27 23:19:28 +02:00
parent 1291ca9d08
commit ab87bced81
12 changed files with 398 additions and 195 deletions

View File

@@ -17,10 +17,10 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .[all]
-        pip install ruff==0.11.10 pylint==1.1.400
+        pip install ruff==0.11.10 pyright==1.1.327
     - name: Run linters
       run: |
         ruff check .
-        pylint $(git ls-files '*.py')
+        pyright $(git ls-files '*.py')
     - name: Run tests
       run: pytest -vv
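To reproduce the new lint step locally before pushing, a minimal sketch (assumes pip, ruff, and pyright are on PATH; only the version pins are taken from the workflow above):

import subprocess

# Mirror the CI steps: install the pinned linters, then run both checks.
for cmd in (
    ["pip", "install", "ruff==0.11.10", "pyright==1.1.327"],
    ["ruff", "check", "."],
    ["pyright"],  # CI passes explicit files via git ls-files; bare pyright checks the project
):
    subprocess.run(cmd, check=True)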

View File

@@ -2,3 +2,4 @@ fastapi==0.112.2
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
+sqladmin

src/memory/api/admin.py (new file, +170)
View File

@@ -0,0 +1,170 @@
"""
SQLAdmin views for the knowledge base database models.
"""

from sqladmin import Admin, ModelView

from memory.common.db.models import (
    Chunk,
    SourceItem,
    MailMessage,
    EmailAttachment,
    Photo,
    Comic,
    Book,
    BookSection,
    BlogPost,
    MiscDoc,
    ArticleFeed,
    EmailAccount,
)

DEFAULT_COLUMNS = (
    "modality",
    "embed_status",
    "inserted_at",
    "tags",
    "size",
    "mime_type",
    "filename",
    "content",
)


def source_columns(model: type[SourceItem], *columns: str):
    return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]


# Create admin views for all models
class SourceItemAdmin(ModelView, model=SourceItem):
    column_list = source_columns(SourceItem)
    column_searchable_list = [
        "modality",
        "filename",
        "embed_status",
    ]


class ChunkAdmin(ModelView, model=Chunk):
    column_list = ["id", "source_id", "embedding_model", "created_at"]


class MailMessageAdmin(ModelView, model=MailMessage):
    column_list = source_columns(
        MailMessage,
        "subject",
        "sender",
        "recipients",
        "folder",
        "message_id",
        "tags",
        "embed_status",
        "inserted_at",
    )
    column_searchable_list = [
        "subject",
        "sender",
        "recipients",
        "folder",
        "message_id",
    ]


class EmailAttachmentAdmin(ModelView, model=EmailAttachment):
    column_list = source_columns(EmailAttachment, "filename", "mime_type", "size")
    column_searchable_list = [
        "filename",
        "mime_type",
    ]


class BlogPostAdmin(ModelView, model=BlogPost):
    column_list = source_columns(
        BlogPost, "title", "author", "url", "published", "domain"
    )
    column_searchable_list = ["title", "author", "domain"]


class PhotoAdmin(ModelView, model=Photo):
    column_list = source_columns(Photo, "exif_taken_at", "camera")


class ComicAdmin(ModelView, model=Comic):
    column_list = source_columns(Comic, "title", "author", "published", "volume")
    column_searchable_list = ["title", "author"]


class BookSectionAdmin(ModelView, model=BookSection):
    column_list = source_columns(
        BookSection,
        "section_title",
        "section_number",
        "section_level",
        "start_page",
        "end_page",
    )
    column_searchable_list = ["section_title"]


class MiscDocAdmin(ModelView, model=MiscDoc):
    column_list = source_columns(MiscDoc, "path")
    column_searchable_list = ["path"]


class BookAdmin(ModelView, model=Book):
    column_list = [
        "id",
        "title",
        "author",
        "series",
        "series_number",
        "published",
    ]
    column_searchable_list = ["title", "author"]


class ArticleFeedAdmin(ModelView, model=ArticleFeed):
    column_list = [
        "id",
        "title",
        "description",
        "url",
        "tags",
        "active",
        "created_at",
        "updated_at",
    ]
    column_searchable_list = ["title", "url"]


class EmailAccountAdmin(ModelView, model=EmailAccount):
    column_list = [
        "id",
        "name",
        "tags",
        "email_address",
        "username",
        "use_ssl",
        "folders",
        "active",
        "created_at",
        "updated_at",
    ]
    column_searchable_list = ["name", "email_address"]


def setup_admin(admin: Admin):
    """Add all admin views to the admin instance."""
    admin.add_view(SourceItemAdmin)
    admin.add_view(ChunkAdmin)
    admin.add_view(EmailAccountAdmin)
    admin.add_view(MailMessageAdmin)
    admin.add_view(EmailAttachmentAdmin)
    admin.add_view(BookAdmin)
    admin.add_view(BookSectionAdmin)
    admin.add_view(MiscDocAdmin)
    admin.add_view(ArticleFeedAdmin)
    admin.add_view(BlogPostAdmin)
    admin.add_view(ComicAdmin)
    admin.add_view(PhotoAdmin)
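The source_columns helper is what keeps these views tolerant of schema differences: it resolves the requested names plus DEFAULT_COLUMNS against the model and drops anything the model lacks via hasattr. A hypothetical illustration (which attributes Photo actually defines is assumed, not checked against the real models):

# Hypothetical: any name Photo does not define is silently dropped
# rather than raising AttributeError.
source_columns(Photo, "exif_taken_at", "camera")
# -> [Photo.exif_taken_at, Photo.camera, Photo.modality, ...] (only attributes that exist)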

View File

@@ -2,49 +2,28 @@
 FastAPI application for the knowledge base.
 """
-import base64
-import io
-from collections import defaultdict
 import pathlib
-from typing import Annotated, List, Optional, Callable
-from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form
-from fastapi.responses import FileResponse
-import qdrant_client
-from qdrant_client.http import models as qdrant_models
-from PIL import Image
-from pydantic import BaseModel
-from memory.common import embedding, qdrant, extract, settings
-from memory.common.collections import get_modality, TEXT_COLLECTIONS, ALL_COLLECTIONS
-from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, SourceItem
 import logging
+from typing import Annotated, Optional
+from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form
+from fastapi.responses import FileResponse
+from sqladmin import Admin
+from memory.common import settings
+from memory.common import extract
+from memory.common.db.connection import get_engine
+from memory.api.admin import setup_admin
+from memory.api.search import search, SearchResult

 logger = logging.getLogger(__name__)

 app = FastAPI(title="Knowledge Base API")

-class AnnotatedChunk(BaseModel):
-    id: str
-    score: float
-    metadata: dict
-    preview: Optional[str | None] = None
-
-class SearchResponse(BaseModel):
-    collection: str
-    results: List[dict]
-
-class SearchResult(BaseModel):
-    id: int
-    size: int
-    mime_type: str
-    chunks: list[AnnotatedChunk]
-    content: Optional[str] = None
-    filename: Optional[str] = None
+# SQLAdmin setup
+engine = get_engine()
+admin = Admin(app, engine)
+setup_admin(admin)

 @app.get("/health")
@@ -53,94 +32,6 @@ def health_check():
     return {"status": "healthy"}

-def annotated_chunk(
-    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
-    def serialize_item(item: bytes | str | Image.Image) -> str | None:
-        if not previews and not isinstance(item, str):
-            return None
-
-        if isinstance(item, Image.Image):
-            buffer = io.BytesIO()
-            format = item.format or "PNG"
-            item.save(buffer, format=format)
-            mime_type = f"image/{format.lower()}"
-            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
-        elif isinstance(item, bytes):
-            return base64.b64encode(item).decode("utf-8")
-        elif isinstance(item, str):
-            return item
-        else:
-            raise ValueError(f"Unsupported item type: {type(item)}")
-
-    metadata = search_result.payload or {}
-    metadata = {
-        k: v
-        for k, v in metadata.items()
-        if k not in ["content", "filename", "size", "content_type", "tags"]
-    }
-    return chunk.source, AnnotatedChunk(
-        id=str(chunk.id),
-        score=search_result.score,
-        metadata=metadata,
-        preview=serialize_item(chunk.data[0]) if chunk.data else None,
-    )
-
-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
-    items = defaultdict(list)
-    for source, chunk in chunks:
-        items[source].append(chunk)
-
-    return [
-        SearchResult(
-            id=source.id,
-            size=source.size or len(source.content),
-            mime_type=source.mime_type or "text/plain",
-            filename=source.filename
-            and source.filename.replace(
-                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
-            ),
-            content=source.display_contents,
-            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
-        )
-        for source, chunks in items.items()
-    ]
-
-def query_chunks(
-    client: qdrant_client.QdrantClient,
-    upload_data: list[extract.DataChunk],
-    allowed_modalities: set[str],
-    embedder: Callable,
-    min_score: float = 0.0,
-    limit: int = 10,
-) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
-        return {}
-
-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
-    if not chunks:
-        logger.error(f"No chunks to embed for {allowed_modalities}")
-        return {}
-
-    vector = embedder(chunks, input_type="query")[0]
-
-    return {
-        collection: [
-            r
-            for r in qdrant.search_vectors(
-                client=client,
-                collection_name=collection,
-                query_vector=vector,
-                limit=limit,
-            )
-            if r.score >= min_score
-        ]
-        for collection in allowed_modalities
-    }

 async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
     if not item:
         return []
@@ -152,7 +43,7 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:

 @app.post("/search", response_model=list[SearchResult])
-async def search(
+async def search_endpoint(
     query: Optional[str] = Form(None),
     previews: Optional[bool] = Form(False),
     modalities: Annotated[list[str], Query()] = [],
@@ -161,62 +52,21 @@ async def search(
     min_text_score: float = Query(0.3, ge=0.0, le=1.0),
     min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
-    """
-    Search across knowledge base using text query and optional files.
-
-    Parameters:
-    - query: Optional text search query
-    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
-    - files: Optional files to include in the search context
-    - limit: Maximum number of results per modality
-
-    Returns:
-    - List of search results sorted by score
-    """
+    """Search endpoint - delegates to search module"""
     upload_data = [
         chunk for item in [query, *files] for chunk in await input_type(item)
     ]
     logger.error(
         f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
     )
-
-    client = qdrant.get_qdrant_client()
-    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
+    return await search(
         upload_data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
+        previews=previews,
+        modalities=modalities,
         limit=limit,
+        min_text_score=min_text_score,
+        min_multimodal_score=min_multimodal_score,
     )
-    multimodal_results = query_chunks(
-        client,
-        upload_data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-
-        results = group_chunks(
-            [
-                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-                for chunk in chunks
-            ]
-        )
-    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)

 @app.get("/files/{path:path}")
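The endpoint's external interface is unchanged by the refactor, so a client call still looks roughly like this sketch (assumes a local server on port 8000 and the httpx package; field names follow the Form and Query parameters shown above):

import httpx

resp = httpx.post(
    "http://localhost:8000/search",
    data={"query": "quarterly report", "previews": "true"},  # Form fields
    params={"modalities": ["mail", "blog"], "limit": 5},  # Query parameters
)
resp.raise_for_status()
for item in resp.json():
    print(item["id"], item["mime_type"], len(item["chunks"]))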

src/memory/api/search.py (new file, +189)
View File

@@ -0,0 +1,189 @@
"""
Search endpoints for the knowledge base API.
"""

import base64
import io
from collections import defaultdict
from typing import Callable, Optional
import logging

from PIL import Image
from pydantic import BaseModel
import qdrant_client
from qdrant_client.http import models as qdrant_models

from memory.common import embedding, qdrant, extract, settings
from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem

logger = logging.getLogger(__name__)


class AnnotatedChunk(BaseModel):
    id: str
    score: float
    metadata: dict
    preview: Optional[str | None] = None


class SearchResponse(BaseModel):
    collection: str
    results: list[dict]


class SearchResult(BaseModel):
    id: int
    size: int
    mime_type: str
    chunks: list[AnnotatedChunk]
    content: Optional[str] = None
    filename: Optional[str] = None


def annotated_chunk(
    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceItem, AnnotatedChunk]:
    def serialize_item(item: bytes | str | Image.Image) -> str | None:
        if not previews and not isinstance(item, str):
            return None

        if isinstance(item, Image.Image):
            buffer = io.BytesIO()
            format = item.format or "PNG"
            item.save(buffer, format=format)
            mime_type = f"image/{format.lower()}"
            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
        elif isinstance(item, bytes):
            return base64.b64encode(item).decode("utf-8")
        elif isinstance(item, str):
            return item
        else:
            raise ValueError(f"Unsupported item type: {type(item)}")

    metadata = search_result.payload or {}
    metadata = {
        k: v
        for k, v in metadata.items()
        if k not in ["content", "filename", "size", "content_type", "tags"]
    }
    return chunk.source, AnnotatedChunk(
        id=str(chunk.id),
        score=search_result.score,
        metadata=metadata,
        preview=serialize_item(chunk.data[0]) if chunk.data else None,
    )


def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
    items = defaultdict(list)
    for source, chunk in chunks:
        items[source].append(chunk)

    return [
        SearchResult(
            id=source.id,
            size=source.size or len(source.content),
            mime_type=source.mime_type or "text/plain",
            filename=source.filename
            and source.filename.replace(
                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
            ),
            content=source.display_contents,
            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
        )
        for source, chunks in items.items()
    ]


def query_chunks(
    client: qdrant_client.QdrantClient,
    upload_data: list[extract.DataChunk],
    allowed_modalities: set[str],
    embedder: Callable,
    min_score: float = 0.0,
    limit: int = 10,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
    if not upload_data:
        return {}

    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
    if not chunks:
        logger.error(f"No chunks to embed for {allowed_modalities}")
        return {}

    vector = embedder(chunks, input_type="query")[0]

    return {
        collection: [
            r
            for r in qdrant.search_vectors(
                client=client,
                collection_name=collection,
                query_vector=vector,
                limit=limit,
            )
            if r.score >= min_score
        ]
        for collection in allowed_modalities
    }


async def search(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: list[str] = [],
    limit: int = 10,
    min_text_score: float = 0.3,
    min_multimodal_score: float = 0.3,
) -> list[SearchResult]:
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - limit: Maximum number of results per modality

    Returns:
    - List of search results sorted by score
    """
    client = qdrant.get_qdrant_client()
    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
    text_results = query_chunks(
        client,
        data,
        allowed_modalities & TEXT_COLLECTIONS,
        embedding.embed_text,
        min_score=min_text_score,
        limit=limit,
    )
    multimodal_results = query_chunks(
        client,
        data,
        allowed_modalities,
        embedding.embed_mixed,
        min_score=min_multimodal_score,
        limit=limit,
    )
    search_results = {
        k: text_results.get(k, []) + multimodal_results.get(k, [])
        for k in allowed_modalities
    }

    found_chunks = {
        str(r.id): r for results in search_results.values() for r in results
    }
    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
        logger.error(f"Found chunks: {chunks}")

        results = group_chunks(
            [
                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
                for chunk in chunks
            ]
        )
    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
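Used directly, outside the HTTP layer, the new module can be exercised roughly as follows. This is a sketch: the DataChunk construction is an assumption inferred from how query_chunks flattens data_chunk.data, and a reachable Qdrant instance plus a populated database are required:

import asyncio
from memory.common import extract
from memory.api.search import search

async def demo():
    # Assumed shape: DataChunk carries a list of embeddable payloads in .data
    chunks = [extract.DataChunk(data=["notes about the q3 invoice"])]
    results = await search(chunks, modalities=["mail"], limit=5)
    for r in results:
        print(r.id, r.filename, max(c.score for c in r.chunks))

asyncio.run(demo())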

View File

@@ -1,8 +1,13 @@
 """
 Database utilities package.
 """
 from memory.common.db.models import Base
-from memory.common.db.connection import get_engine, get_session_factory, get_scoped_session
+from memory.common.db.connection import (
+    get_engine,
+    get_session_factory,
+    get_scoped_session,
+)

 __all__ = [
     "Base",

View File

@@ -2,12 +2,11 @@
 Database models for the knowledge base system.
 """
-from dataclasses import dataclass
 import pathlib
 import re
 import textwrap
 from datetime import datetime
-from typing import Any, ClassVar, Iterable, Sequence, cast
+from typing import Any, Sequence, cast
 import uuid

 from PIL import Image

View File

@@ -4,7 +4,7 @@ import pathlib
 import re
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any
+from typing import Any, cast
 from urllib.parse import urljoin, urlparse

 import requests
@@ -157,7 +157,7 @@ def process_image(url: str, image_dir: pathlib.Path) -> PILImage.Image | None:
     # Download if not already cached
     if not local_path.exists():
-        local_path.write_bytes(fetch_html(url, as_bytes=True))
+        local_path.write_bytes(cast(bytes, fetch_html(url, as_bytes=True)))

     try:
         return PILImage.open(local_path)
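The cast exists purely to satisfy pyright: assuming fetch_html is annotated as returning str | bytes, the checker cannot narrow the union at this call site, while Path.write_bytes accepts only bytes. A minimal reproduction of the pattern (the fetch_html signature here is an assumption):

from pathlib import Path
from typing import cast

def fetch_html(url: str, as_bytes: bool = False) -> str | bytes:  # assumed signature
    return b"..." if as_bytes else "..."

local_path = Path("cache/image.png")
# Without the cast, pyright reports: "str | bytes" is not assignable to "bytes".
local_path.write_bytes(cast(bytes, fetch_html("https://example.com/img.png", as_bytes=True)))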

View File

@@ -128,14 +128,3 @@ def test_docx_to_pdf_default_output():
     assert result_path == SAMPLE_DOCX.with_suffix(".pdf")
     assert result_path.exists()
-
-@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
-def test_extract_docx():
-    pages = extract_docx(SAMPLE_DOCX)
-
-    assert len(pages) > 0
-    assert all(isinstance(page, dict) for page in pages)
-    assert all("contents" in page for page in pages)
-    assert all("metadata" in page for page in pages)
-    assert all(isinstance(page["contents"][0], Image.Image) for page in pages)

View File

@@ -1,6 +1,6 @@
 import pytest
 from datetime import datetime, timedelta
-from unittest.mock import Mock, patch, MagicMock
+from unittest.mock import Mock, patch

 from memory.common.db.models import ArticleFeed, BlogPost
 from memory.workers.tasks import blogs

View File

@@ -2,7 +2,7 @@ import pytest
 from pathlib import Path
 from unittest.mock import patch, Mock

-from memory.common.db.models import Book, BookSection, Chunk
+from memory.common.db.models import Book, BookSection
 from memory.parsers.ebook import Ebook, Section
 from memory.workers.tasks import ebook

View File

@@ -6,7 +6,7 @@ import pytest
 from PIL import Image

 from memory.common import qdrant as qd
-from memory.common import embedding, settings
+from memory.common import settings
 from memory.common.db.models import Chunk, SourceItem
 from memory.workers.tasks.maintenance import (
     clean_collection,