mirror of https://github.com/mruwnik/memory.git, synced 2025-06-08 13:24:41 +02:00

commit 4f1ca777e9 (parent 29a2a7ba3a)

    fix linting

setup.py: 16 lines changed
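Most of the changes below follow one pattern: SQLAlchemy Column attributes and untyped third-party calls trip the type checker, so reads are wrapped in typing.cast(...) and the remaining false positives are silenced with # type: ignore. A minimal sketch of that pattern, using an illustrative Item model that is not from this repo:

from typing import cast

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class Item(Base):
    __tablename__ = "items"
    # At class level these attributes are Column[...] descriptors; on an
    # instance they hold plain values, which the checker cannot see through.
    id = Column(Integer, primary_key=True)
    content = Column(String, nullable=True)


def display(item: Item) -> str:
    # cast() is a no-op at runtime; it only tells the checker the real type.
    return cast(str | None, item.content) or "<empty>"

cast() changes nothing at runtime, so it is the safer tool where the value's type is known; # type: ignore is the blunt fallback for calls the stubs get wrong.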
@@ -4,19 +4,19 @@ from setuptools import setup, find_namespace_packages

 def read_requirements(filename: str) -> list[str]:
     """Read requirements from file, ignoring comments and -r directives."""
-    filename = pathlib.Path(filename)
+    path = pathlib.Path(filename)
     return [
         line.strip()
-        for line in filename.read_text().splitlines()
-        if line.strip() and not line.strip().startswith(('#', '-r'))
+        for line in path.read_text().splitlines()
+        if line.strip() and not line.strip().startswith(("#", "-r"))
     ]


 # Read requirements files
-common_requires = read_requirements('requirements-common.txt')
-api_requires = read_requirements('requirements-api.txt')
-workers_requires = read_requirements('requirements-workers.txt')
-dev_requires = read_requirements('requirements-dev.txt')
+common_requires = read_requirements("requirements-common.txt")
+api_requires = read_requirements("requirements-api.txt")
+workers_requires = read_requirements("requirements-workers.txt")
+dev_requires = read_requirements("requirements-dev.txt")

 setup(
     name="memory",
@@ -31,4 +31,4 @@ setup(
         "dev": dev_requires,
         "all": api_requires + workers_requires + common_requires + dev_requires,
     },
-)
+)
@@ -4,10 +4,9 @@ Database models for the knowledge base system.

 import pathlib
 import re
-from email.message import EmailMessage
 from pathlib import Path
 import textwrap
-from typing import Any, ClassVar
+from typing import Any, ClassVar, cast
 from PIL import Image
 from sqlalchemy import (
     ARRAY,
@@ -31,7 +30,7 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, Session

 from memory.common import settings
-from memory.common.parsers.email import parse_email_message
+from memory.common.parsers.email import parse_email_message, EmailMessage

 Base = declarative_base()

@@ -113,10 +112,10 @@ class Chunk(Base):
     @property
     def data(self) -> list[bytes | str | Image.Image]:
         if self.file_path is None:
-            return [self.content]
+            return [cast(str, self.content)]

         path = pathlib.Path(self.file_path.replace("/app/", ""))
-        if self.file_path.endswith("*"):
+        if cast(str, self.file_path).endswith("*"):
             files = list(path.parent.glob(path.name))
         else:
             files = [path]
@@ -182,7 +181,7 @@ class SourceItem(Base):

     @property
     def display_contents(self) -> str | None:
-        return self.content or self.filename
+        return cast(str | None, self.content) or cast(str | None, self.filename)


 class MailMessage(SourceItem):
@@ -217,8 +216,8 @@ class MailMessage(SourceItem):

     @property
     def attachments_path(self) -> Path:
-        clean_sender = clean_filename(self.sender)
-        clean_folder = clean_filename(self.folder or "INBOX")
+        clean_sender = clean_filename(cast(str, self.sender))
+        clean_folder = clean_filename(cast(str | None, self.folder) or "INBOX")
         return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder

     def safe_filename(self, filename: str) -> Path:
@@ -237,12 +236,12 @@ class MailMessage(SourceItem):
             "recipients": self.recipients,
             "folder": self.folder,
             "tags": self.tags + [self.sender] + self.recipients,
-            "date": self.sent_at and self.sent_at.isoformat() or None,
+            "date": (self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
         }

     @property
     def parsed_content(self) -> EmailMessage:
-        return parse_email_message(self.content, self.message_id)
+        return parse_email_message(cast(str, self.content), cast(str, self.message_id))

     @property
     def body(self) -> str:
@@ -300,7 +299,7 @@ class EmailAttachment(SourceItem):
             "filename": self.filename,
             "content_type": self.mime_type,
             "size": self.size,
-            "created_at": self.created_at and self.created_at.isoformat() or None,
+            "created_at": (self.created_at and self.created_at.isoformat() or None),  # type: ignore
             "mail_message_id": self.mail_message_id,
             "source_id": self.id,
             "tags": self.tags,
@@ -1,7 +1,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Iterable, Literal, NotRequired, TypedDict
+from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast

 import voyageai
 from PIL import Image
@@ -21,7 +21,6 @@ CHARS_PER_TOKEN = 4

 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
-Embedding = tuple[str, Vector, dict[str, Any]]


 class Collection(TypedDict):
@@ -133,16 +132,18 @@ def get_modality(mime_type: str) -> str:


 def embed_chunks(
-    chunks: list[extract.MulitmodalChunk],
+    chunks: list[str] | list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
-    vo = voyageai.Client()
+    vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks, model=model, input_type=input_type
+            chunks,  # type: ignore
+            model=model,
+            input_type=input_type,
         ).embeddings
-    return vo.embed(chunks, model=model, input_type=input_type).embeddings
+    return vo.embed(chunks, model=model, input_type=input_type).embeddings  # type: ignore


 def embed_text(
@@ -162,7 +163,7 @@ def embed_text(

     try:
         return embed_chunks(chunks, model, input_type)
-    except voyageai.error.InvalidRequestError as e:
+    except voyageai.error.InvalidRequestError as e:  # type: ignore
         logger.error(f"Error embedding text: {e}")
         logger.debug(f"Text: {texts}")
         raise
@@ -179,7 +180,7 @@ def embed_mixed(
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
-    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
+    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
                 c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
@@ -190,11 +191,16 @@ def embed_mixed(
     return embed_chunks([chunks], model, input_type)


-def embed_page(page: dict[str, Any]) -> list[Vector]:
+def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
     if all(isinstance(c, str) for c in contents):
-        return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL)
-    return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
+        return embed_text(
+            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+        )
+    return embed_mixed(
+        cast(list[extract.MulitmodalChunk], contents),
+        model=settings.MIXED_EMBEDDING_MODEL,
+    )


 def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
@@ -224,7 +230,7 @@ def make_chunk(
     contents = page["contents"]
     content, filename = None, None
     if all(isinstance(c, str) for c in contents):
-        content = "\n\n".join(contents)
+        content = "\n\n".join(cast(list[str], contents))
         model = settings.TEXT_EMBEDDING_MODEL
     elif len(contents) == 1:
         filename = write_to_file(chunk_id, contents[0]).absolute().as_posix()
@@ -249,7 +255,7 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
-) -> tuple[str, list[Embedding]]:
+) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
     pages = extract.extract_content(mime_type, content)
     chunks = [
@@ -1,13 +1,13 @@
-from contextlib import contextmanager
 import io
-import logging
 import pathlib
 import tempfile
-import pypandoc
-import pymupdf  # PyMuPDF
-from PIL import Image
-from typing import Any, TypedDict, Generator, Sequence
+from contextlib import contextmanager
+from typing import Any, Generator, Sequence, TypedDict, cast
+
+import logging
+import pymupdf  # PyMuPDF
+import pypandoc
+from PIL import Image

 logger = logging.getLogger(__name__)

@@ -105,7 +105,7 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     if isinstance(content, bytes):
         content = content.decode("utf-8")

-    return [{"contents": [content], "metadata": {}}]
+    return [{"contents": [cast(str, content)], "metadata": {}}]


 def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
@@ -1,5 +1,5 @@
 import logging
-from typing import TypedDict, NotRequired
+from typing import TypedDict, NotRequired, cast

 from bs4 import BeautifulSoup, Tag
 import requests
@@ -59,7 +59,7 @@ def extract_smbc(url: str) -> ComicInfo:
         "title": title,
         "image_url": image_url,
         "published_date": published_date,
-        "url": comic_url or url,
+        "url": cast(str, comic_url or url),
     }

@@ -28,10 +28,10 @@ class EmailMessage(TypedDict):
     hash: bytes


-RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
+RawEmailResponse = tuple[str | None, bytes]


-def extract_recipients(msg: email.message.Message) -> list[str]:
+def extract_recipients(msg: email.message.Message) -> list[str]:  # type: ignore
     """
     Extract email recipients from message headers.

@@ -50,7 +50,7 @@ def extract_recipients(msg: email.message.Message) -> list[str]:
     ]


-def extract_date(msg: email.message.Message) -> datetime | None:
+def extract_date(msg: email.message.Message) -> datetime | None:  # type: ignore
     """
     Parse date from email header.

@@ -68,7 +68,7 @@ def extract_date(msg: email.message.Message) -> datetime | None:
     return None


-def extract_body(msg: email.message.Message) -> str:
+def extract_body(msg: email.message.Message) -> str:  # type: ignore
     """
     Extract plain text body from email message.

@@ -99,7 +99,7 @@ def extract_body(msg: email.message.Message) -> str:
     return body


-def extract_attachments(msg: email.message.Message) -> list[Attachment]:
+def extract_attachments(msg: email.message.Message) -> list[Attachment]:  # type: ignore
     """
     Extract attachment metadata and content from email.

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, cast, Iterator, Sequence
+from typing import Any, cast, Generator, Sequence

 import qdrant_client
 from qdrant_client.http import models as qdrant_models
@@ -24,7 +24,7 @@ def get_qdrant_client() -> qdrant_client.QdrantClient:
     return qdrant_client.QdrantClient(
         host=settings.QDRANT_HOST,
         port=settings.QDRANT_PORT,
-        grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None,
+        grpc_port=settings.QDRANT_GRPC_PORT or 6334,
         prefer_grpc=settings.QDRANT_PREFER_GRPC,
         api_key=settings.QDRANT_API_KEY,
         timeout=settings.QDRANT_TIMEOUT,
@@ -80,7 +80,8 @@ def ensure_collection_exists(


 def initialize_collections(
-    client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
+    client: qdrant_client.QdrantClient,
+    collections: dict[str, Collection] | None = None,
 ) -> None:
     """
     Initialize all required collections in Qdrant.
@@ -122,7 +123,7 @@ def upsert_vectors(
     collection_name: str,
     ids: list[str],
     vectors: list[Vector],
-    payloads: list[dict[str, Any]] = None,
+    payloads: list[dict[str, Any]] | None = None,
 ) -> None:
     """Upsert vectors into a collection.

@@ -147,7 +148,7 @@ def upsert_vectors(

     client.upsert(
         collection_name=collection_name,
-        points=points,
+        points=points,  # type: ignore
     )

     logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
@@ -157,7 +158,7 @@ def search_vectors(
     client: qdrant_client.QdrantClient,
     collection_name: str,
     query_vector: Vector,
-    filter_params: dict = None,
+    filter_params: dict | None = None,
     limit: int = 10,
 ) -> list[qdrant_models.ScoredPoint]:
     """Search for similar vectors in a collection.

@@ -200,7 +201,7 @@ def delete_points(
     client.delete(
         collection_name=collection_name,
         points_selector=qdrant_models.PointIdsList(
-            points=ids,
+            points=ids,  # type: ignore
         ),
     )

@@ -226,7 +227,7 @@ def get_collection_info(

 def batch_ids(
     client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
-) -> Iterator[list[str]]:
+) -> Generator[list[str], None, None]:
     """Iterate over all IDs in a collection."""
     offset = None
     while resp := client.scroll(
@@ -236,7 +237,7 @@ def batch_ids(
         limit=batch_size,
     ):
         points, offset = resp
-        yield [point.id for point in points]
+        yield [cast(str, point.id) for point in points]

         if not offset:
             return
@@ -21,7 +21,11 @@ async def make_request(
 ) -> httpx.Response:
     async with httpx.AsyncClient() as client:
         return await client.request(
-            method, f"{SERVER}/{path}", data=data, json=json, files=files
+            method,
+            f"{SERVER}/{path}",
+            data=data,
+            json=json,
+            files=files,  # type: ignore
         )

@@ -31,7 +31,7 @@ app.conf.update(
 )


-@app.on_after_configure.connect
+@app.on_after_configure.connect  # type: ignore
 def ensure_qdrant_initialised(sender, **_):
     from memory.common import qdrant

@@ -1,24 +1,26 @@
 import hashlib
 import imaplib
 import logging
+import pathlib
 import re
+from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Generator, Callable
-import pathlib
-from sqlalchemy.orm import Session
-from collections import defaultdict
-from memory.common import settings, embedding, qdrant
+from typing import Callable, Generator, Sequence, cast
+
+from sqlalchemy.orm import Session, scoped_session
+
+from memory.common import embedding, qdrant, settings
 from memory.common.db.models import (
     EmailAccount,
+    EmailAttachment,
     MailMessage,
     SourceItem,
-    EmailAttachment,
 )
 from memory.common.parsers.email import (
     Attachment,
-    parse_email_message,
     RawEmailResponse,
+    parse_email_message,
 )

 logger = logging.getLogger(__name__)
@@ -89,7 +91,7 @@ def process_attachments(


 def create_mail_message(
-    db_session: Session,
+    db_session: Session | scoped_session,
     tags: list[str],
     folder: str,
     raw_email: str,
@@ -136,7 +138,7 @@ def create_mail_message(


 def does_message_exist(
-    db_session: Session, message_id: str, message_hash: bytes
+    db_session: Session | scoped_session, message_id: str, message_hash: bytes
 ) -> bool:
     """
     Check if a message already exists in the database.
@@ -167,7 +169,7 @@ def does_message_exist(


 def check_message_exists(
-    db: Session, account_id: int, message_id: str, raw_email: str
+    db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
 ) -> bool:
     account = db.query(EmailAccount).get(account_id)
     if not account:
@@ -181,7 +183,9 @@ def check_message_exists(
     return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])


-def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
+def extract_email_uid(
+    msg_data: Sequence[tuple[bytes, bytes]],
+) -> tuple[str | None, bytes]:
     """
     Extract the UID and raw email data from the message data.
     """
@@ -199,7 +203,7 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
             logger.error(f"Error fetching message {uid}")
             return None

-        return extract_email_uid(msg_data)
+        return extract_email_uid(msg_data)  # type: ignore
     except Exception as e:
         logger.error(f"Error processing message {uid}: {str(e)}")
         return None
@@ -248,7 +252,7 @@ def process_folder(
     folder: str,
     account: EmailAccount,
     since_date: datetime,
-    processor: Callable[[int, str, str, bytes], int | None],
+    processor: Callable[[int, str, str, str], int | None],
 ) -> dict:
     """
     Process a single folder from an email account.
@@ -272,7 +276,7 @@ def process_folder(
     for uid, raw_email in emails:
         try:
             task = processor(
-                account_id=account.id,
+                account_id=account.id,  # type: ignore
                 message_id=uid,
                 folder=folder,
                 raw_email=raw_email.decode("utf-8", errors="replace"),
@@ -296,9 +300,11 @@ def process_folder(

 @contextmanager
 def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
-    conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port)
+    conn = imaplib.IMAP4_SSL(
+        host=cast(str, account.imap_server), port=cast(int, account.imap_port)
+    )
     try:
-        conn.login(account.username, account.password)
+        conn.login(cast(str, account.username), cast(str, account.password))
         yield conn
     finally:
         # Always try to logout and close the connection
@@ -318,15 +324,15 @@ def vectorize_email(email: MailMessage):
     )
     email.chunks = chunks
     if chunks:
-        vector_ids = [c.id for c in chunks]
+        vector_ids = [cast(str, c.id) for c in chunks]
         vectors = [c.vector for c in chunks]
         metadata = [c.item_metadata for c in chunks]
         qdrant.upsert_vectors(
             client=qdrant_client,
             collection_name="mail",
             ids=vector_ids,
-            vectors=vectors,
-            payloads=metadata,
+            vectors=vectors,  # type: ignore
+            payloads=metadata,  # type: ignore
         )

     embeds = defaultdict(list)
@@ -356,7 +362,7 @@ def vectorize_email(email: MailMessage):
             payloads=metadata,
         )

-    email.embed_status = "STORED"
+    email.embed_status = "STORED"  # type: ignore
     for attachment in email.attachments:
         attachment.embed_status = "STORED"

@@ -1,7 +1,7 @@
 import hashlib
 import logging
 from datetime import datetime
-from typing import Callable
+from typing import Callable, cast

 import feedparser
 import requests
@@ -35,7 +35,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
         logger.error(f"Failed to fetch or parse {rss_url}: {e}")
         return set()

-    urls = {item.get("link") or item.get("id") for item in feed.entries}
+    urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}

     with make_session() as session:
         known = {
@@ -46,7 +46,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
             )
         }

-    return urls - known
+    return cast(set[str], urls - known)


 def fetch_new_comics(
@@ -56,7 +56,7 @@ def fetch_new_comics(

     for url in new_urls:
         data = parser(url) | {"author": base_url, "url": url}
-        sync_comic.delay(**data)
+        sync_comic.delay(**data)  # type: ignore
     return new_urls


@@ -108,7 +108,7 @@ def sync_comic(
         client=qdrant.get_qdrant_client(),
         collection_name="comic",
         ids=[str(chunk.id)],
-        vectors=[chunk.vector],
+        vectors=[chunk.vector],  # type: ignore
         payloads=[comic.as_payload()],
     )

@@ -130,8 +130,8 @@ def sync_xkcd() -> set[str]:
 @app.task(name=SYNC_ALL_COMICS)
 def sync_all_comics():
     """Synchronize all active comics."""
-    sync_smbc.delay()
-    sync_xkcd.delay()
+    sync_smbc.delay()  # type: ignore
+    sync_xkcd.delay()  # type: ignore


 @app.task(name="memory.workers.tasks.comic.full_sync_comic")
@@ -141,8 +141,8 @@ def trigger_comic_sync():

         response = requests.get(url)
         soup = BeautifulSoup(response.text, "html.parser")
-        if link := soup.find("a", attrs={"class", "cc-prev"}):
-            return link.attrs["href"]
+        if link := soup.find("a", attrs={"class": "cc-prev"}):
+            return link.attrs["href"]  # type: ignore
         return None

     next_url = "https://www.smbc-comics.com"
@@ -155,7 +155,7 @@ def trigger_comic_sync():
             data = comics.extract_smbc(next_url) | {
                 "author": "https://www.smbc-comics.com/"
             }
-            sync_comic.delay(**data)
+            sync_comic.delay(**data)  # type: ignore
         except Exception as e:
             logger.error(f"failed to sync {next_url}: {e}")
         urls.append(next_url)
@@ -167,6 +167,6 @@ def trigger_comic_sync():
         url = f"{BASE_XKCD_URL}/{i}"
         try:
             data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"}
-            sync_comic.delay(**data)
+            sync_comic.delay(**data)  # type: ignore
         except Exception as e:
             logger.error(f"failed to sync {url}: {e}")
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime
-
+from typing import cast
 from memory.common.db.connection import make_session
 from memory.common.db.models import EmailAccount
 from memory.workers.celery_app import app
@@ -71,7 +71,7 @@ def process_message(
         for chunk in attachment.chunks:
             logger.info(f"  - {chunk.id}")

-    return mail_message.id
+    return cast(int, mail_message.id)


 @app.task(name=SYNC_ACCOUNT)
@@ -89,15 +89,17 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:

     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
-        if not account or not account.active:
+        if not account or not cast(bool, account.active):
             logger.warning(f"Account {account_id} not found or inactive")
             return {"error": "Account not found or inactive"}

-        folders_to_process: list[str] = account.folders or ["INBOX"]
+        folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
         if since_date:
             cutoff_date = datetime.fromisoformat(since_date)
         else:
-            cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1)
+            cutoff_date: datetime = cast(datetime, account.last_sync_at) or datetime(
+                1970, 1, 1
+            )

         messages_found = 0
         new_messages = 0
@@ -106,9 +108,9 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
         def process_message_wrapper(
             account_id: int, message_id: str, folder: str, raw_email: str
         ) -> int | None:
-            if check_message_exists(db, account_id, message_id, raw_email):
+            if check_message_exists(db, account_id, message_id, raw_email):  # type: ignore
                 return None
-            return process_message.delay(account_id, message_id, folder, raw_email)
+            return process_message.delay(account_id, message_id, folder, raw_email)  # type: ignore

         try:
             with imap_connection(account) as conn:
@@ -121,7 +123,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
                 new_messages += folder_stats["new_messages"]
                 errors += folder_stats["errors"]

-            account.last_sync_at = datetime.now()
+            account.last_sync_at = datetime.now()  # type: ignore
             db.commit()
         except Exception as e:
             logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
@@ -152,7 +154,7 @@ def sync_all_accounts() -> list[dict]:
         {
             "account_id": account.id,
             "email": account.email_address,
-            "task_id": sync_account.delay(account.id).id,
+            "task_id": sync_account.delay(account.id).id,  # type: ignore
         }
         for account in active_accounts
     ]
@@ -48,7 +48,7 @@ def clean_collection(collection: str) -> dict[str, int]:
 def clean_all_collections():
     logger.info("Cleaning all collections")
     for collection in embedding.ALL_COLLECTIONS:
-        clean_collection.delay(collection)
+        clean_collection.delay(collection)  # type: ignore


 @app.task(name=REINGEST_CHUNK)
@@ -97,7 +97,7 @@ def check_batch(batch: Sequence[Chunk]) -> dict:

     for chunk in chunks:
         if str(chunk.id) in missing:
-            reingest_chunk.delay(str(chunk.id), collection)
+            reingest_chunk.delay(str(chunk.id), collection)  # type: ignore
         else:
             chunk.checked_at = datetime.now()

@@ -132,7 +132,9 @@ def reingest_missing_chunks(batch_size: int = 1000):
         .filter(Chunk.checked_at < since)
         .options(
             contains_eager(Chunk.source).load_only(
-                SourceItem.id, SourceItem.modality, SourceItem.tags
+                SourceItem.id,  # type: ignore
+                SourceItem.modality,  # type: ignore
+                SourceItem.tags,  # type: ignore
             )
         )
         .order_by(Chunk.id)
@@ -252,7 +252,7 @@ def test_parse_simple_email():
         "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
         + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
     }
-    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
+    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400  # type: ignore


 def test_parse_email_with_attachments():
@@ -1,7 +1,7 @@
 import pathlib
 import uuid
 from unittest.mock import Mock, patch
-
+from typing import cast
 import pytest
 from PIL import Image

@@ -72,12 +72,12 @@ def test_embed_mixed(mock_embed):

 def test_embed_page_text_only(mock_embed):
     page = {"contents": ["text1", "text2"]}
-    assert embed_page(page) == [[0], [1]]
+    assert embed_page(page) == [[0], [1]]  # type: ignore


 def test_embed_page_mixed_content(mock_embed):
     page = {"contents": ["text", {"type": "image", "data": "base64"}]}
-    assert embed_page(page) == [[0]]
+    assert embed_page(page) == [[0]]  # type: ignore


 def test_embed(mock_embed):
@@ -91,12 +91,12 @@ def test_embed(mock_embed):
     assert modality == "text"
     assert [
         {
-            "id": c.id,
-            "file_path": c.file_path,
-            "content": c.content,
-            "embedding_model": c.embedding_model,
-            "vector": c.vector,
-            "item_metadata": c.item_metadata,
+            "id": c.id,  # type: ignore
+            "file_path": c.file_path,  # type: ignore
+            "content": c.content,  # type: ignore
+            "embedding_model": c.embedding_model,  # type: ignore
+            "vector": c.vector,  # type: ignore
+            "item_metadata": c.item_metadata,  # type: ignore
         }
         for c in chunks
     ] == [
@@ -128,7 +128,7 @@ def test_write_to_file_bytes(mock_file_storage):
     chunk_id = "test-chunk-id"
     content = b"These are test bytes"

-    file_path = write_to_file(chunk_id, content)
+    file_path = write_to_file(chunk_id, content)  # type: ignore

     assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
     assert file_path.exists()
@@ -140,7 +140,7 @@ def test_write_to_file_image(mock_file_storage):
     img = Image.new("RGB", (100, 100), color=(73, 109, 137))
     chunk_id = "test-chunk-id"

-    file_path = write_to_file(chunk_id, img)
+    file_path = write_to_file(chunk_id, img)  # type: ignore

     assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
     assert file_path.exists()
@@ -155,7 +155,7 @@ def test_write_to_file_unsupported_type(mock_file_storage):
     content = 123  # Integer is not a supported type

     with pytest.raises(ValueError, match="Unsupported content type"):
-        write_to_file(chunk_id, content)
+        write_to_file(chunk_id, content)  # type: ignore


 def test_make_chunk_text_only(mock_file_storage, db_session):
@@ -170,12 +170,12 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
     with patch.object(
         uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
     ):
-        chunk = make_chunk(page, vector, metadata)
+        chunk = make_chunk(page, vector, metadata)  # type: ignore

-    assert chunk.id == "00000000-0000-0000-0000-000000000001"
-    assert chunk.content == "text content 1\n\ntext content 2"
+    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
+    assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
     assert chunk.file_path is None
-    assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL
+    assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL
     assert chunk.vector == vector
     assert chunk.item_metadata == metadata

@@ -190,19 +190,19 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
     with patch.object(
         uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
     ):
-        chunk = make_chunk(page, vector, metadata)
+        chunk = make_chunk(page, vector, metadata)  # type: ignore

-    assert chunk.id == "00000000-0000-0000-0000-000000000002"
+    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
     assert chunk.content is None
-    assert chunk.file_path == str(
+    assert cast(str, chunk.file_path) == str(
         settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
     )
-    assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
+    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
     assert chunk.vector == vector
     assert chunk.item_metadata == metadata

     # Verify the file exists
-    assert pathlib.Path(chunk.file_path[0]).exists()
+    assert pathlib.Path(cast(str, chunk.file_path)).exists()


 def test_make_chunk_mixed_content(mock_file_storage, db_session):
@@ -215,14 +215,14 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
     with patch.object(
         uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
     ):
-        chunk = make_chunk(page, vector, metadata)
+        chunk = make_chunk(page, vector, metadata)  # type: ignore

-    assert chunk.id == "00000000-0000-0000-0000-000000000003"
+    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
     assert chunk.content is None
-    assert chunk.file_path == str(
+    assert cast(str, chunk.file_path) == str(
         settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
     )
-    assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
+    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
     assert chunk.vector == vector
     assert chunk.item_metadata == metadata

@@ -87,7 +87,7 @@ def test_extract_image_with_path(tmp_path):
     img.save(img_path)

     (page,) = extract_image(img_path)
-    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
+    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()  # type: ignore
    assert page["metadata"] == {}


@@ -98,7 +98,7 @@ def test_extract_image_with_bytes():
     img_bytes = buffer.getvalue()

     (page,) = extract_image(img_bytes)
-    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
+    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()  # type: ignore
     assert page["metadata"] == {}

@@ -33,7 +33,10 @@ def test_ensure_collection_exists_existing(mock_qdrant_client):

 def test_ensure_collection_exists_new(mock_qdrant_client):
     mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
-        status_code=404, reason_phrase="asd", content=b"asd", headers=None
+        status_code=404,
+        reason_phrase="asd",
+        content=b"asd",
+        headers=None,  # type: ignore
     )

     assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
@@ -1,24 +1,26 @@
 import base64
 import pathlib

 from datetime import datetime
+from typing import cast
 from unittest.mock import MagicMock, patch

 import pytest

+from memory.common import embedding, settings
 from memory.common.db.models import (
-    MailMessage,
-    EmailAttachment,
     EmailAccount,
+    EmailAttachment,
+    MailMessage,
 )
-from memory.common import settings
-from memory.common import embedding
 from memory.common.parsers.email import Attachment
 from memory.workers.email import (
-    extract_email_uid,
     create_mail_message,
+    extract_email_uid,
     fetch_email,
     fetch_email_since,
-    process_folder,
     process_attachment,
     process_attachments,
+    process_folder,
     vectorize_email,
 )
@@ -45,7 +47,9 @@ def mock_uuid4():
     (100, 100, "<test@example.com>"),
     ],
 )
-def test_process_attachment_inline(attachment_size, max_inline_size, message_id):
+def test_process_attachment_inline(
+    attachment_size: int, max_inline_size: int, message_id: str
+):
     attachment = {
         "filename": "test.txt",
         "content_type": "text/plain",
@@ -60,10 +64,12 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
     )

     with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
-        result = process_attachment(attachment, message)
+        result = process_attachment(cast(Attachment, attachment), message)

     assert result is not None
-    assert result.content == attachment["content"].decode("utf-8", errors="replace")
+    assert cast(str, result.content) == attachment["content"].decode(
+        "utf-8", errors="replace"
+    )
     assert result.filename is None


@@ -90,11 +96,11 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
         folder="INBOX",
     )
     with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
-        result = process_attachment(attachment, message)
+        result = process_attachment(cast(Attachment, attachment), message)

     assert result is not None
-    assert not result.content
-    assert result.filename == str(
+    assert not cast(str, result.content)
+    assert cast(str, result.filename) == str(
         settings.FILE_STORAGE_DIR
         / "sender_example_com"
         / "INBOX"
@@ -125,11 +131,11 @@ def test_process_attachment_write_error():
         patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10),
         patch.object(pathlib.Path, "write_bytes", mock_write_bytes),
     ):
-        assert process_attachment(attachment, message) is None
+        assert process_attachment(cast(Attachment, attachment), message) is None


 def test_process_attachments_empty():
-    assert process_attachments([], "<test@example.com>") == []
+    assert process_attachments([], MagicMock()) == []


 def test_process_attachments_mixed():
@@ -167,16 +173,16 @@ def test_process_attachments_mixed():

     with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50):
         # Process attachments
-        results = process_attachments(attachments, message)
+        results = process_attachments(cast(list[Attachment], attachments), message)

     # Verify we have all attachments processed
     assert len(results) == 3

-    assert results[0].content == "a" * 20
-    assert results[2].content == "c" * 30
+    assert cast(str, results[0].content) == "a" * 20
+    assert cast(str, results[2].content) == "c" * 30

     # Verify large attachment has a path
-    assert results[1].filename == str(
+    assert cast(str, results[1].filename) == str(
         settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt"
     )

@@ -239,12 +245,12 @@ def test_create_mail_message(db_session):

     # Verify the mail message was created correctly
     assert isinstance(mail_message, MailMessage)
-    assert mail_message.message_id == "321"
-    assert mail_message.subject == "Test Subject"
-    assert mail_message.sender == "sender@example.com"
-    assert mail_message.recipients == ["recipient@example.com"]
+    assert cast(str, mail_message.message_id) == "321"
+    assert cast(str, mail_message.subject) == "Test Subject"
+    assert cast(str, mail_message.sender) == "sender@example.com"
+    assert cast(list[str], mail_message.recipients) == ["recipient@example.com"]
     assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00"
-    assert mail_message.content == raw_email
+    assert cast(str, mail_message.content) == raw_email
     assert mail_message.body == "Test body content\n"
     assert mail_message.attachments == attachments

@@ -275,7 +281,7 @@ def test_fetch_email_since(email_provider):
     assert len(result) == 2

     # Verify content of fetched emails
-    uids = sorted([uid for uid, _ in result])
+    uids = sorted([uid or "" for uid, _ in result])
     assert uids == ["101", "102"]

     # Test with a folder that doesn't exist
@@ -342,7 +348,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
     db_session.add(mail_message)
     db_session.flush()

-    assert mail_message.embed_status == "RAW"
+    assert cast(str, mail_message.embed_status) == "RAW"

     with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
         vectorize_email(mail_message)
@@ -351,7 +357,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
     ]

     db_session.commit()
-    assert mail_message.embed_status == "STORED"
+    assert cast(str, mail_message.embed_status) == "STORED"


 def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
@@ -418,6 +424,6 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
     ]

     db_session.commit()
-    assert mail_message.embed_status == "STORED"
-    assert attachment1.embed_status == "STORED"
-    assert attachment2.embed_status == "STORED"
+    assert cast(str, mail_message.embed_status) == "STORED"
+    assert cast(str, attachment1.embed_status) == "STORED"
+    assert cast(str, attachment2.embed_status) == "STORED"
@@ -8,115 +8,130 @@ class MockEmailProvider:
     Mock IMAP email provider for integration testing.
     Can be initialized with predefined emails to return.
     """

-    def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] = None):
+    def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] | None = None):
        """
         Initialize with a dictionary of emails organized by folder.

         Args:
             emails_by_folder: A dictionary mapping folder names to lists of email dictionaries.
-                Each email dict should have: 'uid', 'flags', 'date', 'from', 'to', 'subject', 
+                Each email dict should have: 'uid', 'flags', 'date', 'from', 'to', 'subject',
                 'message_id', 'body', and optionally 'attachments'.
         """
         self.emails_by_folder = emails_by_folder or {
             "INBOX": [],
             "Sent": [],
-            "Archive": []
+            "Archive": [],
         }
         self.current_folder = None
         self.is_connected = False

     def _generate_email_string(self, email_data: dict[str, Any]) -> str:
         """Generate a raw email string from the provided email data."""
-        msg = email.message.EmailMessage()
+        msg = email.message.EmailMessage()  # type: ignore
         msg["From"] = email_data.get("from", "sender@example.com")
         msg["To"] = email_data.get("to", "recipient@example.com")
         msg["Subject"] = email_data.get("subject", "Test Subject")
-        msg["Message-ID"] = email_data.get("message_id", f"<test-{email_data['uid']}@example.com>")
-        msg["Date"] = email_data.get("date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000"))
+        msg["Message-ID"] = email_data.get(
+            "message_id", f"<test-{email_data['uid']}@example.com>"
+        )
+        msg["Date"] = email_data.get(
+            "date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000")
+        )

         # Set the body content
-        msg.set_content(email_data.get("body", f"This is test email body {email_data['uid']}"))
+        msg.set_content(
+            email_data.get("body", f"This is test email body {email_data['uid']}")
+        )

         # Add attachments if present
         for attachment in email_data.get("attachments", []):
-            if isinstance(attachment, dict) and "filename" in attachment and "content" in attachment:
+            if (
+                isinstance(attachment, dict)
+                and "filename" in attachment
+                and "content" in attachment
+            ):
                 msg.add_attachment(
                     attachment["content"],
                     maintype=attachment.get("maintype", "application"),
                     subtype=attachment.get("subtype", "octet-stream"),
-                    filename=attachment["filename"]
+                    filename=attachment["filename"],
                 )

         return msg.as_string()

     def login(self, username: str, password: str) -> tuple[str, list[bytes]]:
         """Mock login method."""
         self.is_connected = True
-        return ('OK', [b'Login successful'])
+        return ("OK", [b"Login successful"])

     def logout(self) -> tuple[str, list[bytes]]:
         """Mock logout method."""
         self.is_connected = False
-        return ('OK', [b'Logout successful'])
+        return ("OK", [b"Logout successful"])

     def select(self, folder: str, readonly: bool = False) -> tuple[str, list[bytes]]:
         """
         Select a folder and make it the current active folder.

         Args:
             folder: Folder name to select
             readonly: Whether to open in readonly mode

         Returns:
             IMAP-style response with message count
         """
         folder_name = folder.decode() if isinstance(folder, bytes) else folder
         self.current_folder = folder_name
         message_count = len(self.emails_by_folder.get(folder_name, []))
-        return ('OK', [str(message_count).encode()])
+        return ("OK", [str(message_count).encode()])

-    def list(self, directory: str = '', pattern: str = '*') -> tuple[str, list[bytes]]:
+    def list(self, directory: str = "", pattern: str = "*") -> tuple[str, list[bytes]]:
         """List available folders."""
         folders = []
         for folder in self.emails_by_folder.keys():
             folders.append(f'(\\HasNoChildren) "/" "{folder}"'.encode())
-        return ('OK', folders)
+        return ("OK", folders)

     def search(self, charset, criteria):
         """
         Handle SEARCH command to find email UIDs.

         Args:
             charset: Character set (ignored in mock)
             criteria: Search criteria (ignored in mock, we return all emails)

         Returns:
             All email UIDs in the current folder
         """
         if not self.current_folder or self.current_folder not in self.emails_by_folder:
-            return ('OK', [b''])
+            return ("OK", [b""])

-        uids = [str(email["uid"]).encode() for email in self.emails_by_folder[self.current_folder]]
-        return ('OK', [b' '.join(uids) if uids else b''])
+        uids = [
+            str(email["uid"]).encode()
+            for email in self.emails_by_folder[self.current_folder]
+        ]
+        return ("OK", [b" ".join(uids) if uids else b""])

-    def fetch(self, message_set, message_parts) -> tuple[str, list]:
+    def fetch(self, message_set: bytes | str, message_parts: bytes | str):
         """
         Handle FETCH command to retrieve email data.

         Args:
             message_set: Message numbers/UIDs to fetch
             message_parts: Parts of the message to fetch

         Returns:
             Email data in IMAP format
         """
         if not self.current_folder or self.current_folder not in self.emails_by_folder:
-            return ('OK', [None])
+            return ("OK", [None])

         # For simplicity, we'll just match the UID with the ID provided
-        uid = int(message_set.decode() if isinstance(message_set, bytes) else message_set)
+        uid = int(
+            message_set.decode() if isinstance(message_set, bytes) else message_set
+        )

         # Find the email with the matching UID
         for email_data in self.emails_by_folder[self.current_folder]:
             if email_data["uid"] == uid:
@@ -124,14 +139,16 @@ class MockEmailProvider:
                 email_string = self._generate_email_string(email_data)
                 flags = email_data.get("flags", "\\Seen")
                 date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000")

                 # Format the response as expected by the IMAP client
-                response = [(
-                    f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
-                    f'{{{len(email_string)}}}'.encode(),
-                    email_string.encode()
-                )]
-                return ('OK', response)
+                response = [
+                    (
+                        f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
+                        f"{{{len(email_string)}}}".encode(),
+                        email_string.encode(),
+                    )
+                ]
+                return ("OK", response)

         # No matching email found
-        return ('NO', [b'Email not found'])
+        return ("NO", [b"Email not found"])