mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
multiple dimensions for confidence values
This commit is contained in:
parent
a40e0b50fa
commit
e5da3714de
79
db/migrations/versions/20250603_115642_add_confidences.py
Normal file
79
db/migrations/versions/20250603_115642_add_confidences.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Add confidences
|
||||||
|
|
||||||
|
Revision ID: 152f8b4b52e8
|
||||||
|
Revises: ba301527a2eb
|
||||||
|
Create Date: 2025-06-03 11:56:42.302327
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "152f8b4b52e8"
|
||||||
|
down_revision: Union[str, None] = "ba301527a2eb"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``confidence_score`` table and drop the single-value columns.

    The new table holds one row per (source item, confidence type) pair,
    enforced by the ``unique_source_confidence_type`` constraint, with the
    score limited to the 0.0-1.0 range by ``score_range_check``.  Rows are
    removed together with their source item (``ON DELETE CASCADE``).

    After the table and its indexes exist, the old ``confidence`` columns
    (and their indexes) are removed from ``agent_observation`` and ``notes``.
    """
    op.create_table(
        "confidence_score",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_item_id", sa.BigInteger(), nullable=False),
        sa.Column("confidence_type", sa.Text(), nullable=False),
        sa.Column("score", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.CheckConstraint("score >= 0.0 AND score <= 1.0", name="score_range_check"),
        sa.ForeignKeyConstraint(
            ["source_item_id"], ["source_item.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "source_item_id", "confidence_type", name="unique_source_confidence_type"
        ),
    )

    # Lookup indexes: by score value, by owning item, and by confidence type.
    op.create_index("confidence_score_idx", "confidence_score", ["score"], unique=False)
    op.create_index(
        "confidence_source_idx", "confidence_score", ["source_item_id"], unique=False
    )
    op.create_index(
        "confidence_type_idx", "confidence_score", ["confidence_type"], unique=False
    )

    # NOTE(review): the existing ``confidence`` values are not copied into
    # ``confidence_score`` before the columns are dropped, so they are lost
    # on upgrade -- confirm this is intended.
    # Each index is dropped before the column it covers.
    op.drop_index("agent_obs_confidence_idx", table_name="agent_observation")
    op.drop_column("agent_observation", "confidence")
    op.drop_index("note_confidence_idx", table_name="notes")
    op.drop_column("notes", "confidence")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Restore the single ``confidence`` columns and remove ``confidence_score``.

    ``notes.confidence`` and ``agent_observation.confidence`` are re-created
    as non-nullable ``NUMERIC(3, 2)`` columns with a server default of 0.5,
    together with their indexes.  The ``confidence_score`` indexes and table
    are then dropped.
    """
    # NOTE(review): nothing here copies rows from ``confidence_score`` back
    # into the restored columns; every existing row gets the 0.5 server
    # default and the per-type scores are discarded with the table.
    op.add_column(
        "notes",
        sa.Column(
            "confidence",
            sa.NUMERIC(precision=3, scale=2),
            server_default=sa.text("0.5"),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.create_index("note_confidence_idx", "notes", ["confidence"], unique=False)

    op.add_column(
        "agent_observation",
        sa.Column(
            "confidence",
            sa.NUMERIC(precision=3, scale=2),
            server_default=sa.text("0.5"),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.create_index(
        "agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False
    )

    # Remove the indexes first, then the table they belong to.
    op.drop_index("confidence_type_idx", table_name="confidence_score")
    op.drop_index("confidence_source_idx", table_name="confidence_score")
    op.drop_index("confidence_score_idx", table_name="confidence_score")
    op.drop_table("confidence_score")
|
@ -654,11 +654,9 @@ async def create_note(
|
|||||||
"""
|
"""
|
||||||
if filename:
|
if filename:
|
||||||
path = pathlib.Path(filename)
|
path = pathlib.Path(filename)
|
||||||
if path.is_absolute():
|
if not path.is_absolute():
|
||||||
path = path.relative_to(settings.NOTES_STORAGE_DIR)
|
|
||||||
else:
|
|
||||||
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
|
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
|
||||||
filename = path.as_posix()
|
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
|
@ -2,6 +2,7 @@ from memory.common.db.models.base import Base
|
|||||||
from memory.common.db.models.source_item import (
|
from memory.common.db.models.source_item import (
|
||||||
Chunk,
|
Chunk,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
|
ConfidenceScore,
|
||||||
clean_filename,
|
clean_filename,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.source_items import (
|
from memory.common.db.models.source_items import (
|
||||||
@ -37,6 +38,7 @@ __all__ = [
|
|||||||
"Chunk",
|
"Chunk",
|
||||||
"clean_filename",
|
"clean_filename",
|
||||||
"SourceItem",
|
"SourceItem",
|
||||||
|
"ConfidenceScore",
|
||||||
"MailMessage",
|
"MailMessage",
|
||||||
"EmailAttachment",
|
"EmailAttachment",
|
||||||
"AgentObservation",
|
"AgentObservation",
|
||||||
|
@ -22,9 +22,11 @@ from sqlalchemy import (
|
|||||||
Text,
|
Text,
|
||||||
event,
|
event,
|
||||||
func,
|
func,
|
||||||
|
UniqueConstraint,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import BYTEA
|
from sqlalchemy.dialects.postgresql import BYTEA
|
||||||
from sqlalchemy.orm import Session, relationship
|
from sqlalchemy.orm import Session, relationship
|
||||||
|
from sqlalchemy.types import Numeric
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
import memory.common.extract as extract
|
import memory.common.extract as extract
|
||||||
@ -191,6 +193,41 @@ class Chunk(Base):
|
|||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
class ConfidenceScore(Base):
    """
    Stores structured confidence scores for source items.

    Provides detailed confidence dimensions instead of a single score:
    each row is one named score (``confidence_type``) attached to one
    source item, and a source item may hold at most one row per type.
    """

    __tablename__ = "confidence_score"

    id = Column(BigInteger, primary_key=True)
    # Owning source item; rows are deleted with it (ON DELETE CASCADE).
    source_item_id = Column(
        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
    )
    confidence_type = Column(
        Text, nullable=False
    )  # e.g., "observation_accuracy", "interpretation", "predictive_value"
    score = Column(Numeric(3, 2), nullable=False)  # 0.0-1.0

    # Relationship back to source item
    source_item = relationship("SourceItem", back_populates="confidence_scores")

    __table_args__ = (
        Index("confidence_source_idx", "source_item_id"),
        Index("confidence_type_idx", "confidence_type"),
        Index("confidence_score_idx", "score"),
        CheckConstraint("score >= 0.0 AND score <= 1.0", name="score_range_check"),
        # Ensure each source_item can only have one score per confidence_type
        UniqueConstraint(
            "source_item_id", "confidence_type", name="unique_source_confidence_type"
        ),
    )

    def __repr__(self) -> str:
        """Return a short debugging representation with the type and score."""
        return f"<ConfidenceScore(type={self.confidence_type}, score={self.score})>"
|
||||||
|
|
||||||
|
|
||||||
class SourceItem(Base):
|
class SourceItem(Base):
|
||||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||||
|
|
||||||
@ -216,6 +253,11 @@ class SourceItem(Base):
|
|||||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
# Confidence scores relationship
|
||||||
|
confidence_scores = relationship(
|
||||||
|
"ConfidenceScore", back_populates="source_item", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
# Discriminator column for SQLAlchemy inheritance
|
# Discriminator column for SQLAlchemy inheritance
|
||||||
type = Column(String(50))
|
type = Column(String(50))
|
||||||
|
|
||||||
@ -235,6 +277,35 @@ class SourceItem(Base):
|
|||||||
"""Get vector IDs from associated chunks."""
|
"""Get vector IDs from associated chunks."""
|
||||||
return [chunk.id for chunk in self.chunks]
|
return [chunk.id for chunk in self.chunks]
|
||||||
|
|
||||||
|
@property
def confidence_dict(self) -> dict[str, float]:
    """Return this item's confidence scores as ``{confidence_type: score}``.

    Each stored score is converted with ``float()``.  An item without any
    scores yields an empty dict.
    """
    return {
        score.confidence_type: float(score.score)
        for score in self.confidence_scores
    }
|
||||||
|
|
||||||
|
def update_confidences(self, confidence_updates: dict[str, float]) -> None:
    """
    Update confidence scores for this source item.

    Merges new scores with existing ones, overwriting duplicates: a type
    that already has a score gets its value replaced in place, any other
    type gets a new ``ConfidenceScore`` appended to ``confidence_scores``.
    An empty (or otherwise falsy) mapping is a no-op.

    Args:
        confidence_updates: Dict mapping confidence_type to score (0.0-1.0)
    """
    if not confidence_updates:
        return

    # Index the existing score objects by type so each update is one lookup.
    current = {s.confidence_type: s for s in self.confidence_scores}

    for confidence_type, score in confidence_updates.items():
        if current_score := current.get(confidence_type):
            current_score.score = score
        else:
            new_score = ConfidenceScore(
                source_item_id=self.id, confidence_type=confidence_type, score=score
            )
            self.confidence_scores.append(new_score)
|
||||||
|
|
||||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
content = cast(str | None, self.content)
|
content = cast(str | None, self.content)
|
||||||
if content:
|
if content:
|
||||||
|
@ -505,7 +505,6 @@ class Note(SourceItem):
|
|||||||
)
|
)
|
||||||
note_type = Column(Text, nullable=True)
|
note_type = Column(Text, nullable=True)
|
||||||
subject = Column(Text, nullable=True)
|
subject = Column(Text, nullable=True)
|
||||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.5) # 0.0-1.0
|
|
||||||
|
|
||||||
__mapper_args__ = {
|
__mapper_args__ = {
|
||||||
"polymorphic_identity": "note",
|
"polymorphic_identity": "note",
|
||||||
@ -514,7 +513,6 @@ class Note(SourceItem):
|
|||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("note_type_idx", "note_type"),
|
Index("note_type_idx", "note_type"),
|
||||||
Index("note_subject_idx", "subject"),
|
Index("note_subject_idx", "subject"),
|
||||||
Index("note_confidence_idx", "confidence"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> dict:
|
||||||
@ -522,7 +520,7 @@ class Note(SourceItem):
|
|||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"note_type": self.note_type,
|
"note_type": self.note_type,
|
||||||
"subject": self.subject,
|
"subject": self.subject,
|
||||||
"confidence": float(cast(Any, self.confidence)),
|
"confidence": self.confidence_dict,
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -531,7 +529,7 @@ class Note(SourceItem):
|
|||||||
"subject": self.subject,
|
"subject": self.subject,
|
||||||
"content": self.content,
|
"content": self.content,
|
||||||
"note_type": self.note_type,
|
"note_type": self.note_type,
|
||||||
"confidence": self.confidence,
|
"confidence": self.confidence_dict,
|
||||||
"tags": self.tags,
|
"tags": self.tags,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -573,7 +571,6 @@ class AgentObservation(SourceItem):
|
|||||||
Text, nullable=False
|
Text, nullable=False
|
||||||
) # belief, preference, pattern, contradiction, behavior
|
) # belief, preference, pattern, contradiction, behavior
|
||||||
subject = Column(Text, nullable=False) # What/who the observation is about
|
subject = Column(Text, nullable=False) # What/who the observation is about
|
||||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0
|
|
||||||
evidence = Column(JSONB) # Supporting context, quotes, etc.
|
evidence = Column(JSONB) # Supporting context, quotes, etc.
|
||||||
agent_model = Column(Text, nullable=False) # Which AI model made this observation
|
agent_model = Column(Text, nullable=False) # Which AI model made this observation
|
||||||
|
|
||||||
@ -599,7 +596,6 @@ class AgentObservation(SourceItem):
|
|||||||
Index("agent_obs_session_idx", "session_id"),
|
Index("agent_obs_session_idx", "session_id"),
|
||||||
Index("agent_obs_type_idx", "observation_type"),
|
Index("agent_obs_type_idx", "observation_type"),
|
||||||
Index("agent_obs_subject_idx", "subject"),
|
Index("agent_obs_subject_idx", "subject"),
|
||||||
Index("agent_obs_confidence_idx", "confidence"),
|
|
||||||
Index("agent_obs_model_idx", "agent_model"),
|
Index("agent_obs_model_idx", "agent_model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -613,7 +609,7 @@ class AgentObservation(SourceItem):
|
|||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"observation_type": self.observation_type,
|
"observation_type": self.observation_type,
|
||||||
"subject": self.subject,
|
"subject": self.subject,
|
||||||
"confidence": float(cast(Any, self.confidence)),
|
"confidence": self.confidence_dict,
|
||||||
"evidence": self.evidence,
|
"evidence": self.evidence,
|
||||||
"agent_model": self.agent_model,
|
"agent_model": self.agent_model,
|
||||||
}
|
}
|
||||||
@ -633,7 +629,7 @@ class AgentObservation(SourceItem):
|
|||||||
"content": self.content,
|
"content": self.content,
|
||||||
"observation_type": self.observation_type,
|
"observation_type": self.observation_type,
|
||||||
"evidence": self.evidence,
|
"evidence": self.evidence,
|
||||||
"confidence": self.confidence,
|
"confidence": self.confidence_dict,
|
||||||
"agent_model": self.agent_model,
|
"agent_model": self.agent_model,
|
||||||
"tags": self.tags,
|
"tags": self.tags,
|
||||||
}
|
}
|
||||||
@ -664,7 +660,6 @@ class AgentObservation(SourceItem):
|
|||||||
temporal_text = observation.generate_temporal_text(
|
temporal_text = observation.generate_temporal_text(
|
||||||
cast(str, self.subject),
|
cast(str, self.subject),
|
||||||
cast(str, self.content),
|
cast(str, self.content),
|
||||||
cast(float, self.confidence),
|
|
||||||
cast(datetime, self.inserted_at),
|
cast(datetime, self.inserted_at),
|
||||||
)
|
)
|
||||||
if temporal_text:
|
if temporal_text:
|
||||||
|
@ -31,7 +31,6 @@ def generate_semantic_text(
|
|||||||
def generate_temporal_text(
|
def generate_temporal_text(
|
||||||
subject: str,
|
subject: str,
|
||||||
content: str,
|
content: str,
|
||||||
confidence: float,
|
|
||||||
created_at: datetime,
|
created_at: datetime,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate text with temporal context for time-pattern search."""
|
"""Generate text with temporal context for time-pattern search."""
|
||||||
@ -55,8 +54,6 @@ def generate_temporal_text(
|
|||||||
f"Subject: {subject}",
|
f"Subject: {subject}",
|
||||||
f"Observation: {content}",
|
f"Observation: {content}",
|
||||||
]
|
]
|
||||||
if confidence is not None:
|
|
||||||
parts.append(f"Confidence: {confidence}")
|
|
||||||
|
|
||||||
return " | ".join(parts)
|
return " | ".join(parts)
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from memory.common import settings
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import Note
|
from memory.common.db.models import Note
|
||||||
from memory.common.celery_app import app, SYNC_NOTE, SYNC_NOTES
|
from memory.common.celery_app import app, SYNC_NOTE, SYNC_NOTES
|
||||||
@ -23,7 +22,7 @@ def sync_note(
|
|||||||
content: str,
|
content: str,
|
||||||
filename: str | None = None,
|
filename: str | None = None,
|
||||||
note_type: str | None = None,
|
note_type: str | None = None,
|
||||||
confidence: float | None = None,
|
confidences: dict[str, float] = {},
|
||||||
tags: list[str] = [],
|
tags: list[str] = [],
|
||||||
):
|
):
|
||||||
logger.info(f"Syncing note {subject}")
|
logger.info(f"Syncing note {subject}")
|
||||||
@ -32,6 +31,8 @@ def sync_note(
|
|||||||
|
|
||||||
if filename:
|
if filename:
|
||||||
filename = filename.lstrip("/")
|
filename = filename.lstrip("/")
|
||||||
|
if not filename.endswith(".md"):
|
||||||
|
filename = f"{filename}.md"
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
existing_note = check_content_exists(session, Note, sha256=sha256)
|
existing_note = check_content_exists(session, Note, sha256=sha256)
|
||||||
@ -45,7 +46,6 @@ def sync_note(
|
|||||||
note = Note(
|
note = Note(
|
||||||
modality="note",
|
modality="note",
|
||||||
mime_type="text/markdown",
|
mime_type="text/markdown",
|
||||||
confidence=confidence or 0.5,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Editing preexisting note")
|
logger.info("Editing preexisting note")
|
||||||
@ -58,11 +58,10 @@ def sync_note(
|
|||||||
|
|
||||||
if note_type:
|
if note_type:
|
||||||
note.note_type = note_type # type: ignore
|
note.note_type = note_type # type: ignore
|
||||||
if confidence:
|
|
||||||
note.confidence = confidence # type: ignore
|
|
||||||
if tags:
|
if tags:
|
||||||
note.tags = tags # type: ignore
|
note.tags = tags # type: ignore
|
||||||
|
|
||||||
|
note.update_confidences(confidences)
|
||||||
note.save_to_file()
|
note.save_to_file()
|
||||||
return process_content_item(note, session)
|
return process_content_item(note, session)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ def sync_observation(
|
|||||||
content: str,
|
content: str,
|
||||||
observation_type: str,
|
observation_type: str,
|
||||||
evidence: dict | None = None,
|
evidence: dict | None = None,
|
||||||
confidence: float = 0.5,
|
confidences: dict[str, float] = {},
|
||||||
session_id: str | None = None,
|
session_id: str | None = None,
|
||||||
agent_model: str = "unknown",
|
agent_model: str = "unknown",
|
||||||
tags: list[str] = [],
|
tags: list[str] = [],
|
||||||
@ -33,7 +33,6 @@ def sync_observation(
|
|||||||
content=content,
|
content=content,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
observation_type=observation_type,
|
observation_type=observation_type,
|
||||||
confidence=confidence,
|
|
||||||
evidence=evidence,
|
evidence=evidence,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -43,6 +42,7 @@ def sync_observation(
|
|||||||
sha256=sha256,
|
sha256=sha256,
|
||||||
modality="observation",
|
modality="observation",
|
||||||
)
|
)
|
||||||
|
observation.update_confidences(confidences)
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
existing_observation = check_content_exists(
|
existing_observation = check_content_exists(
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -583,7 +583,6 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
|||||||
tags=["bla"],
|
tags=["bla"],
|
||||||
observation_type="belief",
|
observation_type="belief",
|
||||||
subject="humans",
|
subject="humans",
|
||||||
confidence=0.8,
|
|
||||||
evidence={
|
evidence={
|
||||||
"quote": "All humans are mortal.",
|
"quote": "All humans are mortal.",
|
||||||
"source": "https://en.wikipedia.org/wiki/Human",
|
"source": "https://en.wikipedia.org/wiki/Human",
|
||||||
@ -591,6 +590,7 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
|||||||
agent_model="gpt-4o",
|
agent_model="gpt-4o",
|
||||||
inserted_at=datetime(2025, 1, 1, 12, 0, 0),
|
inserted_at=datetime(2025, 1, 1, 12, 0, 0),
|
||||||
)
|
)
|
||||||
|
item.update_confidences({"observation_accuracy": 0.8})
|
||||||
metadata = item.as_payload()
|
metadata = item.as_payload()
|
||||||
metadata["tags"] = {"bla"}
|
metadata["tags"] = {"bla"}
|
||||||
expected = [
|
expected = [
|
||||||
@ -600,7 +600,7 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
|||||||
metadata | {"embedding_type": "semantic"},
|
metadata | {"embedding_type": "semantic"},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
|
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die.",
|
||||||
[],
|
[],
|
||||||
metadata | {"embedding_type": "temporal"},
|
metadata | {"embedding_type": "temporal"},
|
||||||
),
|
),
|
||||||
@ -625,7 +625,7 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
|||||||
assert mock_voyage_client.embed.call_args == call(
|
assert mock_voyage_client.embed.call_args == call(
|
||||||
[
|
[
|
||||||
"Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
|
"Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
|
||||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
|
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die.",
|
||||||
"The user thinks that all men must die.",
|
"The user thinks that all men must die.",
|
||||||
"All humans are mortal.",
|
"All humans are mortal.",
|
||||||
],
|
],
|
||||||
|
@ -499,7 +499,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
|||||||
"size": None,
|
"size": None,
|
||||||
"observation_type": "preference",
|
"observation_type": "preference",
|
||||||
"subject": "programming preferences",
|
"subject": "programming preferences",
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"evidence": {
|
"evidence": {
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -513,7 +513,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
|||||||
"size": None,
|
"size": None,
|
||||||
"observation_type": "preference",
|
"observation_type": "preference",
|
||||||
"subject": "programming preferences",
|
"subject": "programming preferences",
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"evidence": {
|
"evidence": {
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -531,7 +531,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
|||||||
"size": None,
|
"size": None,
|
||||||
"observation_type": "preference",
|
"observation_type": "preference",
|
||||||
"subject": "programming preferences",
|
"subject": "programming preferences",
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"evidence": {
|
"evidence": {
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -546,7 +546,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
|||||||
"size": None,
|
"size": None,
|
||||||
"observation_type": "preference",
|
"observation_type": "preference",
|
||||||
"subject": "programming preferences",
|
"subject": "programming preferences",
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"evidence": {
|
"evidence": {
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -565,7 +565,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
|||||||
"size": None,
|
"size": None,
|
||||||
"observation_type": "preference",
|
"observation_type": "preference",
|
||||||
"subject": "programming preferences",
|
"subject": "programming preferences",
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"evidence": {
|
"evidence": {
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -580,7 +580,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
|||||||
"size": None,
|
"size": None,
|
||||||
"observation_type": "preference",
|
"observation_type": "preference",
|
||||||
"subject": "programming preferences",
|
"subject": "programming preferences",
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"evidence": {
|
"evidence": {
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -603,7 +603,6 @@ def test_agent_observation_data_chunks(
|
|||||||
content="User prefers Python over JavaScript",
|
content="User prefers Python over JavaScript",
|
||||||
subject="programming preferences",
|
subject="programming preferences",
|
||||||
observation_type="preference",
|
observation_type="preference",
|
||||||
confidence=0.9,
|
|
||||||
evidence={
|
evidence={
|
||||||
"quote": "I really like Python",
|
"quote": "I really like Python",
|
||||||
"context": "discussion about languages",
|
"context": "discussion about languages",
|
||||||
@ -612,6 +611,7 @@ def test_agent_observation_data_chunks(
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
tags=observation_tags,
|
tags=observation_tags,
|
||||||
)
|
)
|
||||||
|
observation.update_confidences({"observation_accuracy": 0.9})
|
||||||
# Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
|
# Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
|
||||||
object.__setattr__(observation, "inserted_at", datetime(2023, 1, 1, 12, 0, 0))
|
object.__setattr__(observation, "inserted_at", datetime(2023, 1, 1, 12, 0, 0))
|
||||||
|
|
||||||
@ -634,7 +634,7 @@ def test_agent_observation_data_chunks(
|
|||||||
assert cast(str, semantic_chunk.collection_name) == "semantic"
|
assert cast(str, semantic_chunk.collection_name) == "semantic"
|
||||||
|
|
||||||
temporal_chunk = result[1]
|
temporal_chunk = result[1]
|
||||||
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript | Confidence: 0.9"
|
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript"
|
||||||
assert temporal_chunk.data == [expected_temporal_text]
|
assert temporal_chunk.data == [expected_temporal_text]
|
||||||
|
|
||||||
# Add session_id to expected metadata and remove tags if empty
|
# Add session_id to expected metadata and remove tags if empty
|
||||||
@ -654,11 +654,11 @@ def test_agent_observation_data_chunks_with_none_values():
|
|||||||
content="Content",
|
content="Content",
|
||||||
subject="subject",
|
subject="subject",
|
||||||
observation_type="belief",
|
observation_type="belief",
|
||||||
confidence=0.7,
|
|
||||||
evidence=None,
|
evidence=None,
|
||||||
agent_model="gpt-4",
|
agent_model="gpt-4",
|
||||||
session_id=None,
|
session_id=None,
|
||||||
)
|
)
|
||||||
|
observation.update_confidences({"observation_accuracy": 0.7})
|
||||||
object.__setattr__(observation, "inserted_at", datetime(2023, 2, 15, 9, 30, 0))
|
object.__setattr__(observation, "inserted_at", datetime(2023, 2, 15, 9, 30, 0))
|
||||||
|
|
||||||
result = observation.data_chunks()
|
result = observation.data_chunks()
|
||||||
@ -671,7 +671,7 @@ def test_agent_observation_data_chunks_with_none_values():
|
|||||||
assert [i.data for i in result] == [
|
assert [i.data for i in result] == [
|
||||||
["Subject: subject | Type: belief | Observation: Content"],
|
["Subject: subject | Type: belief | Observation: Content"],
|
||||||
[
|
[
|
||||||
"Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
|
"Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content"
|
||||||
],
|
],
|
||||||
["Content"],
|
["Content"],
|
||||||
]
|
]
|
||||||
@ -684,11 +684,11 @@ def test_agent_observation_data_chunks_merge_metadata_behavior():
|
|||||||
content="test",
|
content="test",
|
||||||
subject="test",
|
subject="test",
|
||||||
observation_type="test",
|
observation_type="test",
|
||||||
confidence=0.8,
|
|
||||||
evidence={},
|
evidence={},
|
||||||
agent_model="test",
|
agent_model="test",
|
||||||
tags=["base_tag"], # Set base tags so they appear in both chunks
|
tags=["base_tag"], # Set base tags so they appear in both chunks
|
||||||
)
|
)
|
||||||
|
observation.update_confidences({"observation_accuracy": 0.9})
|
||||||
object.__setattr__(observation, "inserted_at", datetime.now())
|
object.__setattr__(observation, "inserted_at", datetime.now())
|
||||||
|
|
||||||
# Test that metadata merging preserves original values and adds new ones
|
# Test that metadata merging preserves original values and adds new ones
|
||||||
@ -723,11 +723,10 @@ def test_note_data_chunks(subject, content, expected):
|
|||||||
content=content,
|
content=content,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
note_type="quicky",
|
note_type="quicky",
|
||||||
confidence=0.9,
|
|
||||||
size=123,
|
size=123,
|
||||||
tags=["bla"],
|
tags=["bla"],
|
||||||
)
|
)
|
||||||
|
note.update_confidences({"observation_accuracy": 0.9})
|
||||||
chunks = note.data_chunks()
|
chunks = note.data_chunks()
|
||||||
assert [chunk.content for chunk in chunks] == expected
|
assert [chunk.content for chunk in chunks] == expected
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
@ -736,7 +735,7 @@ def test_note_data_chunks(subject, content, expected):
|
|||||||
if cast(str, chunk.content) == "test summary":
|
if cast(str, chunk.content) == "test summary":
|
||||||
tags |= {"tag1", "tag2"}
|
tags |= {"tag1", "tag2"}
|
||||||
assert chunk.item_metadata == {
|
assert chunk.item_metadata == {
|
||||||
"confidence": 0.9,
|
"confidence": {"observation_accuracy": 0.9},
|
||||||
"note_type": "quicky",
|
"note_type": "quicky",
|
||||||
"size": 123,
|
"size": 123,
|
||||||
"source_id": None,
|
"source_id": None,
|
||||||
|
@ -123,11 +123,10 @@ def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
|
|||||||
result = generate_temporal_text(
|
result = generate_temporal_text(
|
||||||
subject="test_subject",
|
subject="test_subject",
|
||||||
content="test_content",
|
content="test_content",
|
||||||
confidence=0.8,
|
|
||||||
created_at=test_date,
|
created_at=test_date,
|
||||||
)
|
)
|
||||||
time_str = test_date.strftime("%H:%M")
|
time_str = test_date.strftime("%H:%M")
|
||||||
expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content | Confidence: 0.8"
|
expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content"
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
@ -146,7 +145,7 @@ def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
|
|||||||
def test_generate_temporal_text_days_of_week(weekday: int, day_name: str):
|
def test_generate_temporal_text_days_of_week(weekday: int, day_name: str):
|
||||||
test_date = datetime(2024, 1, 15 + weekday, 10, 30)
|
test_date = datetime(2024, 1, 15 + weekday, 10, 30)
|
||||||
result = generate_temporal_text(
|
result = generate_temporal_text(
|
||||||
subject="subject", content="content", confidence=0.5, created_at=test_date
|
subject="subject", content="content", created_at=test_date
|
||||||
)
|
)
|
||||||
assert f"on {day_name}" in result
|
assert f"on {day_name}" in result
|
||||||
|
|
||||||
@ -157,10 +156,8 @@ def test_generate_temporal_text_confidence_values(confidence: float):
|
|||||||
result = generate_temporal_text(
|
result = generate_temporal_text(
|
||||||
subject="subject",
|
subject="subject",
|
||||||
content="content",
|
content="content",
|
||||||
confidence=confidence,
|
|
||||||
created_at=test_date,
|
created_at=test_date,
|
||||||
)
|
)
|
||||||
assert f"Confidence: {confidence}" in result
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -180,7 +177,7 @@ def test_generate_temporal_text_boundary_cases(
|
|||||||
test_date: datetime, expected_period: str
|
test_date: datetime, expected_period: str
|
||||||
):
|
):
|
||||||
result = generate_temporal_text(
|
result = generate_temporal_text(
|
||||||
subject="subject", content="content", confidence=0.8, created_at=test_date
|
subject="subject", content="content", created_at=test_date
|
||||||
)
|
)
|
||||||
assert f"({expected_period})" in result
|
assert f"({expected_period})" in result
|
||||||
|
|
||||||
@ -190,22 +187,16 @@ def test_generate_temporal_text_complete_format():
|
|||||||
result = generate_temporal_text(
|
result = generate_temporal_text(
|
||||||
subject="Important observation",
|
subject="Important observation",
|
||||||
content="User showed strong preference for X",
|
content="User showed strong preference for X",
|
||||||
confidence=0.95,
|
|
||||||
created_at=test_date,
|
created_at=test_date,
|
||||||
)
|
)
|
||||||
expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X | Confidence: 0.95"
|
expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X"
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
def test_generate_temporal_text_empty_strings():
|
def test_generate_temporal_text_empty_strings():
|
||||||
test_date = datetime(2024, 1, 15, 10, 30)
|
test_date = datetime(2024, 1, 15, 10, 30)
|
||||||
result = generate_temporal_text(
|
result = generate_temporal_text(subject="", content="", created_at=test_date)
|
||||||
subject="", content="", confidence=0.0, created_at=test_date
|
assert result == "Time: 10:30 on Monday (morning) | Subject: | Observation:"
|
||||||
)
|
|
||||||
assert (
|
|
||||||
result
|
|
||||||
== "Time: 10:30 on Monday (morning) | Subject: | Observation: | Confidence: 0.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_temporal_text_special_characters():
|
def test_generate_temporal_text_special_characters():
|
||||||
@ -213,8 +204,7 @@ def test_generate_temporal_text_special_characters():
|
|||||||
result = generate_temporal_text(
|
result = generate_temporal_text(
|
||||||
subject="Subject with | pipe",
|
subject="Subject with | pipe",
|
||||||
content="Content with | pipe and @#$ symbols",
|
content="Content with | pipe and @#$ symbols",
|
||||||
confidence=0.75,
|
|
||||||
created_at=test_date,
|
created_at=test_date,
|
||||||
)
|
)
|
||||||
expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols | Confidence: 0.75"
|
expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols"
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
@ -16,7 +16,7 @@ def mock_note_data():
|
|||||||
"content": "This is test note content with enough text to be processed and embedded.",
|
"content": "This is test note content with enough text to be processed and embedded.",
|
||||||
"filename": "test_note.md",
|
"filename": "test_note.md",
|
||||||
"note_type": "observation",
|
"note_type": "observation",
|
||||||
"confidence": 0.8,
|
"confidences": {"observation_accuracy": 0.8},
|
||||||
"tags": ["test", "note"],
|
"tags": ["test", "note"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ def test_sync_note_success(mock_note_data, db_session, qdrant):
|
|||||||
assert note.modality == "note"
|
assert note.modality == "note"
|
||||||
assert note.mime_type == "text/markdown"
|
assert note.mime_type == "text/markdown"
|
||||||
assert note.note_type == "observation"
|
assert note.note_type == "observation"
|
||||||
assert float(note.confidence) == 0.8 # Convert Decimal to float for comparison
|
assert note.confidence_dict == {"observation_accuracy": 0.8}
|
||||||
assert note.filename is not None
|
assert note.filename is not None
|
||||||
assert note.tags == ["test", "note"]
|
assert note.tags == ["test", "note"]
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ def test_sync_note_minimal_data(mock_minimal_note, db_session, qdrant):
|
|||||||
assert note.subject == "Minimal Note"
|
assert note.subject == "Minimal Note"
|
||||||
assert note.content == "Minimal content"
|
assert note.content == "Minimal content"
|
||||||
assert note.note_type is None
|
assert note.note_type is None
|
||||||
assert float(note.confidence) == 0.5 # Default value, convert Decimal to float
|
assert note.confidence_dict == {}
|
||||||
assert note.tags == [] # Default empty list
|
assert note.tags == [] # Default empty list
|
||||||
assert note.filename is not None and "Minimal Note.md" in note.filename
|
assert note.filename is not None and "Minimal Note.md" in note.filename
|
||||||
|
|
||||||
@ -205,6 +205,9 @@ def test_sync_note_edit(mock_note_data, db_session):
|
|||||||
embed_status="RAW",
|
embed_status="RAW",
|
||||||
filename="test_note.md",
|
filename="test_note.md",
|
||||||
)
|
)
|
||||||
|
existing_note.update_confidences(
|
||||||
|
{"observation_accuracy": 0.2, "predictive_value": 0.3}
|
||||||
|
)
|
||||||
db_session.add(existing_note)
|
db_session.add(existing_note)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
@ -225,6 +228,10 @@ def test_sync_note_edit(mock_note_data, db_session):
|
|||||||
assert len(db_session.query(Note).all()) == 1
|
assert len(db_session.query(Note).all()) == 1
|
||||||
db_session.refresh(existing_note)
|
db_session.refresh(existing_note)
|
||||||
assert existing_note.content == "bla bla bla" # type: ignore
|
assert existing_note.content == "bla bla bla" # type: ignore
|
||||||
|
assert existing_note.confidence_dict == {
|
||||||
|
"observation_accuracy": 0.8,
|
||||||
|
"predictive_value": 0.3,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -242,14 +249,14 @@ def test_sync_note_parameters(note_type, confidence, tags, db_session, qdrant):
|
|||||||
subject=f"Test Note {note_type}",
|
subject=f"Test Note {note_type}",
|
||||||
content="Test content for parameter testing",
|
content="Test content for parameter testing",
|
||||||
note_type=note_type,
|
note_type=note_type,
|
||||||
confidence=confidence,
|
confidences={"observation_accuracy": confidence},
|
||||||
tags=tags,
|
tags=tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
note = db_session.query(Note).filter_by(subject=f"Test Note {note_type}").first()
|
note = db_session.query(Note).filter_by(subject=f"Test Note {note_type}").first()
|
||||||
assert note is not None
|
assert note is not None
|
||||||
assert note.note_type == note_type
|
assert note.note_type == note_type
|
||||||
assert float(note.confidence) == confidence # Convert Decimal to float
|
assert note.confidence_dict == {"observation_accuracy": confidence}
|
||||||
assert note.tags == tags
|
assert note.tags == tags
|
||||||
|
|
||||||
# Updated to match actual return format
|
# Updated to match actual return format
|
||||||
|
Loading…
x
Reference in New Issue
Block a user