diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e30a598..e6cd9a3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,10 +17,10 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install .[all]
-          pip install ruff==0.11.10 pylint==1.1.400
+          pip install ruff==0.11.10 pyright==1.1.327
       - name: Run linters
         run: |
           ruff check .
-          pylint $(git ls-files '*.py')
+          pyright $(git ls-files '*.py')
       - name: Run tests
         run: pytest -vv
diff --git a/requirements-api.txt b/requirements-api.txt
index 7d03ed0..5e44c4f 100644
--- a/requirements-api.txt
+++ b/requirements-api.txt
@@ -1,4 +1,5 @@
 fastapi==0.112.2
 uvicorn==0.29.0
 python-jose==3.3.0
-python-multipart==0.0.9
+python-multipart==0.0.9
+sqladmin
diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py
new file mode 100644
index 0000000..5066323
--- /dev/null
+++ b/src/memory/api/admin.py
@@ -0,0 +1,170 @@
+"""
+SQLAdmin views for the knowledge base database models.
+"""
+
+from sqladmin import Admin, ModelView
+
+from memory.common.db.models import (
+    Chunk,
+    SourceItem,
+    MailMessage,
+    EmailAttachment,
+    Photo,
+    Comic,
+    Book,
+    BookSection,
+    BlogPost,
+    MiscDoc,
+    ArticleFeed,
+    EmailAccount,
+)
+
+
+DEFAULT_COLUMNS = (
+    "modality",
+    "embed_status",
+    "inserted_at",
+    "tags",
+    "size",
+    "mime_type",
+    "filename",
+    "content",
+)
+
+
+def source_columns(model: type[SourceItem], *columns: str):
+    return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]
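+# Illustrative: source_columns(BlogPost, "title", "author") yields
+# [BlogPost.title, BlogPost.author, BlogPost.modality, ...] -- the per-model
+# columns first, then whichever DEFAULT_COLUMNS the model actually defines.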
+ "tags", + "active", + "created_at", + "updated_at", + ] + column_searchable_list = ["title", "url"] + + +class EmailAccountAdmin(ModelView, model=EmailAccount): + column_list = [ + "id", + "name", + "tags", + "email_address", + "username", + "use_ssl", + "folders", + "active", + "created_at", + "updated_at", + ] + column_searchable_list = ["name", "email_address"] + + +def setup_admin(admin: Admin): + """Add all admin views to the admin instance.""" + admin.add_view(SourceItemAdmin) + admin.add_view(ChunkAdmin) + admin.add_view(EmailAccountAdmin) + admin.add_view(MailMessageAdmin) + admin.add_view(EmailAttachmentAdmin) + admin.add_view(BookAdmin) + admin.add_view(BookSectionAdmin) + admin.add_view(MiscDocAdmin) + admin.add_view(ArticleFeedAdmin) + admin.add_view(BlogPostAdmin) + admin.add_view(ComicAdmin) + admin.add_view(PhotoAdmin) diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 172e788..1b81edf 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -2,49 +2,28 @@ FastAPI application for the knowledge base. """ -import base64 -import io -from collections import defaultdict import pathlib -from typing import Annotated, List, Optional, Callable -from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form -from fastapi.responses import FileResponse -import qdrant_client -from qdrant_client.http import models as qdrant_models -from PIL import Image -from pydantic import BaseModel - -from memory.common import embedding, qdrant, extract, settings -from memory.common.collections import get_modality, TEXT_COLLECTIONS, ALL_COLLECTIONS -from memory.common.db.connection import make_session -from memory.common.db.models import Chunk, SourceItem - import logging +from typing import Annotated, Optional + +from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form +from fastapi.responses import FileResponse +from sqladmin import Admin + +from memory.common import settings +from memory.common import extract +from memory.common.db.connection import get_engine +from memory.api.admin import setup_admin +from memory.api.search import search, SearchResult logger = logging.getLogger(__name__) app = FastAPI(title="Knowledge Base API") - -class AnnotatedChunk(BaseModel): - id: str - score: float - metadata: dict - preview: Optional[str | None] = None - - -class SearchResponse(BaseModel): - collection: str - results: List[dict] - - -class SearchResult(BaseModel): - id: int - size: int - mime_type: str - chunks: list[AnnotatedChunk] - content: Optional[str] = None - filename: Optional[str] = None +# SQLAdmin setup +engine = get_engine() +admin = Admin(app, engine) +setup_admin(admin) @app.get("/health") @@ -53,94 +32,6 @@ def health_check(): return {"status": "healthy"} -def annotated_chunk( - chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool -) -> tuple[SourceItem, AnnotatedChunk]: - def serialize_item(item: bytes | str | Image.Image) -> str | None: - if not previews and not isinstance(item, str): - return None - - if isinstance(item, Image.Image): - buffer = io.BytesIO() - format = item.format or "PNG" - item.save(buffer, format=format) - mime_type = f"image/{format.lower()}" - return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" - elif isinstance(item, bytes): - return base64.b64encode(item).decode("utf-8") - elif isinstance(item, str): - return item - else: - raise ValueError(f"Unsupported item type: {type(item)}") - - metadata = search_result.payload or {} - metadata = { - k: v - for k, v in 
@@ -53,94 +32,6 @@ def health_check():
     return {"status": "healthy"}
 
 
-def annotated_chunk(
-    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
-    def serialize_item(item: bytes | str | Image.Image) -> str | None:
-        if not previews and not isinstance(item, str):
-            return None
-
-        if isinstance(item, Image.Image):
-            buffer = io.BytesIO()
-            format = item.format or "PNG"
-            item.save(buffer, format=format)
-            mime_type = f"image/{format.lower()}"
-            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
-        elif isinstance(item, bytes):
-            return base64.b64encode(item).decode("utf-8")
-        elif isinstance(item, str):
-            return item
-        else:
-            raise ValueError(f"Unsupported item type: {type(item)}")
-
-    metadata = search_result.payload or {}
-    metadata = {
-        k: v
-        for k, v in metadata.items()
-        if k not in ["content", "filename", "size", "content_type", "tags"]
-    }
-    return chunk.source, AnnotatedChunk(
-        id=str(chunk.id),
-        score=search_result.score,
-        metadata=metadata,
-        preview=serialize_item(chunk.data[0]) if chunk.data else None,
-    )
-
-
-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
-    items = defaultdict(list)
-    for source, chunk in chunks:
-        items[source].append(chunk)
-
-    return [
-        SearchResult(
-            id=source.id,
-            size=source.size or len(source.content),
-            mime_type=source.mime_type or "text/plain",
-            filename=source.filename
-            and source.filename.replace(
-                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
-            ),
-            content=source.display_contents,
-            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
-        )
-        for source, chunks in items.items()
-    ]
-
-
-def query_chunks(
-    client: qdrant_client.QdrantClient,
-    upload_data: list[extract.DataChunk],
-    allowed_modalities: set[str],
-    embedder: Callable,
-    min_score: float = 0.0,
-    limit: int = 10,
-) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
-        return {}
-
-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
-    if not chunks:
-        logger.error(f"No chunks to embed for {allowed_modalities}")
-        return {}
-
-    vector = embedder(chunks, input_type="query")[0]
-
-    return {
-        collection: [
-            r
-            for r in qdrant.search_vectors(
-                client=client,
-                collection_name=collection,
-                query_vector=vector,
-                limit=limit,
-            )
-            if r.score >= min_score
-        ]
-        for collection in allowed_modalities
-    }
-
-
 async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
     if not item:
         return []
@@ -152,7 +43,7 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
 
 
 @app.post("/search", response_model=list[SearchResult])
-async def search(
+async def search_endpoint(
     query: Optional[str] = Form(None),
     previews: Optional[bool] = Form(False),
     modalities: Annotated[list[str], Query()] = [],
@@ -161,62 +52,21 @@
     min_text_score: float = Query(0.3, ge=0.0, le=1.0),
     min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
-    """
-    Search across knowledge base using text query and optional files.
-
-    Parameters:
-    - query: Optional text search query
-    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
-    - files: Optional files to include in the search context
-    - limit: Maximum number of results per modality
-
-    Returns:
-    - List of search results sorted by score
-    """
+    """Search endpoint - delegates to search module"""
     upload_data = [
         chunk for item in [query, *files] for chunk in await input_type(item)
     ]
     logger.error(
         f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
     )
-
-    client = qdrant.get_qdrant_client()
-    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
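+    # Ranking, grouping, and sorting now live in memory.api.search.search.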
+    return await search(
         upload_data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
+        previews=previews,
+        modalities=modalities,
         limit=limit,
+        min_text_score=min_text_score,
+        min_multimodal_score=min_multimodal_score,
     )
-    multimodal_results = query_chunks(
-        client,
-        upload_data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-
-    results = group_chunks(
-        [
-            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-            for chunk in chunks
-        ]
-    )
-    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
 
 
 @app.get("/files/{path:path}")
diff --git a/src/memory/api/search.py b/src/memory/api/search.py
new file mode 100644
index 0000000..9195ff2
--- /dev/null
+++ b/src/memory/api/search.py
@@ -0,0 +1,189 @@
+"""
+Search logic for the knowledge base API.
+"""
+
+import base64
+import io
+from collections import defaultdict
+from typing import Callable, Optional
+import logging
+
+from PIL import Image
+from pydantic import BaseModel
+import qdrant_client
+from qdrant_client.http import models as qdrant_models
+
+from memory.common import embedding, qdrant, extract, settings
+from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+
+logger = logging.getLogger(__name__)
+
+
+class AnnotatedChunk(BaseModel):
+    id: str
+    score: float
+    metadata: dict
+    preview: Optional[str] = None
+
+
+class SearchResponse(BaseModel):
+    collection: str
+    results: list[dict]
+
+
+class SearchResult(BaseModel):
+    id: int
+    size: int
+    mime_type: str
+    chunks: list[AnnotatedChunk]
+    content: Optional[str] = None
+    filename: Optional[str] = None
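+# A SearchResult serializes roughly as (illustrative values):
+#   {"id": 1, "size": 1024, "mime_type": "text/plain", "content": "...",
+#    "filename": "/files/notes/example.md",
+#    "chunks": [{"id": "...", "score": 0.87, "metadata": {...}, "preview": null}]}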
+""" + +import base64 +import io +from collections import defaultdict +from typing import Callable, Optional +import logging + +from PIL import Image +from pydantic import BaseModel +import qdrant_client +from qdrant_client.http import models as qdrant_models + +from memory.common import embedding, qdrant, extract, settings +from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS +from memory.common.db.connection import make_session +from memory.common.db.models import Chunk, SourceItem + +logger = logging.getLogger(__name__) + + +class AnnotatedChunk(BaseModel): + id: str + score: float + metadata: dict + preview: Optional[str | None] = None + + +class SearchResponse(BaseModel): + collection: str + results: list[dict] + + +class SearchResult(BaseModel): + id: int + size: int + mime_type: str + chunks: list[AnnotatedChunk] + content: Optional[str] = None + filename: Optional[str] = None + + +def annotated_chunk( + chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool +) -> tuple[SourceItem, AnnotatedChunk]: + def serialize_item(item: bytes | str | Image.Image) -> str | None: + if not previews and not isinstance(item, str): + return None + + if isinstance(item, Image.Image): + buffer = io.BytesIO() + format = item.format or "PNG" + item.save(buffer, format=format) + mime_type = f"image/{format.lower()}" + return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" + elif isinstance(item, bytes): + return base64.b64encode(item).decode("utf-8") + elif isinstance(item, str): + return item + else: + raise ValueError(f"Unsupported item type: {type(item)}") + + metadata = search_result.payload or {} + metadata = { + k: v + for k, v in metadata.items() + if k not in ["content", "filename", "size", "content_type", "tags"] + } + return chunk.source, AnnotatedChunk( + id=str(chunk.id), + score=search_result.score, + metadata=metadata, + preview=serialize_item(chunk.data[0]) if chunk.data else None, + ) + + +def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]: + items = defaultdict(list) + for source, chunk in chunks: + items[source].append(chunk) + + return [ + SearchResult( + id=source.id, + size=source.size or len(source.content), + mime_type=source.mime_type or "text/plain", + filename=source.filename + and source.filename.replace( + str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files" + ), + content=source.display_contents, + chunks=sorted(chunks, key=lambda x: x.score, reverse=True), + ) + for source, chunks in items.items() + ] + + +def query_chunks( + client: qdrant_client.QdrantClient, + upload_data: list[extract.DataChunk], + allowed_modalities: set[str], + embedder: Callable, + min_score: float = 0.0, + limit: int = 10, +) -> dict[str, list[qdrant_models.ScoredPoint]]: + if not upload_data: + return {} + + chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data] + if not chunks: + logger.error(f"No chunks to embed for {allowed_modalities}") + return {} + + vector = embedder(chunks, input_type="query")[0] + + return { + collection: [ + r + for r in qdrant.search_vectors( + client=client, + collection_name=collection, + query_vector=vector, + limit=limit, + ) + if r.score >= min_score + ] + for collection in allowed_modalities + } + + +async def search( + data: list[extract.DataChunk], + previews: Optional[bool] = False, + modalities: list[str] = [], + limit: int = 10, + min_text_score: float = 0.3, + min_multimodal_score: float = 0.3, +) -> list[SearchResult]: + """ + Search 
+            filename=source.filename
+            and source.filename.replace(
+                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
+            ),
+            content=source.display_contents,
+            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
+        )
+        for source, chunks in items.items()
+    ]
+
+
+def query_chunks(
+    client: qdrant_client.QdrantClient,
+    upload_data: list[extract.DataChunk],
+    allowed_modalities: set[str],
+    embedder: Callable,
+    min_score: float = 0.0,
+    limit: int = 10,
+) -> dict[str, list[qdrant_models.ScoredPoint]]:
+    if not upload_data:
+        return {}
+
+    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
+    if not chunks:
+        logger.error(f"No chunks to embed for {allowed_modalities}")
+        return {}
+
+    vector = embedder(chunks, input_type="query")[0]
+
+    return {
+        collection: [
+            r
+            for r in qdrant.search_vectors(
+                client=client,
+                collection_name=collection,
+                query_vector=vector,
+                limit=limit,
+            )
+            if r.score >= min_score
+        ]
+        for collection in allowed_modalities
+    }
+
+
+async def search(
+    data: list[extract.DataChunk],
+    previews: Optional[bool] = False,
+    modalities: list[str] = [],
+    limit: int = 10,
+    min_text_score: float = 0.3,
+    min_multimodal_score: float = 0.3,
+) -> list[SearchResult]:
+    """
+    Search across the knowledge base using the given data chunks.
+
+    Parameters:
+    - data: Data chunks (text and/or images) to search with
+    - previews: Whether to include base64 previews in the results
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - limit: Maximum number of results per modality
+    - min_text_score: Minimum score for hits from text collections
+    - min_multimodal_score: Minimum score for hits from multimodal collections
+
+    Returns:
+    - List of search results sorted by score
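+
+    Example (illustrative; assumes extract.DataChunk(data=[...]) is a valid
+    constructor):
+        await search([extract.DataChunk(data=["cat pictures"])],
+                     modalities=["photo"], limit=5)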
""" -from dataclasses import dataclass import pathlib import re import textwrap from datetime import datetime -from typing import Any, ClassVar, Iterable, Sequence, cast +from typing import Any, Sequence, cast import uuid from PIL import Image diff --git a/src/memory/parsers/html.py b/src/memory/parsers/html.py index 0d8ca54..a6169a1 100644 --- a/src/memory/parsers/html.py +++ b/src/memory/parsers/html.py @@ -4,7 +4,7 @@ import pathlib import re from dataclasses import dataclass, field from datetime import datetime -from typing import Any +from typing import Any, cast from urllib.parse import urljoin, urlparse import requests @@ -157,7 +157,7 @@ def process_image(url: str, image_dir: pathlib.Path) -> PILImage.Image | None: # Download if not already cached if not local_path.exists(): - local_path.write_bytes(fetch_html(url, as_bytes=True)) + local_path.write_bytes(cast(bytes, fetch_html(url, as_bytes=True))) try: return PILImage.open(local_path) diff --git a/tests/memory/common/test_extract.py b/tests/memory/common/test_extract.py index bf03725..5e391e7 100644 --- a/tests/memory/common/test_extract.py +++ b/tests/memory/common/test_extract.py @@ -128,14 +128,3 @@ def test_docx_to_pdf_default_output(): assert result_path == SAMPLE_DOCX.with_suffix(".pdf") assert result_path.exists() - - -@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed") -def test_extract_docx(): - pages = extract_docx(SAMPLE_DOCX) - - assert len(pages) > 0 - assert all(isinstance(page, dict) for page in pages) - assert all("contents" in page for page in pages) - assert all("metadata" in page for page in pages) - assert all(isinstance(page["contents"][0], Image.Image) for page in pages) diff --git a/tests/memory/workers/tasks/test_blogs_tasks.py b/tests/memory/workers/tasks/test_blogs_tasks.py index 4980e4d..14f1e78 100644 --- a/tests/memory/workers/tasks/test_blogs_tasks.py +++ b/tests/memory/workers/tasks/test_blogs_tasks.py @@ -1,6 +1,6 @@ import pytest from datetime import datetime, timedelta -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch from memory.common.db.models import ArticleFeed, BlogPost from memory.workers.tasks import blogs diff --git a/tests/memory/workers/tasks/test_ebook_tasks.py b/tests/memory/workers/tasks/test_ebook_tasks.py index 96213f5..1191fc3 100644 --- a/tests/memory/workers/tasks/test_ebook_tasks.py +++ b/tests/memory/workers/tasks/test_ebook_tasks.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path from unittest.mock import patch, Mock -from memory.common.db.models import Book, BookSection, Chunk +from memory.common.db.models import Book, BookSection from memory.parsers.ebook import Ebook, Section from memory.workers.tasks import ebook diff --git a/tests/memory/workers/tasks/test_maintenance.py b/tests/memory/workers/tasks/test_maintenance.py index 62611c0..96d7e5d 100644 --- a/tests/memory/workers/tasks/test_maintenance.py +++ b/tests/memory/workers/tasks/test_maintenance.py @@ -6,7 +6,7 @@ import pytest from PIL import Image from memory.common import qdrant as qd -from memory.common import embedding, settings +from memory.common import settings from memory.common.db.models import Chunk, SourceItem from memory.workers.tasks.maintenance import ( clean_collection,