From 4f1ca777e9eb6d04fe34f051a35db596256ef4d9 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 20 May 2025 22:54:45 +0200 Subject: [PATCH] fix linting --- setup.py | 16 +-- src/memory/common/db/models.py | 21 ++-- src/memory/common/embedding.py | 32 +++-- src/memory/common/extract.py | 14 +-- src/memory/common/parsers/comics.py | 4 +- src/memory/common/parsers/email.py | 10 +- src/memory/common/qdrant.py | 19 +-- src/memory/mcp/server.py | 6 +- src/memory/workers/celery_app.py | 2 +- src/memory/workers/email.py | 46 ++++--- src/memory/workers/tasks/comic.py | 22 ++-- src/memory/workers/tasks/email.py | 20 +-- src/memory/workers/tasks/maintenance.py | 8 +- .../common/parsers/test_email_parsers.py | 2 +- tests/memory/common/test_embedding.py | 50 ++++---- tests/memory/common/test_extract.py | 4 +- tests/memory/common/test_qdrant.py | 5 +- tests/memory/workers/test_email.py | 66 +++++----- tests/providers/email_provider.py | 117 ++++++++++-------- 19 files changed, 255 insertions(+), 209 deletions(-) diff --git a/setup.py b/setup.py index 1c6ccba..7a527e3 100644 --- a/setup.py +++ b/setup.py @@ -4,19 +4,19 @@ from setuptools import setup, find_namespace_packages def read_requirements(filename: str) -> list[str]: """Read requirements from file, ignoring comments and -r directives.""" - filename = pathlib.Path(filename) + path = pathlib.Path(filename) return [ line.strip() - for line in filename.read_text().splitlines() - if line.strip() and not line.strip().startswith(('#', '-r')) + for line in path.read_text().splitlines() + if line.strip() and not line.strip().startswith(("#", "-r")) ] # Read requirements files -common_requires = read_requirements('requirements-common.txt') -api_requires = read_requirements('requirements-api.txt') -workers_requires = read_requirements('requirements-workers.txt') -dev_requires = read_requirements('requirements-dev.txt') +common_requires = read_requirements("requirements-common.txt") +api_requires = read_requirements("requirements-api.txt") +workers_requires = read_requirements("requirements-workers.txt") +dev_requires = read_requirements("requirements-dev.txt") setup( name="memory", @@ -31,4 +31,4 @@ setup( "dev": dev_requires, "all": api_requires + workers_requires + common_requires + dev_requires, }, -) \ No newline at end of file +) diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 34c3168..c934238 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -4,10 +4,9 @@ Database models for the knowledge base system. import pathlib import re -from email.message import EmailMessage from pathlib import Path import textwrap -from typing import Any, ClassVar +from typing import Any, ClassVar, cast from PIL import Image from sqlalchemy import ( ARRAY, @@ -31,7 +30,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, Session from memory.common import settings -from memory.common.parsers.email import parse_email_message +from memory.common.parsers.email import parse_email_message, EmailMessage Base = declarative_base() @@ -113,10 +112,10 @@ class Chunk(Base): @property def data(self) -> list[bytes | str | Image.Image]: if self.file_path is None: - return [self.content] + return [cast(str, self.content)] path = pathlib.Path(self.file_path.replace("/app/", "")) - if self.file_path.endswith("*"): + if cast(str, self.file_path).endswith("*"): files = list(path.parent.glob(path.name)) else: files = [path] @@ -182,7 +181,7 @@ class SourceItem(Base): @property def display_contents(self) -> str | None: - return self.content or self.filename + return cast(str | None, self.content) or cast(str | None, self.filename) class MailMessage(SourceItem): @@ -217,8 +216,8 @@ class MailMessage(SourceItem): @property def attachments_path(self) -> Path: - clean_sender = clean_filename(self.sender) - clean_folder = clean_filename(self.folder or "INBOX") + clean_sender = clean_filename(cast(str, self.sender)) + clean_folder = clean_filename(cast(str | None, self.folder) or "INBOX") return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder def safe_filename(self, filename: str) -> Path: @@ -237,12 +236,12 @@ class MailMessage(SourceItem): "recipients": self.recipients, "folder": self.folder, "tags": self.tags + [self.sender] + self.recipients, - "date": self.sent_at and self.sent_at.isoformat() or None, + "date": (self.sent_at and self.sent_at.isoformat() or None), # type: ignore } @property def parsed_content(self) -> EmailMessage: - return parse_email_message(self.content, self.message_id) + return parse_email_message(cast(str, self.content), cast(str, self.message_id)) @property def body(self) -> str: @@ -300,7 +299,7 @@ class EmailAttachment(SourceItem): "filename": self.filename, "content_type": self.mime_type, "size": self.size, - "created_at": self.created_at and self.created_at.isoformat() or None, + "created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore "mail_message_id": self.mail_message_id, "source_id": self.id, "tags": self.tags, diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index 2b68dff..1332355 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,7 +1,7 @@ import logging import pathlib import uuid -from typing import Any, Iterable, Literal, NotRequired, TypedDict +from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast import voyageai from PIL import Image @@ -21,7 +21,6 @@ CHARS_PER_TOKEN = 4 DistanceType = Literal["Cosine", "Dot", "Euclidean"] Vector = list[float] -Embedding = tuple[str, Vector, dict[str, Any]] class Collection(TypedDict): @@ -133,16 +132,18 @@ def get_modality(mime_type: str) -> str: def embed_chunks( - chunks: list[extract.MulitmodalChunk], + chunks: list[str] | list[list[extract.MulitmodalChunk]], model: str = settings.TEXT_EMBEDDING_MODEL, input_type: Literal["document", "query"] = "document", ) -> list[Vector]: - vo = voyageai.Client() + vo = voyageai.Client() # type: ignore if model == settings.MIXED_EMBEDDING_MODEL: return vo.multimodal_embed( - chunks, model=model, input_type=input_type + chunks, # type: ignore + model=model, + input_type=input_type, ).embeddings - return vo.embed(chunks, model=model, input_type=input_type).embeddings + return vo.embed(chunks, model=model, input_type=input_type).embeddings # type: ignore def embed_text( @@ -162,7 +163,7 @@ def embed_text( try: return embed_chunks(chunks, model, input_type) - except voyageai.error.InvalidRequestError as e: + except voyageai.error.InvalidRequestError as e: # type: ignore logger.error(f"Error embedding text: {e}") logger.debug(f"Text: {texts}") raise @@ -179,7 +180,7 @@ def embed_mixed( model: str = settings.MIXED_EMBEDDING_MODEL, input_type: Literal["document", "query"] = "document", ) -> list[Vector]: - def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]: + def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]: if isinstance(item, str): return [ c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip() @@ -190,11 +191,16 @@ def embed_mixed( return embed_chunks([chunks], model, input_type) -def embed_page(page: dict[str, Any]) -> list[Vector]: +def embed_page(page: extract.Page) -> list[Vector]: contents = page["contents"] if all(isinstance(c, str) for c in contents): - return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL) - return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL) + return embed_text( + cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL + ) + return embed_mixed( + cast(list[extract.MulitmodalChunk], contents), + model=settings.MIXED_EMBEDDING_MODEL, + ) def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: @@ -224,7 +230,7 @@ def make_chunk( contents = page["contents"] content, filename = None, None if all(isinstance(c, str) for c in contents): - content = "\n\n".join(contents) + content = "\n\n".join(cast(list[str], contents)) model = settings.TEXT_EMBEDDING_MODEL elif len(contents) == 1: filename = write_to_file(chunk_id, contents[0]).absolute().as_posix() @@ -249,7 +255,7 @@ def embed( mime_type: str, content: bytes | str | pathlib.Path, metadata: dict[str, Any] = {}, -) -> tuple[str, list[Embedding]]: +) -> tuple[str, list[Chunk]]: modality = get_modality(mime_type) pages = extract.extract_content(mime_type, content) chunks = [ diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py index 5e4cf1c..dc21892 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -1,13 +1,13 @@ -from contextlib import contextmanager import io +import logging import pathlib import tempfile -import pypandoc -import pymupdf # PyMuPDF -from PIL import Image -from typing import Any, TypedDict, Generator, Sequence +from contextlib import contextmanager +from typing import Any, Generator, Sequence, TypedDict, cast -import logging +import pymupdf # PyMuPDF +import pypandoc +from PIL import Image logger = logging.getLogger(__name__) @@ -105,7 +105,7 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]: if isinstance(content, bytes): content = content.decode("utf-8") - return [{"contents": [content], "metadata": {}}] + return [{"contents": [cast(str, content)], "metadata": {}}] def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]: diff --git a/src/memory/common/parsers/comics.py b/src/memory/common/parsers/comics.py index 12f0231..33798f4 100644 --- a/src/memory/common/parsers/comics.py +++ b/src/memory/common/parsers/comics.py @@ -1,5 +1,5 @@ import logging -from typing import TypedDict, NotRequired +from typing import TypedDict, NotRequired, cast from bs4 import BeautifulSoup, Tag import requests @@ -59,7 +59,7 @@ def extract_smbc(url: str) -> ComicInfo: "title": title, "image_url": image_url, "published_date": published_date, - "url": comic_url or url, + "url": cast(str, comic_url or url), } diff --git a/src/memory/common/parsers/email.py b/src/memory/common/parsers/email.py index 99313f6..1680f2b 100644 --- a/src/memory/common/parsers/email.py +++ b/src/memory/common/parsers/email.py @@ -28,10 +28,10 @@ class EmailMessage(TypedDict): hash: bytes -RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes] +RawEmailResponse = tuple[str | None, bytes] -def extract_recipients(msg: email.message.Message) -> list[str]: +def extract_recipients(msg: email.message.Message) -> list[str]: # type: ignore """ Extract email recipients from message headers. @@ -50,7 +50,7 @@ def extract_recipients(msg: email.message.Message) -> list[str]: ] -def extract_date(msg: email.message.Message) -> datetime | None: +def extract_date(msg: email.message.Message) -> datetime | None: # type: ignore """ Parse date from email header. @@ -68,7 +68,7 @@ def extract_date(msg: email.message.Message) -> datetime | None: return None -def extract_body(msg: email.message.Message) -> str: +def extract_body(msg: email.message.Message) -> str: # type: ignore """ Extract plain text body from email message. @@ -99,7 +99,7 @@ def extract_body(msg: email.message.Message) -> str: return body -def extract_attachments(msg: email.message.Message) -> list[Attachment]: +def extract_attachments(msg: email.message.Message) -> list[Attachment]: # type: ignore """ Extract attachment metadata and content from email. diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py index 2bf7637..9bfde5e 100644 --- a/src/memory/common/qdrant.py +++ b/src/memory/common/qdrant.py @@ -1,5 +1,5 @@ import logging -from typing import Any, cast, Iterator, Sequence +from typing import Any, cast, Generator, Sequence import qdrant_client from qdrant_client.http import models as qdrant_models @@ -24,7 +24,7 @@ def get_qdrant_client() -> qdrant_client.QdrantClient: return qdrant_client.QdrantClient( host=settings.QDRANT_HOST, port=settings.QDRANT_PORT, - grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None, + grpc_port=settings.QDRANT_GRPC_PORT or 6334, prefer_grpc=settings.QDRANT_PREFER_GRPC, api_key=settings.QDRANT_API_KEY, timeout=settings.QDRANT_TIMEOUT, @@ -80,7 +80,8 @@ def ensure_collection_exists( def initialize_collections( - client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None + client: qdrant_client.QdrantClient, + collections: dict[str, Collection] | None = None, ) -> None: """ Initialize all required collections in Qdrant. @@ -122,7 +123,7 @@ def upsert_vectors( collection_name: str, ids: list[str], vectors: list[Vector], - payloads: list[dict[str, Any]] = None, + payloads: list[dict[str, Any]] | None = None, ) -> None: """Upsert vectors into a collection. @@ -147,7 +148,7 @@ def upsert_vectors( client.upsert( collection_name=collection_name, - points=points, + points=points, # type: ignore ) logger.debug(f"Upserted {len(ids)} vectors into {collection_name}") @@ -157,7 +158,7 @@ def search_vectors( client: qdrant_client.QdrantClient, collection_name: str, query_vector: Vector, - filter_params: dict = None, + filter_params: dict | None = None, limit: int = 10, ) -> list[qdrant_models.ScoredPoint]: """Search for similar vectors in a collection. @@ -200,7 +201,7 @@ def delete_points( client.delete( collection_name=collection_name, points_selector=qdrant_models.PointIdsList( - points=ids, + points=ids, # type: ignore ), ) @@ -226,7 +227,7 @@ def get_collection_info( def batch_ids( client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000 -) -> Iterator[list[str]]: +) -> Generator[list[str], None, None]: """Iterate over all IDs in a collection.""" offset = None while resp := client.scroll( @@ -236,7 +237,7 @@ def batch_ids( limit=batch_size, ): points, offset = resp - yield [point.id for point in points] + yield [cast(str, point.id) for point in points] if not offset: return diff --git a/src/memory/mcp/server.py b/src/memory/mcp/server.py index 001dc9d..96ea9c2 100644 --- a/src/memory/mcp/server.py +++ b/src/memory/mcp/server.py @@ -21,7 +21,11 @@ async def make_request( ) -> httpx.Response: async with httpx.AsyncClient() as client: return await client.request( - method, f"{SERVER}/{path}", data=data, json=json, files=files + method, + f"{SERVER}/{path}", + data=data, + json=json, + files=files, # type: ignore ) diff --git a/src/memory/workers/celery_app.py b/src/memory/workers/celery_app.py index f730bc8..bde18a0 100644 --- a/src/memory/workers/celery_app.py +++ b/src/memory/workers/celery_app.py @@ -31,7 +31,7 @@ app.conf.update( ) -@app.on_after_configure.connect +@app.on_after_configure.connect # type: ignore def ensure_qdrant_initialised(sender, **_): from memory.common import qdrant diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index adcc133..c9f3fbb 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -1,24 +1,26 @@ import hashlib import imaplib import logging +import pathlib import re +from collections import defaultdict from contextlib import contextmanager from datetime import datetime -from typing import Generator, Callable -import pathlib -from sqlalchemy.orm import Session -from collections import defaultdict -from memory.common import settings, embedding, qdrant +from typing import Callable, Generator, Sequence, cast + +from sqlalchemy.orm import Session, scoped_session + +from memory.common import embedding, qdrant, settings from memory.common.db.models import ( EmailAccount, + EmailAttachment, MailMessage, SourceItem, - EmailAttachment, ) from memory.common.parsers.email import ( Attachment, - parse_email_message, RawEmailResponse, + parse_email_message, ) logger = logging.getLogger(__name__) @@ -89,7 +91,7 @@ def process_attachments( def create_mail_message( - db_session: Session, + db_session: Session | scoped_session, tags: list[str], folder: str, raw_email: str, @@ -136,7 +138,7 @@ def create_mail_message( def does_message_exist( - db_session: Session, message_id: str, message_hash: bytes + db_session: Session | scoped_session, message_id: str, message_hash: bytes ) -> bool: """ Check if a message already exists in the database. @@ -167,7 +169,7 @@ def does_message_exist( def check_message_exists( - db: Session, account_id: int, message_id: str, raw_email: str + db: Session | scoped_session, account_id: int, message_id: str, raw_email: str ) -> bool: account = db.query(EmailAccount).get(account_id) if not account: @@ -181,7 +183,9 @@ def check_message_exists( return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"]) -def extract_email_uid(msg_data: bytes) -> tuple[str, str]: +def extract_email_uid( + msg_data: Sequence[tuple[bytes, bytes]], +) -> tuple[str | None, bytes]: """ Extract the UID and raw email data from the message data. """ @@ -199,7 +203,7 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None: logger.error(f"Error fetching message {uid}") return None - return extract_email_uid(msg_data) + return extract_email_uid(msg_data) # type: ignore except Exception as e: logger.error(f"Error processing message {uid}: {str(e)}") return None @@ -248,7 +252,7 @@ def process_folder( folder: str, account: EmailAccount, since_date: datetime, - processor: Callable[[int, str, str, bytes], int | None], + processor: Callable[[int, str, str, str], int | None], ) -> dict: """ Process a single folder from an email account. @@ -272,7 +276,7 @@ def process_folder( for uid, raw_email in emails: try: task = processor( - account_id=account.id, + account_id=account.id, # type: ignore message_id=uid, folder=folder, raw_email=raw_email.decode("utf-8", errors="replace"), @@ -296,9 +300,11 @@ def process_folder( @contextmanager def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]: - conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port) + conn = imaplib.IMAP4_SSL( + host=cast(str, account.imap_server), port=cast(int, account.imap_port) + ) try: - conn.login(account.username, account.password) + conn.login(cast(str, account.username), cast(str, account.password)) yield conn finally: # Always try to logout and close the connection @@ -318,15 +324,15 @@ def vectorize_email(email: MailMessage): ) email.chunks = chunks if chunks: - vector_ids = [c.id for c in chunks] + vector_ids = [cast(str, c.id) for c in chunks] vectors = [c.vector for c in chunks] metadata = [c.item_metadata for c in chunks] qdrant.upsert_vectors( client=qdrant_client, collection_name="mail", ids=vector_ids, - vectors=vectors, - payloads=metadata, + vectors=vectors, # type: ignore + payloads=metadata, # type: ignore ) embeds = defaultdict(list) @@ -356,7 +362,7 @@ def vectorize_email(email: MailMessage): payloads=metadata, ) - email.embed_status = "STORED" + email.embed_status = "STORED" # type: ignore for attachment in email.attachments: attachment.embed_status = "STORED" diff --git a/src/memory/workers/tasks/comic.py b/src/memory/workers/tasks/comic.py index 549bd99..0016403 100644 --- a/src/memory/workers/tasks/comic.py +++ b/src/memory/workers/tasks/comic.py @@ -1,7 +1,7 @@ import hashlib import logging from datetime import datetime -from typing import Callable +from typing import Callable, cast import feedparser import requests @@ -35,7 +35,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]: logger.error(f"Failed to fetch or parse {rss_url}: {e}") return set() - urls = {item.get("link") or item.get("id") for item in feed.entries} + urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries} with make_session() as session: known = { @@ -46,7 +46,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]: ) } - return urls - known + return cast(set[str], urls - known) def fetch_new_comics( @@ -56,7 +56,7 @@ def fetch_new_comics( for url in new_urls: data = parser(url) | {"author": base_url, "url": url} - sync_comic.delay(**data) + sync_comic.delay(**data) # type: ignore return new_urls @@ -108,7 +108,7 @@ def sync_comic( client=qdrant.get_qdrant_client(), collection_name="comic", ids=[str(chunk.id)], - vectors=[chunk.vector], + vectors=[chunk.vector], # type: ignore payloads=[comic.as_payload()], ) @@ -130,8 +130,8 @@ def sync_xkcd() -> set[str]: @app.task(name=SYNC_ALL_COMICS) def sync_all_comics(): """Synchronize all active comics.""" - sync_smbc.delay() - sync_xkcd.delay() + sync_smbc.delay() # type: ignore + sync_xkcd.delay() # type: ignore @app.task(name="memory.workers.tasks.comic.full_sync_comic") @@ -141,8 +141,8 @@ def trigger_comic_sync(): response = requests.get(url) soup = BeautifulSoup(response.text, "html.parser") - if link := soup.find("a", attrs={"class", "cc-prev"}): - return link.attrs["href"] + if link := soup.find("a", attrs={"class": "cc-prev"}): + return link.attrs["href"] # type: ignore return None next_url = "https://www.smbc-comics.com" @@ -155,7 +155,7 @@ def trigger_comic_sync(): data = comics.extract_smbc(next_url) | { "author": "https://www.smbc-comics.com/" } - sync_comic.delay(**data) + sync_comic.delay(**data) # type: ignore except Exception as e: logger.error(f"failed to sync {next_url}: {e}") urls.append(next_url) @@ -167,6 +167,6 @@ def trigger_comic_sync(): url = f"{BASE_XKCD_URL}/{i}" try: data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"} - sync_comic.delay(**data) + sync_comic.delay(**data) # type: ignore except Exception as e: logger.error(f"failed to sync {url}: {e}") diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index 37a84b4..eb5e74a 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -1,6 +1,6 @@ import logging from datetime import datetime - +from typing import cast from memory.common.db.connection import make_session from memory.common.db.models import EmailAccount from memory.workers.celery_app import app @@ -71,7 +71,7 @@ def process_message( for chunk in attachment.chunks: logger.info(f" - {chunk.id}") - return mail_message.id + return cast(int, mail_message.id) @app.task(name=SYNC_ACCOUNT) @@ -89,15 +89,17 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: with make_session() as db: account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first() - if not account or not account.active: + if not account or not cast(bool, account.active): logger.warning(f"Account {account_id} not found or inactive") return {"error": "Account not found or inactive"} - folders_to_process: list[str] = account.folders or ["INBOX"] + folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"] if since_date: cutoff_date = datetime.fromisoformat(since_date) else: - cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1) + cutoff_date: datetime = cast(datetime, account.last_sync_at) or datetime( + 1970, 1, 1 + ) messages_found = 0 new_messages = 0 @@ -106,9 +108,9 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: def process_message_wrapper( account_id: int, message_id: str, folder: str, raw_email: str ) -> int | None: - if check_message_exists(db, account_id, message_id, raw_email): + if check_message_exists(db, account_id, message_id, raw_email): # type: ignore return None - return process_message.delay(account_id, message_id, folder, raw_email) + return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore try: with imap_connection(account) as conn: @@ -121,7 +123,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict: new_messages += folder_stats["new_messages"] errors += folder_stats["errors"] - account.last_sync_at = datetime.now() + account.last_sync_at = datetime.now() # type: ignore db.commit() except Exception as e: logger.error(f"Error connecting to server {account.imap_server}: {str(e)}") @@ -152,7 +154,7 @@ def sync_all_accounts() -> list[dict]: { "account_id": account.id, "email": account.email_address, - "task_id": sync_account.delay(account.id).id, + "task_id": sync_account.delay(account.id).id, # type: ignore } for account in active_accounts ] diff --git a/src/memory/workers/tasks/maintenance.py b/src/memory/workers/tasks/maintenance.py index 243c13b..cdd8085 100644 --- a/src/memory/workers/tasks/maintenance.py +++ b/src/memory/workers/tasks/maintenance.py @@ -48,7 +48,7 @@ def clean_collection(collection: str) -> dict[str, int]: def clean_all_collections(): logger.info("Cleaning all collections") for collection in embedding.ALL_COLLECTIONS: - clean_collection.delay(collection) + clean_collection.delay(collection) # type: ignore @app.task(name=REINGEST_CHUNK) @@ -97,7 +97,7 @@ def check_batch(batch: Sequence[Chunk]) -> dict: for chunk in chunks: if str(chunk.id) in missing: - reingest_chunk.delay(str(chunk.id), collection) + reingest_chunk.delay(str(chunk.id), collection) # type: ignore else: chunk.checked_at = datetime.now() @@ -132,7 +132,9 @@ def reingest_missing_chunks(batch_size: int = 1000): .filter(Chunk.checked_at < since) .options( contains_eager(Chunk.source).load_only( - SourceItem.id, SourceItem.modality, SourceItem.tags + SourceItem.id, # type: ignore + SourceItem.modality, # type: ignore + SourceItem.tags, # type: ignore ) ) .order_by(Chunk.id) diff --git a/tests/memory/common/parsers/test_email_parsers.py b/tests/memory/common/parsers/test_email_parsers.py index d4936e6..449ea20 100644 --- a/tests/memory/common/parsers/test_email_parsers.py +++ b/tests/memory/common/parsers/test_email_parsers.py @@ -252,7 +252,7 @@ def test_parse_simple_email(): "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87" + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1", } - assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400 + assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400 # type: ignore def test_parse_email_with_attachments(): diff --git a/tests/memory/common/test_embedding.py b/tests/memory/common/test_embedding.py index f177ad9..adf3602 100644 --- a/tests/memory/common/test_embedding.py +++ b/tests/memory/common/test_embedding.py @@ -1,7 +1,7 @@ import pathlib import uuid from unittest.mock import Mock, patch - +from typing import cast import pytest from PIL import Image @@ -72,12 +72,12 @@ def test_embed_mixed(mock_embed): def test_embed_page_text_only(mock_embed): page = {"contents": ["text1", "text2"]} - assert embed_page(page) == [[0], [1]] + assert embed_page(page) == [[0], [1]] # type: ignore def test_embed_page_mixed_content(mock_embed): page = {"contents": ["text", {"type": "image", "data": "base64"}]} - assert embed_page(page) == [[0]] + assert embed_page(page) == [[0]] # type: ignore def test_embed(mock_embed): @@ -91,12 +91,12 @@ def test_embed(mock_embed): assert modality == "text" assert [ { - "id": c.id, - "file_path": c.file_path, - "content": c.content, - "embedding_model": c.embedding_model, - "vector": c.vector, - "item_metadata": c.item_metadata, + "id": c.id, # type: ignore + "file_path": c.file_path, # type: ignore + "content": c.content, # type: ignore + "embedding_model": c.embedding_model, # type: ignore + "vector": c.vector, # type: ignore + "item_metadata": c.item_metadata, # type: ignore } for c in chunks ] == [ @@ -128,7 +128,7 @@ def test_write_to_file_bytes(mock_file_storage): chunk_id = "test-chunk-id" content = b"These are test bytes" - file_path = write_to_file(chunk_id, content) + file_path = write_to_file(chunk_id, content) # type: ignore assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin" assert file_path.exists() @@ -140,7 +140,7 @@ def test_write_to_file_image(mock_file_storage): img = Image.new("RGB", (100, 100), color=(73, 109, 137)) chunk_id = "test-chunk-id" - file_path = write_to_file(chunk_id, img) + file_path = write_to_file(chunk_id, img) # type: ignore assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png" assert file_path.exists() @@ -155,7 +155,7 @@ def test_write_to_file_unsupported_type(mock_file_storage): content = 123 # Integer is not a supported type with pytest.raises(ValueError, match="Unsupported content type"): - write_to_file(chunk_id, content) + write_to_file(chunk_id, content) # type: ignore def test_make_chunk_text_only(mock_file_storage, db_session): @@ -170,12 +170,12 @@ def test_make_chunk_text_only(mock_file_storage, db_session): with patch.object( uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001") ): - chunk = make_chunk(page, vector, metadata) + chunk = make_chunk(page, vector, metadata) # type: ignore - assert chunk.id == "00000000-0000-0000-0000-000000000001" - assert chunk.content == "text content 1\n\ntext content 2" + assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001" + assert cast(str, chunk.content) == "text content 1\n\ntext content 2" assert chunk.file_path is None - assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL + assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL assert chunk.vector == vector assert chunk.item_metadata == metadata @@ -190,19 +190,19 @@ def test_make_chunk_single_image(mock_file_storage, db_session): with patch.object( uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002") ): - chunk = make_chunk(page, vector, metadata) + chunk = make_chunk(page, vector, metadata) # type: ignore - assert chunk.id == "00000000-0000-0000-0000-000000000002" + assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002" assert chunk.content is None - assert chunk.file_path == str( + assert cast(str, chunk.file_path) == str( settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png", ) - assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL + assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL assert chunk.vector == vector assert chunk.item_metadata == metadata # Verify the file exists - assert pathlib.Path(chunk.file_path[0]).exists() + assert pathlib.Path(cast(str, chunk.file_path)).exists() def test_make_chunk_mixed_content(mock_file_storage, db_session): @@ -215,14 +215,14 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session): with patch.object( uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003") ): - chunk = make_chunk(page, vector, metadata) + chunk = make_chunk(page, vector, metadata) # type: ignore - assert chunk.id == "00000000-0000-0000-0000-000000000003" + assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003" assert chunk.content is None - assert chunk.file_path == str( + assert cast(str, chunk.file_path) == str( settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*", ) - assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL + assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL assert chunk.vector == vector assert chunk.item_metadata == metadata diff --git a/tests/memory/common/test_extract.py b/tests/memory/common/test_extract.py index 8e07c6c..8be3aef 100644 --- a/tests/memory/common/test_extract.py +++ b/tests/memory/common/test_extract.py @@ -87,7 +87,7 @@ def test_extract_image_with_path(tmp_path): img.save(img_path) (page,) = extract_image(img_path) - assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() + assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore assert page["metadata"] == {} @@ -98,7 +98,7 @@ def test_extract_image_with_bytes(): img_bytes = buffer.getvalue() (page,) = extract_image(img_bytes) - assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() + assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore assert page["metadata"] == {} diff --git a/tests/memory/common/test_qdrant.py b/tests/memory/common/test_qdrant.py index 0cd037a..491a478 100644 --- a/tests/memory/common/test_qdrant.py +++ b/tests/memory/common/test_qdrant.py @@ -33,7 +33,10 @@ def test_ensure_collection_exists_existing(mock_qdrant_client): def test_ensure_collection_exists_new(mock_qdrant_client): mock_qdrant_client.get_collection.side_effect = UnexpectedResponse( - status_code=404, reason_phrase="asd", content=b"asd", headers=None + status_code=404, + reason_phrase="asd", + content=b"asd", + headers=None, # type: ignore ) assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128) diff --git a/tests/memory/workers/test_email.py b/tests/memory/workers/test_email.py index bef72f5..154aca0 100644 --- a/tests/memory/workers/test_email.py +++ b/tests/memory/workers/test_email.py @@ -1,24 +1,26 @@ import base64 import pathlib - from datetime import datetime +from typing import cast from unittest.mock import MagicMock, patch + import pytest + +from memory.common import embedding, settings from memory.common.db.models import ( - MailMessage, - EmailAttachment, EmailAccount, + EmailAttachment, + MailMessage, ) -from memory.common import settings -from memory.common import embedding +from memory.common.parsers.email import Attachment from memory.workers.email import ( - extract_email_uid, create_mail_message, + extract_email_uid, fetch_email, fetch_email_since, - process_folder, process_attachment, process_attachments, + process_folder, vectorize_email, ) @@ -45,7 +47,9 @@ def mock_uuid4(): (100, 100, ""), ], ) -def test_process_attachment_inline(attachment_size, max_inline_size, message_id): +def test_process_attachment_inline( + attachment_size: int, max_inline_size: int, message_id: str +): attachment = { "filename": "test.txt", "content_type": "text/plain", @@ -60,10 +64,12 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id) ) with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size): - result = process_attachment(attachment, message) + result = process_attachment(cast(Attachment, attachment), message) assert result is not None - assert result.content == attachment["content"].decode("utf-8", errors="replace") + assert cast(str, result.content) == attachment["content"].decode( + "utf-8", errors="replace" + ) assert result.filename is None @@ -90,11 +96,11 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id): folder="INBOX", ) with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size): - result = process_attachment(attachment, message) + result = process_attachment(cast(Attachment, attachment), message) assert result is not None - assert not result.content - assert result.filename == str( + assert not cast(str, result.content) + assert cast(str, result.filename) == str( settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" @@ -125,11 +131,11 @@ def test_process_attachment_write_error(): patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10), patch.object(pathlib.Path, "write_bytes", mock_write_bytes), ): - assert process_attachment(attachment, message) is None + assert process_attachment(cast(Attachment, attachment), message) is None def test_process_attachments_empty(): - assert process_attachments([], "") == [] + assert process_attachments([], MagicMock()) == [] def test_process_attachments_mixed(): @@ -167,16 +173,16 @@ def test_process_attachments_mixed(): with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50): # Process attachments - results = process_attachments(attachments, message) + results = process_attachments(cast(list[Attachment], attachments), message) # Verify we have all attachments processed assert len(results) == 3 - assert results[0].content == "a" * 20 - assert results[2].content == "c" * 30 + assert cast(str, results[0].content) == "a" * 20 + assert cast(str, results[2].content) == "c" * 30 # Verify large attachment has a path - assert results[1].filename == str( + assert cast(str, results[1].filename) == str( settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt" ) @@ -239,12 +245,12 @@ def test_create_mail_message(db_session): # Verify the mail message was created correctly assert isinstance(mail_message, MailMessage) - assert mail_message.message_id == "321" - assert mail_message.subject == "Test Subject" - assert mail_message.sender == "sender@example.com" - assert mail_message.recipients == ["recipient@example.com"] + assert cast(str, mail_message.message_id) == "321" + assert cast(str, mail_message.subject) == "Test Subject" + assert cast(str, mail_message.sender) == "sender@example.com" + assert cast(list[str], mail_message.recipients) == ["recipient@example.com"] assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00" - assert mail_message.content == raw_email + assert cast(str, mail_message.content) == raw_email assert mail_message.body == "Test body content\n" assert mail_message.attachments == attachments @@ -275,7 +281,7 @@ def test_fetch_email_since(email_provider): assert len(result) == 2 # Verify content of fetched emails - uids = sorted([uid for uid, _ in result]) + uids = sorted([uid or "" for uid, _ in result]) assert uids == ["101", "102"] # Test with a folder that doesn't exist @@ -342,7 +348,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): db_session.add(mail_message) db_session.flush() - assert mail_message.embed_status == "RAW" + assert cast(str, mail_message.embed_status) == "RAW" with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]): vectorize_email(mail_message) @@ -351,7 +357,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): ] db_session.commit() - assert mail_message.embed_status == "STORED" + assert cast(str, mail_message.embed_status) == "STORED" def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): @@ -418,6 +424,6 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): ] db_session.commit() - assert mail_message.embed_status == "STORED" - assert attachment1.embed_status == "STORED" - assert attachment2.embed_status == "STORED" + assert cast(str, mail_message.embed_status) == "STORED" + assert cast(str, attachment1.embed_status) == "STORED" + assert cast(str, attachment2.embed_status) == "STORED" diff --git a/tests/providers/email_provider.py b/tests/providers/email_provider.py index e2c3bcc..a5e2054 100644 --- a/tests/providers/email_provider.py +++ b/tests/providers/email_provider.py @@ -8,115 +8,130 @@ class MockEmailProvider: Mock IMAP email provider for integration testing. Can be initialized with predefined emails to return. """ - - def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] = None): + + def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] | None = None): """ Initialize with a dictionary of emails organized by folder. - + Args: emails_by_folder: A dictionary mapping folder names to lists of email dictionaries. - Each email dict should have: 'uid', 'flags', 'date', 'from', 'to', 'subject', + Each email dict should have: 'uid', 'flags', 'date', 'from', 'to', 'subject', 'message_id', 'body', and optionally 'attachments'. """ self.emails_by_folder = emails_by_folder or { "INBOX": [], "Sent": [], - "Archive": [] + "Archive": [], } self.current_folder = None self.is_connected = False - + def _generate_email_string(self, email_data: dict[str, Any]) -> str: """Generate a raw email string from the provided email data.""" - msg = email.message.EmailMessage() + msg = email.message.EmailMessage() # type: ignore msg["From"] = email_data.get("from", "sender@example.com") msg["To"] = email_data.get("to", "recipient@example.com") msg["Subject"] = email_data.get("subject", "Test Subject") - msg["Message-ID"] = email_data.get("message_id", f"") - msg["Date"] = email_data.get("date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000")) - + msg["Message-ID"] = email_data.get( + "message_id", f"" + ) + msg["Date"] = email_data.get( + "date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000") + ) + # Set the body content - msg.set_content(email_data.get("body", f"This is test email body {email_data['uid']}")) - + msg.set_content( + email_data.get("body", f"This is test email body {email_data['uid']}") + ) + # Add attachments if present for attachment in email_data.get("attachments", []): - if isinstance(attachment, dict) and "filename" in attachment and "content" in attachment: + if ( + isinstance(attachment, dict) + and "filename" in attachment + and "content" in attachment + ): msg.add_attachment( attachment["content"], maintype=attachment.get("maintype", "application"), subtype=attachment.get("subtype", "octet-stream"), - filename=attachment["filename"] + filename=attachment["filename"], ) - + return msg.as_string() - + def login(self, username: str, password: str) -> tuple[str, list[bytes]]: """Mock login method.""" self.is_connected = True - return ('OK', [b'Login successful']) - + return ("OK", [b"Login successful"]) + def logout(self) -> tuple[str, list[bytes]]: """Mock logout method.""" self.is_connected = False - return ('OK', [b'Logout successful']) - + return ("OK", [b"Logout successful"]) + def select(self, folder: str, readonly: bool = False) -> tuple[str, list[bytes]]: """ Select a folder and make it the current active folder. - + Args: folder: Folder name to select readonly: Whether to open in readonly mode - + Returns: IMAP-style response with message count """ folder_name = folder.decode() if isinstance(folder, bytes) else folder self.current_folder = folder_name message_count = len(self.emails_by_folder.get(folder_name, [])) - return ('OK', [str(message_count).encode()]) - - def list(self, directory: str = '', pattern: str = '*') -> tuple[str, list[bytes]]: + return ("OK", [str(message_count).encode()]) + + def list(self, directory: str = "", pattern: str = "*") -> tuple[str, list[bytes]]: """List available folders.""" folders = [] for folder in self.emails_by_folder.keys(): folders.append(f'(\\HasNoChildren) "/" "{folder}"'.encode()) - return ('OK', folders) - + return ("OK", folders) + def search(self, charset, criteria): """ Handle SEARCH command to find email UIDs. - + Args: charset: Character set (ignored in mock) criteria: Search criteria (ignored in mock, we return all emails) - + Returns: All email UIDs in the current folder """ if not self.current_folder or self.current_folder not in self.emails_by_folder: - return ('OK', [b'']) - - uids = [str(email["uid"]).encode() for email in self.emails_by_folder[self.current_folder]] - return ('OK', [b' '.join(uids) if uids else b'']) - - def fetch(self, message_set, message_parts) -> tuple[str, list]: + return ("OK", [b""]) + + uids = [ + str(email["uid"]).encode() + for email in self.emails_by_folder[self.current_folder] + ] + return ("OK", [b" ".join(uids) if uids else b""]) + + def fetch(self, message_set: bytes | str, message_parts: bytes | str): """ Handle FETCH command to retrieve email data. - + Args: message_set: Message numbers/UIDs to fetch message_parts: Parts of the message to fetch - + Returns: Email data in IMAP format """ if not self.current_folder or self.current_folder not in self.emails_by_folder: - return ('OK', [None]) - + return ("OK", [None]) + # For simplicity, we'll just match the UID with the ID provided - uid = int(message_set.decode() if isinstance(message_set, bytes) else message_set) - + uid = int( + message_set.decode() if isinstance(message_set, bytes) else message_set + ) + # Find the email with the matching UID for email_data in self.emails_by_folder[self.current_folder]: if email_data["uid"] == uid: @@ -124,14 +139,16 @@ class MockEmailProvider: email_string = self._generate_email_string(email_data) flags = email_data.get("flags", "\\Seen") date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000") - + # Format the response as expected by the IMAP client - response = [( - f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 ' - f'{{{len(email_string)}}}'.encode(), - email_string.encode() - )] - return ('OK', response) - + response = [ + ( + f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 ' + f"{{{len(email_string)}}}".encode(), + email_string.encode(), + ) + ] + return ("OK", response) + # No matching email found - return ('NO', [b'Email not found']) \ No newline at end of file + return ("NO", [b"Email not found"])