notes and observations triggered as jobs

Daniel O'Connell 2025-06-02 14:22:21 +02:00
parent 29b8ce6860
commit ac3b48a04c
35 changed files with 1271 additions and 296 deletions

View File

@ -0,0 +1,41 @@
"""Add note
Revision ID: ba301527a2eb
Revises: 6554eb260176
Create Date: 2025-06-02 10:38:20.112303
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "ba301527a2eb"
down_revision: Union[str, None] = "6554eb260176"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"notes",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("note_type", sa.Text(), nullable=True),
sa.Column("subject", sa.Text(), nullable=True),
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("note_confidence_idx", "notes", ["confidence"], unique=False)
op.create_index("note_subject_idx", "notes", ["subject"], unique=False)
op.create_index("note_type_idx", "notes", ["note_type"], unique=False)
def downgrade() -> None:
op.drop_index("note_type_idx", table_name="notes")
op.drop_index("note_subject_idx", table_name="notes")
op.drop_index("note_confidence_idx", table_name="notes")
op.drop_table("notes")

View File

@ -25,6 +25,8 @@ x-common-env: &env
RABBITMQ_HOST: rabbitmq
QDRANT_HOST: qdrant
DB_HOST: postgres
+DB_PORT: 5432
+RABBITMQ_PORT: 5672
FILE_STORAGE_DIR: /app/memory_files
TZ: "Etc/UTC"
@ -212,6 +214,13 @@ services:
QUEUES: "photo_embed,comic"
# deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
+worker-notes:
+<<: *worker-base
+environment:
+<<: *worker-env
+QUEUES: "notes"
+# deploy: { resources: { limits: { cpus: "4", memory: 4g } } }
worker-maintenance:
<<: *worker-base
environment:

View File

@ -7,7 +7,7 @@ LOGLEVEL=${LOGLEVEL:-INFO}
HOSTNAME="${QUEUES%@*}@$(hostname)"
-exec celery -A memory.workers.celery_app worker \
+exec celery -A memory.common.celery_app worker \
 -Q "${QUEUES}" \
 --concurrency="${CONCURRENCY}" \
 --hostname="${HOSTNAME}" \

View File

@ -8,3 +8,4 @@ qdrant-client==1.9.0
anthropic==0.18.1
# Pin the httpx version, as newer versions break the anthropic client
httpx==0.27.0
celery==5.3.6

View File

@ -1,2 +0,0 @@
mcp==1.7.1
httpx==0.25.1

View File

@ -1,4 +1,3 @@
celery==5.3.6
openai==1.25.0
pillow==10.4.0
pypandoc==1.15.0

View File

@ -19,21 +19,25 @@ Usage:
import json
import sys
-from pathlib import Path
from typing import Any
import click
-from memory.workers.tasks.blogs import (
+from celery import Celery
+from memory.common.celery_app import (
SYNC_ALL_ARTICLE_FEEDS,
SYNC_ARTICLE_FEED,
SYNC_WEBPAGE,
SYNC_WEBSITE_ARCHIVE,
-)
-from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_COMIC, SYNC_SMBC, SYNC_XKCD
-from memory.workers.tasks.ebook import SYNC_BOOK
-from memory.workers.tasks.email import PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
-from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
-from memory.workers.tasks.maintenance import (
+SYNC_ALL_COMICS,
+SYNC_COMIC,
+SYNC_SMBC,
+SYNC_XKCD,
+SYNC_BOOK,
+PROCESS_EMAIL,
+SYNC_ACCOUNT,
+SYNC_ALL_ACCOUNTS,
+SYNC_LESSWRONG,
+SYNC_LESSWRONG_POST,
CLEAN_ALL_COLLECTIONS,
CLEAN_COLLECTION,
REINGEST_CHUNK,
@ -43,14 +47,9 @@ from memory.workers.tasks.maintenance import (
REINGEST_MISSING_CHUNKS,
UPDATE_METADATA_FOR_ITEM,
UPDATE_METADATA_FOR_SOURCE_ITEMS,
+app,
)
-# Add the src directory to Python path so we can import memory modules
-sys.path.insert(0, str(Path(__file__).parent / "src"))
-from celery import Celery
-from memory.common import settings
TASK_MAPPINGS = {
"email": {
@ -96,25 +95,6 @@ QUEUE_MAPPINGS = {
}
def create_local_celery_app() -> Celery:
"""Create a Celery app configured to connect to the Docker RabbitMQ."""
# Override settings for local connection to Docker services
rabbitmq_url = f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@localhost:15673//"
app = Celery(
"memory-local",
broker=rabbitmq_url,
backend=settings.CELERY_RESULT_BACKEND.replace(
"postgres:5432", "localhost:15432"
),
)
# Import task modules so they're registered
app.autodiscover_tasks(["memory.workers.tasks"])
return app
def run_task(app: Celery, category: str, task_name: str, **kwargs) -> str:
"""Run a task using the task mappings."""
if category not in TASK_MAPPINGS:
@ -152,7 +132,7 @@ def cli(ctx, wait, timeout):
ctx.obj["timeout"] = timeout
try:
-ctx.obj["app"] = create_local_celery_app()
+ctx.obj["app"] = app
except Exception as e:
click.echo(f"Error connecting to Celery broker: {e}")
click.echo(

View File

@ -2,23 +2,22 @@
MCP tools for the epistemic sparring partner system.
"""
-from datetime import datetime, timezone
-from hashlib import sha256
import logging
-import uuid
-from typing import cast
+import pathlib
+from datetime import datetime, timezone
-from mcp.server.fastmcp import FastMCP
-from memory.common.db.models.source_item import SourceItem
-from sqlalchemy.dialects.postgresql import ARRAY
-from sqlalchemy import func, cast as sql_cast, Text
-from memory.common.db.connection import make_session
-from memory.common import extract
-from memory.common.db.models import AgentObservation
-from memory.api.search.search import search, SearchFilters
-from memory.common.formatters import observation
-from memory.workers.tasks.content_processing import process_content_item
+from mcp.server.fastmcp import FastMCP
+from sqlalchemy import Text, func
+from sqlalchemy import cast as sql_cast
+from sqlalchemy.dialects.postgresql import ARRAY
+from memory.api.search.search import SearchFilters, search
+from memory.common import extract, settings
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
+from memory.common.db.connection import make_session
+from memory.common.db.models import AgentObservation, SourceItem
+from memory.common.formatters import observation
+from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
logger = logging.getLogger(__name__)
@ -452,44 +451,24 @@ async def observe(
agent_model="gpt-4" agent_model="gpt-4"
) )
""" """
# Create the observation task = celery_app.send_task(
observation = AgentObservation( SYNC_OBSERVATION,
content=content, queue="notes",
subject=subject, kwargs={
observation_type=observation_type, "subject": subject,
confidence=confidence, "content": content,
evidence=evidence, "observation_type": observation_type,
tags=tags or [], "confidence": confidence,
session_id=uuid.UUID(session_id) if session_id else None, "evidence": evidence,
agent_model=agent_model, "tags": tags,
size=len(content), "session_id": session_id,
mime_type="text/plain", "agent_model": agent_model,
sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(), },
inserted_at=datetime.now(timezone.utc),
modality="observation",
) )
try: return {
with make_session() as session: "task_id": task.id,
process_content_item(observation, session) "status": "queued",
}
if not cast(int | None, observation.id):
raise ValueError("Observation not created")
logger.info(
f"Observation created: {observation.id}, {observation.inserted_at}"
)
return {
"id": observation.id,
"created_at": observation.inserted_at.isoformat(),
"subject": observation.subject,
"observation_type": observation.observation_type,
"confidence": cast(float, observation.confidence),
"tags": observation.tags,
}
except Exception as e:
logger.error(f"Error creating observation: {e}")
raise
@mcp.tool() @mcp.tool()
@ -651,3 +630,75 @@ async def search_observations(
}
for r in results
]
@mcp.tool()
async def create_note(
subject: str,
content: str,
filename: str | None = None,
note_type: str | None = None,
confidence: float = 0.5,
tags: list[str] = [],
) -> dict:
"""
Create a note when the user asks for something to be noted down.
Purpose:
Use this tool when the user explicitly asks to note, save, or record
something for later reference. Notes don't have to be short - long
markdown documents are fine, as long as that's what was asked for.
When to use:
- User says "note down that..." or "please save this"
- User asks to record information for future reference
- User wants to remember something specific
Args:
subject: What the note is about (e.g., "meeting_notes", "idea")
content: The actual content to note down, as markdown
filename: Optional path relative to notes folder (e.g., "project/ideas.md")
note_type: Optional categorization of the note
confidence: How confident you are in the note accuracy (0.0-1.0)
tags: Optional tags for organization
Example:
# User: "Please note down that we decided to use React for the frontend"
await create_note(
subject="project_decisions",
content="Decided to use React for the frontend",
tags=["project", "frontend"]
)
"""
if filename:
path = pathlib.Path(filename)
if path.is_absolute():
path = path.relative_to(settings.NOTES_STORAGE_DIR)
else:
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
filename = path.as_posix()
try:
task = celery_app.send_task(
SYNC_NOTE,
queue="notes",
kwargs={
"subject": subject,
"content": content,
"filename": filename,
"note_type": note_type,
"confidence": confidence,
"tags": tags,
},
)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Error creating note: {e}")
raise
return {
"task_id": task.id,
"status": "queued",
}

View File

@ -19,6 +19,7 @@ from memory.common.db.models import (
EmailAccount,
ForumPost,
AgentObservation,
+Note,
)
@ -189,10 +190,24 @@ class AgentObservationAdmin(ModelView, model=AgentObservation):
column_searchable_list = ["subject", "observation_type"]
class NoteAdmin(ModelView, model=Note):
column_list = [
"id",
"subject",
"content",
"note_type",
"confidence",
"tags",
"inserted_at",
]
column_searchable_list = ["subject", "content"]
def setup_admin(admin: Admin):
"""Add all admin views to the admin instance."""
admin.add_view(SourceItemAdmin)
admin.add_view(AgentObservationAdmin)
+admin.add_view(NoteAdmin)
admin.add_view(ChunkAdmin)
admin.add_view(EmailAccountAdmin)
admin.add_view(MailMessageAdmin)

View File

@ -0,0 +1,79 @@
from celery import Celery
from memory.common import settings
EMAIL_ROOT = "memory.workers.tasks.email"
FORUMS_ROOT = "memory.workers.tasks.forums"
BLOGS_ROOT = "memory.workers.tasks.blogs"
PHOTO_ROOT = "memory.workers.tasks.photo"
COMIC_ROOT = "memory.workers.tasks.comic"
EBOOK_ROOT = "memory.workers.tasks.ebook"
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
NOTES_ROOT = "memory.workers.tasks.notes"
OBSERVATIONS_ROOT = "memory.workers.tasks.observations"
SYNC_NOTES = f"{NOTES_ROOT}.sync_notes"
SYNC_NOTE = f"{NOTES_ROOT}.sync_note"
SYNC_OBSERVATION = f"{OBSERVATIONS_ROOT}.sync_observation"
SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics"
SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc"
SYNC_XKCD = f"{COMIC_ROOT}.sync_xkcd"
SYNC_COMIC = f"{COMIC_ROOT}.sync_comic"
SYNC_BOOK = f"{EBOOK_ROOT}.sync_book"
PROCESS_EMAIL = f"{EMAIL_ROOT}.process_message"
SYNC_ACCOUNT = f"{EMAIL_ROOT}.sync_account"
SYNC_ALL_ACCOUNTS = f"{EMAIL_ROOT}.sync_all_accounts"
SYNC_LESSWRONG = f"{FORUMS_ROOT}.sync_lesswrong"
SYNC_LESSWRONG_POST = f"{FORUMS_ROOT}.sync_lesswrong_post"
CLEAN_ALL_COLLECTIONS = f"{MAINTENANCE_ROOT}.clean_all_collections"
CLEAN_COLLECTION = f"{MAINTENANCE_ROOT}.clean_collection"
REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
REINGEST_ALL_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_all_empty_source_items"
UPDATE_METADATA_FOR_SOURCE_ITEMS = (
f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
)
UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
def rabbit_url() -> str:
return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:{settings.RABBITMQ_PORT}//"
app = Celery(
"memory",
broker=rabbit_url(),
backend=settings.CELERY_RESULT_BACKEND,
)
app.autodiscover_tasks(["memory.workers.tasks"])
app.conf.update(
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
task_routes={
f"{EMAIL_ROOT}.*": {"queue": "email"},
f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
f"{COMIC_ROOT}.*": {"queue": "comic"},
f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
f"{BLOGS_ROOT}.*": {"queue": "blogs"},
f"{FORUMS_ROOT}.*": {"queue": "forums"},
f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
f"{NOTES_ROOT}.*": {"queue": "notes"},
f"{OBSERVATIONS_ROOT}.*": {"queue": "notes"},
},
)
@app.on_after_configure.connect # type: ignore
def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant
qdrant.setup_qdrant()
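Because every task now has a stable, fully qualified name and a queue route declared in this one module, a caller can enqueue work by name without importing any worker code. A rough, illustrative sketch of such a dispatch (the subject and content values are made up, not part of this commit):

# Hypothetical dispatch example (not part of this commit): enqueue a single note by task name.
from memory.common.celery_app import app, SYNC_NOTE

result = app.send_task(SYNC_NOTE, queue="notes", kwargs={"subject": "scratch", "content": "remember this"})
print(result.id)  # the outcome is stored in the configured Celery result backend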

View File

@ -17,6 +17,7 @@ from memory.common.db.models.source_items import (
GitCommit,
Photo,
MiscDoc,
+Note,
)
from memory.common.db.models.observations import (
ObservationContradiction,
@ -48,6 +49,7 @@ __all__ = [
"GitCommit", "GitCommit",
"Photo", "Photo",
"MiscDoc", "MiscDoc",
"Note",
# Observations # Observations
"ObservationContradiction", "ObservationContradiction",
"ReactionPattern", "ReactionPattern",

View File

@ -168,13 +168,18 @@ class Chunk(Base):
@property
def data(self) -> list[bytes | str | Image.Image]:
-if self.file_paths is None:
-return [cast(str, self.content)]
+content = cast(str | None, self.content)
+file_paths = cast(Sequence[str] | None, self.file_paths)
+items: list[bytes | str | Image.Image] = []
+if content:
+items = [content]
+if not file_paths:
+return items
-paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
+paths = [pathlib.Path(cast(str, p)) for p in file_paths]
files = [path for path in paths if path.exists()]
-items = []
for file_path in files:
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
if file_path.exists():

View File

@ -495,6 +495,68 @@ class GithubItem(SourceItem):
)
class Note(SourceItem):
"""A quick note of something of interest."""
__tablename__ = "notes"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
note_type = Column(Text, nullable=True)
subject = Column(Text, nullable=True)
confidence = Column(Numeric(3, 2), nullable=False, default=0.5) # 0.0-1.0
__mapper_args__ = {
"polymorphic_identity": "note",
}
__table_args__ = (
Index("note_type_idx", "note_type"),
Index("note_subject_idx", "subject"),
Index("note_confidence_idx", "confidence"),
)
def as_payload(self) -> dict:
return {
**super().as_payload(),
"note_type": self.note_type,
"subject": self.subject,
"confidence": float(cast(Any, self.confidence)),
}
@property
def display_contents(self) -> dict:
return {
"subject": self.subject,
"content": self.content,
"note_type": self.note_type,
"confidence": self.confidence,
"tags": self.tags,
}
def save_to_file(self):
if not self.filename:
path = settings.NOTES_STORAGE_DIR / f"{self.subject}.md"
else:
path = pathlib.Path(self.filename)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(cast(str, self.content))
self.filename = path.as_posix()
@staticmethod
def as_text(content: str, subject: str | None = None) -> str:
text = content
if subject:
text = f"# {subject}\n\n{text}"
return text
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
return extract.extract_text(
self.as_text(cast(str, self.content), cast(str | None, self.subject))
)
class AgentObservation(SourceItem):
"""
Records observations made by AI agents about the user.
@ -578,23 +640,27 @@ class AgentObservation(SourceItem):
"tags": self.tags,
}
-def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+def _chunk_contents(self) -> Sequence[extract.DataChunk]:
"""
Generate multiple chunks for different embedding dimensions.
Each chunk goes to a different Qdrant collection for specialized search.
"""
# 1. Semantic chunk - standard content representation
+chunks: list[extract.DataChunk] = []
semantic_text = observation.generate_semantic_text(
cast(str, self.subject),
cast(str, self.observation_type),
cast(str, self.content),
-cast(observation.Evidence, self.evidence),
+cast(observation.Evidence | None, self.evidence),
)
-semantic_chunk = extract.DataChunk(
-data=[semantic_text],
-metadata=extract.merge_metadata(metadata, {"embedding_type": "semantic"}),
-modality="semantic",
-)
+if semantic_text:
+chunks += [
+extract.DataChunk(
+data=[semantic_text],
+metadata={"embedding_type": "semantic"},
+modality="semantic",
+)
+]
# 2. Temporal chunk - time-aware representation
temporal_text = observation.generate_temporal_text(
@ -603,27 +669,27 @@ class AgentObservation(SourceItem):
cast(float, self.confidence),
cast(datetime, self.inserted_at),
)
-temporal_chunk = extract.DataChunk(
-data=[temporal_text],
-metadata=extract.merge_metadata(metadata, {"embedding_type": "temporal"}),
-modality="temporal",
-)
-others = [
-self._make_chunk(
-extract.DataChunk(
-data=[i],
-metadata=extract.merge_metadata(
-metadata, {"embedding_type": "semantic"}
-),
-modality="semantic",
-)
-)
-for i in [
-self.content,
-self.evidence.get("quote", ""),
-]
-if i
-]
+if temporal_text:
+chunks += [
+extract.DataChunk(
+data=[temporal_text],
+metadata={"embedding_type": "temporal"},
+modality="temporal",
+)
+]
+raw_data = [
+self.content,
+cast(dict | None, self.evidence) and self.evidence.get("quote"),
+]
+chunks += [
+extract.DataChunk(
+data=[datum],
+metadata={"embedding_type": "semantic"},
+modality="semantic",
+)
+for datum in raw_data
+if datum and all(datum)
+]
# TODO: Add more embedding dimensions here:
@ -651,7 +717,4 @@ class AgentObservation(SourceItem):
# collection_name="observations_relational"
# ))
-return [
-self._make_chunk(semantic_chunk),
-self._make_chunk(temporal_chunk),
-] + others
+return chunks

View File

@ -55,7 +55,7 @@ def generate_temporal_text(
f"Subject: {subject}", f"Subject: {subject}",
f"Observation: {content}", f"Observation: {content}",
] ]
if confidence: if confidence is not None:
parts.append(f"Confidence: {confidence}") parts.append(f"Confidence: {confidence}")
return " | ".join(parts) return " | ".join(parts)

View File

@ -33,6 +33,7 @@ DB_URL = os.getenv("DATABASE_URL", make_db_url())
RABBITMQ_USER = os.getenv("RABBITMQ_USER", "kb")
RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", "kb")
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq")
+RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", "5672")
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", f"db+{DB_URL}")
@ -57,6 +58,9 @@ PHOTO_STORAGE_DIR = pathlib.Path(
WEBPAGE_STORAGE_DIR = pathlib.Path(
os.getenv("WEBPAGE_STORAGE_DIR", FILE_STORAGE_DIR / "webpages")
)
+NOTES_STORAGE_DIR = pathlib.Path(
+os.getenv("NOTES_STORAGE_DIR", FILE_STORAGE_DIR / "notes")
+)
storage_dirs = [
FILE_STORAGE_DIR,
@ -66,6 +70,7 @@ storage_dirs = [
COMIC_STORAGE_DIR,
PHOTO_STORAGE_DIR,
WEBPAGE_STORAGE_DIR,
+NOTES_STORAGE_DIR,
]
for dir in storage_dirs:
dir.mkdir(parents=True, exist_ok=True)

View File

@ -1,47 +0,0 @@
from celery import Celery
from memory.common import settings
EMAIL_ROOT = "memory.workers.tasks.email"
FORUMS_ROOT = "memory.workers.tasks.forums"
BLOGS_ROOT = "memory.workers.tasks.blogs"
PHOTO_ROOT = "memory.workers.tasks.photo"
COMIC_ROOT = "memory.workers.tasks.comic"
EBOOK_ROOT = "memory.workers.tasks.ebook"
MAINTENANCE_ROOT = "memory.workers.tasks.maintenance"
def rabbit_url() -> str:
return f"amqp://{settings.RABBITMQ_USER}:{settings.RABBITMQ_PASSWORD}@{settings.RABBITMQ_HOST}:5672//"
app = Celery(
"memory",
broker=rabbit_url(),
backend=settings.CELERY_RESULT_BACKEND,
)
app.autodiscover_tasks(["memory.workers.tasks"])
app.conf.update(
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
task_routes={
f"{EMAIL_ROOT}.*": {"queue": "email"},
f"{PHOTO_ROOT}.*": {"queue": "photo_embed"},
f"{COMIC_ROOT}.*": {"queue": "comic"},
f"{EBOOK_ROOT}.*": {"queue": "ebooks"},
f"{BLOGS_ROOT}.*": {"queue": "blogs"},
f"{FORUMS_ROOT}.*": {"queue": "forums"},
f"{MAINTENANCE_ROOT}.*": {"queue": "maintenance"},
},
)
@app.on_after_configure.connect # type: ignore
def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant
qdrant.setup_qdrant()

View File

@ -1,8 +1,11 @@
import logging
from memory.common import settings
-from memory.workers.celery_app import app
-from memory.workers.tasks import CLEAN_ALL_COLLECTIONS, REINGEST_MISSING_CHUNKS
+from memory.common.celery_app import (
+app,
+CLEAN_ALL_COLLECTIONS,
+REINGEST_MISSING_CHUNKS,
+)
logger = logging.getLogger(__name__)

View File

@ -2,24 +2,16 @@
Import sub-modules so Celery can register their @app.task decorators.
"""
-from memory.workers.tasks import email, comic, blogs, ebook, forums  # noqa
-from memory.workers.tasks.blogs import (
-SYNC_WEBPAGE,
-SYNC_ARTICLE_FEED,
-SYNC_ALL_ARTICLE_FEEDS,
-SYNC_WEBSITE_ARCHIVE,
-)
-from memory.workers.tasks.comic import SYNC_ALL_COMICS, SYNC_SMBC, SYNC_XKCD
-from memory.workers.tasks.ebook import SYNC_BOOK
-from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
-from memory.workers.tasks.forums import SYNC_LESSWRONG, SYNC_LESSWRONG_POST
-from memory.workers.tasks.maintenance import (
-CLEAN_ALL_COLLECTIONS,
-CLEAN_COLLECTION,
-REINGEST_MISSING_CHUNKS,
-REINGEST_CHUNK,
-REINGEST_ITEM,
-)
+from memory.workers.tasks import (
+email,
+comic,
+blogs,
+ebook,
+forums,
+maintenance,
+notes,
+observations,
+)  # noqa
__all__ = [
@ -28,22 +20,7 @@ __all__ = [
"blogs", "blogs",
"ebook", "ebook",
"forums", "forums",
"SYNC_WEBPAGE", "maintenance",
"SYNC_ARTICLE_FEED", "notes",
"SYNC_ALL_ARTICLE_FEEDS", "observations",
"SYNC_WEBSITE_ARCHIVE",
"SYNC_ALL_COMICS",
"SYNC_SMBC",
"SYNC_XKCD",
"SYNC_BOOK",
"SYNC_ACCOUNT",
"SYNC_LESSWRONG",
"SYNC_LESSWRONG_POST",
"SYNC_ALL_ACCOUNTS",
"PROCESS_EMAIL",
"CLEAN_ALL_COLLECTIONS",
"CLEAN_COLLECTION",
"REINGEST_MISSING_CHUNKS",
"REINGEST_CHUNK",
"REINGEST_ITEM",
] ]

View File

@ -7,7 +7,13 @@ from memory.common.db.models import ArticleFeed, BlogPost
from memory.parsers.blogs import parse_webpage
from memory.parsers.feeds import get_feed_parser
from memory.parsers.archives import get_archive_fetcher
-from memory.workers.celery_app import app, BLOGS_ROOT
+from memory.common.celery_app import (
+app,
+SYNC_WEBPAGE,
+SYNC_ARTICLE_FEED,
+SYNC_ALL_ARTICLE_FEEDS,
+SYNC_WEBSITE_ARCHIVE,
+)
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -18,11 +24,6 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
-SYNC_WEBPAGE = f"{BLOGS_ROOT}.sync_webpage"
-SYNC_ARTICLE_FEED = f"{BLOGS_ROOT}.sync_article_feed"
-SYNC_ALL_ARTICLE_FEEDS = f"{BLOGS_ROOT}.sync_all_article_feeds"
-SYNC_WEBSITE_ARCHIVE = f"{BLOGS_ROOT}.sync_website_archive"
@app.task(name=SYNC_WEBPAGE)
@safe_task_execution

View File

@ -9,7 +9,13 @@ from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import Comic, clean_filename
from memory.parsers import comics
-from memory.workers.celery_app import app, COMIC_ROOT
+from memory.common.celery_app import (
+app,
+SYNC_ALL_COMICS,
+SYNC_COMIC,
+SYNC_SMBC,
+SYNC_XKCD,
+)
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -19,11 +25,6 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
-SYNC_ALL_COMICS = f"{COMIC_ROOT}.sync_all_comics"
-SYNC_SMBC = f"{COMIC_ROOT}.sync_smbc"
-SYNC_XKCD = f"{COMIC_ROOT}.sync_xkcd"
-SYNC_COMIC = f"{COMIC_ROOT}.sync_comic"
BASE_SMBC_URL = "https://www.smbc-comics.com/"
SMBC_RSS_URL = "https://www.smbc-comics.com/comic/rss"

View File

@ -6,7 +6,7 @@ import memory.common.settings as settings
from memory.parsers.ebook import Ebook, parse_ebook, Section
from memory.common.db.models import Book, BookSection
from memory.common.db.connection import make_session
-from memory.workers.celery_app import app, EBOOK_ROOT
+from memory.common.celery_app import app, SYNC_BOOK
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -17,7 +17,6 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
-SYNC_BOOK = f"{EBOOK_ROOT}.sync_book"
# Minimum section length to embed (avoid noise from very short sections)
MIN_SECTION_LENGTH = 100

View File

@ -3,7 +3,7 @@ from datetime import datetime
from typing import cast
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount, MailMessage
-from memory.workers.celery_app import app, EMAIL_ROOT
+from memory.common.celery_app import app, PROCESS_EMAIL, SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS
from memory.workers.email import (
create_mail_message,
imap_connection,
@ -18,10 +18,6 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
-PROCESS_EMAIL = f"{EMAIL_ROOT}.process_message"
-SYNC_ACCOUNT = f"{EMAIL_ROOT}.sync_account"
-SYNC_ALL_ACCOUNTS = f"{EMAIL_ROOT}.sync_all_accounts"
@app.task(name=PROCESS_EMAIL)
@safe_task_execution

View File

@ -4,7 +4,7 @@ import logging
from memory.parsers.lesswrong import fetch_lesswrong_posts, LessWrongPost
from memory.common.db.connection import make_session
from memory.common.db.models import ForumPost
-from memory.workers.celery_app import app, FORUMS_ROOT
+from memory.common.celery_app import app, SYNC_LESSWRONG, SYNC_LESSWRONG_POST
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
@ -15,9 +15,6 @@ from memory.workers.tasks.content_processing import (
logger = logging.getLogger(__name__)
-SYNC_LESSWRONG = f"{FORUMS_ROOT}.sync_lesswrong"
-SYNC_LESSWRONG_POST = f"{FORUMS_ROOT}.sync_lesswrong_post"
@app.task(name=SYNC_LESSWRONG_POST)
@safe_task_execution

View File

@ -3,6 +3,7 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Sequence, Any
+from memory.common import extract
from memory.workers.tasks.content_processing import process_content_item
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
@ -10,24 +11,22 @@ from sqlalchemy.orm import contains_eager
from memory.common import collections, embedding, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
-from memory.workers.celery_app import app, MAINTENANCE_ROOT
+from memory.common.celery_app import (
+app,
+CLEAN_ALL_COLLECTIONS,
+CLEAN_COLLECTION,
+REINGEST_MISSING_CHUNKS,
+REINGEST_CHUNK,
+REINGEST_ITEM,
+REINGEST_EMPTY_SOURCE_ITEMS,
+REINGEST_ALL_EMPTY_SOURCE_ITEMS,
+UPDATE_METADATA_FOR_SOURCE_ITEMS,
+UPDATE_METADATA_FOR_ITEM,
+)
logger = logging.getLogger(__name__)
CLEAN_ALL_COLLECTIONS = f"{MAINTENANCE_ROOT}.clean_all_collections"
CLEAN_COLLECTION = f"{MAINTENANCE_ROOT}.clean_collection"
REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
REINGEST_ALL_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_all_empty_source_items"
UPDATE_METADATA_FOR_SOURCE_ITEMS = (
f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
)
UPDATE_METADATA_FOR_ITEM = f"{MAINTENANCE_ROOT}.update_metadata_for_item"
@app.task(name=CLEAN_COLLECTION)
def clean_collection(collection: str) -> dict[str, int]:
logger.info(f"Cleaning collection {collection}")
@ -75,11 +74,11 @@ def reingest_chunk(chunk_id: str, collection: str):
if collection not in collections.ALL_COLLECTIONS:
raise ValueError(f"Unsupported collection {collection}")
-data = chunk.data
+data = [extract.DataChunk(data=chunk.data)]
if collection in collections.MULTIMODAL_COLLECTIONS:
-vector = embedding.embed_mixed([data])[0]
+vector = embedding.embed_mixed(data)[0]
elif collection in collections.TEXT_COLLECTIONS:
-vector = embedding.embed_text([data])[0]
+vector = embedding.embed_text(data)[0]
else:
raise ValueError(f"Unsupported data type for collection {collection}")

View File

@ -0,0 +1,77 @@
import logging
import pathlib
from memory.common.db.connection import make_session
from memory.common.db.models import Note
from memory.common.celery_app import app, SYNC_NOTE, SYNC_NOTES
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
create_task_result,
process_content_item,
safe_task_execution,
)
logger = logging.getLogger(__name__)
@app.task(name=SYNC_NOTE)
@safe_task_execution
def sync_note(
subject: str,
content: str,
filename: str | None = None,
note_type: str | None = None,
confidence: float = 0.5,
tags: list[str] = [],
):
logger.info(f"Syncing note {subject}")
text = Note.as_text(content, subject)
sha256 = create_content_hash(text)
note = Note(
subject=subject,
content=content,
embed_status="RAW",
size=len(text.encode("utf-8")),
modality="note",
mime_type="text/markdown",
sha256=sha256,
note_type=note_type,
confidence=confidence,
tags=tags,
filename=filename,
)
note.save_to_file()
with make_session() as session:
existing_note = check_content_exists(session, Note, sha256=sha256)
if existing_note:
logger.info(f"Note already exists: {existing_note.subject}")
return create_task_result(existing_note, "already_exists")
return process_content_item(note, session)
@app.task(name=SYNC_NOTES)
@safe_task_execution
def sync_notes(folder: str):
path = pathlib.Path(folder)
logger.info(f"Syncing notes from {folder}")
new_notes = 0
all_files = list(path.rglob("*.md"))
with make_session() as session:
for filename in all_files:
if not check_content_exists(session, Note, filename=filename.as_posix()):
new_notes += 1
sync_note.delay(
subject=filename.stem,
content=filename.read_text(),
filename=filename.as_posix(),
)
return {
"notes_num": len(all_files),
"new_notes": new_notes,
}
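For a one-off bulk import of an existing folder, the same name-based dispatch presumably works; sync_notes then fans out one sync_note task per markdown file it has not seen before. A sketch only, using the default notes directory from settings:

# Hypothetical one-off dispatch (not part of this commit).
from memory.common import settings
from memory.common.celery_app import app, SYNC_NOTES

app.send_task(SYNC_NOTES, queue="notes", kwargs={"folder": str(settings.NOTES_STORAGE_DIR)})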

View File

@ -0,0 +1,55 @@
import logging
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation
from memory.common.celery_app import app, SYNC_OBSERVATION
from memory.workers.tasks.content_processing import (
check_content_exists,
create_content_hash,
create_task_result,
process_content_item,
safe_task_execution,
)
logger = logging.getLogger(__name__)
@app.task(name=SYNC_OBSERVATION)
@safe_task_execution
def sync_observation(
subject: str,
content: str,
observation_type: str,
evidence: dict | None = None,
confidence: float = 0.5,
session_id: str | None = None,
agent_model: str = "unknown",
tags: list[str] = [],
):
logger.info(f"Syncing observation {subject}")
sha256 = create_content_hash(f"{content}{subject}{observation_type}")
observation = AgentObservation(
content=content,
subject=subject,
observation_type=observation_type,
confidence=confidence,
evidence=evidence,
tags=tags or [],
session_id=session_id,
agent_model=agent_model,
size=len(content),
mime_type="text/plain",
sha256=sha256,
modality="observation",
)
with make_session() as session:
existing_observation = check_content_exists(
session, AgentObservation, sha256=sha256
)
if existing_observation:
logger.info(f"Observation already exists: {existing_observation.subject}")
return create_task_result(existing_observation, "already_exists")
return process_content_item(observation, session)

View File

@ -211,11 +211,14 @@ def mock_file_storage(tmp_path: Path):
image_storage_dir.mkdir(parents=True, exist_ok=True)
email_storage_dir = tmp_path / "emails"
email_storage_dir.mkdir(parents=True, exist_ok=True)
+notes_storage_dir = tmp_path / "notes"
+notes_storage_dir.mkdir(parents=True, exist_ok=True)
with (
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
patch.object(settings, "CHUNK_STORAGE_DIR", chunk_storage_dir),
patch.object(settings, "WEBPAGE_STORAGE_DIR", image_storage_dir),
patch.object(settings, "EMAIL_STORAGE_DIR", email_storage_dir),
+patch.object(settings, "NOTES_STORAGE_DIR", notes_storage_dir),
):
yield yield

View File

@ -25,7 +25,6 @@ SAMPLE_HTML = f"""
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.</p>
-patch
<p>Modern development has also seen the emergence of domain-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that

View File

@ -22,8 +22,7 @@ def default_chunk_size():
real_chunker = chunker.chunk_text
def chunk_text(text: str, max_tokens: int = 0):
-max_tokens = max_tokens or chunk_length
-return real_chunker(text, max_tokens=max_tokens)
+return real_chunker(text, max_tokens=chunk_length)
def set_size(new_size: int):
nonlocal chunk_length
@ -258,7 +257,9 @@ def test_source_item_chunk_contents_text(chunk_length, expected, default_chunk_s
)
default_chunk_size(chunk_length)
-assert source._chunk_contents() == [extract.DataChunk(data=e) for e in expected]
+assert source._chunk_contents() == [
+extract.DataChunk(data=e, modality="text") for e in expected
+]
def test_source_item_chunk_contents_image(tmp_path):
@ -368,20 +369,32 @@ def test_source_item_as_payload():
@pytest.mark.parametrize(
-"content,filename,expected",
+"content,filename",
[
-("Test content", None, "Test content"),
-(None, "test.txt", "test.txt"),
-("Test content", "test.txt", "Test content"),  # content takes precedence
-(None, None, None),
+("Test content", None),
+(None, "test.txt"),
+("Test content", "test.txt"),
+(None, None),
],
)
-def test_source_item_display_contents(content, filename, expected):
+def test_source_item_display_contents(content, filename):
"""Test SourceItem.display_contents property"""
source = SourceItem(
-sha256=b"test123", content=content, filename=filename, modality="text"
+sha256=b"test123",
+content=content,
+filename=filename,
+modality="text",
+mime_type="text/plain",
+size=123,
+tags=["bla", "ble"],
)
-assert source.display_contents == expected
+assert source.display_contents == {
+"content": content,
+"filename": filename,
+"mime_type": "text/plain",
+"size": 123,
+"tags": ["bla", "ble"],
+}
def test_unique_source_items_same_commit(db_session: Session):

View File

@ -22,9 +22,17 @@ from memory.common.embedding import embed_source_item
from memory.common.extract import page_to_image
from tests.data.contents import (
CHUNKS,
+DATA_DIR,
+LANG_TIMELINE,
LANG_TIMELINE_HASH,
+CODE_COMPLEXITY,
+CODE_COMPLEXITY_HASH,
SAMPLE_MARKDOWN,
SAMPLE_TEXT,
+SECOND_PAGE,
+SECOND_PAGE_MARKDOWN,
+SECOND_PAGE_TEXT,
+TWO_PAGE_CHUNKS,
image_hash,
)

View File

@ -13,7 +13,9 @@ from memory.common.db.models.source_items import (
BookSection,
BlogPost,
AgentObservation,
+Note,
)
+from tests.data.contents import SAMPLE_MARKDOWN, CHUNKS
@pytest.fixture
@ -330,7 +332,7 @@ def test_email_attachment_cascade_delete(db_session: Session):
"Page 1\n\nPage 2\n\nPage 3", "Page 1\n\nPage 2\n\nPage 3",
{"type": "section", "tags": {"tag1", "tag2"}}, {"type": "section", "tags": {"tag1", "tag2"}},
), ),
("test", {"type": "summary", "tags": {"tag1", "tag2"}}), ("test summary", {"type": "summary", "tags": {"tag1", "tag2"}}),
], ],
), ),
# Empty/whitespace pages filtered out # Empty/whitespace pages filtered out
@ -397,7 +399,7 @@ def test_book_section_data_chunks(pages, expected_chunks):
metadata={"tags": ["tag1", "tag2"]}, metadata={"tags": ["tag1", "tag2"]},
), ),
extract.DataChunk( extract.DataChunk(
data=["test"], data=["test summary"],
metadata={"tags": ["tag1", "tag2"]}, metadata={"tags": ["tag1", "tag2"]},
), ),
], ],
@ -482,7 +484,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
f"Second picture is here: {img2_path.as_posix()}", f"Second picture is here: {img2_path.as_posix()}",
img2_path.as_posix(), img2_path.as_posix(),
], ],
["test"], ["test summary"],
] ]
@ -491,20 +493,102 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
[
(
{},
-{"embedding_type": "semantic"},
-{"embedding_type": "temporal"},
{
"source_id": None,
"tags": [],
"size": None,
"observation_type": "preference",
"subject": "programming preferences",
"confidence": 0.9,
"evidence": {
"quote": "I really like Python",
"context": "discussion about languages",
},
"agent_model": "claude-3.5-sonnet",
"embedding_type": "semantic",
},
{
"source_id": None,
"tags": [],
"size": None,
"observation_type": "preference",
"subject": "programming preferences",
"confidence": 0.9,
"evidence": {
"quote": "I really like Python",
"context": "discussion about languages",
},
"agent_model": "claude-3.5-sonnet",
"embedding_type": "temporal",
},
[],
),
(
{"extra_key": "extra_value"},
-{"extra_key": "extra_value", "embedding_type": "semantic"},
-{"extra_key": "extra_value", "embedding_type": "temporal"},
{
"source_id": None,
"tags": [],
"size": None,
"observation_type": "preference",
"subject": "programming preferences",
"confidence": 0.9,
"evidence": {
"quote": "I really like Python",
"context": "discussion about languages",
},
"agent_model": "claude-3.5-sonnet",
"extra_key": "extra_value",
"embedding_type": "semantic",
},
{
"source_id": None,
"tags": [],
"size": None,
"observation_type": "preference",
"subject": "programming preferences",
"confidence": 0.9,
"evidence": {
"quote": "I really like Python",
"context": "discussion about languages",
},
"agent_model": "claude-3.5-sonnet",
"extra_key": "extra_value",
"embedding_type": "temporal",
},
[],
),
(
{"tags": ["existing_tag"], "source": "test"},
-{"tags": {"existing_tag"}, "source": "test", "embedding_type": "semantic"},
-{"tags": {"existing_tag"}, "source": "test", "embedding_type": "temporal"},
{
"source_id": None,
"tags": {"existing_tag"},
"size": None,
"observation_type": "preference",
"subject": "programming preferences",
"confidence": 0.9,
"evidence": {
"quote": "I really like Python",
"context": "discussion about languages",
},
"agent_model": "claude-3.5-sonnet",
"source": "test",
"embedding_type": "semantic",
},
{
"source_id": None,
"tags": {"existing_tag"},
"size": None,
"observation_type": "preference",
"subject": "programming preferences",
"confidence": 0.9,
"evidence": {
"quote": "I really like Python",
"context": "discussion about languages",
},
"agent_model": "claude-3.5-sonnet",
"source": "test",
"embedding_type": "temporal",
},
[],
),
],
@ -513,6 +597,7 @@ def test_agent_observation_data_chunks(
metadata, expected_semantic_metadata, expected_temporal_metadata, observation_tags
):
"""Test AgentObservation.data_chunks generates correct chunks with proper metadata"""
+session_id = uuid.uuid4()
observation = AgentObservation(
sha256=b"test_obs",
content="User prefers Python over JavaScript",
@ -524,7 +609,7 @@ def test_agent_observation_data_chunks(
"context": "discussion about languages", "context": "discussion about languages",
}, },
agent_model="claude-3.5-sonnet", agent_model="claude-3.5-sonnet",
session_id=uuid.uuid4(), session_id=session_id,
tags=observation_tags, tags=observation_tags,
) )
# Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions # Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
@ -533,19 +618,33 @@ def test_agent_observation_data_chunks(
result = observation.data_chunks(metadata)
# Verify chunks
-assert len(result) == 2
+assert len(result) == 4
semantic_chunk = result[0]
expected_semantic_text = "Subject: programming preferences | Type: preference | Observation: User prefers Python over JavaScript | Quote: I really like Python | Context: discussion about languages"
assert semantic_chunk.data == [expected_semantic_text]
-assert semantic_chunk.metadata == expected_semantic_metadata
-assert semantic_chunk.collection_name == "semantic"
+# Add session_id to expected metadata and remove tags if empty
+expected_semantic_with_session = expected_semantic_metadata.copy()
+expected_semantic_with_session["session_id"] = str(session_id)
+if not expected_semantic_with_session.get("tags"):
+del expected_semantic_with_session["tags"]
+assert semantic_chunk.item_metadata == expected_semantic_with_session
+assert cast(str, semantic_chunk.collection_name) == "semantic"
temporal_chunk = result[1]
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript | Confidence: 0.9"
assert temporal_chunk.data == [expected_temporal_text]
-assert temporal_chunk.metadata == expected_temporal_metadata
-assert temporal_chunk.collection_name == "temporal"
+# Add session_id to expected metadata and remove tags if empty
+expected_temporal_with_session = expected_temporal_metadata.copy()
+expected_temporal_with_session["session_id"] = str(session_id)
+if not expected_temporal_with_session.get("tags"):
+del expected_temporal_with_session["tags"]
+assert temporal_chunk.item_metadata == expected_temporal_with_session
+assert cast(str, temporal_chunk.collection_name) == "temporal"
def test_agent_observation_data_chunks_with_none_values():
@ -564,16 +663,18 @@ def test_agent_observation_data_chunks_with_none_values():
result = observation.data_chunks()
-assert len(result) == 2
-assert result[0].collection_name == "semantic"
-assert result[1].collection_name == "temporal"
+assert len(result) == 3
+assert cast(str, result[0].collection_name) == "semantic"
+assert cast(str, result[1].collection_name) == "temporal"
# Verify content with None evidence
-semantic_text = "Subject: subject | Type: belief | Observation: Content"
-assert result[0].data == [semantic_text]
-temporal_text = "Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
-assert result[1].data == [temporal_text]
+assert [i.data for i in result] == [
+["Subject: subject | Type: belief | Observation: Content"],
+[
+"Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
+],
+["Content"],
+]
def test_agent_observation_data_chunks_merge_metadata_behavior():
@ -594,14 +695,51 @@ def test_agent_observation_data_chunks_merge_metadata_behavior():
input_metadata = {"existing": "value", "tags": ["tag1"]} input_metadata = {"existing": "value", "tags": ["tag1"]}
result = observation.data_chunks(input_metadata) result = observation.data_chunks(input_metadata)
semantic_metadata = result[0].metadata semantic_metadata = result[0].item_metadata
temporal_metadata = result[1].metadata temporal_metadata = result[1].item_metadata
# Both should have the existing metadata plus embedding_type # Both should have the existing metadata plus embedding_type
assert semantic_metadata["existing"] == "value" assert semantic_metadata["existing"] == "value"
assert semantic_metadata["tags"] == {"tag1"} # Merged tags assert semantic_metadata["tags"] == {"tag1", "base_tag"} # Merged tags
assert semantic_metadata["embedding_type"] == "semantic" assert semantic_metadata["embedding_type"] == "semantic"
assert temporal_metadata["existing"] == "value" assert temporal_metadata["existing"] == "value"
assert temporal_metadata["tags"] == {"tag1"} # Merged tags assert temporal_metadata["tags"] == {"tag1", "base_tag"} # Merged tags
assert temporal_metadata["embedding_type"] == "temporal" assert temporal_metadata["embedding_type"] == "temporal"
@pytest.mark.parametrize(
"subject, content, expected",
(
(None, "bla bla bla", ["bla bla bla"]),
(None, " \n\n bla bla bla \t\t \n ", ["bla bla bla"]),
("my gosh, a subject!", "blee bleee", ["# my gosh, a subject!\n\nblee bleee"]),
(None, SAMPLE_MARKDOWN, [i.strip() for i in CHUNKS] + ["test summary"]),
),
)
def test_note_data_chunks(subject, content, expected):
note = Note(
sha256=b"test_obs",
content=content,
subject=subject,
note_type="quicky",
confidence=0.9,
size=123,
tags=["bla"],
)
chunks = note.data_chunks()
assert [chunk.content for chunk in chunks] == expected
for chunk in chunks:
assert cast(list, chunk.file_paths) == []
tags = {"bla"}
if cast(str, chunk.content) == "test summary":
tags |= {"tag1", "tag2"}
assert chunk.item_metadata == {
"confidence": 0.9,
"note_type": "quicky",
"size": 123,
"source_id": None,
"subject": subject,
"tags": tags,
}

View File

@ -6,6 +6,7 @@ from memory.common.embedding import (
embed_mixed,
embed_text,
)
+from memory.common.extract import DataChunk
@pytest.fixture
@ -44,10 +45,10 @@ def test_get_modality(mime_type, expected_modality):
def test_embed_text(mock_embed):
-texts = ["text1 with words", "text2"]
-assert embed_text(texts) == [[0], [1]]
+chunks = [DataChunk(data=["text1 with words"]), DataChunk(data=["text2"])]
+assert embed_text(chunks) == [[0], [1]]
def test_embed_mixed(mock_embed):
-items = ["text", {"type": "image", "data": "base64"}]
+items = [DataChunk(data=["text"])]
assert embed_mixed(items) == [[0]]

View File

@ -10,7 +10,7 @@ from memory.common.extract import (
doc_to_images,
extract_image,
docx_to_pdf,
-extract_docx,
+merge_metadata,
DataChunk,
)
@ -48,8 +48,11 @@ def test_as_file_with_str():
@pytest.mark.parametrize(
"input_content,expected",
[
-("simple text", [DataChunk(data=["simple text"], metadata={})]),
-(b"bytes text", [DataChunk(data=["bytes text"], metadata={})]),
+(
+"simple text",
+[DataChunk(data=["simple text"], metadata={}, modality="text")],
+),
+(b"bytes text", [DataChunk(data=["bytes text"], metadata={}, modality="text")]),
],
)
def test_extract_text(input_content, expected):
@ -61,7 +64,7 @@ def test_extract_text_with_path(tmp_path):
test_file.write_text("file text content")
assert extract_text(test_file) == [
-DataChunk(data=["file text content"], metadata={})
+DataChunk(data=["file text content"], metadata={}, modality="text")
]
@ -128,3 +131,111 @@ def test_docx_to_pdf_default_output():
assert result_path == SAMPLE_DOCX.with_suffix(".pdf")
assert result_path.exists()
@pytest.mark.parametrize(
"dicts,expected",
[
# Empty input
([], {}),
# Single dict without tags
([{"key": "value"}], {"key": "value"}),
# Single dict with tags as list
(
[{"key": "value", "tags": ["tag1", "tag2"]}],
{"key": "value", "tags": {"tag1", "tag2"}},
),
# Single dict with tags as set
(
[{"key": "value", "tags": {"tag1", "tag2"}}],
{"key": "value", "tags": {"tag1", "tag2"}},
),
# Multiple dicts without tags
(
[{"key1": "value1"}, {"key2": "value2"}],
{"key1": "value1", "key2": "value2"},
),
# Multiple dicts with non-overlapping tags
(
[
{"key1": "value1", "tags": ["tag1"]},
{"key2": "value2", "tags": ["tag2"]},
],
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2"}},
),
# Multiple dicts with overlapping tags
(
[
{"key1": "value1", "tags": ["tag1", "tag2"]},
{"key2": "value2", "tags": ["tag2", "tag3"]},
],
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2", "tag3"}},
),
# Overlapping keys - later dict wins
(
[
{"key": "value1", "other": "data1"},
{"key": "value2", "another": "data2"},
],
{"key": "value2", "other": "data1", "another": "data2"},
),
# Mixed tags types (list and set)
(
[
{"key1": "value1", "tags": ["tag1", "tag2"]},
{"key2": "value2", "tags": {"tag3", "tag4"}},
],
{
"key1": "value1",
"key2": "value2",
"tags": {"tag1", "tag2", "tag3", "tag4"},
},
),
# Empty tags
(
[{"key": "value", "tags": []}, {"key2": "value2", "tags": []}],
{"key": "value", "key2": "value2"},
),
# None values
(
[{"key1": None, "key2": "value"}, {"key3": None}],
{"key1": None, "key2": "value", "key3": None},
),
# Complex nested structures
(
[
{"nested": {"inner": "value1"}, "list": [1, 2, 3], "tags": ["tag1"]},
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": ["tag2"]},
],
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": {"tag1", "tag2"}},
),
# Boolean and numeric values
(
[
{"bool": True, "int": 42, "float": 3.14, "tags": ["numeric"]},
{"bool": False, "int": 100},
],
{"bool": False, "int": 100, "float": 3.14, "tags": {"numeric"}},
),
# Three or more dicts
(
[
{"a": 1, "tags": ["t1"]},
{"b": 2, "tags": ["t2", "t3"]},
{"c": 3, "a": 10, "tags": ["t3", "t4"]},
],
{"a": 10, "b": 2, "c": 3, "tags": {"t1", "t2", "t3", "t4"}},
),
# Dict with only tags
([{"tags": ["tag1", "tag2"]}], {"tags": {"tag1", "tag2"}}),
# Empty dicts
([{}, {}], {}),
# Mix of empty and non-empty dicts
(
[{}, {"key": "value", "tags": ["tag"]}, {}],
{"key": "value", "tags": {"tag"}},
),
],
)
def test_merge_metadata(dicts, expected):
assert merge_metadata(*dicts) == expected
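The parametrized cases above pin down merge_metadata's contract: later dicts win on key collisions, tags from all inputs are unioned into a set, and the tags key is dropped when nothing was collected. A minimal implementation consistent with these cases (a sketch for orientation, not the actual code in memory.common.extract):

def merge_metadata(*dicts: dict) -> dict:
    merged: dict = {}
    tags: set = set()
    for d in dicts:
        d = dict(d)                      # copy so the caller's dict isn't mutated
        tags |= set(d.pop("tags", []))   # accept tags given as a list or a set
        merged.update(d)                 # later dicts override earlier keys
    if tags:
        merged["tags"] = tags            # only emit tags when some were collected
    return merged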

View File

@@ -325,4 +325,4 @@ def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
     # Check that the full content was passed to the embedding function
     texts = mock_voyage_client.embed.call_args[0][0]
-    assert texts == [[large_section_content.strip()], ["test"]]
+    assert texts == [large_section_content.strip(), "test summary"]

View File

@@ -0,0 +1,396 @@
import pytest
import pathlib
from decimal import Decimal
from unittest.mock import Mock, patch
from memory.common.db.models import Note
from memory.workers.tasks import notes
from memory.workers.tasks.content_processing import create_content_hash
from memory.common import settings
@pytest.fixture
def mock_note_data():
"""Mock note data for testing."""
test_filename = pathlib.Path(settings.NOTES_STORAGE_DIR) / "test_note.md"
return {
"subject": "Test Note Subject",
"content": "This is test note content with enough text to be processed and embedded.",
"filename": str(test_filename),
"note_type": "observation",
"confidence": 0.8,
"tags": ["test", "note"],
}
@pytest.fixture
def mock_minimal_note():
"""Mock note with minimal required data."""
return {
"subject": "Minimal Note",
"content": "Minimal content",
}
@pytest.fixture
def mock_empty_note():
"""Mock note with empty content."""
return {
"subject": "Empty Note",
"content": "",
}
@pytest.fixture
def markdown_files_in_storage():
"""Create real markdown files in the notes storage directory."""
notes_dir = pathlib.Path(settings.NOTES_STORAGE_DIR)
notes_dir.mkdir(parents=True, exist_ok=True)
# Create test markdown files
files = []
file1 = notes_dir / "note1.md"
file1.write_text("Content of note 1")
files.append(file1)
file2 = notes_dir / "note2.md"
file2.write_text("Content of note 2")
files.append(file2)
file3 = notes_dir / "note3.md"
file3.write_text("Content of note 3")
files.append(file3)
# Create a subdirectory with a file
subdir = notes_dir / "subdir"
subdir.mkdir(exist_ok=True)
file4 = subdir / "note4.md"
file4.write_text("Content of note 4 in subdirectory")
files.append(file4)
# Create a non-markdown file that should be ignored
txt_file = notes_dir / "not_markdown.txt"
txt_file.write_text("This should be ignored")
return files
def test_sync_note_success(mock_note_data, db_session, qdrant):
"""Test successful note synchronization."""
result = notes.sync_note(**mock_note_data)
# Verify the Note was created in the database
note = db_session.query(Note).filter_by(subject="Test Note Subject").first()
assert note is not None
assert note.subject == "Test Note Subject"
assert (
note.content
== "This is test note content with enough text to be processed and embedded."
)
assert note.modality == "note"
assert note.mime_type == "text/markdown"
assert note.note_type == "observation"
assert float(note.confidence) == 0.8 # Convert Decimal to float for comparison
assert note.filename is not None
assert note.tags == ["test", "note"]
# Verify the result
assert result["status"] == "processed"
assert result["note_id"] == note.id
assert (
"subject" not in result
) # create_task_result doesn't include subject for Note
def test_sync_note_minimal_data(mock_minimal_note, db_session, qdrant):
"""Test note sync with minimal required data."""
result = notes.sync_note(**mock_minimal_note)
note = db_session.query(Note).filter_by(subject="Minimal Note").first()
assert note is not None
assert note.subject == "Minimal Note"
assert note.content == "Minimal content"
assert note.note_type is None
assert float(note.confidence) == 0.5 # Default value, convert Decimal to float
assert note.tags == [] # Default empty list
assert note.filename is not None and "Minimal Note.md" in note.filename
assert result["status"] == "processed"
def test_sync_note_empty_content(mock_empty_note, db_session, qdrant):
"""Test note sync with empty content."""
result = notes.sync_note(**mock_empty_note)
# Note is still created even with empty content
note = db_session.query(Note).filter_by(subject="Empty Note").first()
assert note is not None
assert note.subject == "Empty Note"
assert note.content == ""
# Empty content with subject header "# Empty Note" still generates chunks
assert result["status"] == "processed"
assert result["chunks_count"] > 0
def test_sync_note_already_exists(mock_note_data, db_session):
"""Test note sync when content already exists."""
# Create the content text the same way sync_note does
text = Note.as_text(mock_note_data["content"], mock_note_data["subject"])
sha256 = create_content_hash(text)
# Add existing note with same content hash but different filename to avoid file conflicts
existing_note = Note(
subject="Existing Note",
content=mock_note_data["content"],
sha256=sha256,
modality="note",
tags=["existing"],
mime_type="text/markdown",
size=len(text.encode("utf-8")),
embed_status="RAW",
filename=str(pathlib.Path(settings.NOTES_STORAGE_DIR) / "existing_note.md"),
)
db_session.add(existing_note)
db_session.commit()
result = notes.sync_note(**mock_note_data)
assert result["status"] == "already_exists"
assert result["note_id"] == existing_note.id
# Verify no duplicate was created
notes_with_hash = db_session.query(Note).filter_by(sha256=sha256).all()
assert len(notes_with_hash) == 1
@pytest.mark.parametrize(
"note_type,confidence,tags",
[
("observation", 0.9, ["high-confidence", "important"]),
("reflection", 0.6, ["personal", "thoughts"]),
(None, 0.5, []),
("meeting", 1.0, ["work", "notes", "2024"]),
],
)
def test_sync_note_parameters(note_type, confidence, tags, db_session, qdrant):
"""Test note sync with various parameter combinations."""
result = notes.sync_note(
subject=f"Test Note {note_type}",
content="Test content for parameter testing",
note_type=note_type,
confidence=confidence,
tags=tags,
)
note = db_session.query(Note).filter_by(subject=f"Test Note {note_type}").first()
assert note is not None
assert note.note_type == note_type
assert float(note.confidence) == confidence # Convert Decimal to float
assert note.tags == tags
assert result["status"] == "processed"
def test_sync_note_content_hash_consistency(db_session):
"""Test that content hash is calculated consistently."""
note_data = {
"subject": "Hash Test",
"content": "Consistent content for hashing",
"tags": ["hash-test"],
}
# Sync the same note twice
result1 = notes.sync_note(**note_data)
result2 = notes.sync_note(**note_data)
# First should succeed, second should detect existing
assert result1["status"] == "processed"
assert result2["status"] == "already_exists"
assert result1["note_id"] == result2["note_id"]
# Verify only one note exists in database
notes_in_db = db_session.query(Note).filter_by(subject="Hash Test").all()
assert len(notes_in_db) == 1
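# test_sync_note_already_exists and test_sync_note_content_hash_consistency assume
# sync_note short-circuits on a sha256 of the rendered note text. Roughly
# (illustrative only; the actual task lives in memory.workers.tasks.notes):
#
#     text = Note.as_text(content, subject)
#     sha256 = create_content_hash(text)
#     existing = session.query(Note).filter(Note.sha256 == sha256).first()
#     if existing is not None:
#         return {"status": "already_exists", "note_id": existing.id}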
@patch("memory.workers.tasks.notes.sync_note")
def test_sync_notes_success(mock_sync_note, markdown_files_in_storage, db_session):
"""Test successful notes folder synchronization."""
mock_sync_note.delay.return_value = Mock(id="task-123")
result = notes.sync_notes(settings.NOTES_STORAGE_DIR)
assert result["notes_num"] == 4 # 4 markdown files created by fixture
assert result["new_notes"] == 4 # All are new
# Verify sync_note.delay was called for each file
assert mock_sync_note.delay.call_count == 4
# Check some of the calls were made with correct parameters
call_args_list = mock_sync_note.delay.call_args_list
subjects = [call[1]["subject"] for call in call_args_list]
contents = [call[1]["content"] for call in call_args_list]
assert subjects == ["note1", "note2", "note3", "note4"]
assert contents == [
"Content of note 1",
"Content of note 2",
"Content of note 3",
"Content of note 4 in subdirectory",
]
def test_sync_notes_empty_folder(db_session):
"""Test sync when folder contains no markdown files."""
# Create an empty directory
empty_dir = pathlib.Path(settings.NOTES_STORAGE_DIR) / "empty"
empty_dir.mkdir(parents=True, exist_ok=True)
result = notes.sync_notes(str(empty_dir))
assert result["notes_num"] == 0
assert result["new_notes"] == 0
@patch("memory.workers.tasks.notes.sync_note")
def test_sync_notes_with_existing_notes(
mock_sync_note, markdown_files_in_storage, db_session
):
"""Test sync when some notes already exist."""
# Create one existing note in the database
existing_file = markdown_files_in_storage[0] # note1.md
existing_note = Note(
subject="note1",
content="Content of note 1",
sha256=b"existing_hash" + bytes(24),
modality="note",
tags=["existing"],
mime_type="text/markdown",
size=100,
filename=str(existing_file),
embed_status="RAW",
)
db_session.add(existing_note)
db_session.commit()
mock_sync_note.delay.return_value = Mock(id="task-456")
result = notes.sync_notes(settings.NOTES_STORAGE_DIR)
assert result["notes_num"] == 4
assert result["new_notes"] == 3 # Only 3 new notes (one already exists)
# Verify sync_note.delay was called only for new notes
assert mock_sync_note.delay.call_count == 3
def test_sync_notes_nonexistent_folder(db_session):
"""Test sync_notes with a folder that doesn't exist."""
nonexistent_path = "/nonexistent/folder/path"
result = notes.sync_notes(nonexistent_path)
# sync_notes should return successfully with 0 notes when folder doesn't exist
# This is the actual behavior - it gracefully handles the case
assert result["notes_num"] == 0
assert result["new_notes"] == 0
@patch("memory.workers.tasks.notes.sync_note")
def test_sync_notes_only_processes_md_files(
mock_sync_note, markdown_files_in_storage, db_session
):
"""Test that sync_notes only processes markdown files."""
mock_sync_note.delay.return_value = Mock(id="task-123")
# The fixture creates a .txt file that should be ignored
result = notes.sync_notes(settings.NOTES_STORAGE_DIR)
# Should only process the 4 .md files, not the .txt file
assert result["notes_num"] == 4
assert result["new_notes"] == 4
def test_note_as_text_method():
"""Test the Note.as_text static method used in sync_note."""
content = "This is the note content"
subject = "Note Subject"
text = Note.as_text(content, subject)
# The method should combine subject and content appropriately
assert subject in text
assert content in text
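# Note.as_text itself isn't shown in this diff; judging from test_sync_note_empty_content
# (the "# Empty Note" subject header alone still yields chunks), its rendering is
# presumably something like this sketch (an assumption, not the model's actual code):
#
#     @staticmethod
#     def as_text(content: str, subject: str | None = None) -> str:
#         if subject:
#             return f"# {subject}\n\n{content}"
#         return content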
def test_sync_note_with_long_content(db_session, qdrant):
"""Test sync_note with longer content to ensure proper chunking."""
long_content = "This is a longer note content. " * 100 # Make it substantial
result = notes.sync_note(
subject="Long Note",
content=long_content,
tags=["long", "test"],
)
note = db_session.query(Note).filter_by(subject="Long Note").first()
assert note is not None
assert note.content == long_content
assert result["status"] == "processed"
assert result["chunks_count"] > 0
def test_sync_note_unicode_content(db_session, qdrant):
"""Test sync_note with unicode content."""
unicode_content = "This note contains unicode: 你好世界 🌍 математика"
result = notes.sync_note(
subject="Unicode Note",
content=unicode_content,
)
note = db_session.query(Note).filter_by(subject="Unicode Note").first()
assert note is not None
assert note.content == unicode_content
assert result["status"] == "processed"
@patch("memory.workers.tasks.notes.sync_note")
def test_sync_notes_recursive_discovery(mock_sync_note, db_session):
"""Test that sync_notes discovers files recursively in subdirectories."""
mock_sync_note.delay.return_value = Mock(id="task-123")
# Create nested directory structure
notes_dir = pathlib.Path(settings.NOTES_STORAGE_DIR)
deep_dir = notes_dir / "level1" / "level2" / "level3"
deep_dir.mkdir(parents=True, exist_ok=True)
deep_file = deep_dir / "deep_note.md"
deep_file.write_text("This is a note in a deep subdirectory")
result = notes.sync_notes(settings.NOTES_STORAGE_DIR)
# Should find the deep file
assert result["new_notes"] >= 1
# Verify the file was processed
processed_files = list(pathlib.Path(settings.NOTES_STORAGE_DIR).rglob("*.md"))
assert any("deep_note.md" in str(f) for f in processed_files)
@patch("memory.workers.tasks.notes.sync_note")
def test_sync_notes_handles_file_read_errors(mock_sync_note, db_session):
"""Test sync_notes handles file read errors gracefully."""
# Create a markdown file
notes_dir = pathlib.Path(settings.NOTES_STORAGE_DIR)
notes_dir.mkdir(parents=True, exist_ok=True)
test_file = notes_dir / "test.md"
test_file.write_text("Test content")
# Mock sync_note to raise an exception
mock_sync_note.delay.side_effect = Exception("File read error")
# This should not crash the whole operation
result = notes.sync_notes(settings.NOTES_STORAGE_DIR)
# Should catch the error and return error status
assert result["status"] == "error"
assert "File read error" in result["error"]
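Taken together, these tests constrain sync_notes to: recursively discover *.md files under the given folder, skip files whose path already exists on a Note row, queue sync_note for the rest, report notes_num and new_notes counts, and fold any exception into an error result. A sketch consistent with that behavior (assumptions only; the real task in memory.workers.tasks.notes may differ in details such as the session helper, file ordering, and error wrapping):

import pathlib

from memory.common.db.connection import make_session  # assumed helper name
from memory.common.db.models import Note


def sync_notes(folder: str) -> dict:
    try:
        root = pathlib.Path(folder)
        # only markdown files, searched recursively; sorted for deterministic dispatch order
        files = sorted(root.rglob("*.md")) if root.exists() else []
        with make_session() as session:
            known = {
                filename
                for (filename,) in session.query(Note.filename).filter(Note.filename.isnot(None))
            }
        new_files = [f for f in files if str(f) not in known]
        for f in new_files:
            # sync_note is the per-file Celery task defined in the same module
            sync_note.delay(
                subject=f.stem,
                content=f.read_text(),
                filename=str(f),
            )
        return {"notes_num": len(files), "new_notes": len(new_files)}
    except Exception as e:  # the error-handling test expects a status/error dict
        return {"status": "error", "error": str(e)}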