diff --git a/db/migrations/env.py b/db/migrations/env.py index c1ab961..3dc310c 100644 --- a/db/migrations/env.py +++ b/db/migrations/env.py @@ -12,6 +12,32 @@ from alembic import context from memory.common import settings from memory.common.db.models import Base +# Import all models to ensure they're registered with Base.metadata +from memory.common.db.models import ( + SourceItem, + Chunk, + MailMessage, + EmailAttachment, + ChatMessage, + BlogPost, + Comic, + BookSection, + ForumPost, + GithubItem, + GitCommit, + Photo, + MiscDoc, + AgentObservation, + ObservationContradiction, + ReactionPattern, + ObservationPattern, + BeliefCluster, + ConversationMetrics, + Book, + ArticleFeed, + EmailAccount, +) + # this is the Alembic Config object config = context.config diff --git a/db/migrations/versions/20250531_154947_add_observation_models.py b/db/migrations/versions/20250531_154947_add_observation_models.py new file mode 100644 index 0000000..2b7d88d --- /dev/null +++ b/db/migrations/versions/20250531_154947_add_observation_models.py @@ -0,0 +1,279 @@ +"""Add observation models + +Revision ID: 6554eb260176 +Revises: 2524646f56f6 +Create Date: 2025-05-31 15:49:47.579256 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "6554eb260176" +down_revision: Union[str, None] = "2524646f56f6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "belief_cluster", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("cluster_name", sa.Text(), nullable=False), + sa.Column("core_beliefs", sa.ARRAY(sa.Text()), nullable=False), + sa.Column("peripheral_beliefs", sa.ARRAY(sa.Text()), nullable=True), + sa.Column( + "internal_consistency", sa.Numeric(precision=3, scale=2), nullable=True + ), + sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "last_updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "cluster_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "belief_cluster_consistency_idx", + "belief_cluster", + ["internal_consistency"], + unique=False, + ) + op.create_index( + "belief_cluster_name_idx", "belief_cluster", ["cluster_name"], unique=False + ) + op.create_table( + "conversation_metrics", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("session_id", sa.UUID(), nullable=False), + sa.Column("depth_score", sa.Numeric(precision=3, scale=2), nullable=True), + sa.Column("breakthrough_count", sa.Integer(), nullable=True), + sa.Column( + "challenge_acceptance", sa.Numeric(precision=3, scale=2), nullable=True + ), + sa.Column("new_insights", sa.Integer(), nullable=True), + sa.Column("user_engagement", sa.Numeric(precision=3, scale=2), nullable=True), + sa.Column("duration_minutes", sa.Integer(), nullable=True), + sa.Column("observation_count", sa.Integer(), nullable=True), + sa.Column("contradiction_count", sa.Integer(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "metrics_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True + 
), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "conv_metrics_breakthrough_idx", + "conversation_metrics", + ["breakthrough_count"], + unique=False, + ) + op.create_index( + "conv_metrics_depth_idx", "conversation_metrics", ["depth_score"], unique=False + ) + op.create_index( + "conv_metrics_session_idx", "conversation_metrics", ["session_id"], unique=True + ) + op.create_table( + "observation_pattern", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("pattern_type", sa.Text(), nullable=False), + sa.Column("description", sa.Text(), nullable=False), + sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=False), + sa.Column("exceptions", sa.ARRAY(sa.BigInteger()), nullable=True), + sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False), + sa.Column("validity_start", sa.DateTime(timezone=True), nullable=True), + sa.Column("validity_end", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "pattern_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "obs_pattern_confidence_idx", + "observation_pattern", + ["confidence"], + unique=False, + ) + op.create_index( + "obs_pattern_type_idx", "observation_pattern", ["pattern_type"], unique=False + ) + op.create_index( + "obs_pattern_validity_idx", + "observation_pattern", + ["validity_start", "validity_end"], + unique=False, + ) + op.create_table( + "reaction_pattern", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("trigger_type", sa.Text(), nullable=False), + sa.Column("reaction_type", sa.Text(), nullable=False), + sa.Column("frequency", sa.Numeric(precision=3, scale=2), nullable=False), + sa.Column( + "first_observed", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "last_observed", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column("example_observations", sa.ARRAY(sa.BigInteger()), nullable=True), + sa.Column( + "reaction_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "reaction_frequency_idx", "reaction_pattern", ["frequency"], unique=False + ) + op.create_index( + "reaction_trigger_idx", "reaction_pattern", ["trigger_type"], unique=False + ) + op.create_index( + "reaction_type_idx", "reaction_pattern", ["reaction_type"], unique=False + ) + op.create_table( + "agent_observation", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("session_id", sa.UUID(), nullable=True), + sa.Column("observation_type", sa.Text(), nullable=False), + sa.Column("subject", sa.Text(), nullable=False), + sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False), + sa.Column("evidence", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("agent_model", sa.Text(), nullable=False), + sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False + ) + op.create_index( + "agent_obs_model_idx", "agent_observation", ["agent_model"], unique=False + ) + op.create_index( + "agent_obs_session_idx", "agent_observation", 
["session_id"], unique=False + ) + op.create_index( + "agent_obs_subject_idx", "agent_observation", ["subject"], unique=False + ) + op.create_index( + "agent_obs_type_idx", "agent_observation", ["observation_type"], unique=False + ) + op.create_table( + "observation_contradiction", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("observation_1_id", sa.BigInteger(), nullable=False), + sa.Column("observation_2_id", sa.BigInteger(), nullable=False), + sa.Column("contradiction_type", sa.Text(), nullable=False), + sa.Column( + "detected_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column("detection_method", sa.Text(), nullable=False), + sa.Column("resolution", sa.Text(), nullable=True), + sa.Column( + "observation_metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["observation_1_id"], ["agent_observation.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["observation_2_id"], ["agent_observation.id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "obs_contra_method_idx", + "observation_contradiction", + ["detection_method"], + unique=False, + ) + op.create_index( + "obs_contra_obs1_idx", + "observation_contradiction", + ["observation_1_id"], + unique=False, + ) + op.create_index( + "obs_contra_obs2_idx", + "observation_contradiction", + ["observation_2_id"], + unique=False, + ) + op.create_index( + "obs_contra_type_idx", + "observation_contradiction", + ["contradiction_type"], + unique=False, + ) + op.add_column("chunk", sa.Column("collection_name", sa.Text(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("chunk", "collection_name") + op.drop_index("obs_contra_type_idx", table_name="observation_contradiction") + op.drop_index("obs_contra_obs2_idx", table_name="observation_contradiction") + op.drop_index("obs_contra_obs1_idx", table_name="observation_contradiction") + op.drop_index("obs_contra_method_idx", table_name="observation_contradiction") + op.drop_table("observation_contradiction") + op.drop_index("agent_obs_type_idx", table_name="agent_observation") + op.drop_index("agent_obs_subject_idx", table_name="agent_observation") + op.drop_index("agent_obs_session_idx", table_name="agent_observation") + op.drop_index("agent_obs_model_idx", table_name="agent_observation") + op.drop_index("agent_obs_confidence_idx", table_name="agent_observation") + op.drop_table("agent_observation") + op.drop_index("reaction_type_idx", table_name="reaction_pattern") + op.drop_index("reaction_trigger_idx", table_name="reaction_pattern") + op.drop_index("reaction_frequency_idx", table_name="reaction_pattern") + op.drop_table("reaction_pattern") + op.drop_index("obs_pattern_validity_idx", table_name="observation_pattern") + op.drop_index("obs_pattern_type_idx", table_name="observation_pattern") + op.drop_index("obs_pattern_confidence_idx", table_name="observation_pattern") + op.drop_table("observation_pattern") + op.drop_index("conv_metrics_session_idx", table_name="conversation_metrics") + op.drop_index("conv_metrics_depth_idx", table_name="conversation_metrics") + op.drop_index("conv_metrics_breakthrough_idx", table_name="conversation_metrics") + op.drop_table("conversation_metrics") + op.drop_index("belief_cluster_name_idx", table_name="belief_cluster") + op.drop_index("belief_cluster_consistency_idx", table_name="belief_cluster") + op.drop_table("belief_cluster") diff --git a/src/memory/common/db/__init__.py 
b/src/memory/common/db/__init__.py index 91a2a75..2c908a5 100644 --- a/src/memory/common/db/__init__.py +++ b/src/memory/common/db/__init__.py @@ -2,7 +2,7 @@ Database utilities package. """ -from memory.common.db.models import Base +from memory.common.db.models.base import Base from memory.common.db.connection import ( get_engine, get_session_factory, diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py new file mode 100644 index 0000000..2f1d997 --- /dev/null +++ b/src/memory/common/db/models/__init__.py @@ -0,0 +1,61 @@ +from memory.common.db.models.base import Base +from memory.common.db.models.source_item import ( + Chunk, + SourceItem, + clean_filename, +) +from memory.common.db.models.source_items import ( + MailMessage, + EmailAttachment, + AgentObservation, + ChatMessage, + BlogPost, + Comic, + BookSection, + ForumPost, + GithubItem, + GitCommit, + Photo, + MiscDoc, +) +from memory.common.db.models.observations import ( + ObservationContradiction, + ReactionPattern, + ObservationPattern, + BeliefCluster, + ConversationMetrics, +) +from memory.common.db.models.sources import ( + Book, + ArticleFeed, + EmailAccount, +) + +__all__ = [ + "Base", + "Chunk", + "clean_filename", + "SourceItem", + "MailMessage", + "EmailAttachment", + "AgentObservation", + "ChatMessage", + "BlogPost", + "Comic", + "BookSection", + "ForumPost", + "GithubItem", + "GitCommit", + "Photo", + "MiscDoc", + # Observations + "ObservationContradiction", + "ReactionPattern", + "ObservationPattern", + "BeliefCluster", + "ConversationMetrics", + # Sources + "Book", + "ArticleFeed", + "EmailAccount", +] diff --git a/src/memory/common/db/models/base.py b/src/memory/common/db/models/base.py new file mode 100644 index 0000000..00ea8e1 --- /dev/null +++ b/src/memory/common/db/models/base.py @@ -0,0 +1,4 @@ +from sqlalchemy.ext.declarative import declarative_base + + +Base = declarative_base() diff --git a/src/memory/common/db/models/observations.py b/src/memory/common/db/models/observations.py new file mode 100644 index 0000000..a106ccd --- /dev/null +++ b/src/memory/common/db/models/observations.py @@ -0,0 +1,171 @@ +""" +Agent observation models for the epistemic sparring partner system. +""" + +from sqlalchemy import ( + ARRAY, + BigInteger, + Column, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + Text, + UUID, + func, +) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import relationship + +from memory.common.db.models.base import Base + + +class ObservationContradiction(Base): + """ + Tracks contradictions between observations. + Can be detected automatically or reported by agents. 
+ """ + + __tablename__ = "observation_contradiction" + + id = Column(BigInteger, primary_key=True) + observation_1_id = Column( + BigInteger, + ForeignKey("agent_observation.id", ondelete="CASCADE"), + nullable=False, + ) + observation_2_id = Column( + BigInteger, + ForeignKey("agent_observation.id", ondelete="CASCADE"), + nullable=False, + ) + contradiction_type = Column(Text, nullable=False) # direct, implied, temporal + detected_at = Column(DateTime(timezone=True), server_default=func.now()) + detection_method = Column(Text, nullable=False) # manual, automatic, agent-reported + resolution = Column(Text) # How it was resolved, if at all + observation_metadata = Column(JSONB) + + # Relationships - use string references to avoid circular imports + observation_1 = relationship( + "AgentObservation", + foreign_keys=[observation_1_id], + back_populates="contradictions_as_first", + ) + observation_2 = relationship( + "AgentObservation", + foreign_keys=[observation_2_id], + back_populates="contradictions_as_second", + ) + + __table_args__ = ( + Index("obs_contra_obs1_idx", "observation_1_id"), + Index("obs_contra_obs2_idx", "observation_2_id"), + Index("obs_contra_type_idx", "contradiction_type"), + Index("obs_contra_method_idx", "detection_method"), + ) + + +class ReactionPattern(Base): + """ + Tracks patterns in how the user reacts to certain types of observations or challenges. + """ + + __tablename__ = "reaction_pattern" + + id = Column(BigInteger, primary_key=True) + trigger_type = Column( + Text, nullable=False + ) # What kind of observation triggers this + reaction_type = Column(Text, nullable=False) # How user typically responds + frequency = Column(Numeric(3, 2), nullable=False) # How often this pattern appears + first_observed = Column(DateTime(timezone=True), server_default=func.now()) + last_observed = Column(DateTime(timezone=True), server_default=func.now()) + example_observations = Column( + ARRAY(BigInteger) + ) # IDs of observations showing this pattern + reaction_metadata = Column(JSONB) + + __table_args__ = ( + Index("reaction_trigger_idx", "trigger_type"), + Index("reaction_type_idx", "reaction_type"), + Index("reaction_frequency_idx", "frequency"), + ) + + +class ObservationPattern(Base): + """ + Higher-level patterns detected across multiple observations. + """ + + __tablename__ = "observation_pattern" + + id = Column(BigInteger, primary_key=True) + pattern_type = Column(Text, nullable=False) # behavioral, cognitive, emotional + description = Column(Text, nullable=False) + supporting_observations = Column( + ARRAY(BigInteger), nullable=False + ) # Observation IDs + exceptions = Column(ARRAY(BigInteger)) # Observations that don't fit + confidence = Column(Numeric(3, 2), nullable=False, default=0.7) + validity_start = Column(DateTime(timezone=True)) + validity_end = Column(DateTime(timezone=True)) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now()) + pattern_metadata = Column(JSONB) + + __table_args__ = ( + Index("obs_pattern_type_idx", "pattern_type"), + Index("obs_pattern_confidence_idx", "confidence"), + Index("obs_pattern_validity_idx", "validity_start", "validity_end"), + ) + + +class BeliefCluster(Base): + """ + Groups of related beliefs that support or depend on each other. 
+ """ + + __tablename__ = "belief_cluster" + + id = Column(BigInteger, primary_key=True) + cluster_name = Column(Text, nullable=False) + core_beliefs = Column(ARRAY(Text), nullable=False) + peripheral_beliefs = Column(ARRAY(Text)) + internal_consistency = Column(Numeric(3, 2)) # How well beliefs align + supporting_observations = Column(ARRAY(BigInteger)) # Observation IDs + created_at = Column(DateTime(timezone=True), server_default=func.now()) + last_updated = Column(DateTime(timezone=True), server_default=func.now()) + cluster_metadata = Column(JSONB) + + __table_args__ = ( + Index("belief_cluster_name_idx", "cluster_name"), + Index("belief_cluster_consistency_idx", "internal_consistency"), + ) + + +class ConversationMetrics(Base): + """ + Tracks the effectiveness and depth of conversations. + """ + + __tablename__ = "conversation_metrics" + + id = Column(BigInteger, primary_key=True) + session_id = Column(UUID(as_uuid=True), nullable=False) + depth_score = Column(Numeric(3, 2)) # How deep the conversation went + breakthrough_count = Column(Integer, default=0) + challenge_acceptance = Column(Numeric(3, 2)) # How well challenges were received + new_insights = Column(Integer, default=0) + user_engagement = Column(Numeric(3, 2)) # Inferred engagement level + duration_minutes = Column(Integer) + observation_count = Column(Integer, default=0) + contradiction_count = Column(Integer, default=0) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + metrics_metadata = Column(JSONB) + + __table_args__ = ( + Index("conv_metrics_session_idx", "session_id", unique=True), + Index("conv_metrics_depth_idx", "depth_score"), + Index("conv_metrics_breakthrough_idx", "breakthrough_count"), + ) diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py new file mode 100644 index 0000000..6d8cf1d --- /dev/null +++ b/src/memory/common/db/models/source_item.py @@ -0,0 +1,287 @@ +""" +Database models for the knowledge base system. +""" + +import pathlib +import re +from typing import Any, Sequence, cast +import uuid + +from PIL import Image +from sqlalchemy import ( + ARRAY, + UUID, + BigInteger, + CheckConstraint, + Column, + DateTime, + ForeignKey, + Index, + Integer, + String, + Text, + event, + func, +) +from sqlalchemy.dialects.postgresql import BYTEA +from sqlalchemy.orm import Session, relationship + +from memory.common import settings +import memory.common.extract as extract +import memory.common.collections as collections +import memory.common.chunker as chunker +import memory.common.summarizer as summarizer +from memory.common.db.models.base import Base + + +@event.listens_for(Session, "before_flush") +def handle_duplicate_sha256(session, flush_context, instances): + """ + Event listener that efficiently checks for duplicate sha256 values before flush + and removes items with duplicate sha256 from the session. + + Uses a single query to identify all duplicates rather than querying for each item. 
+ """ + # Find all SourceItem objects being added + new_items = [obj for obj in session.new if isinstance(obj, SourceItem)] + if not new_items: + return + + items = {} + for item in new_items: + try: + if (sha256 := item.sha256) is None: + continue + + if sha256 in items: + session.expunge(item) + continue + + items[sha256] = item + except (AttributeError, TypeError): + continue + + if not new_items: + return + + # Query database for existing items with these sha256 values in a single query + existing_sha256s = set( + row[0] + for row in session.query(SourceItem.sha256).filter( + SourceItem.sha256.in_(items.keys()) + ) + ) + + # Remove objects with duplicate sha256 values from the session + for sha256 in existing_sha256s: + if sha256 in items: + session.expunge(items[sha256]) + + +def clean_filename(filename: str) -> str: + return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_") + + +def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]: + for i, image in enumerate(images): + if not image.filename: # type: ignore + filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore + image.save(filename) + image.filename = str(filename) # type: ignore + + return [image.filename for image in images] # type: ignore + + +def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]: + return [chunk] + [ + i + for i in images + if getattr(i, "filename", None) and i.filename in chunk # type: ignore + ] + + +def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]: + final = {} + for m in metadata: + if tags := set(m.pop("tags", [])): + final["tags"] = tags | final.get("tags", set()) + final |= m + return final + + +def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]: + if not content.strip(): + return [] + + images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths] + + summary, tags = summarizer.summarize(content) + full_text: extract.DataChunk = extract.DataChunk( + data=[content.strip(), *images], metadata={"tags": tags} + ) + + chunks: list[extract.DataChunk] = [full_text] + tokens = chunker.approx_token_count(content) + if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: + chunks += [ + extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags}) + for c in chunker.chunk_text(content) + ] + chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags})) + + return [c for c in chunks if c.data] + + +class Chunk(Base): + """Stores content chunks with their vector embeddings.""" + + __tablename__ = "chunk" + + # The ID is also used as the vector ID in the vector database + id = Column( + UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() + ) + source_id = Column( + BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False + ) + file_paths = Column( + ARRAY(Text), nullable=True + ) # Path to content if stored as a file + content = Column(Text) # Direct content storage + embedding_model = Column(Text) + collection_name = Column(Text) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + checked_at = Column(DateTime(timezone=True), server_default=func.now()) + vector: list[float] = [] + item_metadata: dict[str, Any] = {} + images: list[Image.Image] = [] + + # One of file_path or content must be populated + __table_args__ = ( + CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"), + Index("chunk_source_idx", "source_id"), + ) + + @property + def chunks(self) -> 
list[extract.MulitmodalChunk]: + chunks: list[extract.MulitmodalChunk] = [] + if cast(str | None, self.content): + chunks = [cast(str, self.content)] + if self.images: + chunks += self.images + elif cast(Sequence[str] | None, self.file_paths): + chunks += [ + Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths + ] + return chunks + + @property + def data(self) -> list[bytes | str | Image.Image]: + if self.file_paths is None: + return [cast(str, self.content)] + + paths = [pathlib.Path(cast(str, p)) for p in self.file_paths] + files = [path for path in paths if path.exists()] + + items = [] + for file_path in files: + if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}: + if file_path.exists(): + items.append(Image.open(file_path)) + elif file_path.suffix == ".bin": + items.append(file_path.read_bytes()) + else: + items.append(file_path.read_text()) + return items + + +class SourceItem(Base): + """Base class for all content in the system using SQLAlchemy's joined table inheritance.""" + + __tablename__ = "source_item" + __allow_unmapped__ = True + + id = Column(BigInteger, primary_key=True) + modality = Column(Text, nullable=False) + sha256 = Column(BYTEA, nullable=False, unique=True) + inserted_at = Column(DateTime(timezone=True), server_default=func.now()) + tags = Column(ARRAY(Text), nullable=False, server_default="{}") + size = Column(Integer) + mime_type = Column(Text) + + # Content is stored in the database if it's small enough and text + content = Column(Text) + # Otherwise the content is stored on disk + filename = Column(Text, nullable=True) + + # Chunks relationship + embed_status = Column(Text, nullable=False, server_default="RAW") + chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan") + + # Discriminator column for SQLAlchemy inheritance + type = Column(String(50)) + + __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"} + + # Add table-level constraint and indexes + __table_args__ = ( + CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), + Index("source_modality_idx", "modality"), + Index("source_status_idx", "embed_status"), + Index("source_tags_idx", "tags", postgresql_using="gin"), + Index("source_filename_idx", "filename"), + ) + + @property + def vector_ids(self): + """Get vector IDs from associated chunks.""" + return [chunk.id for chunk in self.chunks] + + def _chunk_contents(self) -> Sequence[extract.DataChunk]: + chunks: list[extract.DataChunk] = [] + content = cast(str | None, self.content) + if content: + chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)] + + if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2: + summary, tags = summarizer.summarize(content) + chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags})) + + mime_type = cast(str | None, self.mime_type) + if mime_type and mime_type.startswith("image/"): + chunks.append(extract.DataChunk(data=[Image.open(self.filename)])) + return chunks + + def _make_chunk( + self, data: extract.DataChunk, metadata: dict[str, Any] = {} + ) -> Chunk: + chunk_id = str(uuid.uuid4()) + text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip()) + images = [c for c in data.data if isinstance(c, Image.Image)] + image_names = image_filenames(chunk_id, images) + + chunk = Chunk( + id=chunk_id, + source=self, + content=text or None, + images=images, + file_paths=image_names, + collection_name=data.collection_name or cast(str, self.modality), + 
embedding_model=collections.collection_model(cast(str, self.modality)), + item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata), + ) + return chunk + + def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: + return [self._make_chunk(data) for data in self._chunk_contents()] + + def as_payload(self) -> dict: + return { + "source_id": self.id, + "tags": self.tags, + "size": self.size, + } + + @property + def display_contents(self) -> str | None: + return cast(str | None, self.content) or cast(str | None, self.filename) diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models/source_items.py similarity index 55% rename from src/memory/common/db/models.py rename to src/memory/common/db/models/source_items.py index 52e7735..2dcf47e 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models/source_items.py @@ -3,18 +3,14 @@ Database models for the knowledge base system. """ import pathlib -import re import textwrap from datetime import datetime from typing import Any, Sequence, cast -import uuid from PIL import Image from sqlalchemy import ( ARRAY, - UUID, BigInteger, - Boolean, CheckConstraint, Column, DateTime, @@ -22,273 +18,23 @@ from sqlalchemy import ( Index, Integer, Numeric, - String, Text, - event, func, ) -from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, relationship +from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR, UUID +from sqlalchemy.orm import relationship from memory.common import settings import memory.common.extract as extract -import memory.common.collections as collections -import memory.common.chunker as chunker import memory.common.summarizer as summarizer -Base = declarative_base() - - -@event.listens_for(Session, "before_flush") -def handle_duplicate_sha256(session, flush_context, instances): - """ - Event listener that efficiently checks for duplicate sha256 values before flush - and removes items with duplicate sha256 from the session. - - Uses a single query to identify all duplicates rather than querying for each item. 
- """ - # Find all SourceItem objects being added - new_items = [obj for obj in session.new if isinstance(obj, SourceItem)] - if not new_items: - return - - items = {} - for item in new_items: - try: - if (sha256 := item.sha256) is None: - continue - - if sha256 in items: - session.expunge(item) - continue - - items[sha256] = item - except (AttributeError, TypeError): - continue - - if not new_items: - return - - # Query database for existing items with these sha256 values in a single query - existing_sha256s = set( - row[0] - for row in session.query(SourceItem.sha256).filter( - SourceItem.sha256.in_(items.keys()) - ) - ) - - # Remove objects with duplicate sha256 values from the session - for sha256 in existing_sha256s: - if sha256 in items: - session.expunge(items[sha256]) - - -def clean_filename(filename: str) -> str: - return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_") - - -def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]: - for i, image in enumerate(images): - if not image.filename: # type: ignore - filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore - image.save(filename) - image.filename = str(filename) # type: ignore - - return [image.filename for image in images] # type: ignore - - -def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]: - return [chunk] + [ - i - for i in images - if getattr(i, "filename", None) and i.filename in chunk # type: ignore - ] - - -def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]: - final = {} - for m in metadata: - if tags := set(m.pop("tags", [])): - final["tags"] = tags | final.get("tags", set()) - final |= m - return final - - -def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]: - if not content.strip(): - return [] - - images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths] - - summary, tags = summarizer.summarize(content) - full_text: extract.DataChunk = extract.DataChunk( - data=[content.strip(), *images], metadata={"tags": tags} - ) - - chunks: list[extract.DataChunk] = [full_text] - tokens = chunker.approx_token_count(content) - if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: - chunks += [ - extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags}) - for c in chunker.chunk_text(content) - ] - chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags})) - - return [c for c in chunks if c.data] - - -class Chunk(Base): - """Stores content chunks with their vector embeddings.""" - - __tablename__ = "chunk" - - # The ID is also used as the vector ID in the vector database - id = Column( - UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4() - ) - source_id = Column( - BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False - ) - file_paths = Column( - ARRAY(Text), nullable=True - ) # Path to content if stored as a file - content = Column(Text) # Direct content storage - embedding_model = Column(Text) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - checked_at = Column(DateTime(timezone=True), server_default=func.now()) - vector: list[float] = [] - item_metadata: dict[str, Any] = {} - images: list[Image.Image] = [] - - # One of file_path or content must be populated - __table_args__ = ( - CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"), - Index("chunk_source_idx", "source_id"), - ) - - @property - def chunks(self) -> list[extract.MulitmodalChunk]: - chunks: 
list[extract.MulitmodalChunk] = [] - if cast(str | None, self.content): - chunks = [cast(str, self.content)] - if self.images: - chunks += self.images - elif cast(Sequence[str] | None, self.file_paths): - chunks += [ - Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths - ] - return chunks - - @property - def data(self) -> list[bytes | str | Image.Image]: - if self.file_paths is None: - return [cast(str, self.content)] - - paths = [pathlib.Path(cast(str, p)) for p in self.file_paths] - files = [path for path in paths if path.exists()] - - items = [] - for file_path in files: - if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}: - if file_path.exists(): - items.append(Image.open(file_path)) - elif file_path.suffix == ".bin": - items.append(file_path.read_bytes()) - else: - items.append(file_path.read_text()) - return items - - -class SourceItem(Base): - """Base class for all content in the system using SQLAlchemy's joined table inheritance.""" - - __tablename__ = "source_item" - __allow_unmapped__ = True - - id = Column(BigInteger, primary_key=True) - modality = Column(Text, nullable=False) - sha256 = Column(BYTEA, nullable=False, unique=True) - inserted_at = Column(DateTime(timezone=True), server_default=func.now()) - tags = Column(ARRAY(Text), nullable=False, server_default="{}") - size = Column(Integer) - mime_type = Column(Text) - - # Content is stored in the database if it's small enough and text - content = Column(Text) - # Otherwise the content is stored on disk - filename = Column(Text, nullable=True) - - # Chunks relationship - embed_status = Column(Text, nullable=False, server_default="RAW") - chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan") - - # Discriminator column for SQLAlchemy inheritance - type = Column(String(50)) - - __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"} - - # Add table-level constraint and indexes - __table_args__ = ( - CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), - Index("source_modality_idx", "modality"), - Index("source_status_idx", "embed_status"), - Index("source_tags_idx", "tags", postgresql_using="gin"), - Index("source_filename_idx", "filename"), - ) - - @property - def vector_ids(self): - """Get vector IDs from associated chunks.""" - return [chunk.id for chunk in self.chunks] - - def _chunk_contents(self) -> Sequence[extract.DataChunk]: - chunks: list[extract.DataChunk] = [] - content = cast(str | None, self.content) - if content: - chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)] - - if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2: - summary, tags = summarizer.summarize(content) - chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags})) - - mime_type = cast(str | None, self.mime_type) - if mime_type and mime_type.startswith("image/"): - chunks.append(extract.DataChunk(data=[Image.open(self.filename)])) - return chunks - - def _make_chunk( - self, data: extract.DataChunk, metadata: dict[str, Any] = {} - ) -> Chunk: - chunk_id = str(uuid.uuid4()) - text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip()) - images = [c for c in data.data if isinstance(c, Image.Image)] - image_names = image_filenames(chunk_id, images) - - chunk = Chunk( - id=chunk_id, - source=self, - content=text or None, - images=images, - file_paths=image_names, - embedding_model=collections.collection_model(cast(str, self.modality)), - item_metadata=merge_metadata(self.as_payload(), 
data.metadata, metadata), - ) - return chunk - - def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: - return [self._make_chunk(data) for data in self._chunk_contents()] - - def as_payload(self) -> dict: - return { - "source_id": self.id, - "tags": self.tags, - "size": self.size, - } - - @property - def display_contents(self) -> str | None: - return cast(str | None, self.content) or cast(str | None, self.filename) +from memory.common.db.models.source_item import ( + SourceItem, + Chunk, + clean_filename, + merge_metadata, + chunk_mixed, +) class MailMessage(SourceItem): @@ -528,51 +274,6 @@ class Comic(SourceItem): return [extract.DataChunk(data=[image, description])] -class Book(Base): - """Book-level metadata table""" - - __tablename__ = "book" - - id = Column(BigInteger, primary_key=True) - isbn = Column(Text, unique=True) - title = Column(Text, nullable=False) - author = Column(Text) - publisher = Column(Text) - published = Column(DateTime(timezone=True)) - language = Column(Text) - edition = Column(Text) - series = Column(Text) - series_number = Column(Integer) - total_pages = Column(Integer) - file_path = Column(Text) - tags = Column(ARRAY(Text), nullable=False, server_default="{}") - - # Metadata from ebook parser - book_metadata = Column(JSONB, name="metadata") - - created_at = Column(DateTime(timezone=True), server_default=func.now()) - - __table_args__ = ( - Index("book_isbn_idx", "isbn"), - Index("book_author_idx", "author"), - Index("book_title_idx", "title"), - ) - - def as_payload(self) -> dict: - return { - **super().as_payload(), - "isbn": self.isbn, - "title": self.title, - "author": self.author, - "publisher": self.publisher, - "published": self.published, - "language": self.language, - "edition": self.edition, - "series": self.series, - "series_number": self.series_number, - } | (cast(dict, self.book_metadata) or {}) - - class BookSection(SourceItem): """Individual sections/chapters of books""" @@ -797,58 +498,73 @@ class GithubItem(SourceItem): ) -class ArticleFeed(Base): - __tablename__ = "article_feeds" +class AgentObservation(SourceItem): + """ + Records observations made by AI agents about the user. + This is the primary data model for the epistemic sparring partner. + """ - id = Column(BigInteger, primary_key=True) - url = Column(Text, nullable=False, unique=True) - title = Column(Text) - description = Column(Text) - tags = Column(ARRAY(Text), nullable=False, server_default="{}") - check_interval = Column( - Integer, nullable=False, server_default="60", doc="Minutes between checks" + __tablename__ = "agent_observation" + + id = Column( + BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True ) - last_checked_at = Column(DateTime(timezone=True)) - active = Column(Boolean, nullable=False, server_default="true") - created_at = Column( - DateTime(timezone=True), nullable=False, server_default=func.now() + session_id = Column( + UUID(as_uuid=True) + ) # Groups observations from same conversation + observation_type = Column( + Text, nullable=False + ) # belief, preference, pattern, contradiction, behavior + subject = Column(Text, nullable=False) # What/who the observation is about + confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0 + evidence = Column(JSONB) # Supporting context, quotes, etc. 
+ agent_model = Column(Text, nullable=False) # Which AI model made this observation + + # Relationships + contradictions_as_first = relationship( + "ObservationContradiction", + foreign_keys="ObservationContradiction.observation_1_id", + back_populates="observation_1", + cascade="all, delete-orphan", ) - updated_at = Column( - DateTime(timezone=True), nullable=False, server_default=func.now() + contradictions_as_second = relationship( + "ObservationContradiction", + foreign_keys="ObservationContradiction.observation_2_id", + back_populates="observation_2", + cascade="all, delete-orphan", ) - # Add indexes + __mapper_args__ = { + "polymorphic_identity": "agent_observation", + } + __table_args__ = ( - Index("article_feeds_active_idx", "active", "last_checked_at"), - Index("article_feeds_tags_idx", "tags", postgresql_using="gin"), + Index("agent_obs_session_idx", "session_id"), + Index("agent_obs_type_idx", "observation_type"), + Index("agent_obs_subject_idx", "subject"), + Index("agent_obs_confidence_idx", "confidence"), + Index("agent_obs_model_idx", "agent_model"), ) + def __init__(self, **kwargs): + if not kwargs.get("modality"): + kwargs["modality"] = "observation" + super().__init__(**kwargs) -class EmailAccount(Base): - __tablename__ = "email_accounts" + def as_payload(self) -> dict: + payload = { + **super().as_payload(), + "observation_type": self.observation_type, + "subject": self.subject, + "confidence": float(cast(Any, self.confidence)), + "evidence": self.evidence, + "agent_model": self.agent_model, + } + if self.session_id is not None: + payload["session_id"] = str(self.session_id) + return payload - id = Column(BigInteger, primary_key=True) - name = Column(Text, nullable=False) - email_address = Column(Text, nullable=False, unique=True) - imap_server = Column(Text, nullable=False) - imap_port = Column(Integer, nullable=False, server_default="993") - username = Column(Text, nullable=False) - password = Column(Text, nullable=False) - use_ssl = Column(Boolean, nullable=False, server_default="true") - folders = Column(ARRAY(Text), nullable=False, server_default="{}") - tags = Column(ARRAY(Text), nullable=False, server_default="{}") - last_sync_at = Column(DateTime(timezone=True)) - active = Column(Boolean, nullable=False, server_default="true") - created_at = Column( - DateTime(timezone=True), nullable=False, server_default=func.now() - ) - updated_at = Column( - DateTime(timezone=True), nullable=False, server_default=func.now() - ) - - # Add indexes - __table_args__ = ( - Index("email_accounts_address_idx", "email_address", unique=True), - Index("email_accounts_active_idx", "active", "last_sync_at"), - Index("email_accounts_tags_idx", "tags", postgresql_using="gin"), - ) + @property + def all_contradictions(self): + """Get all contradictions involving this observation.""" + return self.contradictions_as_first + self.contradictions_as_second diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py new file mode 100644 index 0000000..78e2df9 --- /dev/null +++ b/src/memory/common/db/models/sources.py @@ -0,0 +1,122 @@ +""" +Database models for the knowledge base system. 
+""" + +from typing import cast + +from sqlalchemy import ( + ARRAY, + BigInteger, + Boolean, + Column, + DateTime, + Index, + Integer, + Text, + func, +) +from sqlalchemy.dialects.postgresql import JSONB + +from memory.common.db.models.base import Base + + +class Book(Base): + """Book-level metadata table""" + + __tablename__ = "book" + + id = Column(BigInteger, primary_key=True) + isbn = Column(Text, unique=True) + title = Column(Text, nullable=False) + author = Column(Text) + publisher = Column(Text) + published = Column(DateTime(timezone=True)) + language = Column(Text) + edition = Column(Text) + series = Column(Text) + series_number = Column(Integer) + total_pages = Column(Integer) + file_path = Column(Text) + tags = Column(ARRAY(Text), nullable=False, server_default="{}") + + # Metadata from ebook parser + book_metadata = Column(JSONB, name="metadata") + + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + __table_args__ = ( + Index("book_isbn_idx", "isbn"), + Index("book_author_idx", "author"), + Index("book_title_idx", "title"), + ) + + def as_payload(self) -> dict: + return { + **super().as_payload(), + "isbn": self.isbn, + "title": self.title, + "author": self.author, + "publisher": self.publisher, + "published": self.published, + "language": self.language, + "edition": self.edition, + "series": self.series, + "series_number": self.series_number, + } | (cast(dict, self.book_metadata) or {}) + + +class ArticleFeed(Base): + __tablename__ = "article_feeds" + + id = Column(BigInteger, primary_key=True) + url = Column(Text, nullable=False, unique=True) + title = Column(Text) + description = Column(Text) + tags = Column(ARRAY(Text), nullable=False, server_default="{}") + check_interval = Column( + Integer, nullable=False, server_default="60", doc="Minutes between checks" + ) + last_checked_at = Column(DateTime(timezone=True)) + active = Column(Boolean, nullable=False, server_default="true") + created_at = Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at = Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + # Add indexes + __table_args__ = ( + Index("article_feeds_active_idx", "active", "last_checked_at"), + Index("article_feeds_tags_idx", "tags", postgresql_using="gin"), + ) + + +class EmailAccount(Base): + __tablename__ = "email_accounts" + + id = Column(BigInteger, primary_key=True) + name = Column(Text, nullable=False) + email_address = Column(Text, nullable=False, unique=True) + imap_server = Column(Text, nullable=False) + imap_port = Column(Integer, nullable=False, server_default="993") + username = Column(Text, nullable=False) + password = Column(Text, nullable=False) + use_ssl = Column(Boolean, nullable=False, server_default="true") + folders = Column(ARRAY(Text), nullable=False, server_default="{}") + tags = Column(ARRAY(Text), nullable=False, server_default="{}") + last_sync_at = Column(DateTime(timezone=True)) + active = Column(Boolean, nullable=False, server_default="true") + created_at = Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at = Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + # Add indexes + __table_args__ = ( + Index("email_accounts_address_idx", "email_address", unique=True), + Index("email_accounts_active_idx", "active", "last_sync_at"), + Index("email_accounts_tags_idx", "tags", postgresql_using="gin"), + ) diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py 
index 785ba7d..6d1ad0e 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -21,6 +21,7 @@ class DataChunk: data: Sequence[MulitmodalChunk] metadata: dict[str, Any] = field(default_factory=dict) mime_type: str = "text/plain" + collection_name: str | None = None @contextmanager diff --git a/src/memory/workers/tasks/content_processing.py b/src/memory/workers/tasks/content_processing.py index 51dd4cc..b026986 100644 --- a/src/memory/workers/tasks/content_processing.py +++ b/src/memory/workers/tasks/content_processing.py @@ -6,13 +6,14 @@ the complete workflow: existence checking, content hashing, embedding generation vector storage, and result tracking. """ +from collections import defaultdict import hashlib import traceback import logging from typing import Any, Callable, Iterable, Sequence, cast from memory.common import embedding, qdrant -from memory.common.db.models import SourceItem +from memory.common.db.models import SourceItem, Chunk logger = logging.getLogger(__name__) @@ -103,7 +104,19 @@ def embed_source_item(source_item: SourceItem) -> int: return 0 -def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str): +def by_collection(chunks: Sequence[Chunk]) -> dict[str, dict[str, Any]]: + collections: dict[str, dict[str, list[Any]]] = defaultdict( + lambda: defaultdict(list) + ) + for chunk in chunks: + collection = collections[cast(str, chunk.collection_name)] + collection["ids"].append(chunk.id) + collection["vectors"].append(chunk.vector) + collection["payloads"].append(chunk.item_metadata) + return collections + + +def push_to_qdrant(source_items: Sequence[SourceItem]): """ Push embeddings to Qdrant vector database. @@ -135,17 +148,16 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str): return try: - vector_ids = [str(chunk.id) for chunk in all_chunks] - vectors = [chunk.vector for chunk in all_chunks] - payloads = [chunk.item_metadata for chunk in all_chunks] - - qdrant.upsert_vectors( - client=qdrant.get_qdrant_client(), - collection_name=collection_name, - ids=vector_ids, - vectors=vectors, - payloads=payloads, - ) + client = qdrant.get_qdrant_client() + collections = by_collection(all_chunks) + for collection_name, collection in collections.items(): + qdrant.upsert_vectors( + client=client, + collection_name=collection_name, + ids=collection["ids"], + vectors=collection["vectors"], + payloads=collection["payloads"], + ) for item in items_to_process: item.embed_status = "STORED" # type: ignore @@ -222,7 +234,7 @@ def process_content_item(item: SourceItem, session) -> dict[str, Any]: return create_task_result(item, status, content_length=getattr(item, "size", 0)) try: - push_to_qdrant([item], cast(str, item.modality)) + push_to_qdrant([item]) status = "processed" item.embed_status = "STORED" # type: ignore logger.info( diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py index 1eedc88..189f243 100644 --- a/src/memory/workers/tasks/ebook.py +++ b/src/memory/workers/tasks/ebook.py @@ -185,7 +185,7 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict: f"Embedded section: {section.section_title} - {section.content[:100]}" ) logger.info("Pushing to Qdrant") - push_to_qdrant(all_sections, "book") + push_to_qdrant(all_sections) logger.info("Committing session") session.commit() diff --git a/tests/memory/common/db/test_models.py b/tests/memory/common/db/test_models.py index 099d7a1..da90fa8 100644 --- a/tests/memory/common/db/test_models.py +++ 
b/tests/memory/common/db/test_models.py @@ -1,4 +1,3 @@ -from memory.common.db.models import SourceItem from sqlalchemy.orm import Session from unittest.mock import patch, Mock from typing import cast @@ -6,17 +5,20 @@ import pytest from PIL import Image from datetime import datetime from memory.common import settings, chunker, extract -from memory.common.db.models import ( +from memory.common.db.models.sources import Book +from memory.common.db.models.source_items import ( Chunk, - clean_filename, - image_filenames, - add_pics, MailMessage, EmailAttachment, BookSection, BlogPost, - Book, +) +from memory.common.db.models.source_item import ( + SourceItem, + image_filenames, + add_pics, merge_metadata, + clean_filename, ) diff --git a/tests/memory/workers/tasks/test_content_processing.py b/tests/memory/workers/tasks/test_content_processing.py index c3e8467..0918d9c 100644 --- a/tests/memory/workers/tasks/test_content_processing.py +++ b/tests/memory/workers/tasks/test_content_processing.py @@ -18,6 +18,7 @@ from memory.workers.tasks.content_processing import ( process_content_item, push_to_qdrant, safe_task_execution, + by_collection, ) @@ -71,6 +72,7 @@ def mock_chunk(): chunk.id = "00000000-0000-0000-0000-000000000001" chunk.vector = [0.1] * 1024 chunk.item_metadata = {"source_id": 1, "tags": ["test"]} + chunk.collection_name = "mail" return chunk @@ -242,17 +244,19 @@ def test_push_to_qdrant_success(qdrant): mock_chunk1.id = "00000000-0000-0000-0000-000000000001" mock_chunk1.vector = [0.1] * 1024 mock_chunk1.item_metadata = {"source_id": 1, "tags": ["test"]} + mock_chunk1.collection_name = "mail" mock_chunk2 = MagicMock() mock_chunk2.id = "00000000-0000-0000-0000-000000000002" mock_chunk2.vector = [0.2] * 1024 mock_chunk2.item_metadata = {"source_id": 2, "tags": ["test"]} + mock_chunk2.collection_name = "mail" # Assign chunks directly (bypassing SQLAlchemy relationship) item1.chunks = [mock_chunk1] item2.chunks = [mock_chunk2] - push_to_qdrant([item1, item2], "mail") + push_to_qdrant([item1, item2]) assert str(item1.embed_status) == "STORED" assert str(item2.embed_status) == "STORED" @@ -294,6 +298,7 @@ def test_push_to_qdrant_no_processing( mock_chunk.id = f"00000000-0000-0000-0000-00000000000{suffix}" mock_chunk.vector = [0.1] * 1024 mock_chunk.item_metadata = {"source_id": int(suffix), "tags": ["test"]} + mock_chunk.collection_name = "mail" item.chunks = [mock_chunk] else: item.chunks = [] @@ -302,7 +307,7 @@ def test_push_to_qdrant_no_processing( item1 = create_item("1", item1_status, item1_has_chunks) item2 = create_item("2", item2_status, item2_has_chunks) - push_to_qdrant([item1, item2], "mail") + push_to_qdrant([item1, item2]) assert str(item1.embed_status) == expected_item1_status assert str(item2.embed_status) == expected_item2_status @@ -317,7 +322,7 @@ def test_push_to_qdrant_exception(sample_mail_message, mock_chunk): side_effect=Exception("Qdrant error"), ): with pytest.raises(Exception, match="Qdrant error"): - push_to_qdrant([sample_mail_message], "mail") + push_to_qdrant([sample_mail_message]) assert str(sample_mail_message.embed_status) == "FAILED" @@ -418,6 +423,196 @@ def test_create_task_result_no_title(): assert result["chunks_count"] == 0 +def test_by_collection_empty_chunks(): + result = by_collection([]) + assert result == {} + + +def test_by_collection_single_chunk(): + chunk = Chunk( + id="00000000-0000-0000-0000-000000000001", + content="test content", + embedding_model="test-model", + vector=[0.1, 0.2, 0.3], + item_metadata={"source_id": 1, "tags": ["test"]}, 
+ collection_name="test_collection", + ) + + result = by_collection([chunk]) + + assert len(result) == 1 + assert "test_collection" in result + assert result["test_collection"]["ids"] == ["00000000-0000-0000-0000-000000000001"] + assert result["test_collection"]["vectors"] == [[0.1, 0.2, 0.3]] + assert result["test_collection"]["payloads"] == [{"source_id": 1, "tags": ["test"]}] + + +def test_by_collection_multiple_chunks_same_collection(): + chunks = [ + Chunk( + id="00000000-0000-0000-0000-000000000001", + content="test content 1", + embedding_model="test-model", + vector=[0.1, 0.2], + item_metadata={"source_id": 1}, + collection_name="collection_a", + ), + Chunk( + id="00000000-0000-0000-0000-000000000002", + content="test content 2", + embedding_model="test-model", + vector=[0.3, 0.4], + item_metadata={"source_id": 2}, + collection_name="collection_a", + ), + ] + + result = by_collection(chunks) + + assert len(result) == 1 + assert "collection_a" in result + assert result["collection_a"]["ids"] == [ + "00000000-0000-0000-0000-000000000001", + "00000000-0000-0000-0000-000000000002", + ] + assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.3, 0.4]] + assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 2}] + + +def test_by_collection_multiple_chunks_different_collections(): + chunks = [ + Chunk( + id="00000000-0000-0000-0000-000000000001", + content="test content 1", + embedding_model="test-model", + vector=[0.1, 0.2], + item_metadata={"source_id": 1}, + collection_name="collection_a", + ), + Chunk( + id="00000000-0000-0000-0000-000000000002", + content="test content 2", + embedding_model="test-model", + vector=[0.3, 0.4], + item_metadata={"source_id": 2}, + collection_name="collection_b", + ), + Chunk( + id="00000000-0000-0000-0000-000000000003", + content="test content 3", + embedding_model="test-model", + vector=[0.5, 0.6], + item_metadata={"source_id": 3}, + collection_name="collection_a", + ), + ] + + result = by_collection(chunks) + + assert len(result) == 2 + assert "collection_a" in result + assert "collection_b" in result + + # Check collection_a + assert result["collection_a"]["ids"] == [ + "00000000-0000-0000-0000-000000000001", + "00000000-0000-0000-0000-000000000003", + ] + assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.5, 0.6]] + assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 3}] + + # Check collection_b + assert result["collection_b"]["ids"] == ["00000000-0000-0000-0000-000000000002"] + assert result["collection_b"]["vectors"] == [[0.3, 0.4]] + assert result["collection_b"]["payloads"] == [{"source_id": 2}] + + +@pytest.mark.parametrize( + "collection_names,expected_collections", + [ + (["col1", "col1", "col1"], 1), + (["col1", "col2", "col3"], 3), + (["col1", "col2", "col1", "col2"], 2), + (["single"], 1), + ], +) +def test_by_collection_various_groupings(collection_names, expected_collections): + chunks = [ + Chunk( + id=f"00000000-0000-0000-0000-00000000000{i}", + content=f"test content {i}", + embedding_model="test-model", + vector=[float(i)], + item_metadata={"index": i}, + collection_name=collection_name, + ) + for i, collection_name in enumerate(collection_names, 1) + ] + + result = by_collection(chunks) + + assert len(result) == expected_collections + # Verify all chunks are accounted for + total_chunks = sum(len(coll["ids"]) for coll in result.values()) + assert total_chunks == len(chunks) + + +def test_by_collection_with_none_values(): + chunks = [ + Chunk( + 
id="00000000-0000-0000-0000-000000000001", + content="test content", + embedding_model="test-model", + vector=None, # None vector + item_metadata=None, # None metadata + collection_name="test_collection", + ), + Chunk( + id="00000000-0000-0000-0000-000000000002", + content="test content 2", + embedding_model="test-model", + vector=[0.1, 0.2], + item_metadata={"key": "value"}, + collection_name="test_collection", + ), + ] + + result = by_collection(chunks) + + assert len(result) == 1 + assert "test_collection" in result + assert result["test_collection"]["ids"] == [ + "00000000-0000-0000-0000-000000000001", + "00000000-0000-0000-0000-000000000002", + ] + assert result["test_collection"]["vectors"] == [None, [0.1, 0.2]] + assert result["test_collection"]["payloads"] == [None, {"key": "value"}] + + +def test_by_collection_preserves_order(): + chunks = [] + for i in range(5): + chunks.append( + Chunk( + id=f"00000000-0000-0000-0000-00000000000{i}", + content=f"test content {i}", + embedding_model="test-model", + vector=[float(i)], + item_metadata={"order": i}, + collection_name="ordered_collection", + ) + ) + + result = by_collection(chunks) + + assert len(result) == 1 + assert result["ordered_collection"]["ids"] == [ + f"00000000-0000-0000-0000-00000000000{i}" for i in range(5) + ] + assert result["ordered_collection"]["vectors"] == [[float(i)] for i in range(5)] + assert result["ordered_collection"]["payloads"] == [{"order": i} for i in range(5)] + + @pytest.mark.parametrize( "embedding_return,qdrant_error,expected_status,expected_embed_status", [ @@ -459,6 +654,7 @@ def test_process_content_item( embedding_model="test-model", vector=[0.1] * 1024, item_metadata={"source_id": 1, "tags": ["test"]}, + collection_name="mail", ) mock_chunks = [real_chunk] else: # empty diff --git a/tests/memory/workers/tasks/test_maintenance.py b/tests/memory/workers/tasks/test_maintenance.py index 6a48c8b..38649fb 100644 --- a/tests/memory/workers/tasks/test_maintenance.py +++ b/tests/memory/workers/tasks/test_maintenance.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime, timedelta from unittest.mock import patch, call +from typing import cast import pytest from PIL import Image @@ -350,6 +351,7 @@ def test_reingest_item_success(db_session, qdrant, item_type): id=chunk_id, source=item, content=f"Test chunk content {i}", + collection_name=item.modality, embedding_model="test-model", ) for i, chunk_id in enumerate(chunk_ids) @@ -358,7 +360,7 @@ def test_reingest_item_success(db_session, qdrant, item_type): db_session.commit() # Add vectors to Qdrant - modality = "mail" if item_type == "MailMessage" else "blog" + modality = cast(str, item.modality) qd.ensure_collection_exists(qdrant, modality, 1024) qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids)) @@ -375,6 +377,7 @@ def test_reingest_item_success(db_session, qdrant, item_type): id=str(uuid.uuid4()), content="New chunk content 1", embedding_model="test-model", + collection_name=modality, vector=[0.1] * 1024, item_metadata={"source_id": item.id, "tags": ["test"]}, ), @@ -382,6 +385,7 @@ def test_reingest_item_success(db_session, qdrant, item_type): id=str(uuid.uuid4()), content="New chunk content 2", embedding_model="test-model", + collection_name=modality, vector=[0.2] * 1024, item_metadata={"source_id": item.id, "tags": ["test"]}, ), @@ -449,6 +453,7 @@ def test_reingest_item_no_chunks(db_session, qdrant): id=str(uuid.uuid4()), content="New chunk content", embedding_model="test-model", + collection_name=item.modality, 
vector=[0.1] * 1024, item_metadata={"source_id": item.id, "tags": ["test"]}, ), @@ -538,6 +543,7 @@ def test_reingest_empty_source_items_success(db_session, item_type): source=item_with_chunks, content="Test chunk content", embedding_model="test-model", + collection_name=item_with_chunks.modality, ) db_session.add(chunk) db_session.commit()
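The behavioural core of this patch is that each Chunk now carries a collection_name (defaulting to the source item's modality), and push_to_qdrant groups chunks with by_collection instead of taking a single collection argument. Below is a minimal illustrative sketch of that grouping, not part of the diff itself; it mirrors the constructor keywords used in the new tests above, and the ids, vectors, and payloads are placeholder values:

    from memory.common.db.models import Chunk
    from memory.workers.tasks.content_processing import by_collection, push_to_qdrant

    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="mail chunk",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"source_id": 1},
            collection_name="mail",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="blog chunk",
            embedding_model="test-model",
            vector=[0.3, 0.4],
            item_metadata={"source_id": 2},
            collection_name="blog",
        ),
    ]

    grouped = by_collection(chunks)
    # grouped["mail"]["ids"]     -> ["00000000-0000-0000-0000-000000000001"]
    # grouped["blog"]["vectors"] -> [[0.3, 0.4]]
    #
    # push_to_qdrant(items) now runs one upsert per collection found this way,
    # which is why callers such as sync_book no longer pass a collection name.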