Mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 13:24:41 +02:00)

Commit 004bd39987: Add observations model (parent: e505f9b53c)
@@ -12,6 +12,32 @@ from alembic import context
from memory.common import settings
from memory.common.db.models import Base

# Import all models to ensure they're registered with Base.metadata
from memory.common.db.models import (
    SourceItem,
    Chunk,
    MailMessage,
    EmailAttachment,
    ChatMessage,
    BlogPost,
    Comic,
    BookSection,
    ForumPost,
    GithubItem,
    GitCommit,
    Photo,
    MiscDoc,
    AgentObservation,
    ObservationContradiction,
    ReactionPattern,
    ObservationPattern,
    BeliefCluster,
    ConversationMetrics,
    Book,
    ArticleFeed,
    EmailAccount,
)


# this is the Alembic Config object
config = context.config
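A side note on why this import block matters (not part of the diff): Alembic's autogenerate only sees tables whose model classes have been imported, because importing them is what registers them on Base.metadata. A typical env.py then hands that metadata to Alembic; the line below is the conventional pattern, not something visible in this hunk.

# Hypothetical continuation of env.py (standard Alembic convention):
target_metadata = Base.metadata  # autogenerate diffs the database against this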
db/migrations/versions/20250531_154947_add_observation_models.py (new file, 279 lines)
@@ -0,0 +1,279 @@
"""Add observation models

Revision ID: 6554eb260176
Revises: 2524646f56f6
Create Date: 2025-05-31 15:49:47.579256

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "6554eb260176"
down_revision: Union[str, None] = "2524646f56f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "belief_cluster",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("cluster_name", sa.Text(), nullable=False),
        sa.Column("core_beliefs", sa.ARRAY(sa.Text()), nullable=False),
        sa.Column("peripheral_beliefs", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column(
            "internal_consistency", sa.Numeric(precision=3, scale=2), nullable=True
        ),
        sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "last_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "cluster_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "belief_cluster_consistency_idx",
        "belief_cluster",
        ["internal_consistency"],
        unique=False,
    )
    op.create_index(
        "belief_cluster_name_idx", "belief_cluster", ["cluster_name"], unique=False
    )
    op.create_table(
        "conversation_metrics",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("session_id", sa.UUID(), nullable=False),
        sa.Column("depth_score", sa.Numeric(precision=3, scale=2), nullable=True),
        sa.Column("breakthrough_count", sa.Integer(), nullable=True),
        sa.Column(
            "challenge_acceptance", sa.Numeric(precision=3, scale=2), nullable=True
        ),
        sa.Column("new_insights", sa.Integer(), nullable=True),
        sa.Column("user_engagement", sa.Numeric(precision=3, scale=2), nullable=True),
        sa.Column("duration_minutes", sa.Integer(), nullable=True),
        sa.Column("observation_count", sa.Integer(), nullable=True),
        sa.Column("contradiction_count", sa.Integer(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "metrics_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "conv_metrics_breakthrough_idx",
        "conversation_metrics",
        ["breakthrough_count"],
        unique=False,
    )
    op.create_index(
        "conv_metrics_depth_idx", "conversation_metrics", ["depth_score"], unique=False
    )
    op.create_index(
        "conv_metrics_session_idx", "conversation_metrics", ["session_id"], unique=True
    )
    op.create_table(
        "observation_pattern",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("pattern_type", sa.Text(), nullable=False),
        sa.Column("description", sa.Text(), nullable=False),
        sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=False),
        sa.Column("exceptions", sa.ARRAY(sa.BigInteger()), nullable=True),
        sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.Column("validity_start", sa.DateTime(timezone=True), nullable=True),
        sa.Column("validity_end", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "pattern_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "obs_pattern_confidence_idx",
        "observation_pattern",
        ["confidence"],
        unique=False,
    )
    op.create_index(
        "obs_pattern_type_idx", "observation_pattern", ["pattern_type"], unique=False
    )
    op.create_index(
        "obs_pattern_validity_idx",
        "observation_pattern",
        ["validity_start", "validity_end"],
        unique=False,
    )
    op.create_table(
        "reaction_pattern",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("trigger_type", sa.Text(), nullable=False),
        sa.Column("reaction_type", sa.Text(), nullable=False),
        sa.Column("frequency", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.Column(
            "first_observed",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "last_observed",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column("example_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
        sa.Column(
            "reaction_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "reaction_frequency_idx", "reaction_pattern", ["frequency"], unique=False
    )
    op.create_index(
        "reaction_trigger_idx", "reaction_pattern", ["trigger_type"], unique=False
    )
    op.create_index(
        "reaction_type_idx", "reaction_pattern", ["reaction_type"], unique=False
    )
    op.create_table(
        "agent_observation",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("session_id", sa.UUID(), nullable=True),
        sa.Column("observation_type", sa.Text(), nullable=False),
        sa.Column("subject", sa.Text(), nullable=False),
        sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.Column("evidence", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("agent_model", sa.Text(), nullable=False),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False
    )
    op.create_index(
        "agent_obs_model_idx", "agent_observation", ["agent_model"], unique=False
    )
    op.create_index(
        "agent_obs_session_idx", "agent_observation", ["session_id"], unique=False
    )
    op.create_index(
        "agent_obs_subject_idx", "agent_observation", ["subject"], unique=False
    )
    op.create_index(
        "agent_obs_type_idx", "agent_observation", ["observation_type"], unique=False
    )
    op.create_table(
        "observation_contradiction",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("observation_1_id", sa.BigInteger(), nullable=False),
        sa.Column("observation_2_id", sa.BigInteger(), nullable=False),
        sa.Column("contradiction_type", sa.Text(), nullable=False),
        sa.Column(
            "detected_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column("detection_method", sa.Text(), nullable=False),
        sa.Column("resolution", sa.Text(), nullable=True),
        sa.Column(
            "observation_metadata",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["observation_1_id"], ["agent_observation.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(
            ["observation_2_id"], ["agent_observation.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "obs_contra_method_idx",
        "observation_contradiction",
        ["detection_method"],
        unique=False,
    )
    op.create_index(
        "obs_contra_obs1_idx",
        "observation_contradiction",
        ["observation_1_id"],
        unique=False,
    )
    op.create_index(
        "obs_contra_obs2_idx",
        "observation_contradiction",
        ["observation_2_id"],
        unique=False,
    )
    op.create_index(
        "obs_contra_type_idx",
        "observation_contradiction",
        ["contradiction_type"],
        unique=False,
    )
    op.add_column("chunk", sa.Column("collection_name", sa.Text(), nullable=True))


def downgrade() -> None:
    op.drop_column("chunk", "collection_name")
    op.drop_index("obs_contra_type_idx", table_name="observation_contradiction")
    op.drop_index("obs_contra_obs2_idx", table_name="observation_contradiction")
    op.drop_index("obs_contra_obs1_idx", table_name="observation_contradiction")
    op.drop_index("obs_contra_method_idx", table_name="observation_contradiction")
    op.drop_table("observation_contradiction")
    op.drop_index("agent_obs_type_idx", table_name="agent_observation")
    op.drop_index("agent_obs_subject_idx", table_name="agent_observation")
    op.drop_index("agent_obs_session_idx", table_name="agent_observation")
    op.drop_index("agent_obs_model_idx", table_name="agent_observation")
    op.drop_index("agent_obs_confidence_idx", table_name="agent_observation")
    op.drop_table("agent_observation")
    op.drop_index("reaction_type_idx", table_name="reaction_pattern")
    op.drop_index("reaction_trigger_idx", table_name="reaction_pattern")
    op.drop_index("reaction_frequency_idx", table_name="reaction_pattern")
    op.drop_table("reaction_pattern")
    op.drop_index("obs_pattern_validity_idx", table_name="observation_pattern")
    op.drop_index("obs_pattern_type_idx", table_name="observation_pattern")
    op.drop_index("obs_pattern_confidence_idx", table_name="observation_pattern")
    op.drop_table("observation_pattern")
    op.drop_index("conv_metrics_session_idx", table_name="conversation_metrics")
    op.drop_index("conv_metrics_depth_idx", table_name="conversation_metrics")
    op.drop_index("conv_metrics_breakthrough_idx", table_name="conversation_metrics")
    op.drop_table("conversation_metrics")
    op.drop_index("belief_cluster_name_idx", table_name="belief_cluster")
    op.drop_index("belief_cluster_consistency_idx", table_name="belief_cluster")
    op.drop_table("belief_cluster")
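For orientation (not part of the diff): a revision like this would normally be applied and rolled back through Alembic's command API or CLI. The alembic.ini location below is an assumption about the project layout, not something the commit shows.

# Hypothetical sketch: apply or revert this revision programmatically.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")              # assumed path to the Alembic config
command.upgrade(cfg, "6554eb260176")     # create the observation tables
command.downgrade(cfg, "2524646f56f6")   # drop them again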
@@ -2,7 +2,7 @@
 Database utilities package.
 """

-from memory.common.db.models import Base
+from memory.common.db.models.base import Base
 from memory.common.db.connection import (
     get_engine,
     get_session_factory,
src/memory/common/db/models/__init__.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from memory.common.db.models.base import Base
from memory.common.db.models.source_item import (
    Chunk,
    SourceItem,
    clean_filename,
)
from memory.common.db.models.source_items import (
    MailMessage,
    EmailAttachment,
    AgentObservation,
    ChatMessage,
    BlogPost,
    Comic,
    BookSection,
    ForumPost,
    GithubItem,
    GitCommit,
    Photo,
    MiscDoc,
)
from memory.common.db.models.observations import (
    ObservationContradiction,
    ReactionPattern,
    ObservationPattern,
    BeliefCluster,
    ConversationMetrics,
)
from memory.common.db.models.sources import (
    Book,
    ArticleFeed,
    EmailAccount,
)

__all__ = [
    "Base",
    "Chunk",
    "clean_filename",
    "SourceItem",
    "MailMessage",
    "EmailAttachment",
    "AgentObservation",
    "ChatMessage",
    "BlogPost",
    "Comic",
    "BookSection",
    "ForumPost",
    "GithubItem",
    "GitCommit",
    "Photo",
    "MiscDoc",
    # Observations
    "ObservationContradiction",
    "ReactionPattern",
    "ObservationPattern",
    "BeliefCluster",
    "ConversationMetrics",
    # Sources
    "Book",
    "ArticleFeed",
    "EmailAccount",
]
src/memory/common/db/models/base.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from sqlalchemy.ext.declarative import declarative_base


Base = declarative_base()
src/memory/common/db/models/observations.py (new file, 171 lines)
@@ -0,0 +1,171 @@
"""
Agent observation models for the epistemic sparring partner system.
"""

from sqlalchemy import (
    ARRAY,
    BigInteger,
    Column,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    Text,
    UUID,
    func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship

from memory.common.db.models.base import Base


class ObservationContradiction(Base):
    """
    Tracks contradictions between observations.
    Can be detected automatically or reported by agents.
    """

    __tablename__ = "observation_contradiction"

    id = Column(BigInteger, primary_key=True)
    observation_1_id = Column(
        BigInteger,
        ForeignKey("agent_observation.id", ondelete="CASCADE"),
        nullable=False,
    )
    observation_2_id = Column(
        BigInteger,
        ForeignKey("agent_observation.id", ondelete="CASCADE"),
        nullable=False,
    )
    contradiction_type = Column(Text, nullable=False)  # direct, implied, temporal
    detected_at = Column(DateTime(timezone=True), server_default=func.now())
    detection_method = Column(Text, nullable=False)  # manual, automatic, agent-reported
    resolution = Column(Text)  # How it was resolved, if at all
    observation_metadata = Column(JSONB)

    # Relationships - use string references to avoid circular imports
    observation_1 = relationship(
        "AgentObservation",
        foreign_keys=[observation_1_id],
        back_populates="contradictions_as_first",
    )
    observation_2 = relationship(
        "AgentObservation",
        foreign_keys=[observation_2_id],
        back_populates="contradictions_as_second",
    )

    __table_args__ = (
        Index("obs_contra_obs1_idx", "observation_1_id"),
        Index("obs_contra_obs2_idx", "observation_2_id"),
        Index("obs_contra_type_idx", "contradiction_type"),
        Index("obs_contra_method_idx", "detection_method"),
    )


class ReactionPattern(Base):
    """
    Tracks patterns in how the user reacts to certain types of observations or challenges.
    """

    __tablename__ = "reaction_pattern"

    id = Column(BigInteger, primary_key=True)
    trigger_type = Column(
        Text, nullable=False
    )  # What kind of observation triggers this
    reaction_type = Column(Text, nullable=False)  # How user typically responds
    frequency = Column(Numeric(3, 2), nullable=False)  # How often this pattern appears
    first_observed = Column(DateTime(timezone=True), server_default=func.now())
    last_observed = Column(DateTime(timezone=True), server_default=func.now())
    example_observations = Column(
        ARRAY(BigInteger)
    )  # IDs of observations showing this pattern
    reaction_metadata = Column(JSONB)

    __table_args__ = (
        Index("reaction_trigger_idx", "trigger_type"),
        Index("reaction_type_idx", "reaction_type"),
        Index("reaction_frequency_idx", "frequency"),
    )


class ObservationPattern(Base):
    """
    Higher-level patterns detected across multiple observations.
    """

    __tablename__ = "observation_pattern"

    id = Column(BigInteger, primary_key=True)
    pattern_type = Column(Text, nullable=False)  # behavioral, cognitive, emotional
    description = Column(Text, nullable=False)
    supporting_observations = Column(
        ARRAY(BigInteger), nullable=False
    )  # Observation IDs
    exceptions = Column(ARRAY(BigInteger))  # Observations that don't fit
    confidence = Column(Numeric(3, 2), nullable=False, default=0.7)
    validity_start = Column(DateTime(timezone=True))
    validity_end = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now())
    pattern_metadata = Column(JSONB)

    __table_args__ = (
        Index("obs_pattern_type_idx", "pattern_type"),
        Index("obs_pattern_confidence_idx", "confidence"),
        Index("obs_pattern_validity_idx", "validity_start", "validity_end"),
    )


class BeliefCluster(Base):
    """
    Groups of related beliefs that support or depend on each other.
    """

    __tablename__ = "belief_cluster"

    id = Column(BigInteger, primary_key=True)
    cluster_name = Column(Text, nullable=False)
    core_beliefs = Column(ARRAY(Text), nullable=False)
    peripheral_beliefs = Column(ARRAY(Text))
    internal_consistency = Column(Numeric(3, 2))  # How well beliefs align
    supporting_observations = Column(ARRAY(BigInteger))  # Observation IDs
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    last_updated = Column(DateTime(timezone=True), server_default=func.now())
    cluster_metadata = Column(JSONB)

    __table_args__ = (
        Index("belief_cluster_name_idx", "cluster_name"),
        Index("belief_cluster_consistency_idx", "internal_consistency"),
    )


class ConversationMetrics(Base):
    """
    Tracks the effectiveness and depth of conversations.
    """

    __tablename__ = "conversation_metrics"

    id = Column(BigInteger, primary_key=True)
    session_id = Column(UUID(as_uuid=True), nullable=False)
    depth_score = Column(Numeric(3, 2))  # How deep the conversation went
    breakthrough_count = Column(Integer, default=0)
    challenge_acceptance = Column(Numeric(3, 2))  # How well challenges were received
    new_insights = Column(Integer, default=0)
    user_engagement = Column(Numeric(3, 2))  # Inferred engagement level
    duration_minutes = Column(Integer)
    observation_count = Column(Integer, default=0)
    contradiction_count = Column(Integer, default=0)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    metrics_metadata = Column(JSONB)

    __table_args__ = (
        Index("conv_metrics_session_idx", "session_id", unique=True),
        Index("conv_metrics_depth_idx", "depth_score"),
        Index("conv_metrics_breakthrough_idx", "breakthrough_count"),
    )
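These aggregate models are plain tables rather than SourceItem subclasses, so they can be written directly through a session. A hedged usage sketch, assuming a configured database and the get_session_factory helper exposed by memory.common.db.connection; the values are made up and this code is not part of the commit:

# Illustrative only: persist a detected pattern and per-session metrics.
import uuid

from memory.common.db.connection import get_session_factory
from memory.common.db.models import ConversationMetrics, ObservationPattern

SessionLocal = get_session_factory()
with SessionLocal() as session:
    session.add(
        ObservationPattern(
            pattern_type="behavioral",
            description="Asks for sources when a claim feels too convenient",
            supporting_observations=[101, 102],  # AgentObservation ids
            confidence=0.75,
        )
    )
    session.add(ConversationMetrics(session_id=uuid.uuid4(), depth_score=0.6))
    session.commit()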
287
src/memory/common/db/models/source_item.py
Normal file
287
src/memory/common/db/models/source_item.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""
|
||||
Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Any, Sequence, cast
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
UUID,
|
||||
BigInteger,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
event,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import BYTEA
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
from memory.common import settings
|
||||
import memory.common.extract as extract
|
||||
import memory.common.collections as collections
|
||||
import memory.common.chunker as chunker
|
||||
import memory.common.summarizer as summarizer
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def handle_duplicate_sha256(session, flush_context, instances):
|
||||
"""
|
||||
Event listener that efficiently checks for duplicate sha256 values before flush
|
||||
and removes items with duplicate sha256 from the session.
|
||||
|
||||
Uses a single query to identify all duplicates rather than querying for each item.
|
||||
"""
|
||||
# Find all SourceItem objects being added
|
||||
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
items = {}
|
||||
for item in new_items:
|
||||
try:
|
||||
if (sha256 := item.sha256) is None:
|
||||
continue
|
||||
|
||||
if sha256 in items:
|
||||
session.expunge(item)
|
||||
continue
|
||||
|
||||
items[sha256] = item
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
# Query database for existing items with these sha256 values in a single query
|
||||
existing_sha256s = set(
|
||||
row[0]
|
||||
for row in session.query(SourceItem.sha256).filter(
|
||||
SourceItem.sha256.in_(items.keys())
|
||||
)
|
||||
)
|
||||
|
||||
# Remove objects with duplicate sha256 values from the session
|
||||
for sha256 in existing_sha256s:
|
||||
if sha256 in items:
|
||||
session.expunge(items[sha256])
|
||||
|
||||
|
||||
def clean_filename(filename: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
|
||||
|
||||
|
||||
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
|
||||
for i, image in enumerate(images):
|
||||
if not image.filename: # type: ignore
|
||||
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
|
||||
image.save(filename)
|
||||
image.filename = str(filename) # type: ignore
|
||||
|
||||
return [image.filename for image in images] # type: ignore
|
||||
|
||||
|
||||
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
|
||||
return [chunk] + [
|
||||
i
|
||||
for i in images
|
||||
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
|
||||
]
|
||||
|
||||
|
||||
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
final = {}
|
||||
for m in metadata:
|
||||
if tags := set(m.pop("tags", [])):
|
||||
final["tags"] = tags | final.get("tags", set())
|
||||
final |= m
|
||||
return final
|
||||
|
||||
|
||||
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
|
||||
if not content.strip():
|
||||
return []
|
||||
|
||||
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
|
||||
|
||||
summary, tags = summarizer.summarize(content)
|
||||
full_text: extract.DataChunk = extract.DataChunk(
|
||||
data=[content.strip(), *images], metadata={"tags": tags}
|
||||
)
|
||||
|
||||
chunks: list[extract.DataChunk] = [full_text]
|
||||
tokens = chunker.approx_token_count(content)
|
||||
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
chunks += [
|
||||
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
|
||||
for c in chunker.chunk_text(content)
|
||||
]
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
return [c for c in chunks if c.data]
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
"""Stores content chunks with their vector embeddings."""
|
||||
|
||||
__tablename__ = "chunk"
|
||||
|
||||
# The ID is also used as the vector ID in the vector database
|
||||
id = Column(
|
||||
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
|
||||
)
|
||||
source_id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
file_paths = Column(
|
||||
ARRAY(Text), nullable=True
|
||||
) # Path to content if stored as a file
|
||||
content = Column(Text) # Direct content storage
|
||||
embedding_model = Column(Text)
|
||||
collection_name = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
checked_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
vector: list[float] = []
|
||||
item_metadata: dict[str, Any] = {}
|
||||
images: list[Image.Image] = []
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
|
||||
Index("chunk_source_idx", "source_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[extract.MulitmodalChunk]:
|
||||
chunks: list[extract.MulitmodalChunk] = []
|
||||
if cast(str | None, self.content):
|
||||
chunks = [cast(str, self.content)]
|
||||
if self.images:
|
||||
chunks += self.images
|
||||
elif cast(Sequence[str] | None, self.file_paths):
|
||||
chunks += [
|
||||
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
|
||||
]
|
||||
return chunks
|
||||
|
||||
@property
|
||||
def data(self) -> list[bytes | str | Image.Image]:
|
||||
if self.file_paths is None:
|
||||
return [cast(str, self.content)]
|
||||
|
||||
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
|
||||
files = [path for path in paths if path.exists()]
|
||||
|
||||
items = []
|
||||
for file_path in files:
|
||||
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||
if file_path.exists():
|
||||
items.append(Image.open(file_path))
|
||||
elif file_path.suffix == ".bin":
|
||||
items.append(file_path.read_bytes())
|
||||
else:
|
||||
items.append(file_path.read_text())
|
||||
return items
|
||||
|
||||
|
||||
class SourceItem(Base):
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
__tablename__ = "source_item"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
modality = Column(Text, nullable=False)
|
||||
sha256 = Column(BYTEA, nullable=False, unique=True)
|
||||
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
size = Column(Integer)
|
||||
mime_type = Column(Text)
|
||||
|
||||
# Content is stored in the database if it's small enough and text
|
||||
content = Column(Text)
|
||||
# Otherwise the content is stored on disk
|
||||
filename = Column(Text, nullable=True)
|
||||
|
||||
# Chunks relationship
|
||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||
|
||||
# Discriminator column for SQLAlchemy inheritance
|
||||
type = Column(String(50))
|
||||
|
||||
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
|
||||
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
Index("source_modality_idx", "modality"),
|
||||
Index("source_status_idx", "embed_status"),
|
||||
Index("source_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("source_filename_idx", "filename"),
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_ids(self):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
chunks: list[extract.DataChunk] = []
|
||||
content = cast(str | None, self.content)
|
||||
if content:
|
||||
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
|
||||
|
||||
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
summary, tags = summarizer.summarize(content)
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
mime_type = cast(str | None, self.mime_type)
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
|
||||
return chunks
|
||||
|
||||
def _make_chunk(
|
||||
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
|
||||
) -> Chunk:
|
||||
chunk_id = str(uuid.uuid4())
|
||||
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
|
||||
images = [c for c in data.data if isinstance(c, Image.Image)]
|
||||
image_names = image_filenames(chunk_id, images)
|
||||
|
||||
chunk = Chunk(
|
||||
id=chunk_id,
|
||||
source=self,
|
||||
content=text or None,
|
||||
images=images,
|
||||
file_paths=image_names,
|
||||
collection_name=data.collection_name or cast(str, self.modality),
|
||||
embedding_model=collections.collection_model(cast(str, self.modality)),
|
||||
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
|
||||
)
|
||||
return chunk
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
return [self._make_chunk(data) for data in self._chunk_contents()]
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
"size": self.size,
|
||||
}
|
||||
|
||||
@property
|
||||
def display_contents(self) -> str | None:
|
||||
return cast(str | None, self.content) or cast(str | None, self.filename)
|
@ -3,18 +3,14 @@ Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, Sequence, cast
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
UUID,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
@ -22,273 +18,23 @@ from sqlalchemy import (
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
event,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common import settings
|
||||
import memory.common.extract as extract
|
||||
import memory.common.collections as collections
|
||||
import memory.common.chunker as chunker
|
||||
import memory.common.summarizer as summarizer
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def handle_duplicate_sha256(session, flush_context, instances):
|
||||
"""
|
||||
Event listener that efficiently checks for duplicate sha256 values before flush
|
||||
and removes items with duplicate sha256 from the session.
|
||||
|
||||
Uses a single query to identify all duplicates rather than querying for each item.
|
||||
"""
|
||||
# Find all SourceItem objects being added
|
||||
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
items = {}
|
||||
for item in new_items:
|
||||
try:
|
||||
if (sha256 := item.sha256) is None:
|
||||
continue
|
||||
|
||||
if sha256 in items:
|
||||
session.expunge(item)
|
||||
continue
|
||||
|
||||
items[sha256] = item
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
# Query database for existing items with these sha256 values in a single query
|
||||
existing_sha256s = set(
|
||||
row[0]
|
||||
for row in session.query(SourceItem.sha256).filter(
|
||||
SourceItem.sha256.in_(items.keys())
|
||||
)
|
||||
)
|
||||
|
||||
# Remove objects with duplicate sha256 values from the session
|
||||
for sha256 in existing_sha256s:
|
||||
if sha256 in items:
|
||||
session.expunge(items[sha256])
|
||||
|
||||
|
||||
def clean_filename(filename: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
|
||||
|
||||
|
||||
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
|
||||
for i, image in enumerate(images):
|
||||
if not image.filename: # type: ignore
|
||||
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
|
||||
image.save(filename)
|
||||
image.filename = str(filename) # type: ignore
|
||||
|
||||
return [image.filename for image in images] # type: ignore
|
||||
|
||||
|
||||
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
|
||||
return [chunk] + [
|
||||
i
|
||||
for i in images
|
||||
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
|
||||
]
|
||||
|
||||
|
||||
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
final = {}
|
||||
for m in metadata:
|
||||
if tags := set(m.pop("tags", [])):
|
||||
final["tags"] = tags | final.get("tags", set())
|
||||
final |= m
|
||||
return final
|
||||
|
||||
|
||||
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
|
||||
if not content.strip():
|
||||
return []
|
||||
|
||||
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
|
||||
|
||||
summary, tags = summarizer.summarize(content)
|
||||
full_text: extract.DataChunk = extract.DataChunk(
|
||||
data=[content.strip(), *images], metadata={"tags": tags}
|
||||
)
|
||||
|
||||
chunks: list[extract.DataChunk] = [full_text]
|
||||
tokens = chunker.approx_token_count(content)
|
||||
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
chunks += [
|
||||
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
|
||||
for c in chunker.chunk_text(content)
|
||||
]
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
return [c for c in chunks if c.data]
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
"""Stores content chunks with their vector embeddings."""
|
||||
|
||||
__tablename__ = "chunk"
|
||||
|
||||
# The ID is also used as the vector ID in the vector database
|
||||
id = Column(
|
||||
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
|
||||
)
|
||||
source_id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
file_paths = Column(
|
||||
ARRAY(Text), nullable=True
|
||||
) # Path to content if stored as a file
|
||||
content = Column(Text) # Direct content storage
|
||||
embedding_model = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
checked_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
vector: list[float] = []
|
||||
item_metadata: dict[str, Any] = {}
|
||||
images: list[Image.Image] = []
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
|
||||
Index("chunk_source_idx", "source_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[extract.MulitmodalChunk]:
|
||||
chunks: list[extract.MulitmodalChunk] = []
|
||||
if cast(str | None, self.content):
|
||||
chunks = [cast(str, self.content)]
|
||||
if self.images:
|
||||
chunks += self.images
|
||||
elif cast(Sequence[str] | None, self.file_paths):
|
||||
chunks += [
|
||||
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
|
||||
]
|
||||
return chunks
|
||||
|
||||
@property
|
||||
def data(self) -> list[bytes | str | Image.Image]:
|
||||
if self.file_paths is None:
|
||||
return [cast(str, self.content)]
|
||||
|
||||
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
|
||||
files = [path for path in paths if path.exists()]
|
||||
|
||||
items = []
|
||||
for file_path in files:
|
||||
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||
if file_path.exists():
|
||||
items.append(Image.open(file_path))
|
||||
elif file_path.suffix == ".bin":
|
||||
items.append(file_path.read_bytes())
|
||||
else:
|
||||
items.append(file_path.read_text())
|
||||
return items
|
||||
|
||||
|
||||
class SourceItem(Base):
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
__tablename__ = "source_item"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
modality = Column(Text, nullable=False)
|
||||
sha256 = Column(BYTEA, nullable=False, unique=True)
|
||||
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
size = Column(Integer)
|
||||
mime_type = Column(Text)
|
||||
|
||||
# Content is stored in the database if it's small enough and text
|
||||
content = Column(Text)
|
||||
# Otherwise the content is stored on disk
|
||||
filename = Column(Text, nullable=True)
|
||||
|
||||
# Chunks relationship
|
||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||
|
||||
# Discriminator column for SQLAlchemy inheritance
|
||||
type = Column(String(50))
|
||||
|
||||
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
|
||||
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
Index("source_modality_idx", "modality"),
|
||||
Index("source_status_idx", "embed_status"),
|
||||
Index("source_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("source_filename_idx", "filename"),
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_ids(self):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
chunks: list[extract.DataChunk] = []
|
||||
content = cast(str | None, self.content)
|
||||
if content:
|
||||
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
|
||||
|
||||
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
summary, tags = summarizer.summarize(content)
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
mime_type = cast(str | None, self.mime_type)
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
|
||||
return chunks
|
||||
|
||||
def _make_chunk(
|
||||
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
|
||||
) -> Chunk:
|
||||
chunk_id = str(uuid.uuid4())
|
||||
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
|
||||
images = [c for c in data.data if isinstance(c, Image.Image)]
|
||||
image_names = image_filenames(chunk_id, images)
|
||||
|
||||
chunk = Chunk(
|
||||
id=chunk_id,
|
||||
source=self,
|
||||
content=text or None,
|
||||
images=images,
|
||||
file_paths=image_names,
|
||||
embedding_model=collections.collection_model(cast(str, self.modality)),
|
||||
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
|
||||
)
|
||||
return chunk
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
return [self._make_chunk(data) for data in self._chunk_contents()]
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
"size": self.size,
|
||||
}
|
||||
|
||||
@property
|
||||
def display_contents(self) -> str | None:
|
||||
return cast(str | None, self.content) or cast(str | None, self.filename)
|
||||
from memory.common.db.models.source_item import (
|
||||
SourceItem,
|
||||
Chunk,
|
||||
clean_filename,
|
||||
merge_metadata,
|
||||
chunk_mixed,
|
||||
)
|
||||
|
||||
|
||||
class MailMessage(SourceItem):
|
||||
@ -528,51 +274,6 @@ class Comic(SourceItem):
|
||||
return [extract.DataChunk(data=[image, description])]
|
||||
|
||||
|
||||
class Book(Base):
|
||||
"""Book-level metadata table"""
|
||||
|
||||
__tablename__ = "book"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
isbn = Column(Text, unique=True)
|
||||
title = Column(Text, nullable=False)
|
||||
author = Column(Text)
|
||||
publisher = Column(Text)
|
||||
published = Column(DateTime(timezone=True))
|
||||
language = Column(Text)
|
||||
edition = Column(Text)
|
||||
series = Column(Text)
|
||||
series_number = Column(Integer)
|
||||
total_pages = Column(Integer)
|
||||
file_path = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# Metadata from ebook parser
|
||||
book_metadata = Column(JSONB, name="metadata")
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("book_isbn_idx", "isbn"),
|
||||
Index("book_author_idx", "author"),
|
||||
Index("book_title_idx", "title"),
|
||||
)
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
**super().as_payload(),
|
||||
"isbn": self.isbn,
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
"publisher": self.publisher,
|
||||
"published": self.published,
|
||||
"language": self.language,
|
||||
"edition": self.edition,
|
||||
"series": self.series,
|
||||
"series_number": self.series_number,
|
||||
} | (cast(dict, self.book_metadata) or {})
|
||||
|
||||
|
||||
class BookSection(SourceItem):
|
||||
"""Individual sections/chapters of books"""
|
||||
|
||||
@ -797,58 +498,73 @@ class GithubItem(SourceItem):
|
||||
)
|
||||
|
||||
|
||||
class ArticleFeed(Base):
|
||||
__tablename__ = "article_feeds"
|
||||
class AgentObservation(SourceItem):
|
||||
"""
|
||||
Records observations made by AI agents about the user.
|
||||
This is the primary data model for the epistemic sparring partner.
|
||||
"""
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
url = Column(Text, nullable=False, unique=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
check_interval = Column(
|
||||
Integer, nullable=False, server_default="60", doc="Minutes between checks"
|
||||
__tablename__ = "agent_observation"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
last_checked_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
session_id = Column(
|
||||
UUID(as_uuid=True)
|
||||
) # Groups observations from same conversation
|
||||
observation_type = Column(
|
||||
Text, nullable=False
|
||||
) # belief, preference, pattern, contradiction, behavior
|
||||
subject = Column(Text, nullable=False) # What/who the observation is about
|
||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0
|
||||
evidence = Column(JSONB) # Supporting context, quotes, etc.
|
||||
agent_model = Column(Text, nullable=False) # Which AI model made this observation
|
||||
|
||||
# Relationships
|
||||
contradictions_as_first = relationship(
|
||||
"ObservationContradiction",
|
||||
foreign_keys="ObservationContradiction.observation_1_id",
|
||||
back_populates="observation_1",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
contradictions_as_second = relationship(
|
||||
"ObservationContradiction",
|
||||
foreign_keys="ObservationContradiction.observation_2_id",
|
||||
back_populates="observation_2",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "agent_observation",
|
||||
}
|
||||
|
||||
__table_args__ = (
|
||||
Index("article_feeds_active_idx", "active", "last_checked_at"),
|
||||
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("agent_obs_session_idx", "session_id"),
|
||||
Index("agent_obs_type_idx", "observation_type"),
|
||||
Index("agent_obs_subject_idx", "subject"),
|
||||
Index("agent_obs_confidence_idx", "confidence"),
|
||||
Index("agent_obs_model_idx", "agent_model"),
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not kwargs.get("modality"):
|
||||
kwargs["modality"] = "observation"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class EmailAccount(Base):
|
||||
__tablename__ = "email_accounts"
|
||||
def as_payload(self) -> dict:
|
||||
payload = {
|
||||
**super().as_payload(),
|
||||
"observation_type": self.observation_type,
|
||||
"subject": self.subject,
|
||||
"confidence": float(cast(Any, self.confidence)),
|
||||
"evidence": self.evidence,
|
||||
"agent_model": self.agent_model,
|
||||
}
|
||||
if self.session_id is not None:
|
||||
payload["session_id"] = str(self.session_id)
|
||||
return payload
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
email_address = Column(Text, nullable=False, unique=True)
|
||||
imap_server = Column(Text, nullable=False)
|
||||
imap_port = Column(Integer, nullable=False, server_default="993")
|
||||
username = Column(Text, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
use_ssl = Column(Boolean, nullable=False, server_default="true")
|
||||
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("email_accounts_address_idx", "email_address", unique=True),
|
||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
@property
|
||||
def all_contradictions(self):
|
||||
"""Get all contradictions involving this observation."""
|
||||
return self.contradictions_as_first + self.contradictions_as_second
|
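To make the AgentObservation / ObservationContradiction pairing above concrete, here is a hedged sketch (not from the commit) of recording two observations and linking them. It assumes the caller supplies the usual SourceItem requirements such as a sha256; the model name and example text are placeholders.

# Hypothetical sketch: record observations and flag a contradiction.
import hashlib
import uuid

from memory.common.db.models import AgentObservation, ObservationContradiction

def observe(text: str, subject: str, session_id) -> AgentObservation:
    return AgentObservation(
        content=text,
        sha256=hashlib.sha256(text.encode()).digest(),  # SourceItem requires a hash
        session_id=session_id,
        observation_type="belief",
        subject=subject,
        confidence=0.8,
        evidence={"quote": text},
        agent_model="example-model",  # placeholder, not mandated by the schema
    )

session_id = uuid.uuid4()
first = observe("Says they prefer written feedback", "feedback style", session_id)
second = observe("Asks for a quick call instead of notes", "feedback style", session_id)
link = ObservationContradiction(
    observation_1=first,
    observation_2=second,
    contradiction_type="direct",
    detection_method="agent-reported",
)
# first.all_contradictions now includes the link via the back_populates relationships.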
122
src/memory/common/db/models/sources.py
Normal file
122
src/memory/common/db/models/sources.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""
|
||||
Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
class Book(Base):
|
||||
"""Book-level metadata table"""
|
||||
|
||||
__tablename__ = "book"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
isbn = Column(Text, unique=True)
|
||||
title = Column(Text, nullable=False)
|
||||
author = Column(Text)
|
||||
publisher = Column(Text)
|
||||
published = Column(DateTime(timezone=True))
|
||||
language = Column(Text)
|
||||
edition = Column(Text)
|
||||
series = Column(Text)
|
||||
series_number = Column(Integer)
|
||||
total_pages = Column(Integer)
|
||||
file_path = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# Metadata from ebook parser
|
||||
book_metadata = Column(JSONB, name="metadata")
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("book_isbn_idx", "isbn"),
|
||||
Index("book_author_idx", "author"),
|
||||
Index("book_title_idx", "title"),
|
||||
)
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
**super().as_payload(),
|
||||
"isbn": self.isbn,
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
"publisher": self.publisher,
|
||||
"published": self.published,
|
||||
"language": self.language,
|
||||
"edition": self.edition,
|
||||
"series": self.series,
|
||||
"series_number": self.series_number,
|
||||
} | (cast(dict, self.book_metadata) or {})
|
||||
|
||||
|
||||
class ArticleFeed(Base):
|
||||
__tablename__ = "article_feeds"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
url = Column(Text, nullable=False, unique=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
check_interval = Column(
|
||||
Integer, nullable=False, server_default="60", doc="Minutes between checks"
|
||||
)
|
||||
last_checked_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("article_feeds_active_idx", "active", "last_checked_at"),
|
||||
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
class EmailAccount(Base):
|
||||
__tablename__ = "email_accounts"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
email_address = Column(Text, nullable=False, unique=True)
|
||||
imap_server = Column(Text, nullable=False)
|
||||
imap_port = Column(Integer, nullable=False, server_default="993")
|
||||
username = Column(Text, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
use_ssl = Column(Boolean, nullable=False, server_default="true")
|
||||
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("email_accounts_address_idx", "email_address", unique=True),
|
||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
@@ -21,6 +21,7 @@ class DataChunk:
     data: Sequence[MulitmodalChunk]
     metadata: dict[str, Any] = field(default_factory=dict)
     mime_type: str = "text/plain"
+    collection_name: str | None = None


 @contextmanager
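The new field lets an extractor route an individual chunk to a specific Qdrant collection; SourceItem._make_chunk falls back to the item's modality when it is left as None. A small illustrative sketch (the collection name is made up):

# Hypothetical: a summary chunk routed to a dedicated collection.
from memory.common import extract

chunk = extract.DataChunk(
    data=["One-paragraph summary of the post"],
    metadata={"tags": {"summary"}},
    collection_name="summaries",  # None would mean "use the source item's modality"
)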
@@ -6,13 +6,14 @@ the complete workflow: existence checking, content hashing, embedding generation
 vector storage, and result tracking.
 """

+from collections import defaultdict
 import hashlib
 import traceback
 import logging
 from typing import Any, Callable, Iterable, Sequence, cast

 from memory.common import embedding, qdrant
-from memory.common.db.models import SourceItem
+from memory.common.db.models import SourceItem, Chunk

 logger = logging.getLogger(__name__)
@@ -103,7 +104,19 @@ def embed_source_item(source_item: SourceItem) -> int:
     return 0


-def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
+def by_collection(chunks: Sequence[Chunk]) -> dict[str, dict[str, Any]]:
+    collections: dict[str, dict[str, list[Any]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
+    for chunk in chunks:
+        collection = collections[cast(str, chunk.collection_name)]
+        collection["ids"].append(chunk.id)
+        collection["vectors"].append(chunk.vector)
+        collection["payloads"].append(chunk.item_metadata)
+    return collections
+
+
+def push_to_qdrant(source_items: Sequence[SourceItem]):
     """
     Push embeddings to Qdrant vector database.
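The helper's behaviour is easiest to see from its own tests (also added in this commit); a condensed version with made-up values:

# Chunks headed for different collections end up in separate buckets.
chunks = [
    Chunk(
        id="00000000-0000-0000-0000-000000000001",
        content="first",
        vector=[0.1, 0.2],
        item_metadata={"source_id": 1},
        collection_name="mail",
    ),
    Chunk(
        id="00000000-0000-0000-0000-000000000002",
        content="second",
        vector=[0.3, 0.4],
        item_metadata={"source_id": 2},
        collection_name="blog",
    ),
]
grouped = by_collection(chunks)
assert set(grouped) == {"mail", "blog"}
assert grouped["mail"]["vectors"] == [[0.1, 0.2]]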
@@ -135,16 +148,15 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
         return

     try:
-        vector_ids = [str(chunk.id) for chunk in all_chunks]
-        vectors = [chunk.vector for chunk in all_chunks]
-        payloads = [chunk.item_metadata for chunk in all_chunks]
-
+        client = qdrant.get_qdrant_client()
+        collections = by_collection(all_chunks)
+        for collection_name, collection in collections.items():
             qdrant.upsert_vectors(
-                client=qdrant.get_qdrant_client(),
+                client=client,
                 collection_name=collection_name,
-                ids=vector_ids,
-                vectors=vectors,
-                payloads=payloads,
+                ids=collection["ids"],
+                vectors=collection["vectors"],
+                payloads=collection["payloads"],
             )

         for item in items_to_process:
@@ -222,7 +234,7 @@ def process_content_item(item: SourceItem, session) -> dict[str, Any]:
         return create_task_result(item, status, content_length=getattr(item, "size", 0))

     try:
-        push_to_qdrant([item], cast(str, item.modality))
+        push_to_qdrant([item])
         status = "processed"
         item.embed_status = "STORED"  # type: ignore
         logger.info(
@@ -185,7 +185,7 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
             f"Embedded section: {section.section_title} - {section.content[:100]}"
         )
     logger.info("Pushing to Qdrant")
-    push_to_qdrant(all_sections, "book")
+    push_to_qdrant(all_sections)
     logger.info("Committing session")

     session.commit()
@ -1,4 +1,3 @@
|
||||
from memory.common.db.models import SourceItem
|
||||
from sqlalchemy.orm import Session
|
||||
from unittest.mock import patch, Mock
|
||||
from typing import cast
|
||||
@ -6,17 +5,20 @@ import pytest
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
from memory.common import settings, chunker, extract
|
||||
from memory.common.db.models import (
|
||||
from memory.common.db.models.sources import Book
|
||||
from memory.common.db.models.source_items import (
|
||||
Chunk,
|
||||
clean_filename,
|
||||
image_filenames,
|
||||
add_pics,
|
||||
MailMessage,
|
||||
EmailAttachment,
|
||||
BookSection,
|
||||
BlogPost,
|
||||
Book,
|
||||
)
|
||||
from memory.common.db.models.source_item import (
|
||||
SourceItem,
|
||||
image_filenames,
|
||||
add_pics,
|
||||
merge_metadata,
|
||||
clean_filename,
|
||||
)
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@ from memory.workers.tasks.content_processing import (
|
||||
process_content_item,
|
||||
push_to_qdrant,
|
||||
safe_task_execution,
|
||||
by_collection,
|
||||
)
|
||||
|
||||
|
||||
@ -71,6 +72,7 @@ def mock_chunk():
    chunk.id = "00000000-0000-0000-0000-000000000001"
    chunk.vector = [0.1] * 1024
    chunk.item_metadata = {"source_id": 1, "tags": ["test"]}
    chunk.collection_name = "mail"
    return chunk

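Because push_to_qdrant now groups chunks by their collection_name, the mock chunk fixture sets that attribute explicitly; without it, by_collection would key the group on a bare MagicMock. Roughly, grouping the fixture above would yield (a sketch derived from the fixture's values, not output taken from the commit):

# by_collection([mock_chunk]) would produce, approximately:
# {
#     "mail": {
#         "ids": ["00000000-0000-0000-0000-000000000001"],
#         "vectors": [[0.1] * 1024],
#         "payloads": [{"source_id": 1, "tags": ["test"]}],
#     }
# }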
@ -242,17 +244,19 @@ def test_push_to_qdrant_success(qdrant):
    mock_chunk1.id = "00000000-0000-0000-0000-000000000001"
    mock_chunk1.vector = [0.1] * 1024
    mock_chunk1.item_metadata = {"source_id": 1, "tags": ["test"]}
    mock_chunk1.collection_name = "mail"

    mock_chunk2 = MagicMock()
    mock_chunk2.id = "00000000-0000-0000-0000-000000000002"
    mock_chunk2.vector = [0.2] * 1024
    mock_chunk2.item_metadata = {"source_id": 2, "tags": ["test"]}
    mock_chunk2.collection_name = "mail"

    # Assign chunks directly (bypassing SQLAlchemy relationship)
    item1.chunks = [mock_chunk1]
    item2.chunks = [mock_chunk2]

    push_to_qdrant([item1, item2], "mail")
    push_to_qdrant([item1, item2])

    assert str(item1.embed_status) == "STORED"
    assert str(item2.embed_status) == "STORED"
@ -294,6 +298,7 @@ def test_push_to_qdrant_no_processing(
            mock_chunk.id = f"00000000-0000-0000-0000-00000000000{suffix}"
            mock_chunk.vector = [0.1] * 1024
            mock_chunk.item_metadata = {"source_id": int(suffix), "tags": ["test"]}
            mock_chunk.collection_name = "mail"
            item.chunks = [mock_chunk]
        else:
            item.chunks = []
@ -302,7 +307,7 @@ def test_push_to_qdrant_no_processing(
    item1 = create_item("1", item1_status, item1_has_chunks)
    item2 = create_item("2", item2_status, item2_has_chunks)

    push_to_qdrant([item1, item2], "mail")
    push_to_qdrant([item1, item2])

    assert str(item1.embed_status) == expected_item1_status
    assert str(item2.embed_status) == expected_item2_status
@ -317,7 +322,7 @@ def test_push_to_qdrant_exception(sample_mail_message, mock_chunk):
        side_effect=Exception("Qdrant error"),
    ):
        with pytest.raises(Exception, match="Qdrant error"):
            push_to_qdrant([sample_mail_message], "mail")
            push_to_qdrant([sample_mail_message])

    assert str(sample_mail_message.embed_status) == "FAILED"

@ -418,6 +423,196 @@ def test_create_task_result_no_title():
    assert result["chunks_count"] == 0


def test_by_collection_empty_chunks():
    result = by_collection([])
    assert result == {}


def test_by_collection_single_chunk():
    chunk = Chunk(
        id="00000000-0000-0000-0000-000000000001",
        content="test content",
        embedding_model="test-model",
        vector=[0.1, 0.2, 0.3],
        item_metadata={"source_id": 1, "tags": ["test"]},
        collection_name="test_collection",
    )

    result = by_collection([chunk])

    assert len(result) == 1
    assert "test_collection" in result
    assert result["test_collection"]["ids"] == ["00000000-0000-0000-0000-000000000001"]
    assert result["test_collection"]["vectors"] == [[0.1, 0.2, 0.3]]
    assert result["test_collection"]["payloads"] == [{"source_id": 1, "tags": ["test"]}]


def test_by_collection_multiple_chunks_same_collection():
    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="test content 1",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"source_id": 1},
            collection_name="collection_a",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="test content 2",
            embedding_model="test-model",
            vector=[0.3, 0.4],
            item_metadata={"source_id": 2},
            collection_name="collection_a",
        ),
    ]

    result = by_collection(chunks)

    assert len(result) == 1
    assert "collection_a" in result
    assert result["collection_a"]["ids"] == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000002",
    ]
    assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.3, 0.4]]
    assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 2}]


def test_by_collection_multiple_chunks_different_collections():
    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="test content 1",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"source_id": 1},
            collection_name="collection_a",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="test content 2",
            embedding_model="test-model",
            vector=[0.3, 0.4],
            item_metadata={"source_id": 2},
            collection_name="collection_b",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000003",
            content="test content 3",
            embedding_model="test-model",
            vector=[0.5, 0.6],
            item_metadata={"source_id": 3},
            collection_name="collection_a",
        ),
    ]

    result = by_collection(chunks)

    assert len(result) == 2
    assert "collection_a" in result
    assert "collection_b" in result

    # Check collection_a
    assert result["collection_a"]["ids"] == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000003",
    ]
    assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.5, 0.6]]
    assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 3}]

    # Check collection_b
    assert result["collection_b"]["ids"] == ["00000000-0000-0000-0000-000000000002"]
    assert result["collection_b"]["vectors"] == [[0.3, 0.4]]
    assert result["collection_b"]["payloads"] == [{"source_id": 2}]


@pytest.mark.parametrize(
    "collection_names,expected_collections",
    [
        (["col1", "col1", "col1"], 1),
        (["col1", "col2", "col3"], 3),
        (["col1", "col2", "col1", "col2"], 2),
        (["single"], 1),
    ],
)
def test_by_collection_various_groupings(collection_names, expected_collections):
    chunks = [
        Chunk(
            id=f"00000000-0000-0000-0000-00000000000{i}",
            content=f"test content {i}",
            embedding_model="test-model",
            vector=[float(i)],
            item_metadata={"index": i},
            collection_name=collection_name,
        )
        for i, collection_name in enumerate(collection_names, 1)
    ]

    result = by_collection(chunks)

    assert len(result) == expected_collections
    # Verify all chunks are accounted for
    total_chunks = sum(len(coll["ids"]) for coll in result.values())
    assert total_chunks == len(chunks)


def test_by_collection_with_none_values():
    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="test content",
            embedding_model="test-model",
            vector=None,  # None vector
            item_metadata=None,  # None metadata
            collection_name="test_collection",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="test content 2",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"key": "value"},
            collection_name="test_collection",
        ),
    ]

    result = by_collection(chunks)

    assert len(result) == 1
    assert "test_collection" in result
    assert result["test_collection"]["ids"] == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000002",
    ]
    assert result["test_collection"]["vectors"] == [None, [0.1, 0.2]]
    assert result["test_collection"]["payloads"] == [None, {"key": "value"}]


def test_by_collection_preserves_order():
    chunks = []
    for i in range(5):
        chunks.append(
            Chunk(
                id=f"00000000-0000-0000-0000-00000000000{i}",
                content=f"test content {i}",
                embedding_model="test-model",
                vector=[float(i)],
                item_metadata={"order": i},
                collection_name="ordered_collection",
            )
        )

    result = by_collection(chunks)

    assert len(result) == 1
    assert result["ordered_collection"]["ids"] == [
        f"00000000-0000-0000-0000-00000000000{i}" for i in range(5)
    ]
    assert result["ordered_collection"]["vectors"] == [[float(i)] for i in range(5)]
    assert result["ordered_collection"]["payloads"] == [{"order": i} for i in range(5)]

@pytest.mark.parametrize(
    "embedding_return,qdrant_error,expected_status,expected_embed_status",
    [
@ -459,6 +654,7 @@ def test_process_content_item(
            embedding_model="test-model",
            vector=[0.1] * 1024,
            item_metadata={"source_id": 1, "tags": ["test"]},
            collection_name="mail",
        )
        mock_chunks = [real_chunk]
    else:  # empty

@ -2,6 +2,7 @@
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
from typing import cast

import pytest
from PIL import Image
@ -350,6 +351,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
            id=chunk_id,
            source=item,
            content=f"Test chunk content {i}",
            collection_name=item.modality,
            embedding_model="test-model",
        )
        for i, chunk_id in enumerate(chunk_ids)
@ -358,7 +360,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
    db_session.commit()

    # Add vectors to Qdrant
    modality = "mail" if item_type == "MailMessage" else "blog"
    modality = cast(str, item.modality)
    qd.ensure_collection_exists(qdrant, modality, 1024)
    qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))

@ -375,6 +377,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
            id=str(uuid.uuid4()),
            content="New chunk content 1",
            embedding_model="test-model",
            collection_name=modality,
            vector=[0.1] * 1024,
            item_metadata={"source_id": item.id, "tags": ["test"]},
        ),
@ -382,6 +385,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
            id=str(uuid.uuid4()),
            content="New chunk content 2",
            embedding_model="test-model",
            collection_name=modality,
            vector=[0.2] * 1024,
            item_metadata={"source_id": item.id, "tags": ["test"]},
        ),
@ -449,6 +453,7 @@ def test_reingest_item_no_chunks(db_session, qdrant):
            id=str(uuid.uuid4()),
            content="New chunk content",
            embedding_model="test-model",
            collection_name=item.modality,
            vector=[0.1] * 1024,
            item_metadata={"source_id": item.id, "tags": ["test"]},
        ),
@ -538,6 +543,7 @@ def test_reingest_empty_source_items_success(db_session, item_type):
            source=item_with_chunks,
            content="Test chunk content",
            embedding_model="test-model",
            collection_name=item_with_chunks.modality,
        )
        db_session.add(chunk)
        db_session.commit()