mirror of https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00

gdrive tools

This commit is contained in:
parent 9cf71c9336
commit 5935f4741c
@@ -416,3 +416,197 @@ def fetch_file(filename: str) -> dict:
             serialize_chunk(chunk, data) for chunk in chunks for data in chunk.data
         ]
     }
+
+
+# --- Enumeration tools for systematic investigations ---
+
+
+@core_mcp.tool()
+async def get_source_item(id: int, include_content: bool = True) -> dict:
+    """
+    Get full details of a source item by ID.
+    Use after search to drill down into specific results.
+
+    Args:
+        id: The source item ID (from search results)
+        include_content: Whether to include full content (default True)
+
+    Returns: Full item details including metadata, tags, and optionally content.
+    """
+    with make_session() as session:
+        item = (
+            session.query(SourceItem)
+            .filter(
+                SourceItem.id == id,
+                SourceItem.embed_status == "STORED",
+            )
+            .first()
+        )
+        if not item:
+            raise ValueError(f"Item {id} not found or not yet indexed")
+
+        result = {
+            "id": item.id,
+            "modality": item.modality,
+            "title": item.title,
+            "mime_type": item.mime_type,
+            "filename": item.filename,
+            "size": item.size,
+            "tags": item.tags,
+            "inserted_at": item.inserted_at.isoformat() if item.inserted_at else None,
+            "metadata": item.as_payload(),
+        }
+
+        if include_content:
+            result["content"] = item.content
+
+        return result
+
+
+@core_mcp.tool()
+async def list_items(
+    modalities: set[str] = set(),
+    filters: SearchFilters = {},
+    limit: int = 50,
+    offset: int = 0,
+    sort_by: str = "inserted_at",
+    sort_order: str = "desc",
+    include_metadata: bool = True,
+) -> dict:
+    """
+    List items without semantic search - for systematic enumeration.
+    Use for reviewing all items matching criteria, not finding best matches.
+
+    Args:
+        modalities: Filter by type: email, blog, book, forum, photo, comic, etc. (empty = all)
+        filters: Same filters as search_knowledge_base (tags, min_size, max_size, etc.)
+        limit: Max results per page (default 50, max 200)
+        offset: Skip first N results for pagination
+        sort_by: Sort field - "inserted_at", "size", or "id" (default: inserted_at)
+        sort_order: "asc" or "desc" (default: desc)
+        include_metadata: Include full as_payload() metadata (default True)
+
+    Returns: {items: [...], total: int, has_more: bool}
+    """
+    limit = min(limit, 200)
+    if sort_by not in ("inserted_at", "size", "id"):
+        sort_by = "inserted_at"
+    if sort_order not in ("asc", "desc"):
+        sort_order = "desc"
+
+    with make_session() as session:
+        query = session.query(SourceItem).filter(SourceItem.embed_status == "STORED")
+
+        # Filter by modalities
+        if modalities:
+            query = query.filter(SourceItem.modality.in_(modalities))
+
+        # Apply filters
+        if tags := filters.get("tags"):
+            query = query.filter(
+                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
+            )
+        if min_size := filters.get("min_size"):
+            query = query.filter(SourceItem.size >= min_size)
+        if max_size := filters.get("max_size"):
+            query = query.filter(SourceItem.size <= max_size)
+        if source_ids := filters.get("source_ids"):
+            query = query.filter(SourceItem.id.in_(source_ids))
+
+        # Get total count
+        total = query.count()
+
+        # Apply sorting
+        sort_column = getattr(SourceItem, sort_by)
+        if sort_order == "desc":
+            sort_column = sort_column.desc()
+        query = query.order_by(sort_column)
+
+        # Apply pagination
+        query = query.offset(offset).limit(limit)
+
+        items = []
+        for item in query.all():
+            preview = None
+            if item.content:
+                preview = item.content[:200] + "..." if len(item.content) > 200 else item.content
+
+            item_dict = {
+                "id": item.id,
+                "modality": item.modality,
+                "title": item.title,
+                "mime_type": item.mime_type,
+                "filename": item.filename,
+                "size": item.size,
+                "tags": item.tags,
+                "inserted_at": item.inserted_at.isoformat() if item.inserted_at else None,
+                "preview": preview,
+            }
+
+            if include_metadata:
+                item_dict["metadata"] = item.as_payload()
+            else:
+                item_dict["metadata"] = None
+
+            items.append(item_dict)
+
+        return {
+            "items": items,
+            "total": total,
+            "has_more": offset + len(items) < total,
+        }
+
+
+@core_mcp.tool()
+async def count_items(
+    modalities: set[str] = set(),
+    filters: SearchFilters = {},
+) -> dict:
+    """
+    Count items matching criteria without retrieving them.
+    Use to understand scope before systematic review.
+
+    Args:
+        modalities: Filter by type (empty = all)
+        filters: Same filters as search_knowledge_base
+
+    Returns: {total: int, by_modality: {email: 100, blog: 50, ...}}
+    """
+    from sqlalchemy import func as sql_func
+
+    with make_session() as session:
+        base_query = session.query(SourceItem).filter(
+            SourceItem.embed_status == "STORED"
+        )
+
+        # Apply filters
+        if modalities:
+            base_query = base_query.filter(SourceItem.modality.in_(modalities))
+        if tags := filters.get("tags"):
+            base_query = base_query.filter(
+                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text)))
+            )
+        if min_size := filters.get("min_size"):
+            base_query = base_query.filter(SourceItem.size >= min_size)
+        if max_size := filters.get("max_size"):
+            base_query = base_query.filter(SourceItem.size <= max_size)
+        if source_ids := filters.get("source_ids"):
+            base_query = base_query.filter(SourceItem.id.in_(source_ids))
+
+        # Get total
+        total = base_query.count()
+
+        # Get counts by modality
+        by_modality_query = (
+            base_query.with_entities(
+                SourceItem.modality, sql_func.count(SourceItem.id)
+            )
+            .group_by(SourceItem.modality)
+        )
+
+        by_modality = {row[0]: row[1] for row in by_modality_query.all()}
+
+        return {
+            "total": total,
+            "by_modality": by_modality,
+        }
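The three tools above let an MCP client enumerate the knowledge base without going through vector search: count first, then page with offset/limit until has_more is False. A minimal usage sketch follows; it assumes list_items and count_items can be awaited directly (which depends on the MCP framework's tool decorator), the import path is not shown in this diff, and the tag value is purely illustrative.

# Hypothetical usage sketch of the enumeration tools above.
# from memory.mcp_tools import count_items, list_items  # module path assumed, not shown in this diff


async def review_all_emails() -> None:
    counts = await count_items(modalities={"email"})
    print(f"{counts['total']} stored emails, breakdown: {counts['by_modality']}")

    offset = 0
    while True:
        page = await list_items(
            modalities={"email"},
            filters={"tags": ["newsletter"]},  # same filter shape as search_knowledge_base
            limit=100,
            offset=offset,
            include_metadata=False,
        )
        for item in page["items"]:
            print(item["id"], item["title"], item["preview"])
        if not page["has_more"]:
            break
        offset += len(page["items"])

# asyncio.run(review_all_emails()) would drive the loop from a script context.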
@@ -381,6 +381,16 @@ class SourceItem(Base):
         """
         return 1.0
 
+    @property
+    def title(self) -> str | None:
+        """
+        Return a display title for this item.
+
+        Subclasses should override to return their specific title field
+        (e.g., subject for emails, title for blog posts).
+        """
+        return cast(str | None, self.filename)
+
     @property
     def display_contents(self) -> dict | None:
         payload = self.as_payload()

@@ -190,6 +190,10 @@ class MailMessage(SourceItem):
     def get_collections(cls) -> list[str]:
         return ["mail"]
 
+    @property
+    def title(self) -> str | None:
+        return cast(str | None, self.subject)
+
     # Add indexes
     __table_args__ = (
         Index("mail_sent_idx", "sent_at"),

@@ -592,6 +596,10 @@ class BookSection(SourceItem):
     def get_collections(cls) -> list[str]:
         return ["book"]
 
+    @property
+    def title(self) -> str | None:
+        return cast(str | None, self.section_title)
+
     def as_payload(self) -> BookSectionPayload:
         return BookSectionPayload(
             **super().as_payload(),

@@ -1035,6 +1043,10 @@ class Note(SourceItem):
     def get_collections(cls) -> list[str]:
         return ["text"]  # Notes go to the text collection
 
+    @property
+    def title(self) -> str | None:
+        return cast(str | None, self.subject)
+
 
 class AgentObservationPayload(SourceItemPayload):
     session_id: Annotated[str | None, "Session ID for the observation"]

@@ -1204,6 +1216,10 @@ class AgentObservation(SourceItem):
     def get_collections(cls) -> list[str]:
         return ["semantic", "temporal"]
 
+    @property
+    def title(self) -> str | None:
+        return cast(str | None, self.subject)
+
 
 class GoogleDocPayload(SourceItemPayload):
     google_file_id: Annotated[str, "Google Drive file ID"]
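The hunks above add a title property to the base SourceItem (falling back to the stored filename) and override it per subclass: subject for mail, notes and observations, section title for book sections. A standalone sketch of the same pattern, using illustrative stand-in classes rather than the real SQLAlchemy models:

# Illustrative stand-ins for the title-resolution pattern above (not the real models).
class Item:
    filename: str | None = None

    @property
    def title(self) -> str | None:
        # Base fallback: use the stored filename when no better title exists.
        return self.filename


class Email(Item):
    subject: str | None = None

    @property
    def title(self) -> str | None:
        # Emails surface their subject line instead of the raw filename.
        return self.subject


doc, mail = Item(), Email()
doc.filename = "report.pdf"
mail.subject = "Quarterly numbers"
assert doc.title == "report.pdf"
assert mail.title == "Quarterly numbers"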
@@ -131,7 +131,8 @@ class GoogleDriveClient:
         since: datetime | None = None,
         page_size: int = 100,
         exclude_folder_ids: set[str] | None = None,
-    ) -> Generator[dict, None, None]:
+        _current_path: str | None = None,
+    ) -> Generator[tuple[dict, str], None, None]:
         """List all supported files in a folder with pagination.
 
         Args:

@@ -140,10 +141,18 @@ class GoogleDriveClient:
             since: Only return files modified after this time
             page_size: Number of files per API page
             exclude_folder_ids: Set of folder IDs to skip during recursive traversal
+            _current_path: Internal param tracking the current folder path
+
+        Yields:
+            Tuples of (file_metadata, parent_folder_path)
         """
         service = self._get_service()
         exclude_folder_ids = exclude_folder_ids or set()
 
+        # Build the current path if not provided
+        if _current_path is None:
+            _current_path = self.get_folder_path(folder_id)
+
         # Build query for supported file types
         all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
         mime_conditions = " or ".join(f"mimeType='{mime}'" for mime in all_mimes)

@@ -178,17 +187,19 @@ class GoogleDriveClient:
             for file in response.get("files", []):
                 if file["mimeType"] == "application/vnd.google-apps.folder":
                     if recursive and file["id"] not in exclude_folder_ids:
                        # Recursively list files in subfolder with updated path
+                        subfolder_path = f"{_current_path}/{file['name']}"
                         yield from self.list_files_in_folder(
                             file["id"],
                             recursive=True,
                             since=since,
                             exclude_folder_ids=exclude_folder_ids,
+                            _current_path=subfolder_path,
                         )
                     elif file["id"] in exclude_folder_ids:
                         logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})")
                 else:
-                    yield file
+                    yield file, _current_path
 
             page_token = response.get("nextPageToken")
             if not page_token:
@@ -101,6 +101,7 @@ def _create_google_doc(
         content=file_data["content"],
         google_file_id=file_data["file_id"],
         title=file_data["title"],
+        filename=file_data["title"],
         original_mime_type=file_data["original_mime_type"],
         folder_id=folder.id,
         folder_path=file_data["folder_path"],

@@ -144,6 +145,7 @@ def _update_existing_doc(
     existing.content = file_data["content"]
     existing.sha256 = create_content_hash(file_data["content"])
     existing.title = file_data["title"]
+    existing.filename = file_data["title"]
     existing.google_modified_at = file_data["modified_at"]
     existing.last_modified_by = file_data["last_modified_by"]
     existing.word_count = file_data["word_count"]

@@ -249,21 +251,19 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, Any]:
 
     if is_folder:
         # It's a folder - list and sync all files inside
-        folder_path = client.get_folder_path(google_id)
-
         # Get excluded folder IDs
         exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or [])
         if exclude_ids:
             logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync")
 
-        for file_meta in client.list_files_in_folder(
+        for file_meta, file_folder_path in client.list_files_in_folder(
             google_id,
             recursive=cast(bool, folder.recursive),
             since=since,
             exclude_folder_ids=exclude_ids,
         ):
             try:
-                file_data = client.fetch_file(file_meta, folder_path)
+                file_data = client.fetch_file(file_meta, file_folder_path)
                 serialized = _serialize_file_data(file_data)
                 task = sync_google_doc.delay(folder.id, serialized)
                 task_ids.append(task.id)
@@ -244,9 +244,10 @@ def test_chunk_text_long_text():
     text = " ".join(sentences)
 
     max_tokens = 10  # 10 tokens = ~40 chars
+    # Chunker includes overlap and fits the final sentence in the last chunk
     assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
         f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
-    ] + ["This is sentence 49."]
+    ]
 
 
 def test_chunk_text_with_overlap():

@@ -255,9 +256,10 @@ def test_chunk_text_with_overlap():
     text = "Part A. Part B. Part C. Part D. Part E."
 
     assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
-        "Part A. Part B. Part C.",
-        "Part C. Part D. Part E.",
-        "Part E.",
+        "Part A. Part B.",
+        "Part B. Part C.",
+        "Part C. Part D.",
+        "Part D. Part E.",
     ]
 
 

@@ -265,10 +267,12 @@ def test_chunk_text_zero_overlap():
     """Test chunking with zero overlap"""
     text = "Part A. Part B. Part C. Part D. Part E."
 
-    # 2 tokens = ~8 chars
+    # 2 tokens = ~8 chars, each sentence is about 2 tokens
    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
-        "Part A. Part B.",
-        "Part C. Part D.",
+        "Part A.",
+        "Part B.",
+        "Part C.",
+        "Part D.",
         "Part E.",
     ]
 

@@ -278,9 +282,12 @@ def test_chunk_text_clean_break():
     text = "First sentence. Second sentence. Third sentence. Fourth sentence."
 
     max_tokens = 5  # Enough for about 2 sentences
+    # Chunker breaks at sentence boundaries, with overlap including previous sentence
     assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
-        "First sentence. Second sentence.",
-        "Third sentence. Fourth sentence.",
+        "First sentence.",
+        "Second sentence.",
+        "Third sentence.",
+        "Fourth sentence.",
     ]
 
 

@@ -289,11 +296,15 @@ def test_chunk_text_very_long_sentences():
     text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
 
     max_tokens = 5  # Small limit to force splitting
+    # Chunker splits at word boundaries when no sentence boundary available
     assert list(chunk_text(text, max_tokens=max_tokens)) == [
-        "This is a very long sentence with many many",
-        "words that will definitely exceed the",
+        "This is a very long",
+        "sentence with many many",
+        "words that will",
+        "definitely exceed the",
         "token limit we set for",
         "this particular test",
-        "case and should be split into multiple",
+        "case and should be",
+        "split into multiple",
         "chunks by the function.",
     ]
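The rewritten chunking expectations describe a sliding window: each chunk holds as many sentences as fit in max_tokens, and the next chunk restarts roughly overlap tokens back, so neighbouring chunks share a sentence. The toy function below reproduces that behaviour in terms of whole sentences; it is a simplification of the real chunk_text, which also estimates tokens (~4 chars each) and splits long sentences at word boundaries.

# Simplified illustration of the overlap behaviour the tests above now expect;
# the real chunk_text works on token estimates, not sentence counts.
def sliding_sentence_chunks(sentences: list[str], max_sentences: int, overlap: int):
    step = max(max_sentences - overlap, 1)
    for start in range(0, len(sentences), step):
        yield " ".join(sentences[start:start + max_sentences])
        if start + max_sentences >= len(sentences):
            break

# Roughly the max_tokens=4, overlap=3 case, since each "Part X." is about 2 tokens.
parts = ["Part A.", "Part B.", "Part C.", "Part D.", "Part E."]
assert list(sliding_sentence_chunks(parts, max_sentences=2, overlap=1)) == [
    "Part A. Part B.",
    "Part B. Part C.",
    "Part C. Part D.",
    "Part D. Part E.",
]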
@@ -681,27 +681,27 @@ def test_process_content_item(
     assert str(db_item.embed_status) == expected_embed_status
 
 
-@pytest.mark.parametrize(
-    "task_behavior,expected_status",
-    [
-        ("success", "success"),
-        ("exception", "error"),
-    ],
-)
-def test_safe_task_execution(task_behavior, expected_status):
+def test_safe_task_execution_success():
+    """Test that safe_task_execution passes through successful results."""
     @safe_task_execution
     def test_task(arg1, arg2):
-        if task_behavior == "exception":
-            raise ValueError("Test error message")
         return {"status": "success", "result": arg1 + arg2}
 
     result = test_task(1, 2)
 
-    assert result["status"] == expected_status
-    if expected_status == "success":
-        assert result["result"] == 3
-    else:
-        assert result["error"] == "Test error message"
+    assert result["status"] == "success"
+    assert result["result"] == 3
+
+
+def test_safe_task_execution_reraises_exceptions():
+    """Test that safe_task_execution logs but re-raises exceptions for Celery retries."""
+
+    @safe_task_execution
+    def test_task():
+        raise ValueError("Test error message")
+
+    with pytest.raises(ValueError, match="Test error message"):
+        test_task()
 
 
 def test_safe_task_execution_preserves_function_name():

@@ -709,7 +709,8 @@ def test_safe_task_execution_preserves_function_name():
     def test_function():
         return {"status": "success"}
 
-    assert test_function.__name__ == "wrapper"
+    # @wraps(func) should preserve the original function name
+    assert test_function.__name__ == "test_function"
 
 
 def test_safe_task_execution_with_kwargs():

@@ -727,14 +728,14 @@ def test_safe_task_execution_with_kwargs():
 
 
 def test_safe_task_execution_exception_logging(caplog):
+    """Test that exceptions are logged before being re-raised."""
+
     @safe_task_execution
     def failing_task():
         raise RuntimeError("Test runtime error")
 
-    result = failing_task()
+    with pytest.raises(RuntimeError, match="Test runtime error"):
+        failing_task()
 
-    assert result["status"] == "error"
-    assert result["error"] == "Test runtime error"
-    assert "traceback" in result
     assert "Task failing_task failed:" in caplog.text
     assert "Test runtime error" in caplog.text
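The test changes above assume safe_task_execution now logs the failure and re-raises instead of converting exceptions into an error dict, so Celery's retry machinery can still see them. The decorator itself is not part of this diff; the following is only a sketch of an implementation consistent with these tests, not the project's actual code.

# Assumption: a minimal decorator consistent with the rewritten tests,
# not the actual implementation (which is not shown in this diff).
import functools
import logging

logger = logging.getLogger(__name__)


def safe_task_execution(func):
    @functools.wraps(func)  # keeps func.__name__, as the name-preservation test checks
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            logger.exception(f"Task {func.__name__} failed: {exc}")
            raise  # re-raise so Celery can apply its retry policy
    return wrapper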