Compare commits

...

3 Commits

Author            SHA1        Message                 Date
Daniel O'Connell  beb94375da  fix tests               2025-07-24 23:34:10 +02:00
EC2 Default User  cf456c04d6  handle books            2025-07-24 21:33:15 +00:00
EC2 Default User  907375eee5  fix approx tokens call  2025-07-24 17:57:39 +00:00
8 changed files with 115 additions and 22 deletions

View File

@@ -0,0 +1,61 @@
+import logging
+
+from sqlalchemy.orm import joinedload
+
+from memory.api.MCP.tools import mcp
+from memory.common.db.connection import make_session
+from memory.common.db.models import Book, BookSection, BookSectionPayload
+
+logger = logging.getLogger(__name__)
+
+
+@mcp.tool()
+async def all_books(sections: bool = False) -> list[dict]:
+    """
+    Get all books in the database.
+
+    If sections is True, the response will include the sections for each book.
+
+    Args:
+        sections: Whether to include sections in the response. Defaults to False.
+
+    Returns:
+        List of books in the database.
+    """
+    options = []
+    if sections:
+        options = [joinedload(Book.sections)]
+
+    with make_session() as session:
+        books = session.query(Book).options(*options).all()
+        return [book.as_payload(sections=sections) for book in books]
+
+
+@mcp.tool()
+def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
+    """
+    Read a book from the database.
+
+    If sections is provided, only the sections with the given IDs will be returned.
+
+    Args:
+        book_id: The ID of the book to read.
+        sections: The IDs of the sections to read. Defaults to all sections.
+
+    Returns:
+        List of sections in the book, with contents. In the case of nested sections, only the top-level sections are returned.
+    """
+    with make_session() as session:
+        book_sections = session.query(BookSection).filter(
+            BookSection.book_id == book_id
+        )
+        if sections:
+            book_sections = book_sections.filter(BookSection.id.in_(sections))
+
+        all_sections = book_sections.all()
+        parents = [section.parent_section_id for section in all_sections]
+        return [
+            section.as_payload()
+            for section in all_sections
+            if section.id not in parents
+        ]
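The de-nesting filter at the end of read_book collects every parent_section_id seen among the fetched rows and drops any section whose own id appears in that list — in other words, child sections survive and their wrapping parents are dropped. A standalone sketch of the same logic, with invented rows (plain dicts standing in for BookSection):

    # Hypothetical rows; ids are invented for illustration.
    fetched = [
        {"id": 1, "parent_section_id": None},  # chapter wrapper with children
        {"id": 2, "parent_section_id": 1},     # subsection of 1
        {"id": 3, "parent_section_id": 1},     # subsection of 1
        {"id": 4, "parent_section_id": None},  # chapter with no subsections
    ]

    # Any id that appears as someone's parent is excluded from the result.
    parents = [s["parent_section_id"] for s in fetched]
    kept = [s for s in fetched if s["id"] not in parents]
    assert [s["id"] for s in kept] == [2, 3, 4]  # the wrapper (id=1) is dropped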

View File

@@ -180,6 +180,7 @@ async def observe(
        session_id: UUID to group observations from same conversation
        agent_model: AI model making observations (for quality tracking)
    """
+    logger.info("MCP: Observing")
    tasks = [
        (
            observation,
@@ -237,6 +238,7 @@ async def search_observations(
    Returns: List with content, tags, created_at, metadata
    Results sorted by relevance to your query.
    """
+    logger.info("MCP: Searching observations for %s", query)
    semantic_text = observation.generate_semantic_text(
        subject=subject or "",
        observation_type="".join(observation_types or []),
@@ -297,6 +299,7 @@ async def create_note(
        confidences: Dict of scores (0.0-1.0), e.g. {"observation_accuracy": 0.9}
        tags: Organization tags for filtering and discovery
    """
+    logger.info("MCP: creating note: %s", subject)
    if filename:
        path = pathlib.Path(filename)
        if not path.is_absolute():
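Two of the new log lines use lazy %s formatting while other log calls in this PR use f-strings. With %s-style arguments the interpolation only happens if the record is actually emitted, which is the cheaper choice when a level is disabled. A quick illustration:

    import logging

    logging.basicConfig(level=logging.WARNING)  # INFO records are discarded
    logger = logging.getLogger(__name__)

    query = "example query"
    logger.info("MCP: Searching observations for %s", query)  # no string built
    logger.info(f"MCP: Searching observations for {query}")   # f-string built anyway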

View File

@@ -108,10 +108,3 @@ async def get_authenticated_user() -> dict:
        "client_id": access_token.client_id,
        "user": user_info,
    }
-
-
-@mcp.tool()
-async def send_response(response: str) -> dict:
-    """Send a response to the user."""
-    logger.info(f"Sending response: {response}")
-    return {"response": response}

View File

@@ -28,7 +28,7 @@ from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import Session, relationship
from sqlalchemy.types import Numeric

-from memory.common import settings
+from memory.common import settings, tokens
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
@@ -125,8 +125,7 @@ def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
    )
    chunks: list[extract.DataChunk] = [full_text]

-    tokens = chunker.approx_token_count(content)
-    if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
+    if tokens.approx_token_count(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
        chunks += [
            extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
            for c in chunker.chunk_text(content)
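The removed temporary was named tokens, which collides with the module imported in the first hunk: a name assigned anywhere in a function body is local to the whole function in Python, so keeping the assignment while also reading the tokens module inside chunk_mixed would raise UnboundLocalError. A minimal, self-contained reproduction using the standard math module (names invented for illustration):

    import math

    def broken(x: float) -> float:
        # `math` is local throughout this function because it is assigned below,
        # so the read on the right-hand side raises UnboundLocalError.
        math = math.sqrt(x)
        return math + 1

    def fixed(x: float) -> float:
        return math.sqrt(x) + 1  # read the module attribute inline instead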

View File

@@ -50,9 +50,9 @@ class Book(Base):
        Index("book_title_idx", "title"),
    )

-    def as_payload(self) -> dict:
-        return {
-            **super().as_payload(),
+    def as_payload(self, sections: bool = False) -> dict:
+        data = {
+            "id": self.id,
            "isbn": self.isbn,
            "title": self.title,
            "author": self.author,
@@ -63,6 +63,9 @@ class Book(Base):
            "series": self.series,
            "series_number": self.series_number,
        } | (cast(dict, self.book_metadata) or {})
+        if sections:
+            data["sections"] = [section.as_payload() for section in self.sections]
+        return data


class ArticleFeed(Base):
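For reference, a sketch of the shape as_payload(sections=True) now produces; all values are invented, and the section payload shape depends on BookSection.as_payload():

    # Hypothetical output of book.as_payload(sections=True).
    payload = {
        "id": 7,
        "isbn": "9780441013593",
        "title": "Dune",
        "author": "Frank Herbert",
        "series": "Dune",
        "series_number": 1,
        # ...remaining columns, plus any keys merged in from book_metadata...
        "sections": [
            {"id": 71, "title": "Book One"},  # shape set by BookSection.as_payload()
        ],
    }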

View File

@@ -3,7 +3,8 @@ from dataclasses import dataclass, field
from typing import Any, cast
from pathlib import Path

-import fitz  # PyMuPDF
+import fitz
+from memory.common import settings

logger = logging.getLogger(__name__)
@@ -27,6 +28,7 @@ class Ebook:
    title: str
    author: str
    file_path: Path
+    relative_path: Path
    metadata: dict[str, Any] = field(default_factory=dict)
    sections: list[Section] = field(default_factory=list)
    full_content: str = ""
@@ -180,6 +182,7 @@ def parse_ebook(file_path: str | Path) -> Ebook:
        sections=sections,
        full_content=full_content,
        file_path=path,
+        relative_path=path.relative_to(settings.FILE_STORAGE_DIR),
        file_type=path.suffix.lower()[1:],
        n_pages=doc.page_count,
    )
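One consequence of computing relative_path with Path.relative_to, not spelled out in the diff: it raises ValueError whenever the ebook does not live under settings.FILE_STORAGE_DIR, so parse_ebook now implicitly assumes all books sit inside the storage directory. A quick illustration with a made-up storage path:

    from pathlib import Path

    storage = Path("/app/memory_files")  # stand-in for settings.FILE_STORAGE_DIR

    Path("/app/memory_files/books/dune.epub").relative_to(storage)
    # -> PosixPath('books/dune.epub')

    Path("/tmp/dune.epub").relative_to(storage)
    # -> ValueError: '/tmp/dune.epub' is not in the subpath of '/app/memory_files'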

View File

@@ -1,12 +1,13 @@
import logging
import pathlib
+from datetime import datetime
from typing import Iterable, cast

import memory.common.settings as settings
-from memory.parsers.ebook import Ebook, parse_ebook, Section
-from memory.common.db.models import Book, BookSection
-from memory.common.celery_app import SYNC_BOOK, app
from memory.common.db.connection import make_session
+from memory.common.celery_app import app, SYNC_BOOK
+from memory.common.db.models import Book, BookSection
+from memory.parsers.ebook import Ebook, Section, parse_ebook
from memory.workers.tasks.content_processing import (
    check_content_exists,
    create_content_hash,
@@ -143,7 +144,18 @@ def embed_sections(all_sections: list[BookSection]) -> int:
@app.task(name=SYNC_BOOK)
@safe_task_execution
-def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
+def sync_book(
+    file_path: str,
+    tags: Iterable[str] = [],
+    title: str = "",
+    author: str = "",
+    publisher: str = "",
+    published: str = "",
+    language: str = "",
+    edition: str = "",
+    series: str = "",
+    series_number: int | None = None,
+) -> dict:
    """
    Synchronize a book from a file path.
@@ -154,12 +166,13 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
        dict: Summary of what was processed
    """
    ebook = validate_and_parse_book(file_path)
-    logger.info(f"Ebook parsed: {ebook.title}")
+    logger.info(f"Ebook parsed: {ebook.title}, {ebook.file_path.as_posix()}")

    with make_session() as session:
        # Check for existing book
+        logger.info(f"Checking for existing book: {ebook.relative_path.as_posix()}")
        existing_book = check_content_exists(
-            session, Book, file_path=ebook.file_path.as_posix()
+            session, Book, file_path=ebook.relative_path.as_posix()
        )

        if existing_book:
            logger.info(f"Book already exists: {existing_book.title}")
@@ -175,6 +188,24 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
        # Create book and sections with relationships
        book, all_sections = create_book_and_sections(ebook, session, tags)

+        if title:
+            book.title = title  # type: ignore
+        if author:
+            book.author = author  # type: ignore
+        if publisher:
+            book.publisher = publisher  # type: ignore
+        if published:
+            book.published = datetime.fromisoformat(published)  # type: ignore
+        if language:
+            book.language = language  # type: ignore
+        if edition:
+            book.edition = edition  # type: ignore
+        if series:
+            book.series = series  # type: ignore
+        if series_number:
+            book.series_number = series_number  # type: ignore
+        session.add(book)
+
        # Embed sections
        logger.info("Embedding sections")
        embedded_count = sum(embed_source_item(section) for section in all_sections)
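A hypothetical call showing how the new overrides would be used, assuming the usual Celery calling convention. Note that published must be an ISO-8601 string (it is parsed with datetime.fromisoformat), and that falsy overrides such as an empty string or series_number=0 are silently skipped by the if guards:

    # Hypothetical enqueue of the task; the file path and metadata are invented.
    sync_book.delay(
        "books/dune.epub",
        tags=["sci-fi"],
        title="Dune",
        author="Frank Herbert",
        published="1965-08-01",  # parsed via datetime.fromisoformat
        series="Dune",
        series_number=1,
    )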

View File

@@ -3,7 +3,7 @@ from unittest.mock import patch
from typing import cast

import pytest
from PIL import Image
-from memory.common import settings, chunker, extract
+from memory.common import settings, chunker, extract, tokens
from memory.common.db.models.source_item import (
    Chunk,
)
@@ -610,7 +610,7 @@ def test_chunk_mixed_long_content(tmp_path):
    with (
        patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
        patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10),
-        patch.object(chunker, "approx_token_count", return_value=100),
+        patch.object(tokens, "approx_token_count", return_value=100),
    ):  # Force it to be > 2 * 10
        result = chunk_mixed(long_content, [])
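The test change follows the standard "patch where it is looked up" rule: chunk_mixed now resolves approx_token_count through the tokens module at call time, so patching chunker.approx_token_count would leave the real function in play. A self-contained sketch of the principle, with a stand-in namespace and invented numbers:

    from types import SimpleNamespace
    from unittest.mock import patch

    # Stand-in for the tokens module (hypothetical, for illustration only).
    tokens = SimpleNamespace(approx_token_count=lambda text: len(text.split()))

    def needs_chunking(content: str) -> bool:
        # Resolved through `tokens` at call time, like chunk_mixed after the fix.
        return tokens.approx_token_count(content) > 2 * 10

    with patch.object(tokens, "approx_token_count", return_value=100):
        assert needs_chunking("tiny")  # mocked count of 100 forces chunking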