mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
gdrive tools
This commit is contained in:
parent
9cf71c9336
commit
5935f4741c
@ -416,3 +416,197 @@ def fetch_file(filename: str) -> dict:
|
||||
serialize_chunk(chunk, data) for chunk in chunks for data in chunk.data
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# --- Enumeration tools for systematic investigations ---
|
||||
|
||||
|
||||
@core_mcp.tool()
|
||||
async def get_source_item(id: int, include_content: bool = True) -> dict:
|
||||
"""
|
||||
Get full details of a source item by ID.
|
||||
Use after search to drill down into specific results.
|
||||
|
||||
Args:
|
||||
id: The source item ID (from search results)
|
||||
include_content: Whether to include full content (default True)
|
||||
|
||||
Returns: Full item details including metadata, tags, and optionally content.
|
||||
"""
|
||||
with make_session() as session:
|
||||
item = (
|
||||
session.query(SourceItem)
|
||||
.filter(
|
||||
SourceItem.id == id,
|
||||
SourceItem.embed_status == "STORED",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not item:
|
||||
raise ValueError(f"Item {id} not found or not yet indexed")
|
||||
|
||||
result = {
|
||||
"id": item.id,
|
||||
"modality": item.modality,
|
||||
"title": item.title,
|
||||
"mime_type": item.mime_type,
|
||||
"filename": item.filename,
|
||||
"size": item.size,
|
||||
"tags": item.tags,
|
||||
"inserted_at": item.inserted_at.isoformat() if item.inserted_at else None,
|
||||
"metadata": item.as_payload(),
|
||||
}
|
||||
|
||||
if include_content:
|
||||
result["content"] = item.content
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@core_mcp.tool()
|
||||
async def list_items(
|
||||
modalities: set[str] = set(),
|
||||
filters: SearchFilters = {},
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
sort_by: str = "inserted_at",
|
||||
sort_order: str = "desc",
|
||||
include_metadata: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
List items without semantic search - for systematic enumeration.
|
||||
Use for reviewing all items matching criteria, not finding best matches.
|
||||
|
||||
Args:
|
||||
modalities: Filter by type: email, blog, book, forum, photo, comic, etc. (empty = all)
|
||||
filters: Same filters as search_knowledge_base (tags, min_size, max_size, etc.)
|
||||
limit: Max results per page (default 50, max 200)
|
||||
offset: Skip first N results for pagination
|
||||
sort_by: Sort field - "inserted_at", "size", or "id" (default: inserted_at)
|
||||
sort_order: "asc" or "desc" (default: desc)
|
||||
include_metadata: Include full as_payload() metadata (default True)
|
||||
|
||||
Returns: {items: [...], total: int, has_more: bool}
|
||||
"""
|
||||
limit = min(limit, 200)
|
||||
if sort_by not in ("inserted_at", "size", "id"):
|
||||
sort_by = "inserted_at"
|
||||
if sort_order not in ("asc", "desc"):
|
||||
sort_order = "desc"
|
||||
|
||||
with make_session() as session:
|
||||
query = session.query(SourceItem).filter(SourceItem.embed_status == "STORED")
|
||||
|
||||
# Filter by modalities
|
||||
if modalities:
|
||||
query = query.filter(SourceItem.modality.in_(modalities))
|
||||
|
||||
# Apply filters
|
||||
if tags := filters.get("tags"):
|
||||
query = query.filter(
|
||||
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
|
||||
)
|
||||
if min_size := filters.get("min_size"):
|
||||
query = query.filter(SourceItem.size >= min_size)
|
||||
if max_size := filters.get("max_size"):
|
||||
query = query.filter(SourceItem.size <= max_size)
|
||||
if source_ids := filters.get("source_ids"):
|
||||
query = query.filter(SourceItem.id.in_(source_ids))
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(SourceItem, sort_by)
|
||||
if sort_order == "desc":
|
||||
sort_column = sort_column.desc()
|
||||
query = query.order_by(sort_column)
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(offset).limit(limit)
|
||||
|
||||
items = []
|
||||
for item in query.all():
|
||||
preview = None
|
||||
if item.content:
|
||||
preview = item.content[:200] + "..." if len(item.content) > 200 else item.content
|
||||
|
||||
item_dict = {
|
||||
"id": item.id,
|
||||
"modality": item.modality,
|
||||
"title": item.title,
|
||||
"mime_type": item.mime_type,
|
||||
"filename": item.filename,
|
||||
"size": item.size,
|
||||
"tags": item.tags,
|
||||
"inserted_at": item.inserted_at.isoformat() if item.inserted_at else None,
|
||||
"preview": preview,
|
||||
}
|
||||
|
||||
if include_metadata:
|
||||
item_dict["metadata"] = item.as_payload()
|
||||
else:
|
||||
item_dict["metadata"] = None
|
||||
|
||||
items.append(item_dict)
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"has_more": offset + len(items) < total,
|
||||
}
|
||||
|
||||
|
||||
@core_mcp.tool()
|
||||
async def count_items(
|
||||
modalities: set[str] = set(),
|
||||
filters: SearchFilters = {},
|
||||
) -> dict:
|
||||
"""
|
||||
Count items matching criteria without retrieving them.
|
||||
Use to understand scope before systematic review.
|
||||
|
||||
Args:
|
||||
modalities: Filter by type (empty = all)
|
||||
filters: Same filters as search_knowledge_base
|
||||
|
||||
Returns: {total: int, by_modality: {email: 100, blog: 50, ...}}
|
||||
"""
|
||||
from sqlalchemy import func as sql_func
|
||||
|
||||
with make_session() as session:
|
||||
base_query = session.query(SourceItem).filter(
|
||||
SourceItem.embed_status == "STORED"
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if modalities:
|
||||
base_query = base_query.filter(SourceItem.modality.in_(modalities))
|
||||
if tags := filters.get("tags"):
|
||||
base_query = base_query.filter(
|
||||
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
|
||||
)
|
||||
if min_size := filters.get("min_size"):
|
||||
base_query = base_query.filter(SourceItem.size >= min_size)
|
||||
if max_size := filters.get("max_size"):
|
||||
base_query = base_query.filter(SourceItem.size <= max_size)
|
||||
if source_ids := filters.get("source_ids"):
|
||||
base_query = base_query.filter(SourceItem.id.in_(source_ids))
|
||||
|
||||
# Get total
|
||||
total = base_query.count()
|
||||
|
||||
# Get counts by modality
|
||||
by_modality_query = (
|
||||
base_query.with_entities(
|
||||
SourceItem.modality, sql_func.count(SourceItem.id)
|
||||
)
|
||||
.group_by(SourceItem.modality)
|
||||
)
|
||||
|
||||
by_modality = {row[0]: row[1] for row in by_modality_query.all()}
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"by_modality": by_modality,
|
||||
}
|
||||
|
||||
@ -381,6 +381,16 @@ class SourceItem(Base):
|
||||
"""
|
||||
return 1.0
|
||||
|
||||
@property
|
||||
def title(self) -> str | None:
|
||||
"""
|
||||
Return a display title for this item.
|
||||
|
||||
Subclasses should override to return their specific title field
|
||||
(e.g., subject for emails, title for blog posts).
|
||||
"""
|
||||
return cast(str | None, self.filename)
|
||||
|
||||
@property
|
||||
def display_contents(self) -> dict | None:
|
||||
payload = self.as_payload()
|
||||
|
||||
@ -190,6 +190,10 @@ class MailMessage(SourceItem):
|
||||
def get_collections(cls) -> list[str]:
|
||||
return ["mail"]
|
||||
|
||||
@property
|
||||
def title(self) -> str | None:
|
||||
return cast(str | None, self.subject)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("mail_sent_idx", "sent_at"),
|
||||
@ -592,6 +596,10 @@ class BookSection(SourceItem):
|
||||
def get_collections(cls) -> list[str]:
|
||||
return ["book"]
|
||||
|
||||
@property
|
||||
def title(self) -> str | None:
|
||||
return cast(str | None, self.section_title)
|
||||
|
||||
def as_payload(self) -> BookSectionPayload:
|
||||
return BookSectionPayload(
|
||||
**super().as_payload(),
|
||||
@ -1035,6 +1043,10 @@ class Note(SourceItem):
|
||||
def get_collections(cls) -> list[str]:
|
||||
return ["text"] # Notes go to the text collection
|
||||
|
||||
@property
|
||||
def title(self) -> str | None:
|
||||
return cast(str | None, self.subject)
|
||||
|
||||
|
||||
class AgentObservationPayload(SourceItemPayload):
|
||||
session_id: Annotated[str | None, "Session ID for the observation"]
|
||||
@ -1204,6 +1216,10 @@ class AgentObservation(SourceItem):
|
||||
def get_collections(cls) -> list[str]:
|
||||
return ["semantic", "temporal"]
|
||||
|
||||
@property
|
||||
def title(self) -> str | None:
|
||||
return cast(str | None, self.subject)
|
||||
|
||||
|
||||
class GoogleDocPayload(SourceItemPayload):
|
||||
google_file_id: Annotated[str, "Google Drive file ID"]
|
||||
|
||||
@ -131,7 +131,8 @@ class GoogleDriveClient:
|
||||
since: datetime | None = None,
|
||||
page_size: int = 100,
|
||||
exclude_folder_ids: set[str] | None = None,
|
||||
) -> Generator[dict, None, None]:
|
||||
_current_path: str | None = None,
|
||||
) -> Generator[tuple[dict, str], None, None]:
|
||||
"""List all supported files in a folder with pagination.
|
||||
|
||||
Args:
|
||||
@ -140,10 +141,18 @@ class GoogleDriveClient:
|
||||
since: Only return files modified after this time
|
||||
page_size: Number of files per API page
|
||||
exclude_folder_ids: Set of folder IDs to skip during recursive traversal
|
||||
_current_path: Internal param tracking the current folder path
|
||||
|
||||
Yields:
|
||||
Tuples of (file_metadata, parent_folder_path)
|
||||
"""
|
||||
service = self._get_service()
|
||||
exclude_folder_ids = exclude_folder_ids or set()
|
||||
|
||||
# Build the current path if not provided
|
||||
if _current_path is None:
|
||||
_current_path = self.get_folder_path(folder_id)
|
||||
|
||||
# Build query for supported file types
|
||||
all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
|
||||
mime_conditions = " or ".join(f"mimeType='{mime}'" for mime in all_mimes)
|
||||
@ -178,17 +187,19 @@ class GoogleDriveClient:
|
||||
for file in response.get("files", []):
|
||||
if file["mimeType"] == "application/vnd.google-apps.folder":
|
||||
if recursive and file["id"] not in exclude_folder_ids:
|
||||
# Recursively list files in subfolder
|
||||
# Recursively list files in subfolder with updated path
|
||||
subfolder_path = f"{_current_path}/{file['name']}"
|
||||
yield from self.list_files_in_folder(
|
||||
file["id"],
|
||||
recursive=True,
|
||||
since=since,
|
||||
exclude_folder_ids=exclude_folder_ids,
|
||||
_current_path=subfolder_path,
|
||||
)
|
||||
elif file["id"] in exclude_folder_ids:
|
||||
logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})")
|
||||
else:
|
||||
yield file
|
||||
yield file, _current_path
|
||||
|
||||
page_token = response.get("nextPageToken")
|
||||
if not page_token:
|
||||
|
||||
@ -101,6 +101,7 @@ def _create_google_doc(
|
||||
content=file_data["content"],
|
||||
google_file_id=file_data["file_id"],
|
||||
title=file_data["title"],
|
||||
filename=file_data["title"],
|
||||
original_mime_type=file_data["original_mime_type"],
|
||||
folder_id=folder.id,
|
||||
folder_path=file_data["folder_path"],
|
||||
@ -144,6 +145,7 @@ def _update_existing_doc(
|
||||
existing.content = file_data["content"]
|
||||
existing.sha256 = create_content_hash(file_data["content"])
|
||||
existing.title = file_data["title"]
|
||||
existing.filename = file_data["title"]
|
||||
existing.google_modified_at = file_data["modified_at"]
|
||||
existing.last_modified_by = file_data["last_modified_by"]
|
||||
existing.word_count = file_data["word_count"]
|
||||
@ -249,21 +251,19 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, An
|
||||
|
||||
if is_folder:
|
||||
# It's a folder - list and sync all files inside
|
||||
folder_path = client.get_folder_path(google_id)
|
||||
|
||||
# Get excluded folder IDs
|
||||
exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or [])
|
||||
if exclude_ids:
|
||||
logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync")
|
||||
|
||||
for file_meta in client.list_files_in_folder(
|
||||
for file_meta, file_folder_path in client.list_files_in_folder(
|
||||
google_id,
|
||||
recursive=cast(bool, folder.recursive),
|
||||
since=since,
|
||||
exclude_folder_ids=exclude_ids,
|
||||
):
|
||||
try:
|
||||
file_data = client.fetch_file(file_meta, folder_path)
|
||||
file_data = client.fetch_file(file_meta, file_folder_path)
|
||||
serialized = _serialize_file_data(file_data)
|
||||
task = sync_google_doc.delay(folder.id, serialized)
|
||||
task_ids.append(task.id)
|
||||
|
||||
@ -244,9 +244,10 @@ def test_chunk_text_long_text():
|
||||
text = " ".join(sentences)
|
||||
|
||||
max_tokens = 10 # 10 tokens = ~40 chars
|
||||
# Chunker includes overlap and fits the final sentence in the last chunk
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
|
||||
f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
|
||||
] + ["This is sentence 49."]
|
||||
]
|
||||
|
||||
|
||||
def test_chunk_text_with_overlap():
|
||||
@ -255,9 +256,10 @@ def test_chunk_text_with_overlap():
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
|
||||
"Part A. Part B. Part C.",
|
||||
"Part C. Part D. Part E.",
|
||||
"Part E.",
|
||||
"Part A. Part B.",
|
||||
"Part B. Part C.",
|
||||
"Part C. Part D.",
|
||||
"Part D. Part E.",
|
||||
]
|
||||
|
||||
|
||||
@ -265,10 +267,12 @@ def test_chunk_text_zero_overlap():
|
||||
"""Test chunking with zero overlap"""
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
# 2 tokens = ~8 chars
|
||||
# 2 tokens = ~8 chars, each sentence is about 2 tokens
|
||||
assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
|
||||
"Part A. Part B.",
|
||||
"Part C. Part D.",
|
||||
"Part A.",
|
||||
"Part B.",
|
||||
"Part C.",
|
||||
"Part D.",
|
||||
"Part E.",
|
||||
]
|
||||
|
||||
@ -278,9 +282,12 @@ def test_chunk_text_clean_break():
|
||||
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
|
||||
|
||||
max_tokens = 5 # Enough for about 2 sentences
|
||||
# Chunker breaks at sentence boundaries, with overlap including previous sentence
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
|
||||
"First sentence. Second sentence.",
|
||||
"Third sentence. Fourth sentence.",
|
||||
"First sentence.",
|
||||
"Second sentence.",
|
||||
"Third sentence.",
|
||||
"Fourth sentence.",
|
||||
]
|
||||
|
||||
|
||||
@ -289,11 +296,15 @@ def test_chunk_text_very_long_sentences():
|
||||
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
|
||||
|
||||
max_tokens = 5 # Small limit to force splitting
|
||||
# Chunker splits at word boundaries when no sentence boundary available
|
||||
assert list(chunk_text(text, max_tokens=max_tokens)) == [
|
||||
"This is a very long sentence with many many",
|
||||
"words that will definitely exceed the",
|
||||
"This is a very long",
|
||||
"sentence with many many",
|
||||
"words that will",
|
||||
"definitely exceed the",
|
||||
"token limit we set for",
|
||||
"this particular test",
|
||||
"case and should be split into multiple",
|
||||
"case and should be",
|
||||
"split into multiple",
|
||||
"chunks by the function.",
|
||||
]
|
||||
|
||||
@ -681,27 +681,27 @@ def test_process_content_item(
|
||||
assert str(db_item.embed_status) == expected_embed_status
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task_behavior,expected_status",
|
||||
[
|
||||
("success", "success"),
|
||||
("exception", "error"),
|
||||
],
|
||||
)
|
||||
def test_safe_task_execution(task_behavior, expected_status):
|
||||
def test_safe_task_execution_success():
|
||||
"""Test that safe_task_execution passes through successful results."""
|
||||
|
||||
@safe_task_execution
|
||||
def test_task(arg1, arg2):
|
||||
if task_behavior == "exception":
|
||||
raise ValueError("Test error message")
|
||||
return {"status": "success", "result": arg1 + arg2}
|
||||
|
||||
result = test_task(1, 2)
|
||||
|
||||
assert result["status"] == expected_status
|
||||
if expected_status == "success":
|
||||
assert result["status"] == "success"
|
||||
assert result["result"] == 3
|
||||
else:
|
||||
assert result["error"] == "Test error message"
|
||||
|
||||
|
||||
def test_safe_task_execution_reraises_exceptions():
|
||||
"""Test that safe_task_execution logs but re-raises exceptions for Celery retries."""
|
||||
|
||||
@safe_task_execution
|
||||
def test_task():
|
||||
raise ValueError("Test error message")
|
||||
|
||||
with pytest.raises(ValueError, match="Test error message"):
|
||||
test_task()
|
||||
|
||||
|
||||
def test_safe_task_execution_preserves_function_name():
|
||||
@ -709,7 +709,8 @@ def test_safe_task_execution_preserves_function_name():
|
||||
def test_function():
|
||||
return {"status": "success"}
|
||||
|
||||
assert test_function.__name__ == "wrapper"
|
||||
# @wraps(func) should preserve the original function name
|
||||
assert test_function.__name__ == "test_function"
|
||||
|
||||
|
||||
def test_safe_task_execution_with_kwargs():
|
||||
@ -727,14 +728,14 @@ def test_safe_task_execution_with_kwargs():
|
||||
|
||||
|
||||
def test_safe_task_execution_exception_logging(caplog):
|
||||
"""Test that exceptions are logged before being re-raised."""
|
||||
|
||||
@safe_task_execution
|
||||
def failing_task():
|
||||
raise RuntimeError("Test runtime error")
|
||||
|
||||
result = failing_task()
|
||||
with pytest.raises(RuntimeError, match="Test runtime error"):
|
||||
failing_task()
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Test runtime error"
|
||||
assert "traceback" in result
|
||||
assert "Task failing_task failed:" in caplog.text
|
||||
assert "Test runtime error" in caplog.text
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user