Add observations model

Daniel O'Connell 2025-05-31 16:15:30 +02:00
parent e505f9b53c
commit 004bd39987
15 changed files with 1262 additions and 379 deletions

View File

@ -12,6 +12,32 @@ from alembic import context
from memory.common import settings
from memory.common.db.models import Base
# Import all models to ensure they're registered with Base.metadata
from memory.common.db.models import (
SourceItem,
Chunk,
MailMessage,
EmailAttachment,
ChatMessage,
BlogPost,
Comic,
BookSection,
ForumPost,
GithubItem,
GitCommit,
Photo,
MiscDoc,
AgentObservation,
ObservationContradiction,
ReactionPattern,
ObservationPattern,
BeliefCluster,
ConversationMetrics,
Book,
ArticleFeed,
EmailAccount,
)
# this is the Alembic Config object
config = context.config

View File

@ -0,0 +1,279 @@
"""Add observation models
Revision ID: 6554eb260176
Revises: 2524646f56f6
Create Date: 2025-05-31 15:49:47.579256
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "6554eb260176"
down_revision: Union[str, None] = "2524646f56f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"belief_cluster",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("cluster_name", sa.Text(), nullable=False),
sa.Column("core_beliefs", sa.ARRAY(sa.Text()), nullable=False),
sa.Column("peripheral_beliefs", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"internal_consistency", sa.Numeric(precision=3, scale=2), nullable=True
),
sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"last_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"cluster_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"belief_cluster_consistency_idx",
"belief_cluster",
["internal_consistency"],
unique=False,
)
op.create_index(
"belief_cluster_name_idx", "belief_cluster", ["cluster_name"], unique=False
)
op.create_table(
"conversation_metrics",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("session_id", sa.UUID(), nullable=False),
sa.Column("depth_score", sa.Numeric(precision=3, scale=2), nullable=True),
sa.Column("breakthrough_count", sa.Integer(), nullable=True),
sa.Column(
"challenge_acceptance", sa.Numeric(precision=3, scale=2), nullable=True
),
sa.Column("new_insights", sa.Integer(), nullable=True),
sa.Column("user_engagement", sa.Numeric(precision=3, scale=2), nullable=True),
sa.Column("duration_minutes", sa.Integer(), nullable=True),
sa.Column("observation_count", sa.Integer(), nullable=True),
sa.Column("contradiction_count", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"metrics_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"conv_metrics_breakthrough_idx",
"conversation_metrics",
["breakthrough_count"],
unique=False,
)
op.create_index(
"conv_metrics_depth_idx", "conversation_metrics", ["depth_score"], unique=False
)
op.create_index(
"conv_metrics_session_idx", "conversation_metrics", ["session_id"], unique=True
)
op.create_table(
"observation_pattern",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("pattern_type", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=False),
sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=False),
sa.Column("exceptions", sa.ARRAY(sa.BigInteger()), nullable=True),
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
sa.Column("validity_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("validity_end", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"pattern_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"obs_pattern_confidence_idx",
"observation_pattern",
["confidence"],
unique=False,
)
op.create_index(
"obs_pattern_type_idx", "observation_pattern", ["pattern_type"], unique=False
)
op.create_index(
"obs_pattern_validity_idx",
"observation_pattern",
["validity_start", "validity_end"],
unique=False,
)
op.create_table(
"reaction_pattern",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("trigger_type", sa.Text(), nullable=False),
sa.Column("reaction_type", sa.Text(), nullable=False),
sa.Column("frequency", sa.Numeric(precision=3, scale=2), nullable=False),
sa.Column(
"first_observed",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"last_observed",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("example_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
sa.Column(
"reaction_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"reaction_frequency_idx", "reaction_pattern", ["frequency"], unique=False
)
op.create_index(
"reaction_trigger_idx", "reaction_pattern", ["trigger_type"], unique=False
)
op.create_index(
"reaction_type_idx", "reaction_pattern", ["reaction_type"], unique=False
)
op.create_table(
"agent_observation",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("session_id", sa.UUID(), nullable=True),
sa.Column("observation_type", sa.Text(), nullable=False),
sa.Column("subject", sa.Text(), nullable=False),
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
sa.Column("evidence", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("agent_model", sa.Text(), nullable=False),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False
)
op.create_index(
"agent_obs_model_idx", "agent_observation", ["agent_model"], unique=False
)
op.create_index(
"agent_obs_session_idx", "agent_observation", ["session_id"], unique=False
)
op.create_index(
"agent_obs_subject_idx", "agent_observation", ["subject"], unique=False
)
op.create_index(
"agent_obs_type_idx", "agent_observation", ["observation_type"], unique=False
)
op.create_table(
"observation_contradiction",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("observation_1_id", sa.BigInteger(), nullable=False),
sa.Column("observation_2_id", sa.BigInteger(), nullable=False),
sa.Column("contradiction_type", sa.Text(), nullable=False),
sa.Column(
"detected_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("detection_method", sa.Text(), nullable=False),
sa.Column("resolution", sa.Text(), nullable=True),
sa.Column(
"observation_metadata",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
sa.ForeignKeyConstraint(
["observation_1_id"], ["agent_observation.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["observation_2_id"], ["agent_observation.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"obs_contra_method_idx",
"observation_contradiction",
["detection_method"],
unique=False,
)
op.create_index(
"obs_contra_obs1_idx",
"observation_contradiction",
["observation_1_id"],
unique=False,
)
op.create_index(
"obs_contra_obs2_idx",
"observation_contradiction",
["observation_2_id"],
unique=False,
)
op.create_index(
"obs_contra_type_idx",
"observation_contradiction",
["contradiction_type"],
unique=False,
)
op.add_column("chunk", sa.Column("collection_name", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("chunk", "collection_name")
op.drop_index("obs_contra_type_idx", table_name="observation_contradiction")
op.drop_index("obs_contra_obs2_idx", table_name="observation_contradiction")
op.drop_index("obs_contra_obs1_idx", table_name="observation_contradiction")
op.drop_index("obs_contra_method_idx", table_name="observation_contradiction")
op.drop_table("observation_contradiction")
op.drop_index("agent_obs_type_idx", table_name="agent_observation")
op.drop_index("agent_obs_subject_idx", table_name="agent_observation")
op.drop_index("agent_obs_session_idx", table_name="agent_observation")
op.drop_index("agent_obs_model_idx", table_name="agent_observation")
op.drop_index("agent_obs_confidence_idx", table_name="agent_observation")
op.drop_table("agent_observation")
op.drop_index("reaction_type_idx", table_name="reaction_pattern")
op.drop_index("reaction_trigger_idx", table_name="reaction_pattern")
op.drop_index("reaction_frequency_idx", table_name="reaction_pattern")
op.drop_table("reaction_pattern")
op.drop_index("obs_pattern_validity_idx", table_name="observation_pattern")
op.drop_index("obs_pattern_type_idx", table_name="observation_pattern")
op.drop_index("obs_pattern_confidence_idx", table_name="observation_pattern")
op.drop_table("observation_pattern")
op.drop_index("conv_metrics_session_idx", table_name="conversation_metrics")
op.drop_index("conv_metrics_depth_idx", table_name="conversation_metrics")
op.drop_index("conv_metrics_breakthrough_idx", table_name="conversation_metrics")
op.drop_table("conversation_metrics")
op.drop_index("belief_cluster_name_idx", table_name="belief_cluster")
op.drop_index("belief_cluster_consistency_idx", table_name="belief_cluster")
op.drop_table("belief_cluster")

View File

@ -2,7 +2,7 @@
Database utilities package.
"""
from memory.common.db.models import Base
from memory.common.db.models.base import Base
from memory.common.db.connection import (
get_engine,
get_session_factory,

View File

@ -0,0 +1,61 @@
from memory.common.db.models.base import Base
from memory.common.db.models.source_item import (
Chunk,
SourceItem,
clean_filename,
)
from memory.common.db.models.source_items import (
MailMessage,
EmailAttachment,
AgentObservation,
ChatMessage,
BlogPost,
Comic,
BookSection,
ForumPost,
GithubItem,
GitCommit,
Photo,
MiscDoc,
)
from memory.common.db.models.observations import (
ObservationContradiction,
ReactionPattern,
ObservationPattern,
BeliefCluster,
ConversationMetrics,
)
from memory.common.db.models.sources import (
Book,
ArticleFeed,
EmailAccount,
)
__all__ = [
"Base",
"Chunk",
"clean_filename",
"SourceItem",
"MailMessage",
"EmailAttachment",
"AgentObservation",
"ChatMessage",
"BlogPost",
"Comic",
"BookSection",
"ForumPost",
"GithubItem",
"GitCommit",
"Photo",
"MiscDoc",
# Observations
"ObservationContradiction",
"ReactionPattern",
"ObservationPattern",
"BeliefCluster",
"ConversationMetrics",
# Sources
"Book",
"ArticleFeed",
"EmailAccount",
]

View File

@ -0,0 +1,4 @@
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()

View File

@ -0,0 +1,171 @@
"""
Agent observation models for the epistemic sparring partner system.
"""
from sqlalchemy import (
ARRAY,
BigInteger,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
Text,
UUID,
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
class ObservationContradiction(Base):
"""
Tracks contradictions between observations.
Can be detected automatically or reported by agents.
"""
__tablename__ = "observation_contradiction"
id = Column(BigInteger, primary_key=True)
observation_1_id = Column(
BigInteger,
ForeignKey("agent_observation.id", ondelete="CASCADE"),
nullable=False,
)
observation_2_id = Column(
BigInteger,
ForeignKey("agent_observation.id", ondelete="CASCADE"),
nullable=False,
)
contradiction_type = Column(Text, nullable=False) # direct, implied, temporal
detected_at = Column(DateTime(timezone=True), server_default=func.now())
detection_method = Column(Text, nullable=False) # manual, automatic, agent-reported
resolution = Column(Text) # How it was resolved, if at all
observation_metadata = Column(JSONB)
# Relationships - use string references to avoid circular imports
observation_1 = relationship(
"AgentObservation",
foreign_keys=[observation_1_id],
back_populates="contradictions_as_first",
)
observation_2 = relationship(
"AgentObservation",
foreign_keys=[observation_2_id],
back_populates="contradictions_as_second",
)
__table_args__ = (
Index("obs_contra_obs1_idx", "observation_1_id"),
Index("obs_contra_obs2_idx", "observation_2_id"),
Index("obs_contra_type_idx", "contradiction_type"),
Index("obs_contra_method_idx", "detection_method"),
)
class ReactionPattern(Base):
"""
Tracks patterns in how the user reacts to certain types of observations or challenges.
"""
__tablename__ = "reaction_pattern"
id = Column(BigInteger, primary_key=True)
trigger_type = Column(
Text, nullable=False
) # What kind of observation triggers this
reaction_type = Column(Text, nullable=False) # How user typically responds
frequency = Column(Numeric(3, 2), nullable=False) # How often this pattern appears
first_observed = Column(DateTime(timezone=True), server_default=func.now())
last_observed = Column(DateTime(timezone=True), server_default=func.now())
example_observations = Column(
ARRAY(BigInteger)
) # IDs of observations showing this pattern
reaction_metadata = Column(JSONB)
__table_args__ = (
Index("reaction_trigger_idx", "trigger_type"),
Index("reaction_type_idx", "reaction_type"),
Index("reaction_frequency_idx", "frequency"),
)
class ObservationPattern(Base):
"""
Higher-level patterns detected across multiple observations.
"""
__tablename__ = "observation_pattern"
id = Column(BigInteger, primary_key=True)
pattern_type = Column(Text, nullable=False) # behavioral, cognitive, emotional
description = Column(Text, nullable=False)
supporting_observations = Column(
ARRAY(BigInteger), nullable=False
) # Observation IDs
exceptions = Column(ARRAY(BigInteger)) # Observations that don't fit
confidence = Column(Numeric(3, 2), nullable=False, default=0.7)
validity_start = Column(DateTime(timezone=True))
validity_end = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
pattern_metadata = Column(JSONB)
__table_args__ = (
Index("obs_pattern_type_idx", "pattern_type"),
Index("obs_pattern_confidence_idx", "confidence"),
Index("obs_pattern_validity_idx", "validity_start", "validity_end"),
)
class BeliefCluster(Base):
"""
Groups of related beliefs that support or depend on each other.
"""
__tablename__ = "belief_cluster"
id = Column(BigInteger, primary_key=True)
cluster_name = Column(Text, nullable=False)
core_beliefs = Column(ARRAY(Text), nullable=False)
peripheral_beliefs = Column(ARRAY(Text))
internal_consistency = Column(Numeric(3, 2)) # How well beliefs align
supporting_observations = Column(ARRAY(BigInteger)) # Observation IDs
created_at = Column(DateTime(timezone=True), server_default=func.now())
last_updated = Column(DateTime(timezone=True), server_default=func.now())
cluster_metadata = Column(JSONB)
__table_args__ = (
Index("belief_cluster_name_idx", "cluster_name"),
Index("belief_cluster_consistency_idx", "internal_consistency"),
)
class ConversationMetrics(Base):
"""
Tracks the effectiveness and depth of conversations.
"""
__tablename__ = "conversation_metrics"
id = Column(BigInteger, primary_key=True)
session_id = Column(UUID(as_uuid=True), nullable=False)
depth_score = Column(Numeric(3, 2)) # How deep the conversation went
breakthrough_count = Column(Integer, default=0)
challenge_acceptance = Column(Numeric(3, 2)) # How well challenges were received
new_insights = Column(Integer, default=0)
user_engagement = Column(Numeric(3, 2)) # Inferred engagement level
duration_minutes = Column(Integer)
observation_count = Column(Integer, default=0)
contradiction_count = Column(Integer, default=0)
created_at = Column(DateTime(timezone=True), server_default=func.now())
metrics_metadata = Column(JSONB)
__table_args__ = (
Index("conv_metrics_session_idx", "session_id", unique=True),
Index("conv_metrics_depth_idx", "depth_score"),
Index("conv_metrics_breakthrough_idx", "breakthrough_count"),
)
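As a rough usage sketch (not part of this commit), the standalone pattern tables can be written through an ordinary session; the field values and observation ids below are made up, and get_session_factory is assumed to need no arguments here:

from memory.common.db.connection import get_session_factory
from memory.common.db.models import ReactionPattern

Session = get_session_factory()
session = Session()

# Record that challenges to stated beliefs tend to be met with curiosity.
pattern = ReactionPattern(
    trigger_type="belief_challenge",
    reaction_type="curiosity",
    frequency=0.75,
    example_observations=[101, 102],  # illustrative observation ids
)
session.add(pattern)
session.commit()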

View File

@ -0,0 +1,287 @@
"""
Database models for the knowledge base system.
"""
import pathlib
import re
from typing import Any, Sequence, cast
import uuid
from PIL import Image
from sqlalchemy import (
ARRAY,
UUID,
BigInteger,
CheckConstraint,
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
event,
func,
)
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import Session, relationship
from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer
from memory.common.db.models.base import Base
@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
"""
Event listener that efficiently checks for duplicate sha256 values before flush
and removes items with duplicate sha256 from the session.
Uses a single query to identify all duplicates rather than querying for each item.
"""
# Find all SourceItem objects being added
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
if not new_items:
return
items = {}
for item in new_items:
try:
if (sha256 := item.sha256) is None:
continue
if sha256 in items:
session.expunge(item)
continue
items[sha256] = item
except (AttributeError, TypeError):
continue
if not items:
return
# Query database for existing items with these sha256 values in a single query
existing_sha256s = set(
row[0]
for row in session.query(SourceItem.sha256).filter(
SourceItem.sha256.in_(items.keys())
)
)
# Remove objects with duplicate sha256 values from the session
for sha256 in existing_sha256s:
if sha256 in items:
session.expunge(items[sha256])
def clean_filename(filename: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
for i, image in enumerate(images):
if not image.filename: # type: ignore
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
image.save(filename)
image.filename = str(filename) # type: ignore
return [image.filename for image in images] # type: ignore
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
return [chunk] + [
i
for i in images
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
if tags := set(m.pop("tags", [])):
final["tags"] = tags | final.get("tags", set())
final |= m
return final
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
if not content.strip():
return []
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
summary, tags = summarizer.summarize(content)
full_text: extract.DataChunk = extract.DataChunk(
data=[content.strip(), *images], metadata={"tags": tags}
)
chunks: list[extract.DataChunk] = [full_text]
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks += [
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
for c in chunker.chunk_text(content)
]
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
return [c for c in chunks if c.data]
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
__tablename__ = "chunk"
# The ID is also used as the vector ID in the vector database
id = Column(
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
)
source_id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
)
file_paths = Column(
ARRAY(Text), nullable=True
) # Path to content if stored as a file
content = Column(Text) # Direct content storage
embedding_model = Column(Text)
collection_name = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []
# One of file_paths or content must be populated
__table_args__ = (
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
Index("chunk_source_idx", "source_id"),
)
@property
def chunks(self) -> list[extract.MulitmodalChunk]:
chunks: list[extract.MulitmodalChunk] = []
if cast(str | None, self.content):
chunks = [cast(str, self.content)]
if self.images:
chunks += self.images
elif cast(Sequence[str] | None, self.file_paths):
chunks += [
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
]
return chunks
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_paths is None:
return [cast(str, self.content)]
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
files = [path for path in paths if path.exists()]
items = []
for file_path in files:
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
if file_path.exists():
items.append(Image.open(file_path))
elif file_path.suffix == ".bin":
items.append(file_path.read_bytes())
else:
items.append(file_path.read_text())
return items
class SourceItem(Base):
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
__tablename__ = "source_item"
__allow_unmapped__ = True
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
sha256 = Column(BYTEA, nullable=False, unique=True)
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
size = Column(Integer)
mime_type = Column(Text)
# Content is stored in the database if it's small enough and text
content = Column(Text)
# Otherwise the content is stored on disk
filename = Column(Text, nullable=True)
# Chunks relationship
embed_status = Column(Text, nullable=False, server_default="RAW")
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
# Discriminator column for SQLAlchemy inheritance
type = Column(String(50))
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
Index("source_modality_idx", "modality"),
Index("source_status_idx", "embed_status"),
Index("source_tags_idx", "tags", postgresql_using="gin"),
Index("source_filename_idx", "filename"),
)
@property
def vector_ids(self):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
chunks: list[extract.DataChunk] = []
content = cast(str | None, self.content)
if content:
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
summary, tags = summarizer.summarize(content)
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
mime_type = cast(str | None, self.mime_type)
if mime_type and mime_type.startswith("image/"):
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
return chunks
def _make_chunk(
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
) -> Chunk:
chunk_id = str(uuid.uuid4())
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
images = [c for c in data.data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
chunk = Chunk(
id=chunk_id,
source=self,
content=text or None,
images=images,
file_paths=image_names,
collection_name=data.collection_name or cast(str, self.modality),
embedding_model=collections.collection_model(cast(str, self.modality)),
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
)
return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data) for data in self._chunk_contents()]
def as_payload(self) -> dict:
return {
"source_id": self.id,
"tags": self.tags,
"size": self.size,
}
@property
def display_contents(self) -> str | None:
return cast(str | None, self.content) or cast(str | None, self.filename)
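A small worked example of merge_metadata as defined above: tags from every input are unioned into a set, while the remaining keys are merged left to right (note that the inputs' "tags" entries are popped in place):

merged = merge_metadata(
    {"tags": ["philosophy"], "source_id": 1},
    {"tags": ["ethics"], "size": 2048},
)
# merged == {"tags": {"philosophy", "ethics"}, "source_id": 1, "size": 2048}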

View File

@ -3,18 +3,14 @@ Database models for the knowledge base system.
"""
import pathlib
import re
import textwrap
from datetime import datetime
from typing import Any, Sequence, cast
import uuid
from PIL import Image
from sqlalchemy import (
ARRAY,
UUID,
BigInteger,
Boolean,
CheckConstraint,
Column,
DateTime,
@ -22,273 +18,23 @@ from sqlalchemy import (
Index,
Integer,
Numeric,
String,
Text,
event,
func,
)
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR, UUID
from sqlalchemy.orm import relationship
from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer
Base = declarative_base()
@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
"""
Event listener that efficiently checks for duplicate sha256 values before flush
and removes items with duplicate sha256 from the session.
Uses a single query to identify all duplicates rather than querying for each item.
"""
# Find all SourceItem objects being added
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
if not new_items:
return
items = {}
for item in new_items:
try:
if (sha256 := item.sha256) is None:
continue
if sha256 in items:
session.expunge(item)
continue
items[sha256] = item
except (AttributeError, TypeError):
continue
if not new_items:
return
# Query database for existing items with these sha256 values in a single query
existing_sha256s = set(
row[0]
for row in session.query(SourceItem.sha256).filter(
SourceItem.sha256.in_(items.keys())
)
)
# Remove objects with duplicate sha256 values from the session
for sha256 in existing_sha256s:
if sha256 in items:
session.expunge(items[sha256])
def clean_filename(filename: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
for i, image in enumerate(images):
if not image.filename: # type: ignore
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
image.save(filename)
image.filename = str(filename) # type: ignore
return [image.filename for image in images] # type: ignore
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
return [chunk] + [
i
for i in images
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
if tags := set(m.pop("tags", [])):
final["tags"] = tags | final.get("tags", set())
final |= m
return final
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
if not content.strip():
return []
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
summary, tags = summarizer.summarize(content)
full_text: extract.DataChunk = extract.DataChunk(
data=[content.strip(), *images], metadata={"tags": tags}
)
chunks: list[extract.DataChunk] = [full_text]
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks += [
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
for c in chunker.chunk_text(content)
]
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
return [c for c in chunks if c.data]
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
__tablename__ = "chunk"
# The ID is also used as the vector ID in the vector database
id = Column(
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
)
source_id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
)
file_paths = Column(
ARRAY(Text), nullable=True
) # Path to content if stored as a file
content = Column(Text) # Direct content storage
embedding_model = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []
# One of file_path or content must be populated
__table_args__ = (
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
Index("chunk_source_idx", "source_id"),
)
@property
def chunks(self) -> list[extract.MulitmodalChunk]:
chunks: list[extract.MulitmodalChunk] = []
if cast(str | None, self.content):
chunks = [cast(str, self.content)]
if self.images:
chunks += self.images
elif cast(Sequence[str] | None, self.file_paths):
chunks += [
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
]
return chunks
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_paths is None:
return [cast(str, self.content)]
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
files = [path for path in paths if path.exists()]
items = []
for file_path in files:
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
if file_path.exists():
items.append(Image.open(file_path))
elif file_path.suffix == ".bin":
items.append(file_path.read_bytes())
else:
items.append(file_path.read_text())
return items
class SourceItem(Base):
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
__tablename__ = "source_item"
__allow_unmapped__ = True
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
sha256 = Column(BYTEA, nullable=False, unique=True)
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
size = Column(Integer)
mime_type = Column(Text)
# Content is stored in the database if it's small enough and text
content = Column(Text)
# Otherwise the content is stored on disk
filename = Column(Text, nullable=True)
# Chunks relationship
embed_status = Column(Text, nullable=False, server_default="RAW")
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
# Discriminator column for SQLAlchemy inheritance
type = Column(String(50))
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
Index("source_modality_idx", "modality"),
Index("source_status_idx", "embed_status"),
Index("source_tags_idx", "tags", postgresql_using="gin"),
Index("source_filename_idx", "filename"),
)
@property
def vector_ids(self):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
chunks: list[extract.DataChunk] = []
content = cast(str | None, self.content)
if content:
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
summary, tags = summarizer.summarize(content)
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
mime_type = cast(str | None, self.mime_type)
if mime_type and mime_type.startswith("image/"):
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
return chunks
def _make_chunk(
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
) -> Chunk:
chunk_id = str(uuid.uuid4())
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
images = [c for c in data.data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
chunk = Chunk(
id=chunk_id,
source=self,
content=text or None,
images=images,
file_paths=image_names,
embedding_model=collections.collection_model(cast(str, self.modality)),
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
)
return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data) for data in self._chunk_contents()]
def as_payload(self) -> dict:
return {
"source_id": self.id,
"tags": self.tags,
"size": self.size,
}
@property
def display_contents(self) -> str | None:
return cast(str | None, self.content) or cast(str | None, self.filename)
from memory.common.db.models.source_item import (
SourceItem,
Chunk,
clean_filename,
merge_metadata,
chunk_mixed,
)
class MailMessage(SourceItem):
@ -528,51 +274,6 @@ class Comic(SourceItem):
return [extract.DataChunk(data=[image, description])]
class Book(Base):
"""Book-level metadata table"""
__tablename__ = "book"
id = Column(BigInteger, primary_key=True)
isbn = Column(Text, unique=True)
title = Column(Text, nullable=False)
author = Column(Text)
publisher = Column(Text)
published = Column(DateTime(timezone=True))
language = Column(Text)
edition = Column(Text)
series = Column(Text)
series_number = Column(Integer)
total_pages = Column(Integer)
file_path = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
# Metadata from ebook parser
book_metadata = Column(JSONB, name="metadata")
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
Index("book_isbn_idx", "isbn"),
Index("book_author_idx", "author"),
Index("book_title_idx", "title"),
)
def as_payload(self) -> dict:
return {
**super().as_payload(),
"isbn": self.isbn,
"title": self.title,
"author": self.author,
"publisher": self.publisher,
"published": self.published,
"language": self.language,
"edition": self.edition,
"series": self.series,
"series_number": self.series_number,
} | (cast(dict, self.book_metadata) or {})
class BookSection(SourceItem):
"""Individual sections/chapters of books"""
@ -797,58 +498,73 @@ class GithubItem(SourceItem):
)
class ArticleFeed(Base):
__tablename__ = "article_feeds"
class AgentObservation(SourceItem):
"""
Records observations made by AI agents about the user.
This is the primary data model for the epistemic sparring partner.
"""
id = Column(BigInteger, primary_key=True)
url = Column(Text, nullable=False, unique=True)
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
check_interval = Column(
Integer, nullable=False, server_default="60", doc="Minutes between checks"
__tablename__ = "agent_observation"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
session_id = Column(
UUID(as_uuid=True)
) # Groups observations from same conversation
observation_type = Column(
Text, nullable=False
) # belief, preference, pattern, contradiction, behavior
subject = Column(Text, nullable=False) # What/who the observation is about
confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0
evidence = Column(JSONB) # Supporting context, quotes, etc.
agent_model = Column(Text, nullable=False) # Which AI model made this observation
# Relationships
contradictions_as_first = relationship(
"ObservationContradiction",
foreign_keys="ObservationContradiction.observation_1_id",
back_populates="observation_1",
cascade="all, delete-orphan",
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
contradictions_as_second = relationship(
"ObservationContradiction",
foreign_keys="ObservationContradiction.observation_2_id",
back_populates="observation_2",
cascade="all, delete-orphan",
)
# Add indexes
__mapper_args__ = {
"polymorphic_identity": "agent_observation",
}
__table_args__ = (
Index("article_feeds_active_idx", "active", "last_checked_at"),
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
Index("agent_obs_session_idx", "session_id"),
Index("agent_obs_type_idx", "observation_type"),
Index("agent_obs_subject_idx", "subject"),
Index("agent_obs_confidence_idx", "confidence"),
Index("agent_obs_model_idx", "agent_model"),
)
def __init__(self, **kwargs):
if not kwargs.get("modality"):
kwargs["modality"] = "observation"
super().__init__(**kwargs)
class EmailAccount(Base):
__tablename__ = "email_accounts"
def as_payload(self) -> dict:
payload = {
**super().as_payload(),
"observation_type": self.observation_type,
"subject": self.subject,
"confidence": float(cast(Any, self.confidence)),
"evidence": self.evidence,
"agent_model": self.agent_model,
}
if self.session_id is not None:
payload["session_id"] = str(self.session_id)
return payload
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default="993")
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default="true")
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("email_accounts_address_idx", "email_address", unique=True),
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
@property
def all_contradictions(self):
"""Get all contradictions involving this observation."""
return self.contradictions_as_first + self.contradictions_as_second
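A hedged sketch of how two observations and a contradiction between them might be built (not part of this commit): the field values, model name, and the choice of hashing the quote to satisfy SourceItem's required sha256 column are all illustrative.

import hashlib
import uuid

from memory.common.db.models import AgentObservation, ObservationContradiction

session_id = uuid.uuid4()

def observe(text: str, kind: str) -> AgentObservation:
    return AgentObservation(
        subject="user",
        observation_type=kind,
        confidence=0.9,
        evidence={"quote": text},
        agent_model="example-model",  # illustrative model name
        session_id=session_id,
        sha256=hashlib.sha256(text.encode()).digest(),  # placeholder hash of the evidence
        size=len(text),
    )

first = observe("I never plan more than a week ahead", "behavior")
second = observe("I keep a detailed five-year plan", "belief")

# Modality defaults to "observation" via __init__; the contradiction is wired
# through the relationships rather than raw ids.
conflict = ObservationContradiction(
    observation_1=first,
    observation_2=second,
    contradiction_type="direct",
    detection_method="agent-reported",
)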

View File

@ -0,0 +1,122 @@
"""
Database models for the knowledge base system.
"""
from typing import cast
from sqlalchemy import (
ARRAY,
BigInteger,
Boolean,
Column,
DateTime,
Index,
Integer,
Text,
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from memory.common.db.models.base import Base
class Book(Base):
"""Book-level metadata table"""
__tablename__ = "book"
id = Column(BigInteger, primary_key=True)
isbn = Column(Text, unique=True)
title = Column(Text, nullable=False)
author = Column(Text)
publisher = Column(Text)
published = Column(DateTime(timezone=True))
language = Column(Text)
edition = Column(Text)
series = Column(Text)
series_number = Column(Integer)
total_pages = Column(Integer)
file_path = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
# Metadata from ebook parser
book_metadata = Column(JSONB, name="metadata")
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
Index("book_isbn_idx", "isbn"),
Index("book_author_idx", "author"),
Index("book_title_idx", "title"),
)
def as_payload(self) -> dict:
return {
**super().as_payload(),
"isbn": self.isbn,
"title": self.title,
"author": self.author,
"publisher": self.publisher,
"published": self.published,
"language": self.language,
"edition": self.edition,
"series": self.series,
"series_number": self.series_number,
} | (cast(dict, self.book_metadata) or {})
class ArticleFeed(Base):
__tablename__ = "article_feeds"
id = Column(BigInteger, primary_key=True)
url = Column(Text, nullable=False, unique=True)
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
check_interval = Column(
Integer, nullable=False, server_default="60", doc="Minutes between checks"
)
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("article_feeds_active_idx", "active", "last_checked_at"),
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
)
class EmailAccount(Base):
__tablename__ = "email_accounts"
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default="993")
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default="true")
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("email_accounts_address_idx", "email_address", unique=True),
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)

View File

@ -21,6 +21,7 @@ class DataChunk:
data: Sequence[MulitmodalChunk]
metadata: dict[str, Any] = field(default_factory=dict)
mime_type: str = "text/plain"
collection_name: str | None = None
@contextmanager
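The new collection_name field lets an individual chunk override where it is embedded; when it is left as None, SourceItem._make_chunk falls back to the item's modality. A minimal sketch, with an illustrative collection name:

from memory.common import extract

chunk = extract.DataChunk(
    data=["A short observation about the user."],
    metadata={"tags": ["observation"]},
    collection_name="semantic",  # illustrative; None keeps modality-based routing
)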

View File

@ -6,13 +6,14 @@ the complete workflow: existence checking, content hashing, embedding generation,
vector storage, and result tracking.
"""
from collections import defaultdict
import hashlib
import traceback
import logging
from typing import Any, Callable, Iterable, Sequence, cast
from memory.common import embedding, qdrant
from memory.common.db.models import SourceItem
from memory.common.db.models import SourceItem, Chunk
logger = logging.getLogger(__name__)
@ -103,7 +104,19 @@ def embed_source_item(source_item: SourceItem) -> int:
return 0
def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
def by_collection(chunks: Sequence[Chunk]) -> dict[str, dict[str, Any]]:
collections: dict[str, dict[str, list[Any]]] = defaultdict(
lambda: defaultdict(list)
)
for chunk in chunks:
collection = collections[cast(str, chunk.collection_name)]
collection["ids"].append(chunk.id)
collection["vectors"].append(chunk.vector)
collection["payloads"].append(chunk.item_metadata)
return collections
def push_to_qdrant(source_items: Sequence[SourceItem]):
"""
Push embeddings to Qdrant vector database.
@ -135,17 +148,16 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
return
try:
vector_ids = [str(chunk.id) for chunk in all_chunks]
vectors = [chunk.vector for chunk in all_chunks]
payloads = [chunk.item_metadata for chunk in all_chunks]
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name=collection_name,
ids=vector_ids,
vectors=vectors,
payloads=payloads,
)
client = qdrant.get_qdrant_client()
collections = by_collection(all_chunks)
for collection_name, collection in collections.items():
qdrant.upsert_vectors(
client=client,
collection_name=collection_name,
ids=collection["ids"],
vectors=collection["vectors"],
payloads=collection["payloads"],
)
for item in items_to_process:
item.embed_status = "STORED" # type: ignore
@ -222,7 +234,7 @@ def process_content_item(item: SourceItem, session) -> dict[str, Any]:
return create_task_result(item, status, content_length=getattr(item, "size", 0))
try:
push_to_qdrant([item], cast(str, item.modality))
push_to_qdrant([item])
status = "processed"
item.embed_status = "STORED" # type: ignore
logger.info(

View File

@ -185,7 +185,7 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
f"Embedded section: {section.section_title} - {section.content[:100]}"
)
logger.info("Pushing to Qdrant")
push_to_qdrant(all_sections, "book")
push_to_qdrant(all_sections)
logger.info("Committing session")
session.commit()

View File

@ -1,4 +1,3 @@
from memory.common.db.models import SourceItem
from sqlalchemy.orm import Session
from unittest.mock import patch, Mock
from typing import cast
@ -6,17 +5,20 @@ import pytest
from PIL import Image
from datetime import datetime
from memory.common import settings, chunker, extract
from memory.common.db.models import (
from memory.common.db.models.sources import Book
from memory.common.db.models.source_items import (
Chunk,
clean_filename,
image_filenames,
add_pics,
MailMessage,
EmailAttachment,
BookSection,
BlogPost,
Book,
)
from memory.common.db.models.source_item import (
SourceItem,
image_filenames,
add_pics,
merge_metadata,
clean_filename,
)

View File

@ -18,6 +18,7 @@ from memory.workers.tasks.content_processing import (
process_content_item,
push_to_qdrant,
safe_task_execution,
by_collection,
)
@ -71,6 +72,7 @@ def mock_chunk():
chunk.id = "00000000-0000-0000-0000-000000000001"
chunk.vector = [0.1] * 1024
chunk.item_metadata = {"source_id": 1, "tags": ["test"]}
chunk.collection_name = "mail"
return chunk
@ -242,17 +244,19 @@ def test_push_to_qdrant_success(qdrant):
mock_chunk1.id = "00000000-0000-0000-0000-000000000001"
mock_chunk1.vector = [0.1] * 1024
mock_chunk1.item_metadata = {"source_id": 1, "tags": ["test"]}
mock_chunk1.collection_name = "mail"
mock_chunk2 = MagicMock()
mock_chunk2.id = "00000000-0000-0000-0000-000000000002"
mock_chunk2.vector = [0.2] * 1024
mock_chunk2.item_metadata = {"source_id": 2, "tags": ["test"]}
mock_chunk2.collection_name = "mail"
# Assign chunks directly (bypassing SQLAlchemy relationship)
item1.chunks = [mock_chunk1]
item2.chunks = [mock_chunk2]
push_to_qdrant([item1, item2], "mail")
push_to_qdrant([item1, item2])
assert str(item1.embed_status) == "STORED"
assert str(item2.embed_status) == "STORED"
@ -294,6 +298,7 @@ def test_push_to_qdrant_no_processing(
mock_chunk.id = f"00000000-0000-0000-0000-00000000000{suffix}"
mock_chunk.vector = [0.1] * 1024
mock_chunk.item_metadata = {"source_id": int(suffix), "tags": ["test"]}
mock_chunk.collection_name = "mail"
item.chunks = [mock_chunk]
else:
item.chunks = []
@ -302,7 +307,7 @@ def test_push_to_qdrant_no_processing(
item1 = create_item("1", item1_status, item1_has_chunks)
item2 = create_item("2", item2_status, item2_has_chunks)
push_to_qdrant([item1, item2], "mail")
push_to_qdrant([item1, item2])
assert str(item1.embed_status) == expected_item1_status
assert str(item2.embed_status) == expected_item2_status
@ -317,7 +322,7 @@ def test_push_to_qdrant_exception(sample_mail_message, mock_chunk):
side_effect=Exception("Qdrant error"),
):
with pytest.raises(Exception, match="Qdrant error"):
push_to_qdrant([sample_mail_message], "mail")
push_to_qdrant([sample_mail_message])
assert str(sample_mail_message.embed_status) == "FAILED"
@ -418,6 +423,196 @@ def test_create_task_result_no_title():
assert result["chunks_count"] == 0
def test_by_collection_empty_chunks():
result = by_collection([])
assert result == {}
def test_by_collection_single_chunk():
chunk = Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content",
embedding_model="test-model",
vector=[0.1, 0.2, 0.3],
item_metadata={"source_id": 1, "tags": ["test"]},
collection_name="test_collection",
)
result = by_collection([chunk])
assert len(result) == 1
assert "test_collection" in result
assert result["test_collection"]["ids"] == ["00000000-0000-0000-0000-000000000001"]
assert result["test_collection"]["vectors"] == [[0.1, 0.2, 0.3]]
assert result["test_collection"]["payloads"] == [{"source_id": 1, "tags": ["test"]}]
def test_by_collection_multiple_chunks_same_collection():
chunks = [
Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content 1",
embedding_model="test-model",
vector=[0.1, 0.2],
item_metadata={"source_id": 1},
collection_name="collection_a",
),
Chunk(
id="00000000-0000-0000-0000-000000000002",
content="test content 2",
embedding_model="test-model",
vector=[0.3, 0.4],
item_metadata={"source_id": 2},
collection_name="collection_a",
),
]
result = by_collection(chunks)
assert len(result) == 1
assert "collection_a" in result
assert result["collection_a"]["ids"] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000002",
]
assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.3, 0.4]]
assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 2}]
def test_by_collection_multiple_chunks_different_collections():
chunks = [
Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content 1",
embedding_model="test-model",
vector=[0.1, 0.2],
item_metadata={"source_id": 1},
collection_name="collection_a",
),
Chunk(
id="00000000-0000-0000-0000-000000000002",
content="test content 2",
embedding_model="test-model",
vector=[0.3, 0.4],
item_metadata={"source_id": 2},
collection_name="collection_b",
),
Chunk(
id="00000000-0000-0000-0000-000000000003",
content="test content 3",
embedding_model="test-model",
vector=[0.5, 0.6],
item_metadata={"source_id": 3},
collection_name="collection_a",
),
]
result = by_collection(chunks)
assert len(result) == 2
assert "collection_a" in result
assert "collection_b" in result
# Check collection_a
assert result["collection_a"]["ids"] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000003",
]
assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.5, 0.6]]
assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 3}]
# Check collection_b
assert result["collection_b"]["ids"] == ["00000000-0000-0000-0000-000000000002"]
assert result["collection_b"]["vectors"] == [[0.3, 0.4]]
assert result["collection_b"]["payloads"] == [{"source_id": 2}]
@pytest.mark.parametrize(
"collection_names,expected_collections",
[
(["col1", "col1", "col1"], 1),
(["col1", "col2", "col3"], 3),
(["col1", "col2", "col1", "col2"], 2),
(["single"], 1),
],
)
def test_by_collection_various_groupings(collection_names, expected_collections):
chunks = [
Chunk(
id=f"00000000-0000-0000-0000-00000000000{i}",
content=f"test content {i}",
embedding_model="test-model",
vector=[float(i)],
item_metadata={"index": i},
collection_name=collection_name,
)
for i, collection_name in enumerate(collection_names, 1)
]
result = by_collection(chunks)
assert len(result) == expected_collections
# Verify all chunks are accounted for
total_chunks = sum(len(coll["ids"]) for coll in result.values())
assert total_chunks == len(chunks)
def test_by_collection_with_none_values():
chunks = [
Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content",
embedding_model="test-model",
vector=None, # None vector
item_metadata=None, # None metadata
collection_name="test_collection",
),
Chunk(
id="00000000-0000-0000-0000-000000000002",
content="test content 2",
embedding_model="test-model",
vector=[0.1, 0.2],
item_metadata={"key": "value"},
collection_name="test_collection",
),
]
result = by_collection(chunks)
assert len(result) == 1
assert "test_collection" in result
assert result["test_collection"]["ids"] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000002",
]
assert result["test_collection"]["vectors"] == [None, [0.1, 0.2]]
assert result["test_collection"]["payloads"] == [None, {"key": "value"}]
def test_by_collection_preserves_order():
chunks = []
for i in range(5):
chunks.append(
Chunk(
id=f"00000000-0000-0000-0000-00000000000{i}",
content=f"test content {i}",
embedding_model="test-model",
vector=[float(i)],
item_metadata={"order": i},
collection_name="ordered_collection",
)
)
result = by_collection(chunks)
assert len(result) == 1
assert result["ordered_collection"]["ids"] == [
f"00000000-0000-0000-0000-00000000000{i}" for i in range(5)
]
assert result["ordered_collection"]["vectors"] == [[float(i)] for i in range(5)]
assert result["ordered_collection"]["payloads"] == [{"order": i} for i in range(5)]
@pytest.mark.parametrize(
"embedding_return,qdrant_error,expected_status,expected_embed_status",
[
@ -459,6 +654,7 @@ def test_process_content_item(
embedding_model="test-model",
vector=[0.1] * 1024,
item_metadata={"source_id": 1, "tags": ["test"]},
collection_name="mail",
)
mock_chunks = [real_chunk]
else: # empty

View File

@ -2,6 +2,7 @@
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
from typing import cast
import pytest
from PIL import Image
@ -350,6 +351,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
id=chunk_id,
source=item,
content=f"Test chunk content {i}",
collection_name=item.modality,
embedding_model="test-model",
)
for i, chunk_id in enumerate(chunk_ids)
@ -358,7 +360,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
db_session.commit()
# Add vectors to Qdrant
modality = "mail" if item_type == "MailMessage" else "blog"
modality = cast(str, item.modality)
qd.ensure_collection_exists(qdrant, modality, 1024)
qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))
@ -375,6 +377,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
id=str(uuid.uuid4()),
content="New chunk content 1",
embedding_model="test-model",
collection_name=modality,
vector=[0.1] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
@ -382,6 +385,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
id=str(uuid.uuid4()),
content="New chunk content 2",
embedding_model="test-model",
collection_name=modality,
vector=[0.2] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
@ -449,6 +453,7 @@ def test_reingest_item_no_chunks(db_session, qdrant):
id=str(uuid.uuid4()),
content="New chunk content",
embedding_model="test-model",
collection_name=item.modality,
vector=[0.1] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
@ -538,6 +543,7 @@ def test_reingest_empty_source_items_success(db_session, item_type):
source=item_with_chunks,
content="Test chunk content",
embedding_model="test-model",
collection_name=item_with_chunks.modality,
)
db_session.add(chunk)
db_session.commit()