Mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 13:24:41 +02:00)
fix linting

This commit is contained in:
  parent 1291ca9d08
  commit ab87bced81

.github/workflows/ci.yml (vendored, 4 lines changed)
@@ -17,10 +17,10 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .[all]
-        pip install ruff==0.11.10 pylint==1.1.400
+        pip install ruff==0.11.10 pyright==1.1.327
     - name: Run linters
       run: |
         ruff check .
-        pylint $(git ls-files '*.py')
+        pyright $(git ls-files '*.py')
     - name: Run tests
       run: pytest -vv
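The same two checks can be reproduced before pushing. A minimal local sketch, assuming ruff and pyright are installed in the active environment; the helper script itself is hypothetical and not part of this commit (CI passes the tracked *.py files to pyright explicitly, while this sketch lets pyright walk the project on its own):

# run_lint.py - hypothetical helper mirroring the CI "Run linters" step.
import subprocess
import sys

def main() -> int:
    # ruff handles lint/style rules; pyright performs the type checking.
    commands = [["ruff", "check", "."], ["pyright"]]
    status = 0
    for cmd in commands:
        print("$", " ".join(cmd))
        status |= subprocess.run(cmd).returncode
    return status

if __name__ == "__main__":
    sys.exit(main())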
@@ -1,4 +1,5 @@
 fastapi==0.112.2
 uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
+sqladmin
src/memory/api/admin.py (new file, 170 lines)

@@ -0,0 +1,170 @@
+"""
+SQLAdmin views for the knowledge base database models.
+"""
+
+from sqladmin import Admin, ModelView
+
+from memory.common.db.models import (
+    Chunk,
+    SourceItem,
+    MailMessage,
+    EmailAttachment,
+    Photo,
+    Comic,
+    Book,
+    BookSection,
+    BlogPost,
+    MiscDoc,
+    ArticleFeed,
+    EmailAccount,
+)
+
+
+DEFAULT_COLUMNS = (
+    "modality",
+    "embed_status",
+    "inserted_at",
+    "tags",
+    "size",
+    "mime_type",
+    "filename",
+    "content",
+)
+
+
+def source_columns(model: type[SourceItem], *columns: str):
+    return [getattr(model, c) for c in columns + DEFAULT_COLUMNS if hasattr(model, c)]
+
+
+# Create admin views for all models
+class SourceItemAdmin(ModelView, model=SourceItem):
+    column_list = source_columns(SourceItem)
+    column_searchable_list = [
+        "modality",
+        "filename",
+        "embed_status",
+    ]
+
+
+class ChunkAdmin(ModelView, model=Chunk):
+    column_list = ["id", "source_id", "embedding_model", "created_at"]
+
+
+class MailMessageAdmin(ModelView, model=MailMessage):
+    column_list = source_columns(
+        MailMessage,
+        "subject",
+        "sender",
+        "recipients",
+        "folder",
+        "message_id",
+        "tags",
+        "embed_status",
+        "inserted_at",
+    )
+    column_searchable_list = [
+        "subject",
+        "sender",
+        "recipients",
+        "folder",
+        "message_id",
+    ]
+
+
+class EmailAttachmentAdmin(ModelView, model=EmailAttachment):
+    column_list = source_columns(EmailAttachment, "filename", "mime_type", "size")
+    column_searchable_list = [
+        "filename",
+        "mime_type",
+    ]
+
+
+class BlogPostAdmin(ModelView, model=BlogPost):
+    column_list = source_columns(
+        BlogPost, "title", "author", "url", "published", "domain"
+    )
+    column_searchable_list = ["title", "author", "domain"]
+
+
+class PhotoAdmin(ModelView, model=Photo):
+    column_list = source_columns(Photo, "exif_taken_at", "camera")
+
+
+class ComicAdmin(ModelView, model=Comic):
+    column_list = source_columns(Comic, "title", "author", "published", "volume")
+    column_searchable_list = ["title", "author"]
+
+
+class BookSectionAdmin(ModelView, model=BookSection):
+    column_list = source_columns(
+        BookSection,
+        "section_title",
+        "section_number",
+        "section_level",
+        "start_page",
+        "end_page",
+    )
+    column_searchable_list = ["section_title"]
+
+
+class MiscDocAdmin(ModelView, model=MiscDoc):
+    column_list = source_columns(MiscDoc, "path")
+    column_searchable_list = ["path"]
+
+
+class BookAdmin(ModelView, model=Book):
+    column_list = [
+        "id",
+        "title",
+        "author",
+        "series",
+        "series_number",
+        "published",
+    ]
+    column_searchable_list = ["title", "author"]
+
+
+class ArticleFeedAdmin(ModelView, model=ArticleFeed):
+    column_list = [
+        "id",
+        "title",
+        "description",
+        "url",
+        "tags",
+        "active",
+        "created_at",
+        "updated_at",
+    ]
+    column_searchable_list = ["title", "url"]
+
+
+class EmailAccountAdmin(ModelView, model=EmailAccount):
+    column_list = [
+        "id",
+        "name",
+        "tags",
+        "email_address",
+        "username",
+        "use_ssl",
+        "folders",
+        "active",
+        "created_at",
+        "updated_at",
+    ]
+    column_searchable_list = ["name", "email_address"]
+
+
+def setup_admin(admin: Admin):
+    """Add all admin views to the admin instance."""
+    admin.add_view(SourceItemAdmin)
+    admin.add_view(ChunkAdmin)
+    admin.add_view(EmailAccountAdmin)
+    admin.add_view(MailMessageAdmin)
+    admin.add_view(EmailAttachmentAdmin)
+    admin.add_view(BookAdmin)
+    admin.add_view(BookSectionAdmin)
+    admin.add_view(MiscDocAdmin)
+    admin.add_view(ArticleFeedAdmin)
+    admin.add_view(BlogPostAdmin)
+    admin.add_view(ComicAdmin)
+    admin.add_view(PhotoAdmin)
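These views only take effect once they are registered on an Admin instance bound to the FastAPI app; the actual wiring lands in the application-module hunk below. In outline:

# Sketch of the wiring (mirrors the hunk that follows, no new behaviour):
from fastapi import FastAPI
from sqladmin import Admin

from memory.api.admin import setup_admin
from memory.common.db.connection import get_engine

app = FastAPI(title="Knowledge Base API")
admin = Admin(app, get_engine())  # sqladmin serves its UI under /admin by default
setup_admin(admin)                # registers every ModelView defined above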
@@ -2,49 +2,28 @@
 FastAPI application for the knowledge base.
 """
 
-import base64
-import io
-from collections import defaultdict
 import pathlib
-from typing import Annotated, List, Optional, Callable
-from fastapi import FastAPI, File, UploadFile, Query, HTTPException, Form
-from fastapi.responses import FileResponse
-import qdrant_client
-from qdrant_client.http import models as qdrant_models
-from PIL import Image
-from pydantic import BaseModel
-
-from memory.common import embedding, qdrant, extract, settings
-from memory.common.collections import get_modality, TEXT_COLLECTIONS, ALL_COLLECTIONS
-from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, SourceItem
-
 import logging
+from typing import Annotated, Optional
+
+from fastapi import FastAPI, HTTPException, File, UploadFile, Query, Form
+from fastapi.responses import FileResponse
+from sqladmin import Admin
+
+from memory.common import settings
+from memory.common import extract
+from memory.common.db.connection import get_engine
+from memory.api.admin import setup_admin
+from memory.api.search import search, SearchResult
 
 logger = logging.getLogger(__name__)
 
 app = FastAPI(title="Knowledge Base API")
 
-
-class AnnotatedChunk(BaseModel):
-    id: str
-    score: float
-    metadata: dict
-    preview: Optional[str | None] = None
-
-
-class SearchResponse(BaseModel):
-    collection: str
-    results: List[dict]
-
-
-class SearchResult(BaseModel):
-    id: int
-    size: int
-    mime_type: str
-    chunks: list[AnnotatedChunk]
-    content: Optional[str] = None
-    filename: Optional[str] = None
-
+# SQLAdmin setup
+engine = get_engine()
+admin = Admin(app, engine)
+setup_admin(admin)
 
 @app.get("/health")
@@ -53,94 +32,6 @@ def health_check():
     return {"status": "healthy"}
 
 
-def annotated_chunk(
-    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
-    def serialize_item(item: bytes | str | Image.Image) -> str | None:
-        if not previews and not isinstance(item, str):
-            return None
-
-        if isinstance(item, Image.Image):
-            buffer = io.BytesIO()
-            format = item.format or "PNG"
-            item.save(buffer, format=format)
-            mime_type = f"image/{format.lower()}"
-            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
-        elif isinstance(item, bytes):
-            return base64.b64encode(item).decode("utf-8")
-        elif isinstance(item, str):
-            return item
-        else:
-            raise ValueError(f"Unsupported item type: {type(item)}")
-
-    metadata = search_result.payload or {}
-    metadata = {
-        k: v
-        for k, v in metadata.items()
-        if k not in ["content", "filename", "size", "content_type", "tags"]
-    }
-    return chunk.source, AnnotatedChunk(
-        id=str(chunk.id),
-        score=search_result.score,
-        metadata=metadata,
-        preview=serialize_item(chunk.data[0]) if chunk.data else None,
-    )
-
-
-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
-    items = defaultdict(list)
-    for source, chunk in chunks:
-        items[source].append(chunk)
-
-    return [
-        SearchResult(
-            id=source.id,
-            size=source.size or len(source.content),
-            mime_type=source.mime_type or "text/plain",
-            filename=source.filename
-            and source.filename.replace(
-                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
-            ),
-            content=source.display_contents,
-            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
-        )
-        for source, chunks in items.items()
-    ]
-
-
-def query_chunks(
-    client: qdrant_client.QdrantClient,
-    upload_data: list[extract.DataChunk],
-    allowed_modalities: set[str],
-    embedder: Callable,
-    min_score: float = 0.0,
-    limit: int = 10,
-) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
-        return {}
-
-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
-    if not chunks:
-        logger.error(f"No chunks to embed for {allowed_modalities}")
-        return {}
-
-    vector = embedder(chunks, input_type="query")[0]
-
-    return {
-        collection: [
-            r
-            for r in qdrant.search_vectors(
-                client=client,
-                collection_name=collection,
-                query_vector=vector,
-                limit=limit,
-            )
-            if r.score >= min_score
-        ]
-        for collection in allowed_modalities
-    }
-
-
 async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
     if not item:
         return []
@@ -152,7 +43,7 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
 
 
 @app.post("/search", response_model=list[SearchResult])
-async def search(
+async def search_endpoint(
     query: Optional[str] = Form(None),
     previews: Optional[bool] = Form(False),
     modalities: Annotated[list[str], Query()] = [],
@@ -161,62 +52,21 @@ async def search(
     min_text_score: float = Query(0.3, ge=0.0, le=1.0),
     min_multimodal_score: float = Query(0.3, ge=0.0, le=1.0),
 ):
-    """
-    Search across knowledge base using text query and optional files.
-
-    Parameters:
-    - query: Optional text search query
-    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
-    - files: Optional files to include in the search context
-    - limit: Maximum number of results per modality
-
-    Returns:
-    - List of search results sorted by score
-    """
+    """Search endpoint - delegates to search module"""
     upload_data = [
         chunk for item in [query, *files] for chunk in await input_type(item)
     ]
     logger.error(
         f"Querying chunks for {modalities}, query: {query}, previews: {previews}, upload_data: {upload_data}"
     )
-    client = qdrant.get_qdrant_client()
-    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
+    return await search(
         upload_data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
+        previews=previews,
+        modalities=modalities,
         limit=limit,
+        min_text_score=min_text_score,
+        min_multimodal_score=min_multimodal_score,
     )
-    multimodal_results = query_chunks(
-        client,
-        upload_data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-
-        results = group_chunks(
-            [
-                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-                for chunk in chunks
-            ]
-        )
-    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
 
 
 @app.get("/files/{path:path}")
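For illustration, a hypothetical client call against the slimmed-down endpoint; the host and port are assumptions, and per the signature the query fields travel as form data while modalities and limits go as query parameters:

# Hypothetical client for POST /search, assuming the API listens locally.
import requests

response = requests.post(
    "http://localhost:8000/search",
    data={"query": "knowledge base design notes", "previews": "false"},
    params={"modalities": ["text"], "limit": 10, "min_text_score": 0.3},
)
response.raise_for_status()
for result in response.json():
    print(result["id"], result["mime_type"], len(result["chunks"]))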
src/memory/api/search.py (new file, 189 lines)

@@ -0,0 +1,189 @@
+"""
+Search endpoints for the knowledge base API.
+"""
+
+import base64
+import io
+from collections import defaultdict
+from typing import Callable, Optional
+import logging
+
+from PIL import Image
+from pydantic import BaseModel
+import qdrant_client
+from qdrant_client.http import models as qdrant_models
+
+from memory.common import embedding, qdrant, extract, settings
+from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+
+logger = logging.getLogger(__name__)
+
+
+class AnnotatedChunk(BaseModel):
+    id: str
+    score: float
+    metadata: dict
+    preview: Optional[str | None] = None
+
+
+class SearchResponse(BaseModel):
+    collection: str
+    results: list[dict]
+
+
+class SearchResult(BaseModel):
+    id: int
+    size: int
+    mime_type: str
+    chunks: list[AnnotatedChunk]
+    content: Optional[str] = None
+    filename: Optional[str] = None
+
+
+def annotated_chunk(
+    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
+) -> tuple[SourceItem, AnnotatedChunk]:
+    def serialize_item(item: bytes | str | Image.Image) -> str | None:
+        if not previews and not isinstance(item, str):
+            return None
+
+        if isinstance(item, Image.Image):
+            buffer = io.BytesIO()
+            format = item.format or "PNG"
+            item.save(buffer, format=format)
+            mime_type = f"image/{format.lower()}"
+            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
+        elif isinstance(item, bytes):
+            return base64.b64encode(item).decode("utf-8")
+        elif isinstance(item, str):
+            return item
+        else:
+            raise ValueError(f"Unsupported item type: {type(item)}")
+
+    metadata = search_result.payload or {}
+    metadata = {
+        k: v
+        for k, v in metadata.items()
+        if k not in ["content", "filename", "size", "content_type", "tags"]
+    }
+    return chunk.source, AnnotatedChunk(
+        id=str(chunk.id),
+        score=search_result.score,
+        metadata=metadata,
+        preview=serialize_item(chunk.data[0]) if chunk.data else None,
+    )
+
+
+def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
+    items = defaultdict(list)
+    for source, chunk in chunks:
+        items[source].append(chunk)
+
+    return [
+        SearchResult(
+            id=source.id,
+            size=source.size or len(source.content),
+            mime_type=source.mime_type or "text/plain",
+            filename=source.filename
+            and source.filename.replace(
+                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
+            ),
+            content=source.display_contents,
+            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
+        )
+        for source, chunks in items.items()
+    ]
+
+
+def query_chunks(
+    client: qdrant_client.QdrantClient,
+    upload_data: list[extract.DataChunk],
+    allowed_modalities: set[str],
+    embedder: Callable,
+    min_score: float = 0.0,
+    limit: int = 10,
+) -> dict[str, list[qdrant_models.ScoredPoint]]:
+    if not upload_data:
+        return {}
+
+    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
+    if not chunks:
+        logger.error(f"No chunks to embed for {allowed_modalities}")
+        return {}
+
+    vector = embedder(chunks, input_type="query")[0]
+
+    return {
+        collection: [
+            r
+            for r in qdrant.search_vectors(
+                client=client,
+                collection_name=collection,
+                query_vector=vector,
+                limit=limit,
+            )
+            if r.score >= min_score
+        ]
+        for collection in allowed_modalities
+    }
+
+
+async def search(
+    data: list[extract.DataChunk],
+    previews: Optional[bool] = False,
+    modalities: list[str] = [],
+    limit: int = 10,
+    min_text_score: float = 0.3,
+    min_multimodal_score: float = 0.3,
+) -> list[SearchResult]:
+    """
+    Search across knowledge base using text query and optional files.
+
+    Parameters:
+    - query: Optional text search query
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - files: Optional files to include in the search context
+    - limit: Maximum number of results per modality
+
+    Returns:
+    - List of search results sorted by score
+    """
+    client = qdrant.get_qdrant_client()
+    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
+    text_results = query_chunks(
+        client,
+        data,
+        allowed_modalities & TEXT_COLLECTIONS,
+        embedding.embed_text,
+        min_score=min_text_score,
+        limit=limit,
+    )
+    multimodal_results = query_chunks(
+        client,
+        data,
+        allowed_modalities,
+        embedding.embed_mixed,
+        min_score=min_multimodal_score,
+        limit=limit,
+    )
+    search_results = {
+        k: text_results.get(k, []) + multimodal_results.get(k, [])
+        for k in allowed_modalities
+    }
+
+    found_chunks = {
+        str(r.id): r for results in search_results.values() for r in results
+    }
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
+        logger.error(f"Found chunks: {chunks}")
+
+        results = group_chunks(
+            [
+                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
+                for chunk in chunks
+            ]
+        )
+    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
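Since serialize_item returns image previews as data: URLs, a client can recover the original image with plain base64 decoding. A small sketch:

# Decode an image preview produced by serialize_item above.
import base64
import io

from PIL import Image

def decode_preview(preview: str) -> Image.Image:
    # preview looks like "data:image/png;base64,<payload>"
    _header, payload = preview.split(",", 1)
    return Image.open(io.BytesIO(base64.b64decode(payload)))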
@@ -1,12 +1,17 @@
 """
 Database utilities package.
 """
 
 from memory.common.db.models import Base
-from memory.common.db.connection import get_engine, get_session_factory, get_scoped_session
+from memory.common.db.connection import (
+    get_engine,
+    get_session_factory,
+    get_scoped_session,
+)
 
 __all__ = [
     "Base",
     "get_engine",
     "get_session_factory",
     "get_scoped_session",
 ]
@@ -2,12 +2,11 @@
 Database models for the knowledge base system.
 """
 
-from dataclasses import dataclass
 import pathlib
 import re
 import textwrap
 from datetime import datetime
-from typing import Any, ClassVar, Iterable, Sequence, cast
+from typing import Any, Sequence, cast
 import uuid
 
 from PIL import Image
@@ -4,7 +4,7 @@ import pathlib
 import re
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any
+from typing import Any, cast
 from urllib.parse import urljoin, urlparse
 
 import requests
@@ -157,7 +157,7 @@ def process_image(url: str, image_dir: pathlib.Path) -> PILImage.Image | None:
 
     # Download if not already cached
     if not local_path.exists():
-        local_path.write_bytes(fetch_html(url, as_bytes=True))
+        local_path.write_bytes(cast(bytes, fetch_html(url, as_bytes=True)))
 
     try:
         return PILImage.open(local_path)
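The cast here exists purely for the type checker: fetch_html is presumably annotated as returning str | bytes, and pyright cannot narrow that union through the as_bytes flag. A self-contained illustration with a toy function (not the real fetch_html):

# typing.cast has no runtime effect; it only asserts a type to the checker.
from typing import cast

def fetch(as_bytes: bool) -> str | bytes:
    return b"payload" if as_bytes else "payload"

data: bytes = cast(bytes, fetch(True))  # pyright now accepts bytes-only APIs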
@@ -128,14 +128,3 @@ def test_docx_to_pdf_default_output():
 
     assert result_path == SAMPLE_DOCX.with_suffix(".pdf")
     assert result_path.exists()
-
-
-@pytest.mark.skipif(not is_pdflatex_available(), reason="pdflatex not installed")
-def test_extract_docx():
-    pages = extract_docx(SAMPLE_DOCX)
-
-    assert len(pages) > 0
-    assert all(isinstance(page, dict) for page in pages)
-    assert all("contents" in page for page in pages)
-    assert all("metadata" in page for page in pages)
-    assert all(isinstance(page["contents"][0], Image.Image) for page in pages)
@@ -1,6 +1,6 @@
 import pytest
 from datetime import datetime, timedelta
-from unittest.mock import Mock, patch, MagicMock
+from unittest.mock import Mock, patch
 
 from memory.common.db.models import ArticleFeed, BlogPost
 from memory.workers.tasks import blogs
@@ -2,7 +2,7 @@ import pytest
 from pathlib import Path
 from unittest.mock import patch, Mock
 
-from memory.common.db.models import Book, BookSection, Chunk
+from memory.common.db.models import Book, BookSection
 from memory.parsers.ebook import Ebook, Section
 from memory.workers.tasks import ebook
 
@@ -6,7 +6,7 @@ import pytest
 from PIL import Image
 
 from memory.common import qdrant as qd
-from memory.common import embedding, settings
+from memory.common import settings
 from memory.common.db.models import Chunk, SourceItem
 from memory.workers.tasks.maintenance import (
     clean_collection,