mirror of https://github.com/mruwnik/memory.git, synced 2025-06-08 21:34:42 +02:00

fix linting

This commit is contained in:
parent 29a2a7ba3a
commit 4f1ca777e9
setup.py | 14
@@ -4,19 +4,19 @@ from setuptools import setup, find_namespace_packages
 
 def read_requirements(filename: str) -> list[str]:
     """Read requirements from file, ignoring comments and -r directives."""
-    filename = pathlib.Path(filename)
+    path = pathlib.Path(filename)
     return [
         line.strip()
-        for line in filename.read_text().splitlines()
-        if line.strip() and not line.strip().startswith(('#', '-r'))
+        for line in path.read_text().splitlines()
+        if line.strip() and not line.strip().startswith(("#", "-r"))
     ]
 
 
 # Read requirements files
-common_requires = read_requirements('requirements-common.txt')
-api_requires = read_requirements('requirements-api.txt')
-workers_requires = read_requirements('requirements-workers.txt')
-dev_requires = read_requirements('requirements-dev.txt')
+common_requires = read_requirements("requirements-common.txt")
+api_requires = read_requirements("requirements-api.txt")
+workers_requires = read_requirements("requirements-workers.txt")
+dev_requires = read_requirements("requirements-dev.txt")
 
 setup(
     name="memory",
@@ -4,10 +4,9 @@ Database models for the knowledge base system.
 
 import pathlib
 import re
-from email.message import EmailMessage
 from pathlib import Path
 import textwrap
-from typing import Any, ClassVar
+from typing import Any, ClassVar, cast
 from PIL import Image
 from sqlalchemy import (
     ARRAY,
@@ -31,7 +30,7 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, Session
 
 from memory.common import settings
-from memory.common.parsers.email import parse_email_message
+from memory.common.parsers.email import parse_email_message, EmailMessage
 
 Base = declarative_base()
 
@@ -113,10 +112,10 @@ class Chunk(Base):
     @property
     def data(self) -> list[bytes | str | Image.Image]:
         if self.file_path is None:
-            return [self.content]
+            return [cast(str, self.content)]
 
         path = pathlib.Path(self.file_path.replace("/app/", ""))
-        if self.file_path.endswith("*"):
+        if cast(str, self.file_path).endswith("*"):
             files = list(path.parent.glob(path.name))
         else:
             files = [path]
@@ -182,7 +181,7 @@ class SourceItem(Base):
 
     @property
     def display_contents(self) -> str | None:
-        return self.content or self.filename
+        return cast(str | None, self.content) or cast(str | None, self.filename)
 
 
 class MailMessage(SourceItem):
@@ -217,8 +216,8 @@ class MailMessage(SourceItem):
 
     @property
     def attachments_path(self) -> Path:
-        clean_sender = clean_filename(self.sender)
-        clean_folder = clean_filename(self.folder or "INBOX")
+        clean_sender = clean_filename(cast(str, self.sender))
+        clean_folder = clean_filename(cast(str | None, self.folder) or "INBOX")
         return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder
 
     def safe_filename(self, filename: str) -> Path:
@@ -237,12 +236,12 @@ class MailMessage(SourceItem):
             "recipients": self.recipients,
             "folder": self.folder,
             "tags": self.tags + [self.sender] + self.recipients,
-            "date": self.sent_at and self.sent_at.isoformat() or None,
+            "date": (self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
         }
 
     @property
     def parsed_content(self) -> EmailMessage:
-        return parse_email_message(self.content, self.message_id)
+        return parse_email_message(cast(str, self.content), cast(str, self.message_id))
 
     @property
     def body(self) -> str:
@@ -300,7 +299,7 @@ class EmailAttachment(SourceItem):
             "filename": self.filename,
             "content_type": self.mime_type,
             "size": self.size,
-            "created_at": self.created_at and self.created_at.isoformat() or None,
+            "created_at": (self.created_at and self.created_at.isoformat() or None),  # type: ignore
             "mail_message_id": self.mail_message_id,
             "source_id": self.id,
             "tags": self.tags,
@@ -1,7 +1,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Iterable, Literal, NotRequired, TypedDict
+from typing import Any, Iterable, Literal, NotRequired, TypedDict, cast
 
 import voyageai
 from PIL import Image
@@ -21,7 +21,6 @@ CHARS_PER_TOKEN = 4
 
 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
-Embedding = tuple[str, Vector, dict[str, Any]]
 
 
 class Collection(TypedDict):
@@ -133,16 +132,18 @@ def get_modality(mime_type: str) -> str:
 
 
 def embed_chunks(
-    chunks: list[extract.MulitmodalChunk],
+    chunks: list[str] | list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
-    vo = voyageai.Client()
+    vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks, model=model, input_type=input_type
+            chunks,  # type: ignore
+            model=model,
+            input_type=input_type,
         ).embeddings
-    return vo.embed(chunks, model=model, input_type=input_type).embeddings
+    return vo.embed(chunks, model=model, input_type=input_type).embeddings  # type: ignore
 
 
 def embed_text(
@@ -162,7 +163,7 @@ def embed_text(
 
     try:
         return embed_chunks(chunks, model, input_type)
-    except voyageai.error.InvalidRequestError as e:
+    except voyageai.error.InvalidRequestError as e:  # type: ignore
         logger.error(f"Error embedding text: {e}")
         logger.debug(f"Text: {texts}")
         raise
@@ -179,7 +180,7 @@ def embed_mixed(
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
-    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
+    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
                 c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
@@ -190,11 +191,16 @@ def embed_mixed(
     return embed_chunks([chunks], model, input_type)
 
 
-def embed_page(page: dict[str, Any]) -> list[Vector]:
+def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
     if all(isinstance(c, str) for c in contents):
-        return embed_text(contents, model=settings.TEXT_EMBEDDING_MODEL)
-    return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
+        return embed_text(
+            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+        )
+    return embed_mixed(
+        cast(list[extract.MulitmodalChunk], contents),
+        model=settings.MIXED_EMBEDDING_MODEL,
+    )
 
 
 def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
@@ -224,7 +230,7 @@ def make_chunk(
     contents = page["contents"]
     content, filename = None, None
     if all(isinstance(c, str) for c in contents):
-        content = "\n\n".join(contents)
+        content = "\n\n".join(cast(list[str], contents))
         model = settings.TEXT_EMBEDDING_MODEL
     elif len(contents) == 1:
         filename = write_to_file(chunk_id, contents[0]).absolute().as_posix()
@@ -249,7 +255,7 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
-) -> tuple[str, list[Embedding]]:
+) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
     pages = extract.extract_content(mime_type, content)
     chunks = [
@@ -1,13 +1,13 @@
-from contextlib import contextmanager
 import io
+import logging
 import pathlib
 import tempfile
-import pypandoc
-import pymupdf  # PyMuPDF
-from PIL import Image
-from typing import Any, TypedDict, Generator, Sequence
+from contextlib import contextmanager
+from typing import Any, Generator, Sequence, TypedDict, cast
 
-import logging
+import pymupdf  # PyMuPDF
+import pypandoc
+from PIL import Image
 
 logger = logging.getLogger(__name__)
 
@@ -105,7 +105,7 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     if isinstance(content, bytes):
         content = content.decode("utf-8")
 
-    return [{"contents": [content], "metadata": {}}]
+    return [{"contents": [cast(str, content)], "metadata": {}}]
 
 
 def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
@@ -1,5 +1,5 @@
 import logging
-from typing import TypedDict, NotRequired
+from typing import TypedDict, NotRequired, cast
 
 from bs4 import BeautifulSoup, Tag
 import requests
@@ -59,7 +59,7 @@ def extract_smbc(url: str) -> ComicInfo:
         "title": title,
         "image_url": image_url,
         "published_date": published_date,
-        "url": comic_url or url,
+        "url": cast(str, comic_url or url),
     }
 
 
@@ -28,10 +28,10 @@ class EmailMessage(TypedDict):
     hash: bytes
 
 
-RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
+RawEmailResponse = tuple[str | None, bytes]
 
 
-def extract_recipients(msg: email.message.Message) -> list[str]:
+def extract_recipients(msg: email.message.Message) -> list[str]:  # type: ignore
     """
     Extract email recipients from message headers.
 
@@ -50,7 +50,7 @@ def extract_recipients(msg: email.message.Message) -> list[str]:
     ]
 
 
-def extract_date(msg: email.message.Message) -> datetime | None:
+def extract_date(msg: email.message.Message) -> datetime | None:  # type: ignore
     """
     Parse date from email header.
 
@@ -68,7 +68,7 @@ def extract_date(msg: email.message.Message) -> datetime | None:
         return None
 
 
-def extract_body(msg: email.message.Message) -> str:
+def extract_body(msg: email.message.Message) -> str:  # type: ignore
     """
    Extract plain text body from email message.
 
@@ -99,7 +99,7 @@ def extract_body(msg: email.message.Message) -> str:
     return body
 
 
-def extract_attachments(msg: email.message.Message) -> list[Attachment]:
+def extract_attachments(msg: email.message.Message) -> list[Attachment]:  # type: ignore
     """
     Extract attachment metadata and content from email.
 
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, cast, Iterator, Sequence
+from typing import Any, cast, Generator, Sequence
 
 import qdrant_client
 from qdrant_client.http import models as qdrant_models
@@ -24,7 +24,7 @@ def get_qdrant_client() -> qdrant_client.QdrantClient:
     return qdrant_client.QdrantClient(
         host=settings.QDRANT_HOST,
         port=settings.QDRANT_PORT,
-        grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None,
+        grpc_port=settings.QDRANT_GRPC_PORT or 6334,
         prefer_grpc=settings.QDRANT_PREFER_GRPC,
         api_key=settings.QDRANT_API_KEY,
         timeout=settings.QDRANT_TIMEOUT,
@@ -80,7 +80,8 @@ def ensure_collection_exists(
 
 
 def initialize_collections(
-    client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None
+    client: qdrant_client.QdrantClient,
+    collections: dict[str, Collection] | None = None,
 ) -> None:
     """
     Initialize all required collections in Qdrant.
@@ -122,7 +123,7 @@ def upsert_vectors(
     collection_name: str,
     ids: list[str],
     vectors: list[Vector],
-    payloads: list[dict[str, Any]] = None,
+    payloads: list[dict[str, Any]] | None = None,
 ) -> None:
     """Upsert vectors into a collection.
 
@@ -147,7 +148,7 @@ def upsert_vectors(
 
     client.upsert(
         collection_name=collection_name,
-        points=points,
+        points=points,  # type: ignore
     )
 
     logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
@@ -157,7 +158,7 @@ def search_vectors(
     client: qdrant_client.QdrantClient,
     collection_name: str,
     query_vector: Vector,
-    filter_params: dict = None,
+    filter_params: dict | None = None,
     limit: int = 10,
 ) -> list[qdrant_models.ScoredPoint]:
     """Search for similar vectors in a collection.
@@ -200,7 +201,7 @@ def delete_points(
     client.delete(
         collection_name=collection_name,
         points_selector=qdrant_models.PointIdsList(
            points=ids,  # type: ignore
        ),
    )
 
@@ -226,7 +227,7 @@ def get_collection_info(
 
 def batch_ids(
     client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
-) -> Iterator[list[str]]:
+) -> Generator[list[str], None, None]:
     """Iterate over all IDs in a collection."""
     offset = None
     while resp := client.scroll(
@@ -236,7 +237,7 @@ def batch_ids(
         limit=batch_size,
     ):
         points, offset = resp
-        yield [point.id for point in points]
+        yield [cast(str, point.id) for point in points]
 
         if not offset:
             return
@@ -21,7 +21,11 @@ async def make_request(
 ) -> httpx.Response:
     async with httpx.AsyncClient() as client:
         return await client.request(
-            method, f"{SERVER}/{path}", data=data, json=json, files=files
+            method,
+            f"{SERVER}/{path}",
+            data=data,
+            json=json,
+            files=files,  # type: ignore
         )
 
 
@@ -31,7 +31,7 @@ app.conf.update(
 )
 
 
-@app.on_after_configure.connect
+@app.on_after_configure.connect  # type: ignore
 def ensure_qdrant_initialised(sender, **_):
     from memory.common import qdrant
 
@@ -1,24 +1,26 @@
 import hashlib
 import imaplib
 import logging
+import pathlib
 import re
+from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Generator, Callable
-import pathlib
+from typing import Callable, Generator, Sequence, cast
+
-from sqlalchemy.orm import Session
-from collections import defaultdict
+from sqlalchemy.orm import Session, scoped_session
+
-from memory.common import settings, embedding, qdrant
+from memory.common import embedding, qdrant, settings
 from memory.common.db.models import (
     EmailAccount,
+    EmailAttachment,
     MailMessage,
     SourceItem,
-    EmailAttachment,
 )
 from memory.common.parsers.email import (
     Attachment,
-    parse_email_message,
     RawEmailResponse,
+    parse_email_message,
 )
 
 logger = logging.getLogger(__name__)
@@ -89,7 +91,7 @@ def process_attachments(
 
 
 def create_mail_message(
-    db_session: Session,
+    db_session: Session | scoped_session,
     tags: list[str],
     folder: str,
     raw_email: str,
@@ -136,7 +138,7 @@ def create_mail_message(
 
 
 def does_message_exist(
-    db_session: Session, message_id: str, message_hash: bytes
+    db_session: Session | scoped_session, message_id: str, message_hash: bytes
 ) -> bool:
     """
     Check if a message already exists in the database.
@@ -167,7 +169,7 @@ def does_message_exist(
 
 
 def check_message_exists(
-    db: Session, account_id: int, message_id: str, raw_email: str
+    db: Session | scoped_session, account_id: int, message_id: str, raw_email: str
 ) -> bool:
     account = db.query(EmailAccount).get(account_id)
     if not account:
@@ -181,7 +183,9 @@ def check_message_exists(
     return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
 
 
-def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
+def extract_email_uid(
+    msg_data: Sequence[tuple[bytes, bytes]],
+) -> tuple[str | None, bytes]:
     """
     Extract the UID and raw email data from the message data.
     """
@@ -199,7 +203,7 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
             logger.error(f"Error fetching message {uid}")
             return None
 
-        return extract_email_uid(msg_data)
+        return extract_email_uid(msg_data)  # type: ignore
     except Exception as e:
         logger.error(f"Error processing message {uid}: {str(e)}")
         return None
@@ -248,7 +252,7 @@ def process_folder(
     folder: str,
     account: EmailAccount,
     since_date: datetime,
-    processor: Callable[[int, str, str, bytes], int | None],
+    processor: Callable[[int, str, str, str], int | None],
 ) -> dict:
     """
     Process a single folder from an email account.
@@ -272,7 +276,7 @@ def process_folder(
     for uid, raw_email in emails:
         try:
             task = processor(
-                account_id=account.id,
+                account_id=account.id,  # type: ignore
                 message_id=uid,
                 folder=folder,
                 raw_email=raw_email.decode("utf-8", errors="replace"),
@@ -296,9 +300,11 @@ def process_folder(
 
 @contextmanager
 def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
-    conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port)
+    conn = imaplib.IMAP4_SSL(
+        host=cast(str, account.imap_server), port=cast(int, account.imap_port)
+    )
     try:
-        conn.login(account.username, account.password)
+        conn.login(cast(str, account.username), cast(str, account.password))
         yield conn
     finally:
         # Always try to logout and close the connection
@@ -318,15 +324,15 @@ def vectorize_email(email: MailMessage):
     )
     email.chunks = chunks
     if chunks:
-        vector_ids = [c.id for c in chunks]
+        vector_ids = [cast(str, c.id) for c in chunks]
         vectors = [c.vector for c in chunks]
         metadata = [c.item_metadata for c in chunks]
         qdrant.upsert_vectors(
             client=qdrant_client,
             collection_name="mail",
             ids=vector_ids,
-            vectors=vectors,
-            payloads=metadata,
+            vectors=vectors,  # type: ignore
+            payloads=metadata,  # type: ignore
         )
 
     embeds = defaultdict(list)
@@ -356,7 +362,7 @@ def vectorize_email(email: MailMessage):
             payloads=metadata,
         )
 
-    email.embed_status = "STORED"
+    email.embed_status = "STORED"  # type: ignore
     for attachment in email.attachments:
         attachment.embed_status = "STORED"
 
@@ -1,7 +1,7 @@
 import hashlib
 import logging
 from datetime import datetime
-from typing import Callable
+from typing import Callable, cast
 
 import feedparser
 import requests
@@ -35,7 +35,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
         logger.error(f"Failed to fetch or parse {rss_url}: {e}")
         return set()
 
-    urls = {item.get("link") or item.get("id") for item in feed.entries}
+    urls = {cast(str, item.get("link") or item.get("id")) for item in feed.entries}
 
     with make_session() as session:
         known = {
@@ -46,7 +46,7 @@ def find_new_urls(base_url: str, rss_url: str) -> set[str]:
             )
         }
 
-    return urls - known
+    return cast(set[str], urls - known)
 
 
 def fetch_new_comics(
@@ -56,7 +56,7 @@ def fetch_new_comics(
 
     for url in new_urls:
         data = parser(url) | {"author": base_url, "url": url}
-        sync_comic.delay(**data)
+        sync_comic.delay(**data)  # type: ignore
     return new_urls
 
 
@@ -108,7 +108,7 @@ def sync_comic(
         client=qdrant.get_qdrant_client(),
         collection_name="comic",
         ids=[str(chunk.id)],
-        vectors=[chunk.vector],
+        vectors=[chunk.vector],  # type: ignore
         payloads=[comic.as_payload()],
     )
 
@@ -130,8 +130,8 @@ def sync_xkcd() -> set[str]:
 @app.task(name=SYNC_ALL_COMICS)
 def sync_all_comics():
     """Synchronize all active comics."""
-    sync_smbc.delay()
-    sync_xkcd.delay()
+    sync_smbc.delay()  # type: ignore
+    sync_xkcd.delay()  # type: ignore
 
 
 @app.task(name="memory.workers.tasks.comic.full_sync_comic")
@@ -141,8 +141,8 @@ def trigger_comic_sync():
 
         response = requests.get(url)
         soup = BeautifulSoup(response.text, "html.parser")
-        if link := soup.find("a", attrs={"class", "cc-prev"}):
-            return link.attrs["href"]
+        if link := soup.find("a", attrs={"class": "cc-prev"}):
+            return link.attrs["href"]  # type: ignore
         return None
 
     next_url = "https://www.smbc-comics.com"
@@ -155,7 +155,7 @@ def trigger_comic_sync():
             data = comics.extract_smbc(next_url) | {
                 "author": "https://www.smbc-comics.com/"
             }
-            sync_comic.delay(**data)
+            sync_comic.delay(**data)  # type: ignore
         except Exception as e:
             logger.error(f"failed to sync {next_url}: {e}")
         urls.append(next_url)
@@ -167,6 +167,6 @@ def trigger_comic_sync():
         url = f"{BASE_XKCD_URL}/{i}"
         try:
             data = comics.extract_xkcd(url) | {"author": "https://xkcd.com/"}
-            sync_comic.delay(**data)
+            sync_comic.delay(**data)  # type: ignore
         except Exception as e:
             logger.error(f"failed to sync {url}: {e}")
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime
-
+from typing import cast
 from memory.common.db.connection import make_session
 from memory.common.db.models import EmailAccount
 from memory.workers.celery_app import app
@@ -71,7 +71,7 @@ def process_message(
         for chunk in attachment.chunks:
             logger.info(f"  - {chunk.id}")
 
-    return mail_message.id
+    return cast(int, mail_message.id)
 
 
 @app.task(name=SYNC_ACCOUNT)
@@ -89,15 +89,17 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
 
     with make_session() as db:
         account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
-        if not account or not account.active:
+        if not account or not cast(bool, account.active):
             logger.warning(f"Account {account_id} not found or inactive")
             return {"error": "Account not found or inactive"}
 
-        folders_to_process: list[str] = account.folders or ["INBOX"]
+        folders_to_process: list[str] = cast(list[str], account.folders) or ["INBOX"]
         if since_date:
             cutoff_date = datetime.fromisoformat(since_date)
         else:
-            cutoff_date: datetime = account.last_sync_at or datetime(1970, 1, 1)
+            cutoff_date: datetime = cast(datetime, account.last_sync_at) or datetime(
+                1970, 1, 1
+            )
 
         messages_found = 0
         new_messages = 0
@@ -106,9 +108,9 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
         def process_message_wrapper(
             account_id: int, message_id: str, folder: str, raw_email: str
         ) -> int | None:
-            if check_message_exists(db, account_id, message_id, raw_email):
+            if check_message_exists(db, account_id, message_id, raw_email):  # type: ignore
                 return None
-            return process_message.delay(account_id, message_id, folder, raw_email)
+            return process_message.delay(account_id, message_id, folder, raw_email)  # type: ignore
 
         try:
             with imap_connection(account) as conn:
@@ -121,7 +123,7 @@ def sync_account(account_id: int, since_date: str | None = None) -> dict:
                 new_messages += folder_stats["new_messages"]
                 errors += folder_stats["errors"]
 
            account.last_sync_at = datetime.now()  # type: ignore
            db.commit()
        except Exception as e:
            logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
@@ -152,7 +154,7 @@ def sync_all_accounts() -> list[dict]:
         {
             "account_id": account.id,
             "email": account.email_address,
-            "task_id": sync_account.delay(account.id).id,
+            "task_id": sync_account.delay(account.id).id,  # type: ignore
         }
         for account in active_accounts
     ]
@@ -48,7 +48,7 @@ def clean_collection(collection: str) -> dict[str, int]:
 def clean_all_collections():
     logger.info("Cleaning all collections")
     for collection in embedding.ALL_COLLECTIONS:
-        clean_collection.delay(collection)
+        clean_collection.delay(collection)  # type: ignore
 
 
 @app.task(name=REINGEST_CHUNK)
@@ -97,7 +97,7 @@ def check_batch(batch: Sequence[Chunk]) -> dict:
 
     for chunk in chunks:
         if str(chunk.id) in missing:
-            reingest_chunk.delay(str(chunk.id), collection)
+            reingest_chunk.delay(str(chunk.id), collection)  # type: ignore
         else:
             chunk.checked_at = datetime.now()
 
@@ -132,7 +132,9 @@ def reingest_missing_chunks(batch_size: int = 1000):
         .filter(Chunk.checked_at < since)
         .options(
             contains_eager(Chunk.source).load_only(
-                SourceItem.id, SourceItem.modality, SourceItem.tags
+                SourceItem.id,  # type: ignore
+                SourceItem.modality,  # type: ignore
+                SourceItem.tags,  # type: ignore
             )
         )
         .order_by(Chunk.id)
@@ -252,7 +252,7 @@ def test_parse_simple_email():
         "hash": b"\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87"
         + b"\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1",
     }
-    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
+    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400  # type: ignore
 
 
 def test_parse_email_with_attachments():
@@ -1,7 +1,7 @@
 import pathlib
 import uuid
 from unittest.mock import Mock, patch
-
+from typing import cast
 import pytest
 from PIL import Image
 
@@ -72,12 +72,12 @@ def test_embed_mixed(mock_embed):
 
 def test_embed_page_text_only(mock_embed):
     page = {"contents": ["text1", "text2"]}
-    assert embed_page(page) == [[0], [1]]
+    assert embed_page(page) == [[0], [1]]  # type: ignore
 
 
 def test_embed_page_mixed_content(mock_embed):
     page = {"contents": ["text", {"type": "image", "data": "base64"}]}
-    assert embed_page(page) == [[0]]
+    assert embed_page(page) == [[0]]  # type: ignore
 
 
 def test_embed(mock_embed):
@@ -91,12 +91,12 @@ def test_embed(mock_embed):
     assert modality == "text"
     assert [
         {
-            "id": c.id,
-            "file_path": c.file_path,
-            "content": c.content,
-            "embedding_model": c.embedding_model,
-            "vector": c.vector,
-            "item_metadata": c.item_metadata,
+            "id": c.id,  # type: ignore
+            "file_path": c.file_path,  # type: ignore
+            "content": c.content,  # type: ignore
+            "embedding_model": c.embedding_model,  # type: ignore
+            "vector": c.vector,  # type: ignore
+            "item_metadata": c.item_metadata,  # type: ignore
         }
         for c in chunks
     ] == [
@@ -128,7 +128,7 @@ def test_write_to_file_bytes(mock_file_storage):
     chunk_id = "test-chunk-id"
     content = b"These are test bytes"
 
-    file_path = write_to_file(chunk_id, content)
+    file_path = write_to_file(chunk_id, content)  # type: ignore
 
     assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
     assert file_path.exists()
@@ -140,7 +140,7 @@ def test_write_to_file_image(mock_file_storage):
     img = Image.new("RGB", (100, 100), color=(73, 109, 137))
     chunk_id = "test-chunk-id"
 
-    file_path = write_to_file(chunk_id, img)
+    file_path = write_to_file(chunk_id, img)  # type: ignore
 
     assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
     assert file_path.exists()
@@ -155,7 +155,7 @@ def test_write_to_file_unsupported_type(mock_file_storage):
     content = 123  # Integer is not a supported type
 
     with pytest.raises(ValueError, match="Unsupported content type"):
-        write_to_file(chunk_id, content)
+        write_to_file(chunk_id, content)  # type: ignore
 
 
 def test_make_chunk_text_only(mock_file_storage, db_session):
@@ -170,12 +170,12 @@ def test_make_chunk_text_only(mock_file_storage, db_session):
     with patch.object(
         uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
     ):
-        chunk = make_chunk(page, vector, metadata)
+        chunk = make_chunk(page, vector, metadata)  # type: ignore
 
-    assert chunk.id == "00000000-0000-0000-0000-000000000001"
-    assert chunk.content == "text content 1\n\ntext content 2"
+    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
+    assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
     assert chunk.file_path is None
-    assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL
+    assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL
     assert chunk.vector == vector
     assert chunk.item_metadata == metadata
 
@@ -190,19 +190,19 @@ def test_make_chunk_single_image(mock_file_storage, db_session):
     with patch.object(
         uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
    ):
-        chunk = make_chunk(page, vector, metadata)
+        chunk = make_chunk(page, vector, metadata)  # type: ignore
 
-    assert chunk.id == "00000000-0000-0000-0000-000000000002"
+    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
     assert chunk.content is None
-    assert chunk.file_path == str(
+    assert cast(str, chunk.file_path) == str(
         settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
     )
-    assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
+    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
     assert chunk.vector == vector
     assert chunk.item_metadata == metadata
 
     # Verify the file exists
-    assert pathlib.Path(chunk.file_path[0]).exists()
+    assert pathlib.Path(cast(str, chunk.file_path)).exists()
 
 
 def test_make_chunk_mixed_content(mock_file_storage, db_session):
@@ -215,14 +215,14 @@ def test_make_chunk_mixed_content(mock_file_storage, db_session):
     with patch.object(
         uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
     ):
-        chunk = make_chunk(page, vector, metadata)
+        chunk = make_chunk(page, vector, metadata)  # type: ignore
 
-    assert chunk.id == "00000000-0000-0000-0000-000000000003"
+    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
     assert chunk.content is None
-    assert chunk.file_path == str(
+    assert cast(str, chunk.file_path) == str(
         settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
     )
-    assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
+    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
     assert chunk.vector == vector
     assert chunk.item_metadata == metadata
 
@@ -87,7 +87,7 @@ def test_extract_image_with_path(tmp_path):
     img.save(img_path)
 
     (page,) = extract_image(img_path)
-    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
+    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()  # type: ignore
     assert page["metadata"] == {}
 
 
@@ -98,7 +98,7 @@ def test_extract_image_with_bytes():
     img_bytes = buffer.getvalue()
 
     (page,) = extract_image(img_bytes)
-    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
+    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()  # type: ignore
     assert page["metadata"] == {}
 
 
@@ -33,7 +33,10 @@ def test_ensure_collection_exists_existing(mock_qdrant_client):
 
 def test_ensure_collection_exists_new(mock_qdrant_client):
     mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
-        status_code=404, reason_phrase="asd", content=b"asd", headers=None
+        status_code=404,
+        reason_phrase="asd",
+        content=b"asd",
+        headers=None,  # type: ignore
     )
 
     assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
@@ -1,24 +1,26 @@
 import base64
 import pathlib
 
 from datetime import datetime
+from typing import cast
 from unittest.mock import MagicMock, patch
 
 import pytest
 
+from memory.common import embedding, settings
 from memory.common.db.models import (
-    MailMessage,
-    EmailAttachment,
     EmailAccount,
+    EmailAttachment,
+    MailMessage,
 )
-from memory.common import settings
-from memory.common import embedding
+from memory.common.parsers.email import Attachment
 from memory.workers.email import (
-    extract_email_uid,
     create_mail_message,
+    extract_email_uid,
     fetch_email,
     fetch_email_since,
-    process_folder,
     process_attachment,
     process_attachments,
+    process_folder,
     vectorize_email,
 )
 
@@ -45,7 +47,9 @@ def mock_uuid4():
         (100, 100, "<test@example.com>"),
     ],
 )
-def test_process_attachment_inline(attachment_size, max_inline_size, message_id):
+def test_process_attachment_inline(
+    attachment_size: int, max_inline_size: int, message_id: str
+):
     attachment = {
         "filename": "test.txt",
         "content_type": "text/plain",
@@ -60,10 +64,12 @@ def test_process_attachment_inline(
     )
 
     with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
-        result = process_attachment(attachment, message)
+        result = process_attachment(cast(Attachment, attachment), message)
 
     assert result is not None
-    assert result.content == attachment["content"].decode("utf-8", errors="replace")
+    assert cast(str, result.content) == attachment["content"].decode(
+        "utf-8", errors="replace"
+    )
     assert result.filename is None
 
 
@@ -90,11 +96,11 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
         folder="INBOX",
     )
     with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
-        result = process_attachment(attachment, message)
+        result = process_attachment(cast(Attachment, attachment), message)
 
     assert result is not None
-    assert not result.content
-    assert result.filename == str(
+    assert not cast(str, result.content)
+    assert cast(str, result.filename) == str(
         settings.FILE_STORAGE_DIR
         / "sender_example_com"
         / "INBOX"
@@ -125,11 +131,11 @@ def test_process_attachment_write_error():
         patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10),
         patch.object(pathlib.Path, "write_bytes", mock_write_bytes),
     ):
-        assert process_attachment(attachment, message) is None
+        assert process_attachment(cast(Attachment, attachment), message) is None
 
 
 def test_process_attachments_empty():
-    assert process_attachments([], "<test@example.com>") == []
+    assert process_attachments([], MagicMock()) == []
 
 
 def test_process_attachments_mixed():
@@ -167,16 +173,16 @@ def test_process_attachments_mixed():
 
     with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50):
         # Process attachments
-        results = process_attachments(attachments, message)
+        results = process_attachments(cast(list[Attachment], attachments), message)
 
     # Verify we have all attachments processed
     assert len(results) == 3
 
-    assert results[0].content == "a" * 20
-    assert results[2].content == "c" * 30
+    assert cast(str, results[0].content) == "a" * 20
+    assert cast(str, results[2].content) == "c" * 30
 
     # Verify large attachment has a path
-    assert results[1].filename == str(
+    assert cast(str, results[1].filename) == str(
         settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt"
     )
 
@@ -239,12 +245,12 @@ def test_create_mail_message(db_session):
 
     # Verify the mail message was created correctly
     assert isinstance(mail_message, MailMessage)
-    assert mail_message.message_id == "321"
-    assert mail_message.subject == "Test Subject"
-    assert mail_message.sender == "sender@example.com"
-    assert mail_message.recipients == ["recipient@example.com"]
+    assert cast(str, mail_message.message_id) == "321"
+    assert cast(str, mail_message.subject) == "Test Subject"
+    assert cast(str, mail_message.sender) == "sender@example.com"
+    assert cast(list[str], mail_message.recipients) == ["recipient@example.com"]
     assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00"
-    assert mail_message.content == raw_email
+    assert cast(str, mail_message.content) == raw_email
     assert mail_message.body == "Test body content\n"
     assert mail_message.attachments == attachments
 
@@ -275,7 +281,7 @@ def test_fetch_email_since(email_provider):
     assert len(result) == 2
 
     # Verify content of fetched emails
-    uids = sorted([uid for uid, _ in result])
+    uids = sorted([uid or "" for uid, _ in result])
     assert uids == ["101", "102"]
 
     # Test with a folder that doesn't exist
@@ -342,7 +348,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
     db_session.add(mail_message)
     db_session.flush()
 
-    assert mail_message.embed_status == "RAW"
+    assert cast(str, mail_message.embed_status) == "RAW"
 
     with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
         vectorize_email(mail_message)
@@ -351,7 +357,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
     ]
 
     db_session.commit()
-    assert mail_message.embed_status == "STORED"
+    assert cast(str, mail_message.embed_status) == "STORED"
 
 
 def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
@@ -418,6 +424,6 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
     ]
 
     db_session.commit()
-    assert mail_message.embed_status == "STORED"
-    assert attachment1.embed_status == "STORED"
-    assert attachment2.embed_status == "STORED"
+    assert cast(str, mail_message.embed_status) == "STORED"
+    assert cast(str, attachment1.embed_status) == "STORED"
+    assert cast(str, attachment2.embed_status) == "STORED"
@ -9,7 +9,7 @@ class MockEmailProvider:
|
|||||||
Can be initialized with predefined emails to return.
|
Can be initialized with predefined emails to return.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] = None):
|
def __init__(self, emails_by_folder: dict[str, list[dict[str, Any]]] | None = None):
|
||||||
"""
|
"""
|
||||||
Initialize with a dictionary of emails organized by folder.
|
Initialize with a dictionary of emails organized by folder.
|
||||||
|
|
||||||
@ -21,31 +21,41 @@ class MockEmailProvider:
         self.emails_by_folder = emails_by_folder or {
             "INBOX": [],
             "Sent": [],
-            "Archive": []
+            "Archive": [],
         }
         self.current_folder = None
         self.is_connected = False

     def _generate_email_string(self, email_data: dict[str, Any]) -> str:
         """Generate a raw email string from the provided email data."""
-        msg = email.message.EmailMessage()
+        msg = email.message.EmailMessage()  # type: ignore
         msg["From"] = email_data.get("from", "sender@example.com")
         msg["To"] = email_data.get("to", "recipient@example.com")
         msg["Subject"] = email_data.get("subject", "Test Subject")
-        msg["Message-ID"] = email_data.get("message_id", f"<test-{email_data['uid']}@example.com>")
-        msg["Date"] = email_data.get("date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000"))
+        msg["Message-ID"] = email_data.get(
+            "message_id", f"<test-{email_data['uid']}@example.com>"
+        )
+        msg["Date"] = email_data.get(
+            "date", datetime.now().strftime("%a, %d %b %Y %H:%M:%S +0000")
+        )

         # Set the body content
-        msg.set_content(email_data.get("body", f"This is test email body {email_data['uid']}"))
+        msg.set_content(
+            email_data.get("body", f"This is test email body {email_data['uid']}")
+        )

         # Add attachments if present
         for attachment in email_data.get("attachments", []):
-            if isinstance(attachment, dict) and "filename" in attachment and "content" in attachment:
+            if (
+                isinstance(attachment, dict)
+                and "filename" in attachment
+                and "content" in attachment
+            ):
                 msg.add_attachment(
                     attachment["content"],
                     maintype=attachment.get("maintype", "application"),
                     subtype=attachment.get("subtype", "octet-stream"),
-                    filename=attachment["filename"]
+                    filename=attachment["filename"],
                 )

         return msg.as_string()
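Because `_generate_email_string` emits a complete RFC 822 message, its output can be round-tripped through the standard library as a sanity check. A sketch, assuming `provider = MockEmailProvider()` and the default header and body values shown above:

    import email
    from email import policy

    raw = provider._generate_email_string({"uid": 1, "subject": "Hello"})
    msg = email.message_from_string(raw, policy=policy.default)
    assert msg["Subject"] == "Hello"
    # set_content() above supplied the default plain-text body for uid 1.
    assert "test email body 1" in msg.get_body(preferencelist=("plain",)).get_content()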
@ -53,12 +63,12 @@ class MockEmailProvider:
     def login(self, username: str, password: str) -> tuple[str, list[bytes]]:
         """Mock login method."""
         self.is_connected = True
-        return ('OK', [b'Login successful'])
+        return ("OK", [b"Login successful"])

     def logout(self) -> tuple[str, list[bytes]]:
         """Mock logout method."""
         self.is_connected = False
-        return ('OK', [b'Logout successful'])
+        return ("OK", [b"Logout successful"])

     def select(self, folder: str, readonly: bool = False) -> tuple[str, list[bytes]]:
         """
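These `("OK", [b"..."])` pairs mirror `imaplib.IMAP4`, where every command returns a `(type, data)` tuple; that shared shape is what lets tests hand this mock to code expecting a real IMAP connection. A hedged sketch of such a fixture (the wiring is illustrative, not taken from this diff):

    import pytest


    @pytest.fixture
    def email_provider():
        # Same (status, data) response shape as imaplib.IMAP4_SSL.
        provider = MockEmailProvider(
            emails_by_folder={"INBOX": [{"uid": 101, "subject": "Hi"}]}
        )
        status, _ = provider.login("user", "secret")
        assert status == "OK"
        return provider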
@ -74,14 +84,14 @@ class MockEmailProvider:
         folder_name = folder.decode() if isinstance(folder, bytes) else folder
         self.current_folder = folder_name
         message_count = len(self.emails_by_folder.get(folder_name, []))
-        return ('OK', [str(message_count).encode()])
+        return ("OK", [str(message_count).encode()])

-    def list(self, directory: str = '', pattern: str = '*') -> tuple[str, list[bytes]]:
+    def list(self, directory: str = "", pattern: str = "*") -> tuple[str, list[bytes]]:
         """List available folders."""
         folders = []
         for folder in self.emails_by_folder.keys():
             folders.append(f'(\\HasNoChildren) "/" "{folder}"'.encode())
-        return ('OK', folders)
+        return ("OK", folders)

     def search(self, charset, criteria):
         """
@ -95,12 +105,15 @@ class MockEmailProvider:
         All email UIDs in the current folder
         """
         if not self.current_folder or self.current_folder not in self.emails_by_folder:
-            return ('OK', [b''])
+            return ("OK", [b""])

-        uids = [str(email["uid"]).encode() for email in self.emails_by_folder[self.current_folder]]
-        return ('OK', [b' '.join(uids) if uids else b''])
+        uids = [
+            str(email["uid"]).encode()
+            for email in self.emails_by_folder[self.current_folder]
+        ]
+        return ("OK", [b" ".join(uids) if uids else b""])
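As with real IMAP SEARCH, the data payload is a single space-separated byte string of UIDs that the caller must split. A sketch, assuming the provider was seeded with UIDs 101 and 102 in INBOX:

    provider.select("INBOX")
    status, data = provider.search(None, "ALL")
    assert status == "OK"
    assert data[0].split() == [b"101", b"102"]  # b"101 102" split into UIDs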
-    def fetch(self, message_set, message_parts) -> tuple[str, list]:
+    def fetch(self, message_set: bytes | str, message_parts: bytes | str):
         """
         Handle FETCH command to retrieve email data.

@ -112,10 +125,12 @@ class MockEmailProvider:
         Email data in IMAP format
         """
         if not self.current_folder or self.current_folder not in self.emails_by_folder:
-            return ('OK', [None])
+            return ("OK", [None])

         # For simplicity, we'll just match the UID with the ID provided
-        uid = int(message_set.decode() if isinstance(message_set, bytes) else message_set)
+        uid = int(
+            message_set.decode() if isinstance(message_set, bytes) else message_set
+        )

         # Find the email with the matching UID
         for email_data in self.emails_by_folder[self.current_folder]:
@ -126,12 +141,14 @@ class MockEmailProvider:
             date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000")

             # Format the response as expected by the IMAP client
-            response = [(
-                f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
-                f'{{{len(email_string)}}}'.encode(),
-                email_string.encode()
-            )]
-            return ('OK', response)
+            response = [
+                (
+                    f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
+                    f"{{{len(email_string)}}}".encode(),
+                    email_string.encode(),
+                )
+            ]
+            return ("OK", response)

         # No matching email found
-        return ('NO', [b'Email not found'])
+        return ("NO", [b"Email not found"])
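The fetch response reproduces `imaplib`'s shape for an RFC822 fetch: a list holding one `(header, raw_message)` tuple, where the header line carries the UID, flags, internal date, and a `{size}` literal prefix. A consumer-side sketch, assuming UID 101 was seeded in the selected folder with the mock's default headers:

    import email

    status, data = provider.fetch(b"101", b"(RFC822)")
    assert status == "OK"
    _, raw_message = data[0]  # (envelope header bytes, full RFC 822 bytes)
    msg = email.message_from_bytes(raw_message)
    assert msg["Message-ID"] == "<test-101@example.com>"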