fix linting

This commit is contained in:
Daniel O'Connell 2025-05-20 22:54:45 +02:00
parent 29a2a7ba3a
commit 4f1ca777e9
19 changed files with 255 additions and 209 deletions

View File

@ -4,19 +4,19 @@ from setuptools import setup, find_namespace_packages
def read_requirements(filename: str) -> list[str]: def read_requirements(filename: str) -> list[str]:
"""Read requirements from file, ignoring comments and -r directives.""" """Read requirements from file, ignoring comments and -r directives."""
filename = pathlib.Path(filename) path = pathlib.Path(filename)
return [ return [
line.strip() line.strip()
for line in filename.read_text().splitlines() for line in path.read_text().splitlines()
if line.strip() and not line.strip().startswith(('#', '-r')) if line.strip() and not line.strip().startswith(("#", "-r"))
] ]
# Read requirements files # Read requirements files
common_requires = read_requirements('requirements-common.txt') common_requires = read_requirements("requirements-common.txt")
api_requires = read_requirements('requirements-api.txt') api_requires = read_requirements("requirements-api.txt")
workers_requires = read_requirements('requirements-workers.txt') workers_requires = read_requirements("requirements-workers.txt")
dev_requires = read_requirements('requirements-dev.txt') dev_requires = read_requirements("requirements-dev.txt")
setup( setup(
name="memory", name="memory",

View File

@ -4,10 +4,9 @@ Database models for the knowledge base system.
import pathlib import pathlib
import re import re
from email.message import EmailMessage
from pathlib import Path from pathlib import Path
import textwrap import textwrap
from typing import Any, ClassVar from typing import Any, ClassVar, cast
from PIL import Image from PIL import Image
from sqlalchemy import ( from sqlalchemy import (
ARRAY, ARRAY,
@ -31,7 +30,7 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, Session from sqlalchemy.orm import relationship, Session
from memory.common import settings from memory.common import settings
from memory.common.parsers.email import parse_email_message from memory.common.parsers.email import parse_email_message, EmailMessage
Base = declarative_base() Base = declarative_base()
@ -113,10 +112,10 @@ class Chunk(Base):
@property @property
def data(self) -> list[bytes | str | Image.Image]: def data(self) -> list[bytes | str | Image.Image]:
if self.file_path is None: if self.file_path is None:
return [self.content] return [cast(str, self.content)]
path = pathlib.Path(self.file_path.replace("/app/", "")) path = pathlib.Path(self.file_path.replace("/app/", ""))
if self.file_path.endswith("*"): if cast(str, self.file_path).endswith("*"):
files = list(path.parent.glob(path.name)) files = list(path.parent.glob(path.name))
else: else:
files = [path] files = [path]
@ -182,7 +181,7 @@ class SourceItem(Base):
@property @property
def display_contents(self) -> str | None: def display_contents(self) -> str | None:
return self.content or self.filename return cast(str | None, self.content) or cast(str | None, self.filename)
class MailMessage(SourceItem): class MailMessage(SourceItem):
@ -217,8 +216,8 @@ class MailMessage(SourceItem):
@property @property
def attachments_path(self) -> Path: def attachments_path(self) -> Path:
clean_sender = clean_filename(self.sender) clean_sender = clean_filename(cast(str, self.sender))
clean_folder = clean_filename(self.folder or "INBOX") clean_folder = clean_filename(cast(str | None, self.folder) or "INBOX")
return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder
def safe_filename(self, filename: str) -> Path: def safe_filename(self, filename: str) -> Path:
@ -237,12 +236,12 @@ class MailMessage(SourceItem):
"recipients": self.recipients, "recipients": self.recipients,
"folder": self.folder, "folder": self.folder,
"tags": self.tags + [self.sender] + self.recipients, "tags": self.tags + [self.sender] + self.recipients,
"date": self.sent_at and self.sent_at.isoformat() or None, "date": (self.sent_at and self.sent_at.isoformat() or None), # type: ignore
} }
@property @property
def parsed_content(self) -> EmailMessage: def parsed_content(self) -> EmailMessage:
return parse_email_message(self.content, self.message_id) return parse_email_message(cast(str, self.content), cast(str, self.message_id))
@property @property
def body(self) -> str: def body(self) -> str:
@ -300,7 +299,7 @@ class EmailAttachment(SourceItem):
"filename": self.filename, "filename": self.filename,
"content_type": self.mime_type, "content_type": self.mime_type,
"size": self.size, "size": self.size,
"created_at": self.created_at and self.created_at.isoformat() or None, "created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore
"mail_message_id": self.mail_message_id, "mail_message_id": self.mail_message_id,
"source_id": self.id, "source_id": self.id,
"tags": self.tags, "tags": self.tags,

View File

@ -1,7 +1,7 @@
import logging import logging
import pathlib import pathlib
import uuid import uuid
from typing import Any, Iterable, Literal, NotRequired, TypedDict from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast
import voyageai import voyageai
from PIL import Image from PIL import Image
@ -21,7 +21,6 @@ CHARS_PER_TOKEN = 4
DistanceType = Literal["Cosine", "Dot", "Euclidean"] DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float] Vector = list[float]
Embedding = tuple[str, Vector, dict[str, Any]]
class Collection(TypedDict): class Collection(TypedDict):
@ -133,16 +132,18 @@ def get_modality(mime_type: str) -> str:
def embed_chunks( def embed_chunks(
chunks: list[extract.MulitmodalChunk], chunks: list[str] | list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL, model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document", input_type: Literal["document", "query"] = "document",
) -> list[Vector]: ) -> list[Vector]:
vo = voyageai.Client() vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL: if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed( return vo.multimodal_embed(
chunks, model=model, input_type=input_type chunks, # type: ignore
model=model,
input_type=input_type,
).embeddings ).embeddings
return vo.embed(chunks, model=model, input_type=input_type).embeddings return vo.embed(chunks, model=model, input_type=input_type).embeddings # type: ignore
def embed_text( def embed_text(
@ -162,7 +163,7 @@ def embed_text(
try: try:
return embed_chunks(chunks, model, input_type) return embed_chunks(chunks, model, input_type)
except voyageai.error.InvalidRequestError as e: except voyageai.error.InvalidRequestError as e: # type: ignore
logger.error(f"Error embedding text: {e}") logger.error(f"Error embedding text: {e}")
logger.debug(f"Text: {texts}") logger.debug(f"Text: {texts}")
raise raise
@ -179,7 +180,7 @@ def embed_mixed(
model: str = settings.MIXED_EMBEDDING_MODEL, model: str = settings.MIXED_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document", input_type: Literal["document", "query"] = "document",
) -> list[Vector]: ) -> list[Vector]:
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]: def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
if isinstance(item, str): if isinstance(item, str):
return [ return [
c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip() c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
@ -190,11 +191,16 @@ def embed_mixed(
return embed_chunks([chunks], model, input_type) return embed_chunks([chunks], model, input_type)
def embed_page(page: dict[str, Any]) -> list[Vector]: def embed_page(page: extract.Page) -> list[Vector]:
contents = page["contents"] contents = page["contents"]
if all(isinstance(c, str) for c in contents): if all(isinstance(c, str) for c in contents):
return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL) return embed_text(
return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL) cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
)
return embed_mixed(
cast(list[extract.MulitmodalChunk], contents),
model=settings.MIXED_EMBEDDING_MODEL,
)
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
@ -224,7 +230,7 @@ def make_chunk(
contents = page["contents"] contents = page["contents"]
content, filename = None, None content, filename = None, None
if all(isinstance(c, str) for c in contents): if all(isinstance(c, str) for c in contents):
content = "\n\n".join(contents) content = "\n\n".join(cast(list[str], contents))
model = settings.TEXT_EMBEDDING_MODEL model = settings.TEXT_EMBEDDING_MODEL
elif len(contents) == 1: elif len(contents) == 1:
filename = write_to_file(chunk_id, contents[0]).absolute().as_posix() filename = write_to_file(chunk_id, contents[0]).absolute().as_posix()
@ -249,7 +255,7 @@ def embed(
mime_type: str, mime_type: str,
content: bytes | str | pathlib.Path, content: bytes | str | pathlib.Path,
metadata: dict[str, Any] = {}, metadata: dict[str, Any] = {},
) -> tuple[str, list[Embedding]]: ) -> tuple[str, list[Chunk]]:
modality = get_modality(mime_type) modality = get_modality(mime_type)
pages = extract.extract_content(mime_type, content) pages = extract.extract_content(mime_type, content)
chunks = [ chunks = [

View File

@ -1,13 +1,13 @@
from contextlib import contextmanager
import io import io
import logging
import pathlib import pathlib
import tempfile import tempfile
import pypandoc from contextlib import contextmanager
import pymupdf # PyMuPDF from typing import Any, Generator, Sequence, TypedDict, cast
from PIL import Image
from typing import Any, TypedDict, Generator, Sequence
import logging import pymupdf # PyMuPDF
import pypandoc
from PIL import Image
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -105,7 +105,7 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
if isinstance(content, bytes): if isinstance(content, bytes):
content = content.decode("utf-8") content = content.decode("utf-8")
return [{"contents": [content], "metadata": {}}] return [{"contents": [cast(str, content)], "metadata": {}}]
def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]: def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import TypedDict, NotRequired from typing import TypedDict, NotRequired, cast
from bs4 import BeautifulSoup, Tag from bs4 import BeautifulSoup, Tag
import requests import requests
@ -59,7 +59,7 @@ def extract_smbc(url: str) -> ComicInfo:
"title": title, "title": title,
"image_url": image_url, "image_url": image_url,
"published_date": published_date, "published_date": published_date,
"url": comic_url or url, "url": cast(str, comic_url or url),
} }

View File

@ -28,10 +28,10 @@ class EmailMessage(TypedDict):
hash: bytes hash: bytes
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes] RawEmailResponse = tuple[str | None, bytes]
def extract_recipients(msg: email.message.Message) -> list[str]: def extract_recipients(msg: email.message.Message) -> list[str]: # type: ignore
""" """
Extract email recipients from message headers. Extract email recipients from message headers.
@ -50,7 +50,7 @@ def extract_recipients(msg: email.message.Message) -> list[str]:
] ]
def extract_date(msg: email.message.Message) -> datetime | None: def extract_date(msg: email.message.Message) -> datetime | None: # type: ignore
""" """
Parse date from email header. Parse date from email header.
@ -68,7 +68,7 @@ def extract_date(msg: email.message.Message) -> datetime | None:
return None return None
def extract_body(msg: email.message.Message) -> str: def extract_body(msg: email.message.Message) -> str: # type: ignore
""" """
Extract plain text body from email message. Extract plain text body from email message.
@ -99,7 +99,7 @@ def extract_body(msg: email.message.Message) -> str:
return body return body
def extract_attachments(msg: email.message.Message) -> list[Attachment]: def extract_attachments(msg: email.message.Message) -> list[Attachment]: # type: ignore
""" """
Extract attachment metadata and content from email. Extract attachment metadata and content from email.

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, cast, Iterator, Sequence from typing import Any, cast, Generator, Sequence
import qdrant_client import qdrant_client
from qdrant_client.http import models as qdrant_models from qdrant_client.http import models as qdrant_models
@ -24,7 +24,7 @@ def get_qdrant_client() -> qdrant_client.QdrantClient:
return qdrant_client.QdrantClient( return qdrant_client.QdrantClient(
host=settings.QDRANT_HOST, host=settings.QDRANT_HOST,
port=settings.QDRANT_PORT, port=settings.QDRANT_PORT,
grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None, grpc_port=settings.QDRANT_GRPC_PORT or 6334,
prefer_grpc=settings.QDRANT_PREFER_GRPC, prefer_grpc=settings.QDRANT_PREFER_GRPC,
api_key=settings.QDRANT_API_KEY, api_key=settings.QDRANT_API_KEY,
timeout=settings.QDRANT_TIMEOUT, timeout=settings.QDRANT_TIMEOUT,
@ -80,7 +80,8 @@ def ensure_collection_exists(
def initialize_collections( def initialize_collections(
client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None client: qdrant_client.QdrantClient,
collections: dict[str, Collection] | None = None,
) -> None: ) -> None:
""" """
Initialize all required collections in Qdrant. Initialize all required collections in Qdrant.
@ -122,7 +123,7 @@ def upsert_vectors(
collection_name: str, collection_name: str,
ids: list[str], ids: list[str],
vectors: list[Vector], vectors: list[Vector],
payloads: list[dict[str, Any]] = None, payloads: list[dict[str, Any]] | None = None,
) -> None: ) -> None:
"""Upsert vectors into a collection. """Upsert vectors into a collection.
@ -147,7 +148,7 @@ def upsert_vectors(
client.upsert( client.upsert(
collection_name=collection_name, collection_name=collection_name,
points=points, points=points, # type: ignore
) )
logger.debug(f"Upserted {len(ids)} vectors into {collection_name}") logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
@ -157,7 +158,7 @@ def search_vectors(
client: qdrant_client.QdrantClient, client: qdrant_client.QdrantClient,
collection_name: str, collection_name: str,
query_vector: Vector, query_vector: Vector,
filter_params: dict = None, filter_params: dict | None = None,
limit: int = 10, limit: int = 10,
) -> list[qdrant_models.ScoredPoint]: ) -> list[qdrant_models.ScoredPoint]:
"""Search for similar vectors in a collection. """Search for similar vectors in a collection.
@ -200,7 +201,7 @@ def delete_points(
client.delete( client.delete(
collection_name=collection_name, collection_name=collection_name,
points_selector=qdrant_models.PointIdsList( points_selector=qdrant_models.PointIdsList(
points=ids, points=ids, # type: ignore
), ),
) )
@ -226,7 +227,7 @@ def get_collection_info(
def batch_ids( def batch_ids(
client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000 client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
) -> Iterator[list[str]]: ) -> Generator[list[str], None, None]:
"""Iterate over all IDs in a collection.""" """Iterate over all IDs in a collection."""
offset = None offset = None
while resp := client.scroll( while resp := client.scroll(
@ -236,7 +237,7 @@ def batch_ids(
limit=batch_size, limit=batch_size,
): ):
points, offset = resp points, offset = resp
yield [point.id for point in points] yield [cast(str, point.id) for point in points]
if not offset: if not offset:
return return

View File

@ -21,7 +21,11 @@ async def make_request(
) -> httpx.Response: ) -> httpx.Response:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
return await client.request( return await client.request(
method, f"{SERVER}/{path}", data=data, json=json, files=files method,
f"{SERVER}/{path}",
data=data,
json=json,
files=files, # type: ignore
) )

View File

@ -31,7 +31,7 @@ app.conf.update(
) )
@app.on_after_configure.connect @app.on_after_configure.connect # type: ignore
def ensure_qdrant_initialised(sender, **_): def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant from memory.common import qdrant

View File

@ -1,24 +1,26 @@
import hashlib import hashlib
import imaplib import imaplib
import logging import logging
import pathlib
import re import re
from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from typing import Generator, Callable from typing import Callable, Generator, Sequence, cast
import pathlib
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, scoped_session
from collections import defaultdict
from memory.common import settings, embedding, qdrant from memory.common import embedding, qdrant, settings
from memory.common.db.models import ( from memory.common.db.models import (
EmailAccount, EmailAccount,
EmailAttachment,
MailMessage, MailMessage,
SourceItem, SourceItem,
EmailAttachment,
) )
from memory.common.parsers.email import ( from memory.common.parsers.email import (
Attachment, Attachment,
parse_email_message,
RawEmailResponse, RawEmailResponse,
parse_email_message,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -89,7 +91,7 @@ def process_attachments(
def create_mail_message( def create_mail_message(
db_session: Session, db_session: Session | scoped_session,
tags: list[str], tags: list[str],
folder: str, folder: str,
raw_email: str, raw_email: str,
@ -136,7 +138,7 @@ def create_mail_message(
def does_message_exist( def does_message_exist(
db_session: Session, message_id: str, message_hash: bytes db_session: Session | scoped_session, message_id: str, message_hash: bytes
) -> bool: ) -> bool:
""" """
Check if a message already exists in the database. Check if a message already exists in the database.
@ -167,7 +169,7 @@ def does_message_exist(
def check_message_exists( def check_message_exists(
db: Session, account_id: int, message_id: str, raw_email: str db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
) -> bool: ) -> bool:
account = db.query(EmailAccount).get(account_id) account = db.query(EmailAccount).get(account_id)
if not account: if not account:
@ -181,7 +183,9 @@ def check_message_exists(
return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"]) return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
def extract_email_uid(msg_data: bytes) -> tuple[str, str]: def extract_email_uid(
msg_data: Sequence[tuple[bytes, bytes]],
) -> tuple[str | None, bytes]:
""" """
Extract the UID and raw email data from the message data. Extract the UID and raw email data from the message data.
""" """
@ -199,7 +203,7 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
logger.error(f"Error fetching message {uid}") logger.error(f"Error fetching message {uid}")
return None return None
return extract_email_uid(msg_data) return extract_email_uid(msg_data) # type: ignore
except Exception as e: except Exception as e:
logger.error(f"Error processing message {uid}: {str(e)}") logger.error(f"Error processing message {uid}: {str(e)}")
return None return None
@ -248,7 +252,7 @@ def process_folder(
folder: str, folder: str,
account: EmailAccount, account: EmailAccount,
since_date: datetime, since_date: datetime,
processor: Callable[[int, str, str, bytes], int | None], processor: Callable[[int, str, str, str], int | None],
) -> dict: ) -> dict:
""" """
Process a single folder from an email account. Process a single folder from an email account.
@ -272,7 +276,7 @@ def process_folder(
for uid, raw_email in emails: for uid, raw_email in emails:
try: try:
task = processor( task = processor(
account_id=account.id, account_id=account.id, # type: ignore
message_id=uid, message_id=uid,
folder=folder, folder=folder,
raw_email=raw_email.decode("utf-8", errors="replace"), raw_email=raw_email.decode("utf-8", errors="replace"),
@ -296,9 +300,11 @@ def process_folder(
@contextmanager @contextmanager
def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]: def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port) conn = imaplib.IMAP4_SSL(
host=cast(str, account.imap_server), port=cast(int, account.imap_port)
)
try: try:
conn.login(account.username, account.password) conn.login(cast(str, account.username), cast(str, account.password))
yield conn yield conn
finally: finally:
# Always try to logout and close the connection # Always try to logout and close the connection
@ -318,15 +324,15 @@ def vectorize_email(email: MailMessage):
) )
email.chunks = chunks email.chunks = chunks
if chunks: if chunks:
vector_ids = [c.id for c in chunks] vector_ids = [cast(str, c.id) for c in chunks]
vectors = [c.vector for c in chunks] vectors = [c.vector for c in chunks]
metadata = [c.item_metadata for c in chunks] metadata = [c.item_metadata for c in chunks]
qdrant.upsert_vectors( qdrant.upsert_vectors(
client=qdrant_client, client=qdrant_client,
collection_name="mail", collection_name="mail",
ids=vector_ids, ids=vector_ids,
vectors=vectors, vectors=vectors, # type: ignore
payloads=metadata, payloads=metadata, # type: ignore
) )
embeds = defaultdict(list) embeds = defaultdict(list)
@ -356,7 +362,7 @@ def vectorize_email(email: MailMessage):
payloads=metadata, payloads=metadata,
) )
email.embed_status = "STORED" email.embed_status = "STORED" # type: ignore
for attachment in email.attachments: for attachment in email.attachments:
attachment.embed_status = "STORED" attachment.embed_status = "STORED"

View File

@ -1,7 +1,7 @@
import hashlib import hashlib
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Callable from typing import Callable, cast
import feedparser import feedparser
import requests import requests
@ -35,7 +35,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
logger.error(f"Failed to fetch or parse {rss_url}: {e}") logger.error(f"Failed to fetch or parse {rss_url}: {e}")
return set() return set()
urls = {item.get("link") or item.get("id") for item in feed.entries} urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}
with make_session() as session: with make_session() as session:
known = { known = {
@ -46,7 +46,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
) )
} }
return urls - known return cast(set[str], urls - known)
def fetch_new_comics( def fetch_new_comics(
@ -56,7 +56,7 @@ def fetch_new_comics(
for url in new_urls: for url in new_urls:
data = parser(url) | {"author": base_url, "url": url} data = parser(url) | {"author": base_url, "url": url}
sync_comic.delay(**data) sync_comic.delay(**data) # type: ignore
return new_urls return new_urls
@ -108,7 +108,7 @@ def sync_comic(
client=qdrant.get_qdrant_client(), client=qdrant.get_qdrant_client(),
collection_name="comic", collection_name="comic",
ids=[str(chunk.id)], ids=[str(chunk.id)],
vectors=[chunk.vector], vectors=[chunk.vector], # type: ignore
payloads=[comic.as_payload()], payloads=[comic.as_payload()],
) )
@ -130,8 +130,8 @@ def sync_xkcd() -> set[str]:
@app.task(name=SYNC_ALL_COMICS) @app.task(name=SYNC_ALL_COMICS)
def sync_all_comics(): def sync_all_comics():
"""Synchronize all active comics.""" """Synchronize all active comics."""
sync_smbc.delay() sync_smbc.delay() # type: ignore
sync_xkcd.delay() sync_xkcd.delay() # type: ignore
@app.task(name="memory.workers.tasks.comic.full_sync_comic") @app.task(name="memory.workers.tasks.comic.full_sync_comic")
@ -141,8 +141,8 @@ def trigger_comic_sync():
response = requests.get(url) response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser") soup = BeautifulSoup(response.text, "html.parser")
if link := soup.find("a", attrs={"class", "cc-prev"}): if link := soup.find("a", attrs={"class": "cc-prev"}):
return link.attrs["href"] return link.attrs["href"] # type: ignore
return None return None
next_url = "https://www.smbc-comics.com" next_url = "https://www.smbc-comics.com"
@ -155,7 +155,7 @@ def trigger_comic_sync():
data = comics.extract_smbc(next_url) | { data = comics.extract_smbc(next_url) | {
"author": "https://www.smbc-comics.com/" "author": "https://www.smbc-comics.com/"
} }
sync_comic.delay(**data) sync_comic.delay(**data) # type: ignore
except Exception as e: except Exception as e:
logger.error(f"failed to sync {next_url}: {e}") logger.error(f"failed to sync {next_url}: {e}")
urls.append(next_url) urls.append(next_url)
@ -167,6 +167,6 @@ def trigger_comic_sync():
url = f"{BASE_XKCD_URL}/{i}" url = f"{BASE_XKCD_URL}/{i}"
try: try:
data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"} data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"}
sync_comic.delay(**data) sync_comic.delay(**data) # type: ignore
except Exception as e: except Exception as e:
logger.error(f"failed to sync {url}: {e}") logger.error(f"failed to sync {url}: {e}")

View File

@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount from memory.common.db.models import EmailAccount
from memory.workers.celery_app import app from memory.workers.celery_app import app
@ -71,7 +71,7 @@ def process_message(
for chunk in attachment.chunks: for chunk in attachment.chunks:
logger.info(f" - {chunk.id}") logger.info(f" - {chunk.id}")
return mail_message.id return cast(int, mail_message.id)
@app.task(name=SYNC_ACCOUNT) @app.task(name=SYNC_ACCOUNT)
@ -89,15 +89,17 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
with make_session() as db: with make_session() as db:
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first() account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
if not account or not account.active: if not account or not cast(bool, account.active):
logger.warning(f"Account {account_id} not found or inactive") logger.warning(f"Account {account_id} not found or inactive")
return {"error": "Account not found or inactive"} return {"error": "Account not found or inactive"}
folders_to_process: list[str] = account.folders or ["INBOX"] folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
if since_date: if since_date:
cutoff_date = datetime.fromisoformat(since_date) cutoff_date = datetime.fromisoformat(since_date)
else: else:
cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1) cutoff_date: datetime = cast(datetime, account.last_sync_at) or datetime(
1970, 1, 1
)
messages_found = 0 messages_found = 0
new_messages = 0 new_messages = 0
@ -106,9 +108,9 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
def process_message_wrapper( def process_message_wrapper(
account_id: int, message_id: str, folder: str, raw_email: str account_id: int, message_id: str, folder: str, raw_email: str
) -> int | None: ) -> int | None:
if check_message_exists(db, account_id, message_id, raw_email): if check_message_exists(db, account_id, message_id, raw_email): # type: ignore
return None return None
return process_message.delay(account_id, message_id, folder, raw_email) return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore
try: try:
with imap_connection(account) as conn: with imap_connection(account) as conn:
@ -121,7 +123,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
new_messages += folder_stats["new_messages"] new_messages += folder_stats["new_messages"]
errors += folder_stats["errors"] errors += folder_stats["errors"]
account.last_sync_at = datetime.now() account.last_sync_at = datetime.now() # type: ignore
db.commit() db.commit()
except Exception as e: except Exception as e:
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}") logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
@ -152,7 +154,7 @@ def sync_all_accounts() -> list[dict]:
{ {
"account_id": account.id, "account_id": account.id,
"email": account.email_address, "email": account.email_address,
"task_id": sync_account.delay(account.id).id, "task_id": sync_account.delay(account.id).id, # type: ignore
} }
for account in active_accounts for account in active_accounts
] ]

View File

@ -48,7 +48,7 @@ def clean_collection(collection: str) -> dict[str, int]:
def clean_all_collections(): def clean_all_collections():
logger.info("Cleaning all collections") logger.info("Cleaning all collections")
for collection in embedding.ALL_COLLECTIONS: for collection in embedding.ALL_COLLECTIONS:
clean_collection.delay(collection) clean_collection.delay(collection) # type: ignore
@app.task(name=REINGEST_CHUNK) @app.task(name=REINGEST_CHUNK)
@ -97,7 +97,7 @@ def check_batch(batch: Sequence[Chunk]) -> dict:
for chunk in chunks: for chunk in chunks:
if str(chunk.id) in missing: if str(chunk.id) in missing:
reingest_chunk.delay(str(chunk.id), collection) reingest_chunk.delay(str(chunk.id), collection) # type: ignore
else: else:
chunk.checked_at = datetime.now() chunk.checked_at = datetime.now()
@ -132,7 +132,9 @@ def reingest_missing_chunks(batch_size: int = 1000):
.filter(Chunk.checked_at < since) .filter(Chunk.checked_at < since)
.options( .options(
contains_eager(Chunk.source).load_only( contains_eager(Chunk.source).load_only(
SourceItem.id, SourceItem.modality, SourceItem.tags SourceItem.id, # type: ignore
SourceItem.modality, # type: ignore
SourceItem.tags, # type: ignore
) )
) )
.order_by(Chunk.id) .order_by(Chunk.id)

View File

@ -252,7 +252,7 @@ def test_parse_simple_email():
"hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87" "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
+ b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1", + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
} }
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400 assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400 # type: ignore
def test_parse_email_with_attachments(): def test_parse_email_with_attachments():

View File

@ -1,7 +1,7 @@
import pathlib import pathlib
import uuid import uuid
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from typing import cast
import pytest import pytest
from PIL import Image from PIL import Image
@ -72,12 +72,12 @@ def test_embed_mixed(mock_embed):
def test_embed_page_text_only(mock_embed): def test_embed_page_text_only(mock_embed):
page = {"contents": ["text1", "text2"]} page = {"contents": ["text1", "text2"]}
assert embed_page(page) == [[0], [1]] assert embed_page(page) == [[0], [1]] # type: ignore
def test_embed_page_mixed_content(mock_embed): def test_embed_page_mixed_content(mock_embed):
page = {"contents": ["text", {"type": "image", "data": "base64"}]} page = {"contents": ["text", {"type": "image", "data": "base64"}]}
assert embed_page(page) == [[0]] assert embed_page(page) == [[0]] # type: ignore
def test_embed(mock_embed): def test_embed(mock_embed):
@ -91,12 +91,12 @@ def test_embed(mock_embed):
assert modality == "text" assert modality == "text"
assert [ assert [
{ {
"id": c.id, "id": c.id, # type: ignore
"file_path": c.file_path, "file_path": c.file_path, # type: ignore
"content": c.content, "content": c.content, # type: ignore
"embedding_model": c.embedding_model, "embedding_model": c.embedding_model, # type: ignore
"vector": c.vector, "vector": c.vector, # type: ignore
"item_metadata": c.item_metadata, "item_metadata": c.item_metadata, # type: ignore
} }
for c in chunks for c in chunks
] == [ ] == [
@ -128,7 +128,7 @@ def test_write_to_file_bytes(mock_file_storage):
chunk_id = "test-chunk-id" chunk_id = "test-chunk-id"
content = b"These are test bytes" content = b"These are test bytes"
file_path = write_to_file(chunk_id, content) file_path = write_to_file(chunk_id, content) # type: ignore
assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin" assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
assert file_path.exists() assert file_path.exists()
@ -140,7 +140,7 @@ def test_write_to_file_image(mock_file_storage):
img = Image.new("RGB", (100, 100), color=(73, 109, 137)) img = Image.new("RGB", (100, 100), color=(73, 109, 137))
chunk_id = "test-chunk-id" chunk_id = "test-chunk-id"
file_path = write_to_file(chunk_id, img) file_path = write_to_file(chunk_id, img) # type: ignore
assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png" assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
assert file_path.exists() assert file_path.exists()
@ -155,7 +155,7 @@ def test_write_to_file_unsupported_type(mock_file_storage):
content = 123 # Integer is not a supported type content = 123 # Integer is not a supported type
with pytest.raises(ValueError, match="Unsupported content type"): with pytest.raises(ValueError, match="Unsupported content type"):
write_to_file(chunk_id, content) write_to_file(chunk_id, content) # type: ignore
def test_make_chunk_text_only(mock_file_storage, db_session): def test_make_chunk_text_only(mock_file_storage, db_session):
@ -170,12 +170,12 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
with patch.object( with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001") uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
): ):
chunk = make_chunk(page, vector, metadata) chunk = make_chunk(page, vector, metadata) # type: ignore
assert chunk.id == "00000000-0000-0000-0000-000000000001" assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
assert chunk.content == "text content 1\n\ntext content 2" assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
assert chunk.file_path is None assert chunk.file_path is None
assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL
assert chunk.vector == vector assert chunk.vector == vector
assert chunk.item_metadata == metadata assert chunk.item_metadata == metadata
@ -190,19 +190,19 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
with patch.object( with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002") uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
): ):
chunk = make_chunk(page, vector, metadata) chunk = make_chunk(page, vector, metadata) # type: ignore
assert chunk.id == "00000000-0000-0000-0000-000000000002" assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
assert chunk.content is None assert chunk.content is None
assert chunk.file_path == str( assert cast(str, chunk.file_path) == str(
settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png", settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
) )
assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
assert chunk.vector == vector assert chunk.vector == vector
assert chunk.item_metadata == metadata assert chunk.item_metadata == metadata
# Verify the file exists # Verify the file exists
assert pathlib.Path(chunk.file_path[0]).exists() assert pathlib.Path(cast(str, chunk.file_path)).exists()
def test_make_chunk_mixed_content(mock_file_storage, db_session): def test_make_chunk_mixed_content(mock_file_storage, db_session):
@ -215,14 +215,14 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
with patch.object( with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003") uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
): ):
chunk = make_chunk(page, vector, metadata) chunk = make_chunk(page, vector, metadata) # type: ignore
assert chunk.id == "00000000-0000-0000-0000-000000000003" assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
assert chunk.content is None assert chunk.content is None
assert chunk.file_path == str( assert cast(str, chunk.file_path) == str(
settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*", settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
) )
assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
assert chunk.vector == vector assert chunk.vector == vector
assert chunk.item_metadata == metadata assert chunk.item_metadata == metadata

View File

@ -87,7 +87,7 @@ def test_extract_image_with_path(tmp_path):
img.save(img_path) img.save(img_path)
(page,) = extract_image(img_path) (page,) = extract_image(img_path)
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page["metadata"] == {} assert page["metadata"] == {}
@ -98,7 +98,7 @@ def test_extract_image_with_bytes():
img_bytes = buffer.getvalue() img_bytes = buffer.getvalue()
(page,) = extract_image(img_bytes) (page,) = extract_image(img_bytes)
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page["metadata"] == {} assert page["metadata"] == {}

View File

@ -33,7 +33,10 @@ def test_ensure_collection_exists_existing(mock_qdrant_client):
def test_ensure_collection_exists_new(mock_qdrant_client): def test_ensure_collection_exists_new(mock_qdrant_client):
mock_qdrant_client.get_collection.side_effect = UnexpectedResponse( mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
status_code=404, reason_phrase="asd", content=b"asd", headers=None status_code=404,
reason_phrase="asd",
content=b"asd",
headers=None, # type: ignore
) )
assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128) assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)

View File

@ -1,24 +1,26 @@
import base64 import base64
import pathlib import pathlib
from datetime import datetime from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from memory.common import embedding, settings
from memory.common.db.models import ( from memory.common.db.models import (
MailMessage,
EmailAttachment,
EmailAccount, EmailAccount,
EmailAttachment,
MailMessage,
) )
from memory.common import settings from memory.common.parsers.email import Attachment
from memory.common import embedding
from memory.workers.email import ( from memory.workers.email import (
extract_email_uid,
create_mail_message, create_mail_message,
extract_email_uid,
fetch_email, fetch_email,
fetch_email_since, fetch_email_since,
process_folder,
process_attachment, process_attachment,
process_attachments, process_attachments,
process_folder,
vectorize_email, vectorize_email,
) )
@ -45,7 +47,9 @@ def mock_uuid4():
(100, 100, "<test@example.com>"), (100, 100, "<test@example.com>"),
], ],
) )
def test_process_attachment_inline(attachment_size, max_inline_size, message_id): def test_process_attachment_inline(
attachment_size: int, max_inline_size: int, message_id: str
):
attachment = { attachment = {
"filename": "test.txt", "filename": "test.txt",
"content_type": "text/plain", "content_type": "text/plain",
@ -60,10 +64,12 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
) )
with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size): with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
result = process_attachment(attachment, message) result = process_attachment(cast(Attachment, attachment), message)
assert result is not None assert result is not None
assert result.content == attachment["content"].decode("utf-8", errors="replace") assert cast(str, result.content) == attachment["content"].decode(
"utf-8", errors="replace"
)
assert result.filename is None assert result.filename is None
@ -90,11 +96,11 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
folder="INBOX", folder="INBOX",
) )
with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size): with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
result = process_attachment(attachment, message) result = process_attachment(cast(Attachment, attachment), message)
assert result is not None assert result is not None
assert not result.content assert not cast(str, result.content)
assert result.filename == str( assert cast(str, result.filename) == str(
settings.FILE_STORAGE_DIR settings.FILE_STORAGE_DIR
/ "sender_example_com" / "sender_example_com"
/ "INBOX" / "INBOX"
@ -125,11 +131,11 @@ def test_process_attachment_write_error():
patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10), patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10),
patch.object(pathlib.Path, "write_bytes", mock_write_bytes), patch.object(pathlib.Path, "write_bytes", mock_write_bytes),
): ):
assert process_attachment(attachment, message) is None assert process_attachment(cast(Attachment, attachment), message) is None
def test_process_attachments_empty(): def test_process_attachments_empty():
assert process_attachments([], "<test@example.com>") == [] assert process_attachments([], MagicMock()) == []
def test_process_attachments_mixed(): def test_process_attachments_mixed():
@ -167,16 +173,16 @@ def test_process_attachments_mixed():
with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50): with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50):
# Process attachments # Process attachments
results = process_attachments(attachments, message) results = process_attachments(cast(list[Attachment], attachments), message)
# Verify we have all attachments processed # Verify we have all attachments processed
assert len(results) == 3 assert len(results) == 3
assert results[0].content == "a" * 20 assert cast(str, results[0].content) == "a" * 20
assert results[2].content == "c" * 30 assert cast(str, results[2].content) == "c" * 30
# Verify large attachment has a path # Verify large attachment has a path
assert results[1].filename == str( assert cast(str, results[1].filename) == str(
settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt" settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt"
) )
@ -239,12 +245,12 @@ def test_create_mail_message(db_session):
# Verify the mail message was created correctly # Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage) assert isinstance(mail_message, MailMessage)
assert mail_message.message_id == "321" assert cast(str, mail_message.message_id) == "321"
assert mail_message.subject == "Test Subject" assert cast(str, mail_message.subject) == "Test Subject"
assert mail_message.sender == "sender@example.com" assert cast(str, mail_message.sender) == "sender@example.com"
assert mail_message.recipients == ["recipient@example.com"] assert cast(list[str], mail_message.recipients) == ["recipient@example.com"]
assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00" assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00"
assert mail_message.content == raw_email assert cast(str, mail_message.content) == raw_email
assert mail_message.body == "Test body content\n" assert mail_message.body == "Test body content\n"
assert mail_message.attachments == attachments assert mail_message.attachments == attachments
@ -275,7 +281,7 @@ def test_fetch_email_since(email_provider):
assert len(result) == 2 assert len(result) == 2
# Verify content of fetched emails # Verify content of fetched emails
uids = sorted([uid for uid, _ in result]) uids = sorted([uid or "" for uid, _ in result])
assert uids == ["101", "102"] assert uids == ["101", "102"]
# Test with a folder that doesn't exist # Test with a folder that doesn't exist
@ -342,7 +348,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
db_session.add(mail_message) db_session.add(mail_message)
db_session.flush() db_session.flush()
assert mail_message.embed_status == "RAW" assert cast(str, mail_message.embed_status) == "RAW"
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]): with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
vectorize_email(mail_message) vectorize_email(mail_message)
@ -351,7 +357,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
] ]
db_session.commit() db_session.commit()
assert mail_message.embed_status == "STORED" assert cast(str, mail_message.embed_status) == "STORED"
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
@ -418,6 +424,6 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
] ]
db_session.commit() db_session.commit()
assert mail_message.embed_status == "STORED" assert cast(str, mail_message.embed_status) == "STORED"
assert attachment1.embed_status == "STORED" assert cast(str, attachment1.embed_status) == "STORED"
assert attachment2.embed_status == "STORED" assert cast(str, attachment2.embed_status) == "STORED"

View File

@ -9,7 +9,7 @@ class MockEmailProvider:
Can be initialized with predefined emails to return. Can be initialized with predefined emails to return.
""" """
def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] = None): def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] | None = None):
""" """
Initialize with a dictionary of emails organized by folder. Initialize with a dictionary of emails organized by folder.
@ -21,31 +21,41 @@ class MockEmailProvider:
self.emails_by_folder = emails_by_folder or { self.emails_by_folder = emails_by_folder or {
"INBOX": [], "INBOX": [],
"Sent": [], "Sent": [],
"Archive": [] "Archive": [],
} }
self.current_folder = None self.current_folder = None
self.is_connected = False self.is_connected = False
def _generate_email_string(self, email_data: dict[str, Any]) -> str: def _generate_email_string(self, email_data: dict[str, Any]) -> str:
"""Generate a raw email string from the provided email data.""" """Generate a raw email string from the provided email data."""
msg = email.message.EmailMessage() msg = email.message.EmailMessage() # type: ignore
msg["From"] = email_data.get("from", "sender@example.com") msg["From"] = email_data.get("from", "sender@example.com")
msg["To"] = email_data.get("to", "recipient@example.com") msg["To"] = email_data.get("to", "recipient@example.com")
msg["Subject"] = email_data.get("subject", "Test Subject") msg["Subject"] = email_data.get("subject", "Test Subject")
msg["Message-ID"] = email_data.get("message_id", f"<test-{email_data['uid']}@example.com>") msg["Message-ID"] = email_data.get(
msg["Date"] = email_data.get("date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000")) "message_id", f"<test-{email_data['uid']}@example.com>"
)
msg["Date"] = email_data.get(
"date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000")
)
# Set the body content # Set the body content
msg.set_content(email_data.get("body", f"This is test email body {email_data['uid']}")) msg.set_content(
email_data.get("body", f"This is test email body {email_data['uid']}")
)
# Add attachments if present # Add attachments if present
for attachment in email_data.get("attachments", []): for attachment in email_data.get("attachments", []):
if isinstance(attachment, dict) and "filename" in attachment and "content" in attachment: if (
isinstance(attachment, dict)
and "filename" in attachment
and "content" in attachment
):
msg.add_attachment( msg.add_attachment(
attachment["content"], attachment["content"],
maintype=attachment.get("maintype", "application"), maintype=attachment.get("maintype", "application"),
subtype=attachment.get("subtype", "octet-stream"), subtype=attachment.get("subtype", "octet-stream"),
filename=attachment["filename"] filename=attachment["filename"],
) )
return msg.as_string() return msg.as_string()
@ -53,12 +63,12 @@ class MockEmailProvider:
def login(self, username: str, password: str) -> tuple[str, list[bytes]]: def login(self, username: str, password: str) -> tuple[str, list[bytes]]:
"""Mock login method.""" """Mock login method."""
self.is_connected = True self.is_connected = True
return ('OK', [b'Login successful']) return ("OK", [b"Login successful"])
def logout(self) -> tuple[str, list[bytes]]: def logout(self) -> tuple[str, list[bytes]]:
"""Mock logout method.""" """Mock logout method."""
self.is_connected = False self.is_connected = False
return ('OK', [b'Logout successful']) return ("OK", [b"Logout successful"])
def select(self, folder: str, readonly: bool = False) -> tuple[str, list[bytes]]: def select(self, folder: str, readonly: bool = False) -> tuple[str, list[bytes]]:
""" """
@ -74,14 +84,14 @@ class MockEmailProvider:
folder_name = folder.decode() if isinstance(folder, bytes) else folder folder_name = folder.decode() if isinstance(folder, bytes) else folder
self.current_folder = folder_name self.current_folder = folder_name
message_count = len(self.emails_by_folder.get(folder_name, [])) message_count = len(self.emails_by_folder.get(folder_name, []))
return ('OK', [str(message_count).encode()]) return ("OK", [str(message_count).encode()])
def list(self, directory: str = '', pattern: str = '*') -> tuple[str, list[bytes]]: def list(self, directory: str = "", pattern: str = "*") -> tuple[str, list[bytes]]:
"""List available folders.""" """List available folders."""
folders = [] folders = []
for folder in self.emails_by_folder.keys(): for folder in self.emails_by_folder.keys():
folders.append(f'(\\HasNoChildren) "/" "{folder}"'.encode()) folders.append(f'(\\HasNoChildren) "/" "{folder}"'.encode())
return ('OK', folders) return ("OK", folders)
def search(self, charset, criteria): def search(self, charset, criteria):
""" """
@ -95,12 +105,15 @@ class MockEmailProvider:
All email UIDs in the current folder All email UIDs in the current folder
""" """
if not self.current_folder or self.current_folder not in self.emails_by_folder: if not self.current_folder or self.current_folder not in self.emails_by_folder:
return ('OK', [b'']) return ("OK", [b""])
uids = [str(email["uid"]).encode() for email in self.emails_by_folder[self.current_folder]] uids = [
return ('OK', [b' '.join(uids) if uids else b'']) str(email["uid"]).encode()
for email in self.emails_by_folder[self.current_folder]
]
return ("OK", [b" ".join(uids) if uids else b""])
def fetch(self, message_set, message_parts) -> tuple[str, list]: def fetch(self, message_set: bytes | str, message_parts: bytes | str):
""" """
Handle FETCH command to retrieve email data. Handle FETCH command to retrieve email data.
@ -112,10 +125,12 @@ class MockEmailProvider:
Email data in IMAP format Email data in IMAP format
""" """
if not self.current_folder or self.current_folder not in self.emails_by_folder: if not self.current_folder or self.current_folder not in self.emails_by_folder:
return ('OK', [None]) return ("OK", [None])
# For simplicity, we'll just match the UID with the ID provided # For simplicity, we'll just match the UID with the ID provided
uid = int(message_set.decode() if isinstance(message_set, bytes) else message_set) uid = int(
message_set.decode() if isinstance(message_set, bytes) else message_set
)
# Find the email with the matching UID # Find the email with the matching UID
for email_data in self.emails_by_folder[self.current_folder]: for email_data in self.emails_by_folder[self.current_folder]:
@ -126,12 +141,14 @@ class MockEmailProvider:
date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000") date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000")
# Format the response as expected by the IMAP client # Format the response as expected by the IMAP client
response = [( response = [
(
f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 ' f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
f'{{{len(email_string)}}}'.encode(), f"{{{len(email_string)}}}".encode(),
email_string.encode() email_string.encode(),
)] )
return ('OK', response) ]
return ("OK", response)
# No matching email found # No matching email found
return ('NO', [b'Email not found']) return ("NO", [b"Email not found"])