fix linting

This commit is contained in:
Daniel O'Connell 2025-05-20 22:54:45 +02:00
parent 29a2a7ba3a
commit 4f1ca777e9
19 changed files with 255 additions and 209 deletions

View File

@ -4,19 +4,19 @@ from setuptools import setup, find_namespace_packages
def read_requirements(filename: str) -> list[str]:
"""Read requirements from file, ignoring comments and -r directives."""
filename = pathlib.Path(filename)
path = pathlib.Path(filename)
return [
line.strip()
for line in filename.read_text().splitlines()
if line.strip() and not line.strip().startswith(('#', '-r'))
for line in path.read_text().splitlines()
if line.strip() and not line.strip().startswith(("#", "-r"))
]
# Read requirements files
common_requires = read_requirements('requirements-common.txt')
api_requires = read_requirements('requirements-api.txt')
workers_requires = read_requirements('requirements-workers.txt')
dev_requires = read_requirements('requirements-dev.txt')
common_requires = read_requirements("requirements-common.txt")
api_requires = read_requirements("requirements-api.txt")
workers_requires = read_requirements("requirements-workers.txt")
dev_requires = read_requirements("requirements-dev.txt")
setup(
name="memory",
@ -31,4 +31,4 @@ setup(
"dev": dev_requires,
"all": api_requires + workers_requires + common_requires + dev_requires,
},
)
)

View File

@ -4,10 +4,9 @@ Database models for the knowledge base system.
import pathlib
import re
from email.message import EmailMessage
from pathlib import Path
import textwrap
from typing import Any, ClassVar
from typing import Any, ClassVar, cast
from PIL import Image
from sqlalchemy import (
ARRAY,
@ -31,7 +30,7 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, Session
from memory.common import settings
from memory.common.parsers.email import parse_email_message
from memory.common.parsers.email import parse_email_message, EmailMessage
Base = declarative_base()
@ -113,10 +112,10 @@ class Chunk(Base):
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_path is None:
return [self.content]
return [cast(str, self.content)]
path = pathlib.Path(self.file_path.replace("/app/", ""))
if self.file_path.endswith("*"):
if cast(str, self.file_path).endswith("*"):
files = list(path.parent.glob(path.name))
else:
files = [path]
@ -182,7 +181,7 @@ class SourceItem(Base):
@property
def display_contents(self) -> str | None:
return self.content or self.filename
return cast(str | None, self.content) or cast(str | None, self.filename)
class MailMessage(SourceItem):
@ -217,8 +216,8 @@ class MailMessage(SourceItem):
@property
def attachments_path(self) -> Path:
clean_sender = clean_filename(self.sender)
clean_folder = clean_filename(self.folder or "INBOX")
clean_sender = clean_filename(cast(str, self.sender))
clean_folder = clean_filename(cast(str | None, self.folder) or "INBOX")
return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder
def safe_filename(self, filename: str) -> Path:
@ -237,12 +236,12 @@ class MailMessage(SourceItem):
"recipients": self.recipients,
"folder": self.folder,
"tags": self.tags + [self.sender] + self.recipients,
"date": self.sent_at and self.sent_at.isoformat() or None,
"date": (self.sent_at and self.sent_at.isoformat() or None), # type: ignore
}
@property
def parsed_content(self) -> EmailMessage:
return parse_email_message(self.content, self.message_id)
return parse_email_message(cast(str, self.content), cast(str, self.message_id))
@property
def body(self) -> str:
@ -300,7 +299,7 @@ class EmailAttachment(SourceItem):
"filename": self.filename,
"content_type": self.mime_type,
"size": self.size,
"created_at": self.created_at and self.created_at.isoformat() or None,
"created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore
"mail_message_id": self.mail_message_id,
"source_id": self.id,
"tags": self.tags,

View File

@ -1,7 +1,7 @@
import logging
import pathlib
import uuid
from typing import Any, Iterable, Literal, NotRequired, TypedDict
from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast
import voyageai
from PIL import Image
@ -21,7 +21,6 @@ CHARS_PER_TOKEN = 4
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
Embedding = tuple[str, Vector, dict[str, Any]]
class Collection(TypedDict):
@ -133,16 +132,18 @@ def get_modality(mime_type: str) -> str:
def embed_chunks(
chunks: list[extract.MulitmodalChunk],
chunks: list[str] | list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
vo = voyageai.Client()
vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed(
chunks, model=model, input_type=input_type
chunks, # type: ignore
model=model,
input_type=input_type,
).embeddings
return vo.embed(chunks, model=model, input_type=input_type).embeddings
return vo.embed(chunks, model=model, input_type=input_type).embeddings # type: ignore
def embed_text(
@ -162,7 +163,7 @@ def embed_text(
try:
return embed_chunks(chunks, model, input_type)
except voyageai.error.InvalidRequestError as e:
except voyageai.error.InvalidRequestError as e: # type: ignore
logger.error(f"Error embedding text: {e}")
logger.debug(f"Text: {texts}")
raise
@ -179,7 +180,7 @@ def embed_mixed(
model: str = settings.MIXED_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
if isinstance(item, str):
return [
c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
@ -190,11 +191,16 @@ def embed_mixed(
return embed_chunks([chunks], model, input_type)
def embed_page(page: dict[str, Any]) -> list[Vector]:
def embed_page(page: extract.Page) -> list[Vector]:
contents = page["contents"]
if all(isinstance(c, str) for c in contents):
return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL)
return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
return embed_text(
cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
)
return embed_mixed(
cast(list[extract.MulitmodalChunk], contents),
model=settings.MIXED_EMBEDDING_MODEL,
)
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
@ -224,7 +230,7 @@ def make_chunk(
contents = page["contents"]
content, filename = None, None
if all(isinstance(c, str) for c in contents):
content = "\n\n".join(contents)
content = "\n\n".join(cast(list[str], contents))
model = settings.TEXT_EMBEDDING_MODEL
elif len(contents) == 1:
filename = write_to_file(chunk_id, contents[0]).absolute().as_posix()
@ -249,7 +255,7 @@ def embed(
mime_type: str,
content: bytes | str | pathlib.Path,
metadata: dict[str, Any] = {},
) -> tuple[str, list[Embedding]]:
) -> tuple[str, list[Chunk]]:
modality = get_modality(mime_type)
pages = extract.extract_content(mime_type, content)
chunks = [

View File

@ -1,13 +1,13 @@
from contextlib import contextmanager
import io
import logging
import pathlib
import tempfile
import pypandoc
import pymupdf # PyMuPDF
from PIL import Image
from typing import Any, TypedDict, Generator, Sequence
from contextlib import contextmanager
from typing import Any, Generator, Sequence, TypedDict, cast
import logging
import pymupdf # PyMuPDF
import pypandoc
from PIL import Image
logger = logging.getLogger(__name__)
@ -105,7 +105,7 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
if isinstance(content, bytes):
content = content.decode("utf-8")
return [{"contents": [content], "metadata": {}}]
return [{"contents": [cast(str, content)], "metadata": {}}]
def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:

View File

@ -1,5 +1,5 @@
import logging
from typing import TypedDict, NotRequired
from typing import TypedDict, NotRequired, cast
from bs4 import BeautifulSoup, Tag
import requests
@ -59,7 +59,7 @@ def extract_smbc(url: str) -> ComicInfo:
"title": title,
"image_url": image_url,
"published_date": published_date,
"url": comic_url or url,
"url": cast(str, comic_url or url),
}

View File

@ -28,10 +28,10 @@ class EmailMessage(TypedDict):
hash: bytes
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
RawEmailResponse = tuple[str | None, bytes]
def extract_recipients(msg: email.message.Message) -> list[str]:
def extract_recipients(msg: email.message.Message) -> list[str]: # type: ignore
"""
Extract email recipients from message headers.
@ -50,7 +50,7 @@ def extract_recipients(msg: email.message.Message) -> list[str]:
]
def extract_date(msg: email.message.Message) -> datetime | None:
def extract_date(msg: email.message.Message) -> datetime | None: # type: ignore
"""
Parse date from email header.
@ -68,7 +68,7 @@ def extract_date(msg: email.message.Message) -> datetime | None:
return None
def extract_body(msg: email.message.Message) -> str:
def extract_body(msg: email.message.Message) -> str: # type: ignore
"""
Extract plain text body from email message.
@ -99,7 +99,7 @@ def extract_body(msg: email.message.Message) -> str:
return body
def extract_attachments(msg: email.message.Message) -> list[Attachment]:
def extract_attachments(msg: email.message.Message) -> list[Attachment]: # type: ignore
"""
Extract attachment metadata and content from email.

View File

@ -1,5 +1,5 @@
import logging
from typing import Any, cast, Iterator, Sequence
from typing import Any, cast, Generator, Sequence
import qdrant_client
from qdrant_client.http import models as qdrant_models
@ -24,7 +24,7 @@ def get_qdrant_client() -> qdrant_client.QdrantClient:
return qdrant_client.QdrantClient(
host=settings.QDRANT_HOST,
port=settings.QDRANT_PORT,
grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None,
grpc_port=settings.QDRANT_GRPC_PORT or 6334,
prefer_grpc=settings.QDRANT_PREFER_GRPC,
api_key=settings.QDRANT_API_KEY,
timeout=settings.QDRANT_TIMEOUT,
@ -80,7 +80,8 @@ def ensure_collection_exists(
def initialize_collections(
client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
client: qdrant_client.QdrantClient,
collections: dict[str, Collection] | None = None,
) -> None:
"""
Initialize all required collections in Qdrant.
@ -122,7 +123,7 @@ def upsert_vectors(
collection_name: str,
ids: list[str],
vectors: list[Vector],
payloads: list[dict[str, Any]] = None,
payloads: list[dict[str, Any]] | None = None,
) -> None:
"""Upsert vectors into a collection.
@ -147,7 +148,7 @@ def upsert_vectors(
client.upsert(
collection_name=collection_name,
points=points,
points=points, # type: ignore
)
logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
@ -157,7 +158,7 @@ def search_vectors(
client: qdrant_client.QdrantClient,
collection_name: str,
query_vector: Vector,
filter_params: dict = None,
filter_params: dict | None = None,
limit: int = 10,
) -> list[qdrant_models.ScoredPoint]:
"""Search for similar vectors in a collection.
@ -200,7 +201,7 @@ def delete_points(
client.delete(
collection_name=collection_name,
points_selector=qdrant_models.PointIdsList(
points=ids,
points=ids, # type: ignore
),
)
@ -226,7 +227,7 @@ def get_collection_info(
def batch_ids(
client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
) -> Iterator[list[str]]:
) -> Generator[list[str], None, None]:
"""Iterate over all IDs in a collection."""
offset = None
while resp := client.scroll(
@ -236,7 +237,7 @@ def batch_ids(
limit=batch_size,
):
points, offset = resp
yield [point.id for point in points]
yield [cast(str, point.id) for point in points]
if not offset:
return

View File

@ -21,7 +21,11 @@ async def make_request(
) -> httpx.Response:
async with httpx.AsyncClient() as client:
return await client.request(
method, f"{SERVER}/{path}", data=data, json=json, files=files
method,
f"{SERVER}/{path}",
data=data,
json=json,
files=files, # type: ignore
)

View File

@ -31,7 +31,7 @@ app.conf.update(
)
@app.on_after_configure.connect
@app.on_after_configure.connect # type: ignore
def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant

View File

@ -1,24 +1,26 @@
import hashlib
import imaplib
import logging
import pathlib
import re
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Generator, Callable
import pathlib
from sqlalchemy.orm import Session
from collections import defaultdict
from memory.common import settings, embedding, qdrant
from typing import Callable, Generator, Sequence, cast
from sqlalchemy.orm import Session, scoped_session
from memory.common import embedding, qdrant, settings
from memory.common.db.models import (
EmailAccount,
EmailAttachment,
MailMessage,
SourceItem,
EmailAttachment,
)
from memory.common.parsers.email import (
Attachment,
parse_email_message,
RawEmailResponse,
parse_email_message,
)
logger = logging.getLogger(__name__)
@ -89,7 +91,7 @@ def process_attachments(
def create_mail_message(
db_session: Session,
db_session: Session | scoped_session,
tags: list[str],
folder: str,
raw_email: str,
@ -136,7 +138,7 @@ def create_mail_message(
def does_message_exist(
db_session: Session, message_id: str, message_hash: bytes
db_session: Session | scoped_session, message_id: str, message_hash: bytes
) -> bool:
"""
Check if a message already exists in the database.
@ -167,7 +169,7 @@ def does_message_exist(
def check_message_exists(
db: Session, account_id: int, message_id: str, raw_email: str
db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
) -> bool:
account = db.query(EmailAccount).get(account_id)
if not account:
@ -181,7 +183,9 @@ def check_message_exists(
return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
def extract_email_uid(
msg_data: Sequence[tuple[bytes, bytes]],
) -> tuple[str | None, bytes]:
"""
Extract the UID and raw email data from the message data.
"""
@ -199,7 +203,7 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
logger.error(f"Error fetching message {uid}")
return None
return extract_email_uid(msg_data)
return extract_email_uid(msg_data) # type: ignore
except Exception as e:
logger.error(f"Error processing message {uid}: {str(e)}")
return None
@ -248,7 +252,7 @@ def process_folder(
folder: str,
account: EmailAccount,
since_date: datetime,
processor: Callable[[int, str, str, bytes], int | None],
processor: Callable[[int, str, str, str], int | None],
) -> dict:
"""
Process a single folder from an email account.
@ -272,7 +276,7 @@ def process_folder(
for uid, raw_email in emails:
try:
task = processor(
account_id=account.id,
account_id=account.id, # type: ignore
message_id=uid,
folder=folder,
raw_email=raw_email.decode("utf-8", errors="replace"),
@ -296,9 +300,11 @@ def process_folder(
@contextmanager
def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port)
conn = imaplib.IMAP4_SSL(
host=cast(str, account.imap_server), port=cast(int, account.imap_port)
)
try:
conn.login(account.username, account.password)
conn.login(cast(str, account.username), cast(str, account.password))
yield conn
finally:
# Always try to logout and close the connection
@ -318,15 +324,15 @@ def vectorize_email(email: MailMessage):
)
email.chunks = chunks
if chunks:
vector_ids = [c.id for c in chunks]
vector_ids = [cast(str, c.id) for c in chunks]
vectors = [c.vector for c in chunks]
metadata = [c.item_metadata for c in chunks]
qdrant.upsert_vectors(
client=qdrant_client,
collection_name="mail",
ids=vector_ids,
vectors=vectors,
payloads=metadata,
vectors=vectors, # type: ignore
payloads=metadata, # type: ignore
)
embeds = defaultdict(list)
@ -356,7 +362,7 @@ def vectorize_email(email: MailMessage):
payloads=metadata,
)
email.embed_status = "STORED"
email.embed_status = "STORED" # type: ignore
for attachment in email.attachments:
attachment.embed_status = "STORED"

View File

@ -1,7 +1,7 @@
import hashlib
import logging
from datetime import datetime
from typing import Callable
from typing import Callable, cast
import feedparser
import requests
@ -35,7 +35,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
logger.error(f"Failed to fetch or parse {rss_url}: {e}")
return set()
urls = {item.get("link") or item.get("id") for item in feed.entries}
urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}
with make_session() as session:
known = {
@ -46,7 +46,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
)
}
return urls - known
return cast(set[str], urls - known)
def fetch_new_comics(
@ -56,7 +56,7 @@ def fetch_new_comics(
for url in new_urls:
data = parser(url) | {"author": base_url, "url": url}
sync_comic.delay(**data)
sync_comic.delay(**data) # type: ignore
return new_urls
@ -108,7 +108,7 @@ def sync_comic(
client=qdrant.get_qdrant_client(),
collection_name="comic",
ids=[str(chunk.id)],
vectors=[chunk.vector],
vectors=[chunk.vector], # type: ignore
payloads=[comic.as_payload()],
)
@ -130,8 +130,8 @@ def sync_xkcd() -> set[str]:
@app.task(name=SYNC_ALL_COMICS)
def sync_all_comics():
"""Synchronize all active comics."""
sync_smbc.delay()
sync_xkcd.delay()
sync_smbc.delay() # type: ignore
sync_xkcd.delay() # type: ignore
@app.task(name="memory.workers.tasks.comic.full_sync_comic")
@ -141,8 +141,8 @@ def trigger_comic_sync():
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
if link := soup.find("a", attrs={"class", "cc-prev"}):
return link.attrs["href"]
if link := soup.find("a", attrs={"class": "cc-prev"}):
return link.attrs["href"] # type: ignore
return None
next_url = "https://www.smbc-comics.com"
@ -155,7 +155,7 @@ def trigger_comic_sync():
data = comics.extract_smbc(next_url) | {
"author": "https://www.smbc-comics.com/"
}
sync_comic.delay(**data)
sync_comic.delay(**data) # type: ignore
except Exception as e:
logger.error(f"failed to sync {next_url}: {e}")
urls.append(next_url)
@ -167,6 +167,6 @@ def trigger_comic_sync():
url = f"{BASE_XKCD_URL}/{i}"
try:
data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"}
sync_comic.delay(**data)
sync_comic.delay(**data) # type: ignore
except Exception as e:
logger.error(f"failed to sync {url}: {e}")

View File

@ -1,6 +1,6 @@
import logging
from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount
from memory.workers.celery_app import app
@ -71,7 +71,7 @@ def process_message(
for chunk in attachment.chunks:
logger.info(f" - {chunk.id}")
return mail_message.id
return cast(int, mail_message.id)
@app.task(name=SYNC_ACCOUNT)
@ -89,15 +89,17 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
with make_session() as db:
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
if not account or not account.active:
if not account or not cast(bool, account.active):
logger.warning(f"Account {account_id} not found or inactive")
return {"error": "Account not found or inactive"}
folders_to_process: list[str] = account.folders or ["INBOX"]
folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
if since_date:
cutoff_date = datetime.fromisoformat(since_date)
else:
cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1)
cutoff_date: datetime = cast(datetime, account.last_sync_at) or datetime(
1970, 1, 1
)
messages_found = 0
new_messages = 0
@ -106,9 +108,9 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
def process_message_wrapper(
account_id: int, message_id: str, folder: str, raw_email: str
) -> int | None:
if check_message_exists(db, account_id, message_id, raw_email):
if check_message_exists(db, account_id, message_id, raw_email): # type: ignore
return None
return process_message.delay(account_id, message_id, folder, raw_email)
return process_message.delay(account_id, message_id, folder, raw_email) # type: ignore
try:
with imap_connection(account) as conn:
@ -121,7 +123,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
new_messages += folder_stats["new_messages"]
errors += folder_stats["errors"]
account.last_sync_at = datetime.now()
account.last_sync_at = datetime.now() # type: ignore
db.commit()
except Exception as e:
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
@ -152,7 +154,7 @@ def sync_all_accounts() -> list[dict]:
{
"account_id": account.id,
"email": account.email_address,
"task_id": sync_account.delay(account.id).id,
"task_id": sync_account.delay(account.id).id, # type: ignore
}
for account in active_accounts
]

View File

@ -48,7 +48,7 @@ def clean_collection(collection: str) -> dict[str, int]:
def clean_all_collections():
logger.info("Cleaning all collections")
for collection in embedding.ALL_COLLECTIONS:
clean_collection.delay(collection)
clean_collection.delay(collection) # type: ignore
@app.task(name=REINGEST_CHUNK)
@ -97,7 +97,7 @@ def check_batch(batch: Sequence[Chunk]) -> dict:
for chunk in chunks:
if str(chunk.id) in missing:
reingest_chunk.delay(str(chunk.id), collection)
reingest_chunk.delay(str(chunk.id), collection) # type: ignore
else:
chunk.checked_at = datetime.now()
@ -132,7 +132,9 @@ def reingest_missing_chunks(batch_size: int = 1000):
.filter(Chunk.checked_at < since)
.options(
contains_eager(Chunk.source).load_only(
SourceItem.id, SourceItem.modality, SourceItem.tags
SourceItem.id, # type: ignore
SourceItem.modality, # type: ignore
SourceItem.tags, # type: ignore
)
)
.order_by(Chunk.id)

View File

@ -252,7 +252,7 @@ def test_parse_simple_email():
"hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
+ b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
}
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400 # type: ignore
def test_parse_email_with_attachments():

View File

@ -1,7 +1,7 @@
import pathlib
import uuid
from unittest.mock import Mock, patch
from typing import cast
import pytest
from PIL import Image
@ -72,12 +72,12 @@ def test_embed_mixed(mock_embed):
def test_embed_page_text_only(mock_embed):
page = {"contents": ["text1", "text2"]}
assert embed_page(page) == [[0], [1]]
assert embed_page(page) == [[0], [1]] # type: ignore
def test_embed_page_mixed_content(mock_embed):
page = {"contents": ["text", {"type": "image", "data": "base64"}]}
assert embed_page(page) == [[0]]
assert embed_page(page) == [[0]] # type: ignore
def test_embed(mock_embed):
@ -91,12 +91,12 @@ def test_embed(mock_embed):
assert modality == "text"
assert [
{
"id": c.id,
"file_path": c.file_path,
"content": c.content,
"embedding_model": c.embedding_model,
"vector": c.vector,
"item_metadata": c.item_metadata,
"id": c.id, # type: ignore
"file_path": c.file_path, # type: ignore
"content": c.content, # type: ignore
"embedding_model": c.embedding_model, # type: ignore
"vector": c.vector, # type: ignore
"item_metadata": c.item_metadata, # type: ignore
}
for c in chunks
] == [
@ -128,7 +128,7 @@ def test_write_to_file_bytes(mock_file_storage):
chunk_id = "test-chunk-id"
content = b"These are test bytes"
file_path = write_to_file(chunk_id, content)
file_path = write_to_file(chunk_id, content) # type: ignore
assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
assert file_path.exists()
@ -140,7 +140,7 @@ def test_write_to_file_image(mock_file_storage):
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
chunk_id = "test-chunk-id"
file_path = write_to_file(chunk_id, img)
file_path = write_to_file(chunk_id, img) # type: ignore
assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
assert file_path.exists()
@ -155,7 +155,7 @@ def test_write_to_file_unsupported_type(mock_file_storage):
content = 123 # Integer is not a supported type
with pytest.raises(ValueError, match="Unsupported content type"):
write_to_file(chunk_id, content)
write_to_file(chunk_id, content) # type: ignore
def test_make_chunk_text_only(mock_file_storage, db_session):
@ -170,12 +170,12 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
):
chunk = make_chunk(page, vector, metadata)
chunk = make_chunk(page, vector, metadata) # type: ignore
assert chunk.id == "00000000-0000-0000-0000-000000000001"
assert chunk.content == "text content 1\n\ntext content 2"
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
assert chunk.file_path is None
assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL
assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL
assert chunk.vector == vector
assert chunk.item_metadata == metadata
@ -190,19 +190,19 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
):
chunk = make_chunk(page, vector, metadata)
chunk = make_chunk(page, vector, metadata) # type: ignore
assert chunk.id == "00000000-0000-0000-0000-000000000002"
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
assert chunk.content is None
assert chunk.file_path == str(
assert cast(str, chunk.file_path) == str(
settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
)
assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
assert chunk.vector == vector
assert chunk.item_metadata == metadata
# Verify the file exists
assert pathlib.Path(chunk.file_path[0]).exists()
assert pathlib.Path(cast(str, chunk.file_path)).exists()
def test_make_chunk_mixed_content(mock_file_storage, db_session):
@ -215,14 +215,14 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
):
chunk = make_chunk(page, vector, metadata)
chunk = make_chunk(page, vector, metadata) # type: ignore
assert chunk.id == "00000000-0000-0000-0000-000000000003"
assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
assert chunk.content is None
assert chunk.file_path == str(
assert cast(str, chunk.file_path) == str(
settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
)
assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
assert chunk.vector == vector
assert chunk.item_metadata == metadata

View File

@ -87,7 +87,7 @@ def test_extract_image_with_path(tmp_path):
img.save(img_path)
(page,) = extract_image(img_path)
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page["metadata"] == {}
@ -98,7 +98,7 @@ def test_extract_image_with_bytes():
img_bytes = buffer.getvalue()
(page,) = extract_image(img_bytes)
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page["metadata"] == {}

View File

@ -33,7 +33,10 @@ def test_ensure_collection_exists_existing(mock_qdrant_client):
def test_ensure_collection_exists_new(mock_qdrant_client):
mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
status_code=404, reason_phrase="asd", content=b"asd", headers=None
status_code=404,
reason_phrase="asd",
content=b"asd",
headers=None, # type: ignore
)
assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)

View File

@ -1,24 +1,26 @@
import base64
import pathlib
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
from memory.common import embedding, settings
from memory.common.db.models import (
MailMessage,
EmailAttachment,
EmailAccount,
EmailAttachment,
MailMessage,
)
from memory.common import settings
from memory.common import embedding
from memory.common.parsers.email import Attachment
from memory.workers.email import (
extract_email_uid,
create_mail_message,
extract_email_uid,
fetch_email,
fetch_email_since,
process_folder,
process_attachment,
process_attachments,
process_folder,
vectorize_email,
)
@ -45,7 +47,9 @@ def mock_uuid4():
(100, 100, "<test@example.com>"),
],
)
def test_process_attachment_inline(attachment_size, max_inline_size, message_id):
def test_process_attachment_inline(
attachment_size: int, max_inline_size: int, message_id: str
):
attachment = {
"filename": "test.txt",
"content_type": "text/plain",
@ -60,10 +64,12 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
)
with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
result = process_attachment(attachment, message)
result = process_attachment(cast(Attachment, attachment), message)
assert result is not None
assert result.content == attachment["content"].decode("utf-8", errors="replace")
assert cast(str, result.content) == attachment["content"].decode(
"utf-8", errors="replace"
)
assert result.filename is None
@ -90,11 +96,11 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
folder="INBOX",
)
with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
result = process_attachment(attachment, message)
result = process_attachment(cast(Attachment, attachment), message)
assert result is not None
assert not result.content
assert result.filename == str(
assert not cast(str, result.content)
assert cast(str, result.filename) == str(
settings.FILE_STORAGE_DIR
/ "sender_example_com"
/ "INBOX"
@ -125,11 +131,11 @@ def test_process_attachment_write_error():
patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10),
patch.object(pathlib.Path, "write_bytes", mock_write_bytes),
):
assert process_attachment(attachment, message) is None
assert process_attachment(cast(Attachment, attachment), message) is None
def test_process_attachments_empty():
assert process_attachments([], "<test@example.com>") == []
assert process_attachments([], MagicMock()) == []
def test_process_attachments_mixed():
@ -167,16 +173,16 @@ def test_process_attachments_mixed():
with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50):
# Process attachments
results = process_attachments(attachments, message)
results = process_attachments(cast(list[Attachment], attachments), message)
# Verify we have all attachments processed
assert len(results) == 3
assert results[0].content == "a" * 20
assert results[2].content == "c" * 30
assert cast(str, results[0].content) == "a" * 20
assert cast(str, results[2].content) == "c" * 30
# Verify large attachment has a path
assert results[1].filename == str(
assert cast(str, results[1].filename) == str(
settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt"
)
@ -239,12 +245,12 @@ def test_create_mail_message(db_session):
# Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage)
assert mail_message.message_id == "321"
assert mail_message.subject == "Test Subject"
assert mail_message.sender == "sender@example.com"
assert mail_message.recipients == ["recipient@example.com"]
assert cast(str, mail_message.message_id) == "321"
assert cast(str, mail_message.subject) == "Test Subject"
assert cast(str, mail_message.sender) == "sender@example.com"
assert cast(list[str], mail_message.recipients) == ["recipient@example.com"]
assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00"
assert mail_message.content == raw_email
assert cast(str, mail_message.content) == raw_email
assert mail_message.body == "Test body content\n"
assert mail_message.attachments == attachments
@ -275,7 +281,7 @@ def test_fetch_email_since(email_provider):
assert len(result) == 2
# Verify content of fetched emails
uids = sorted([uid for uid, _ in result])
uids = sorted([uid or "" for uid, _ in result])
assert uids == ["101", "102"]
# Test with a folder that doesn't exist
@ -342,7 +348,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
db_session.add(mail_message)
db_session.flush()
assert mail_message.embed_status == "RAW"
assert cast(str, mail_message.embed_status) == "RAW"
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
vectorize_email(mail_message)
@ -351,7 +357,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
]
db_session.commit()
assert mail_message.embed_status == "STORED"
assert cast(str, mail_message.embed_status) == "STORED"
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
@ -418,6 +424,6 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
]
db_session.commit()
assert mail_message.embed_status == "STORED"
assert attachment1.embed_status == "STORED"
assert attachment2.embed_status == "STORED"
assert cast(str, mail_message.embed_status) == "STORED"
assert cast(str, attachment1.embed_status) == "STORED"
assert cast(str, attachment2.embed_status) == "STORED"

View File

@ -8,115 +8,130 @@ class MockEmailProvider:
Mock IMAP email provider for integration testing.
Can be initialized with predefined emails to return.
"""
def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] = None):
def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] | None = None):
"""
Initialize with a dictionary of emails organized by folder.
Args:
emails_by_folder: A dictionary mapping folder names to lists of email dictionaries.
Each email dict should have: 'uid', 'flags', 'date', 'from', 'to', 'subject',
Each email dict should have: 'uid', 'flags', 'date', 'from', 'to', 'subject',
'message_id', 'body', and optionally 'attachments'.
"""
self.emails_by_folder = emails_by_folder or {
"INBOX": [],
"Sent": [],
"Archive": []
"Archive": [],
}
self.current_folder = None
self.is_connected = False
def _generate_email_string(self, email_data: dict[str, Any]) -> str:
"""Generate a raw email string from the provided email data."""
msg = email.message.EmailMessage()
msg = email.message.EmailMessage() # type: ignore
msg["From"] = email_data.get("from", "sender@example.com")
msg["To"] = email_data.get("to", "recipient@example.com")
msg["Subject"] = email_data.get("subject", "Test Subject")
msg["Message-ID"] = email_data.get("message_id", f"<test-{email_data['uid']}@example.com>")
msg["Date"] = email_data.get("date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000"))
msg["Message-ID"] = email_data.get(
"message_id", f"<test-{email_data['uid']}@example.com>"
)
msg["Date"] = email_data.get(
"date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000")
)
# Set the body content
msg.set_content(email_data.get("body", f"This is test email body {email_data['uid']}"))
msg.set_content(
email_data.get("body", f"This is test email body {email_data['uid']}")
)
# Add attachments if present
for attachment in email_data.get("attachments", []):
if isinstance(attachment, dict) and "filename" in attachment and "content" in attachment:
if (
isinstance(attachment, dict)
and "filename" in attachment
and "content" in attachment
):
msg.add_attachment(
attachment["content"],
maintype=attachment.get("maintype", "application"),
subtype=attachment.get("subtype", "octet-stream"),
filename=attachment["filename"]
filename=attachment["filename"],
)
return msg.as_string()
def login(self, username: str, password: str) -> tuple[str, list[bytes]]:
"""Mock login method."""
self.is_connected = True
return ('OK', [b'Login successful'])
return ("OK", [b"Login successful"])
def logout(self) -> tuple[str, list[bytes]]:
"""Mock logout method."""
self.is_connected = False
return ('OK', [b'Logout successful'])
return ("OK", [b"Logout successful"])
def select(self, folder: str, readonly: bool = False) -> tuple[str, list[bytes]]:
"""
Select a folder and make it the current active folder.
Args:
folder: Folder name to select
readonly: Whether to open in readonly mode
Returns:
IMAP-style response with message count
"""
folder_name = folder.decode() if isinstance(folder, bytes) else folder
self.current_folder = folder_name
message_count = len(self.emails_by_folder.get(folder_name, []))
return ('OK', [str(message_count).encode()])
def list(self, directory: str = '', pattern: str = '*') -> tuple[str, list[bytes]]:
return ("OK", [str(message_count).encode()])
def list(self, directory: str = "", pattern: str = "*") -> tuple[str, list[bytes]]:
"""List available folders."""
folders = []
for folder in self.emails_by_folder.keys():
folders.append(f'(\\HasNoChildren) "/" "{folder}"'.encode())
return ('OK', folders)
return ("OK", folders)
def search(self, charset, criteria):
"""
Handle SEARCH command to find email UIDs.
Args:
charset: Character set (ignored in mock)
criteria: Search criteria (ignored in mock, we return all emails)
Returns:
All email UIDs in the current folder
"""
if not self.current_folder or self.current_folder not in self.emails_by_folder:
return ('OK', [b''])
uids = [str(email["uid"]).encode() for email in self.emails_by_folder[self.current_folder]]
return ('OK', [b' '.join(uids) if uids else b''])
def fetch(self, message_set, message_parts) -> tuple[str, list]:
return ("OK", [b""])
uids = [
str(email["uid"]).encode()
for email in self.emails_by_folder[self.current_folder]
]
return ("OK", [b" ".join(uids) if uids else b""])
def fetch(self, message_set: bytes | str, message_parts: bytes | str):
"""
Handle FETCH command to retrieve email data.
Args:
message_set: Message numbers/UIDs to fetch
message_parts: Parts of the message to fetch
Returns:
Email data in IMAP format
"""
if not self.current_folder or self.current_folder not in self.emails_by_folder:
return ('OK', [None])
return ("OK", [None])
# For simplicity, we'll just match the UID with the ID provided
uid = int(message_set.decode() if isinstance(message_set, bytes) else message_set)
uid = int(
message_set.decode() if isinstance(message_set, bytes) else message_set
)
# Find the email with the matching UID
for email_data in self.emails_by_folder[self.current_folder]:
if email_data["uid"] == uid:
@ -124,14 +139,16 @@ class MockEmailProvider:
email_string = self._generate_email_string(email_data)
flags = email_data.get("flags", "\\Seen")
date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000")
# Format the response as expected by the IMAP client
response = [(
f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
f'{{{len(email_string)}}}'.encode(),
email_string.encode()
)]
return ('OK', response)
response = [
(
f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
f"{{{len(email_string)}}}".encode(),
email_string.encode(),
)
]
return ("OK", response)
# No matching email found
return ('NO', [b'Email not found'])
return ("NO", [b"Email not found"])