gdrive tools

Daniel O'Connell 2026-01-01 18:09:54 +01:00
parent 9cf71c9336
commit 5935f4741c
7 changed files with 282 additions and 39 deletions

View File

@ -416,3 +416,197 @@ def fetch_file(filename: str) -> dict:
serialize_chunk(chunk, data) for chunk in chunks for data in chunk.data
]
}
# --- Enumeration tools for systematic investigations ---
@core_mcp.tool()
async def get_source_item(id: int, include_content: bool = True) -> dict:
"""
Get full details of a source item by ID.
Use after search to drill down into specific results.
Args:
id: The source item ID (from search results)
include_content: Whether to include full content (default True)
Returns: Full item details including metadata, tags, and optionally content.
"""
with make_session() as session:
item = (
session.query(SourceItem)
.filter(
SourceItem.id == id,
SourceItem.embed_status == "STORED",
)
.first()
)
if not item:
raise ValueError(f"Item {id} not found or not yet indexed")
result = {
"id": item.id,
"modality": item.modality,
"title": item.title,
"mime_type": item.mime_type,
"filename": item.filename,
"size": item.size,
"tags": item.tags,
"inserted_at": item.inserted_at.isoformat() if item.inserted_at else None,
"metadata": item.as_payload(),
}
if include_content:
result["content"] = item.content
return result
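As a quick illustration of the intended flow (search first, then drill into one hit), here is a hedged sketch; the item ID and the direct awaiting of the tool coroutine are assumptions for the example, not part of the commit:

    # Hypothetical follow-up to a search hit; 123 is an invented item ID.
    item = await get_source_item(id=123, include_content=False)
    print(item["modality"], item["title"], item["tags"])
    # Pull the full body in a second call only when it is actually needed:
    full = await get_source_item(id=123)
    body = full["content"]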
@core_mcp.tool()
async def list_items(
modalities: set[str] = set(),
filters: SearchFilters = {},
limit: int = 50,
offset: int = 0,
sort_by: str = "inserted_at",
sort_order: str = "desc",
include_metadata: bool = True,
) -> dict:
"""
List items without semantic search - for systematic enumeration.
Use for reviewing all items matching criteria, not finding best matches.
Args:
modalities: Filter by type: email, blog, book, forum, photo, comic, etc. (empty = all)
filters: Same filters as search_knowledge_base (tags, min_size, max_size, etc.)
limit: Max results per page (default 50, max 200)
offset: Skip first N results for pagination
sort_by: Sort field - "inserted_at", "size", or "id" (default: inserted_at)
sort_order: "asc" or "desc" (default: desc)
include_metadata: Include full as_payload() metadata (default True)
Returns: {items: [...], total: int, has_more: bool}
"""
limit = min(limit, 200)
if sort_by not in ("inserted_at", "size", "id"):
sort_by = "inserted_at"
if sort_order not in ("asc", "desc"):
sort_order = "desc"
with make_session() as session:
query = session.query(SourceItem).filter(SourceItem.embed_status == "STORED")
# Filter by modalities
if modalities:
query = query.filter(SourceItem.modality.in_(modalities))
# Apply filters
if tags := filters.get("tags"):
query = query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
)
if min_size := filters.get("min_size"):
query = query.filter(SourceItem.size >= min_size)
if max_size := filters.get("max_size"):
query = query.filter(SourceItem.size <= max_size)
if source_ids := filters.get("source_ids"):
query = query.filter(SourceItem.id.in_(source_ids))
# Get total count
total = query.count()
# Apply sorting
sort_column = getattr(SourceItem, sort_by)
if sort_order == "desc":
sort_column = sort_column.desc()
query = query.order_by(sort_column)
# Apply pagination
query = query.offset(offset).limit(limit)
items = []
for item in query.all():
preview = None
if item.content:
preview = item.content[:200] + "..." if len(item.content) > 200 else item.content
item_dict = {
"id": item.id,
"modality": item.modality,
"title": item.title,
"mime_type": item.mime_type,
"filename": item.filename,
"size": item.size,
"tags": item.tags,
"inserted_at": item.inserted_at.isoformat() if item.inserted_at else None,
"preview": preview,
}
if include_metadata:
item_dict["metadata"] = item.as_payload()
else:
item_dict["metadata"] = None
items.append(item_dict)
return {
"items": items,
"total": total,
"has_more": offset + len(items) < total,
}
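A minimal pagination sketch against the tool above; the modality, tag value, and direct awaiting of the tool function are illustrative assumptions:

    # Hypothetical: walk every stored email tagged "receipts", newest first.
    offset = 0
    while True:
        page = await list_items(
            modalities={"email"},
            filters={"tags": ["receipts"]},
            limit=100,
            offset=offset,
        )
        for item in page["items"]:
            print(item["id"], item["inserted_at"], item["title"])
        if not page["has_more"]:
            break
        offset += len(page["items"])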
@core_mcp.tool()
async def count_items(
modalities: set[str] = set(),
filters: SearchFilters = {},
) -> dict:
"""
Count items matching criteria without retrieving them.
Use to understand scope before systematic review.
Args:
modalities: Filter by type (empty = all)
filters: Same filters as search_knowledge_base
Returns: {total: int, by_modality: {email: 100, blog: 50, ...}}
"""
from sqlalchemy import func as sql_func
with make_session() as session:
base_query = session.query(SourceItem).filter(
SourceItem.embed_status == "STORED"
)
# Apply filters
if modalities:
base_query = base_query.filter(SourceItem.modality.in_(modalities))
if tags := filters.get("tags"):
base_query = base_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
)
if min_size := filters.get("min_size"):
base_query = base_query.filter(SourceItem.size >= min_size)
if max_size := filters.get("max_size"):
base_query = base_query.filter(SourceItem.size <= max_size)
if source_ids := filters.get("source_ids"):
base_query = base_query.filter(SourceItem.id.in_(source_ids))
# Get total
total = base_query.count()
# Get counts by modality
by_modality_query = (
base_query.with_entities(
SourceItem.modality, sql_func.count(SourceItem.id)
)
.group_by(SourceItem.modality)
)
by_modality = {row[0]: row[1] for row in by_modality_query.all()}
return {
"total": total,
"by_modality": by_modality,
}
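A matching scoping sketch (filter values invented) showing how the per-modality breakdown can guide a review before paging through list_items:

    # Hypothetical: how much content above ~10 KB is there, and of what kind?
    counts = await count_items(filters={"min_size": 10_000})
    print(counts["total"])        # e.g. 420
    print(counts["by_modality"])  # e.g. {"email": 300, "blog": 80, "book": 40}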

View File

@ -381,6 +381,16 @@ class SourceItem(Base):
"""
return 1.0
@property
def title(self) -> str | None:
"""
Return a display title for this item.
Subclasses should override to return their specific title field
(e.g., subject for emails, title for blog posts).
"""
return cast(str | None, self.filename)
@property
def display_contents(self) -> dict | None:
payload = self.as_payload()
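For intuition, a small hedged sketch of how the new fallback is meant to behave; constructing the models directly like this, and the field values, are illustrative only:

    # Illustrative only: the base property falls back to the stored filename,
    # while subclasses with a better field (e.g. MailMessage below) override it.
    doc = SourceItem(filename="reports/q3-summary.txt")
    assert doc.title == "reports/q3-summary.txt"
    mail = MailMessage(subject="Q3 summary", filename="q3.eml")
    assert mail.title == "Q3 summary"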

View File

@ -190,6 +190,10 @@ class MailMessage(SourceItem):
def get_collections(cls) -> list[str]:
return ["mail"]
@property
def title(self) -> str | None:
return cast(str | None, self.subject)
# Add indexes
__table_args__ = (
Index("mail_sent_idx", "sent_at"),
@ -592,6 +596,10 @@ class BookSection(SourceItem):
def get_collections(cls) -> list[str]:
return ["book"]
@property
def title(self) -> str | None:
return cast(str | None, self.section_title)
def as_payload(self) -> BookSectionPayload:
return BookSectionPayload(
**super().as_payload(),
@ -1035,6 +1043,10 @@ class Note(SourceItem):
def get_collections(cls) -> list[str]:
return ["text"] # Notes go to the text collection
@property
def title(self) -> str | None:
return cast(str | None, self.subject)
class AgentObservationPayload(SourceItemPayload):
session_id: Annotated[str | None, "Session ID for the observation"]
@ -1204,6 +1216,10 @@ class AgentObservation(SourceItem):
def get_collections(cls) -> list[str]:
return ["semantic", "temporal"]
@property
def title(self) -> str | None:
return cast(str | None, self.subject)
class GoogleDocPayload(SourceItemPayload):
google_file_id: Annotated[str, "Google Drive file ID"]

View File

@ -131,7 +131,8 @@ class GoogleDriveClient:
since: datetime | None = None,
page_size: int = 100,
exclude_folder_ids: set[str] | None = None,
- ) -> Generator[dict, None, None]:
+ _current_path: str | None = None,
+ ) -> Generator[tuple[dict, str], None, None]:
"""List all supported files in a folder with pagination.
Args:
@ -140,10 +141,18 @@ class GoogleDriveClient:
since: Only return files modified after this time
page_size: Number of files per API page
exclude_folder_ids: Set of folder IDs to skip during recursive traversal
_current_path: Internal param tracking the current folder path
Yields:
Tuples of (file_metadata, parent_folder_path)
"""
service = self._get_service()
exclude_folder_ids = exclude_folder_ids or set()
# Build the current path if not provided
if _current_path is None:
_current_path = self.get_folder_path(folder_id)
# Build query for supported file types
all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
mime_conditions = " or ".join(f"mimeType='{mime}'" for mime in all_mimes)
@ -178,17 +187,19 @@ class GoogleDriveClient:
for file in response.get("files", []):
if file["mimeType"] == "application/vnd.google-apps.folder":
if recursive and file["id"] not in exclude_folder_ids:
- # Recursively list files in subfolder
+ # Recursively list files in subfolder with updated path
subfolder_path = f"{_current_path}/{file['name']}"
yield from self.list_files_in_folder(
file["id"],
recursive=True,
since=since,
exclude_folder_ids=exclude_folder_ids,
_current_path=subfolder_path,
)
elif file["id"] in exclude_folder_ids:
logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})")
else:
- yield file
+ yield file, _current_path
page_token = response.get("nextPageToken")
if not page_token:

View File

@ -101,6 +101,7 @@ def _create_google_doc(
content=file_data["content"],
google_file_id=file_data["file_id"],
title=file_data["title"],
filename=file_data["title"],
original_mime_type=file_data["original_mime_type"],
folder_id=folder.id,
folder_path=file_data["folder_path"],
@ -144,6 +145,7 @@ def _update_existing_doc(
existing.content = file_data["content"]
existing.sha256 = create_content_hash(file_data["content"])
existing.title = file_data["title"]
existing.filename = file_data["title"]
existing.google_modified_at = file_data["modified_at"]
existing.last_modified_by = file_data["last_modified_by"]
existing.word_count = file_data["word_count"]
@ -249,21 +251,19 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, An
if is_folder:
# It's a folder - list and sync all files inside
folder_path = client.get_folder_path(google_id)
# Get excluded folder IDs
exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or [])
if exclude_ids:
logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync")
- for file_meta in client.list_files_in_folder(
+ for file_meta, file_folder_path in client.list_files_in_folder(
google_id,
recursive=cast(bool, folder.recursive),
since=since,
exclude_folder_ids=exclude_ids,
):
try:
- file_data = client.fetch_file(file_meta, folder_path)
+ file_data = client.fetch_file(file_meta, file_folder_path)
serialized = _serialize_file_data(file_data)
task = sync_google_doc.delay(folder.id, serialized)
task_ids.append(task.id)

View File

@ -244,9 +244,10 @@ def test_chunk_text_long_text():
text = " ".join(sentences)
max_tokens = 10 # 10 tokens = ~40 chars
# Chunker includes overlap and fits the final sentence in the last chunk
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
- ] + ["This is sentence 49."]
+ ]
def test_chunk_text_with_overlap():
@ -255,9 +256,10 @@ def test_chunk_text_with_overlap():
text = "Part A. Part B. Part C. Part D. Part E."
assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
"Part A. Part B. Part C.",
"Part C. Part D. Part E.",
"Part E.",
"Part A. Part B.",
"Part B. Part C.",
"Part C. Part D.",
"Part D. Part E.",
]
@ -265,10 +267,12 @@ def test_chunk_text_zero_overlap():
"""Test chunking with zero overlap"""
text = "Part A. Part B. Part C. Part D. Part E."
- # 2 tokens = ~8 chars
+ # 2 tokens = ~8 chars, each sentence is about 2 tokens
assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
"Part A. Part B.",
"Part C. Part D.",
"Part A.",
"Part B.",
"Part C.",
"Part D.",
"Part E.",
]
@ -278,9 +282,12 @@ def test_chunk_text_clean_break():
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
max_tokens = 5 # Enough for about 2 sentences
# Chunker breaks at sentence boundaries, with overlap including previous sentence
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
"First sentence. Second sentence.",
"Third sentence. Fourth sentence.",
"First sentence.",
"Second sentence.",
"Third sentence.",
"Fourth sentence.",
]
@ -289,11 +296,15 @@ def test_chunk_text_very_long_sentences():
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
max_tokens = 5 # Small limit to force splitting
# Chunker splits at word boundaries when no sentence boundary available
assert list(chunk_text(text, max_tokens=max_tokens)) == [
"This is a very long sentence with many many",
"words that will definitely exceed the",
"This is a very long",
"sentence with many many",
"words that will",
"definitely exceed the",
"token limit we set for",
"this particular test",
"case and should be split into multiple",
"case and should be",
"split into multiple",
"chunks by the function.",
]
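The tightened expectations follow the rough 4-characters-per-token heuristic the inline comments mention; a hedged sanity check of that arithmetic, using only strings from the tests:

    # Not the chunker's real tokenizer, just the back-of-envelope maths behind the comments:
    # max_tokens=2 is roughly 8 characters, and "Part A. " is exactly 8 characters,
    # which is why the zero-overlap test now expects one short sentence per chunk.
    assert len("Part A. ") == 8
    assert len("This is a very long") <= 5 * 4  # max_tokens=5 ~ 20 characters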

View File

@ -681,27 +681,27 @@ def test_process_content_item(
assert str(db_item.embed_status) == expected_embed_status
- @pytest.mark.parametrize(
- "task_behavior,expected_status",
- [
- ("success", "success"),
- ("exception", "error"),
- ],
- )
- def test_safe_task_execution(task_behavior, expected_status):
+ def test_safe_task_execution_success():
+ """Test that safe_task_execution passes through successful results."""
@safe_task_execution
def test_task(arg1, arg2):
if task_behavior == "exception":
raise ValueError("Test error message")
return {"status": "success", "result": arg1 + arg2}
result = test_task(1, 2)
assert result["status"] == expected_status
if expected_status == "success":
assert result["status"] == "success"
assert result["result"] == 3
- else:
- assert result["error"] == "Test error message"
+ def test_safe_task_execution_reraises_exceptions():
+ """Test that safe_task_execution logs but re-raises exceptions for Celery retries."""
+ @safe_task_execution
+ def test_task():
+ raise ValueError("Test error message")
+ with pytest.raises(ValueError, match="Test error message"):
+ test_task()
def test_safe_task_execution_preserves_function_name():
@ -709,7 +709,8 @@ def test_safe_task_execution_preserves_function_name():
def test_function():
return {"status": "success"}
assert test_function.__name__ == "wrapper"
# @wraps(func) should preserve the original function name
assert test_function.__name__ == "test_function"
def test_safe_task_execution_with_kwargs():
@ -727,14 +728,14 @@ def test_safe_task_execution_with_kwargs():
def test_safe_task_execution_exception_logging(caplog):
"""Test that exceptions are logged before being re-raised."""
@safe_task_execution
def failing_task():
raise RuntimeError("Test runtime error")
- result = failing_task()
+ with pytest.raises(RuntimeError, match="Test runtime error"):
+ failing_task()
- assert result["status"] == "error"
- assert result["error"] == "Test runtime error"
- assert "traceback" in result
assert "Task failing_task failed:" in caplog.text
assert "Test runtime error" in caplog.text