mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 05:14:43 +02:00
multiple dimensions for confidence values
This commit is contained in:
parent
a40e0b50fa
commit
e5da3714de
79
db/migrations/versions/20250603_115642_add_confidences.py
Normal file
79
db/migrations/versions/20250603_115642_add_confidences.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Add confidences
|
||||
|
||||
Revision ID: 152f8b4b52e8
|
||||
Revises: ba301527a2eb
|
||||
Create Date: 2025-06-03 11:56:42.302327
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "152f8b4b52e8"
|
||||
down_revision: Union[str, None] = "ba301527a2eb"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Introduce the ``confidence_score`` table and retire the old columns.

    Creates ``confidence_score`` (one row per source item and confidence
    type), adds its three lookup indexes, then removes the single-valued
    ``confidence`` column and its index from ``agent_observation`` and
    ``notes``.

    NOTE(review): nothing here copies the existing ``confidence`` values into
    ``confidence_score`` before the columns are dropped.
    """
    table = "confidence_score"

    columns = [
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_item_id", sa.BigInteger(), nullable=False),
        sa.Column("confidence_type", sa.Text(), nullable=False),
        sa.Column("score", sa.Numeric(precision=3, scale=2), nullable=False),
    ]
    constraints = [
        # Scores are fractions in the closed interval [0, 1].
        sa.CheckConstraint("score >= 0.0 AND score <= 1.0", name="score_range_check"),
        # Scores disappear together with the source item they describe.
        sa.ForeignKeyConstraint(
            ["source_item_id"], ["source_item.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        # At most one score per (source item, confidence type) pair.
        sa.UniqueConstraint(
            "source_item_id", "confidence_type", name="unique_source_confidence_type"
        ),
    ]
    op.create_table(table, *columns, *constraints)

    # (index name, indexed column), created in this order.
    new_indexes = (
        ("confidence_score_idx", "score"),
        ("confidence_source_idx", "source_item_id"),
        ("confidence_type_idx", "confidence_type"),
    )
    for index_name, column_name in new_indexes:
        op.create_index(index_name, table, [column_name], unique=False)

    # The index has to go before the column it covers.
    legacy = (
        ("agent_observation", "agent_obs_confidence_idx"),
        ("notes", "note_confidence_idx"),
    )
    for legacy_table, legacy_index in legacy:
        op.drop_index(legacy_index, table_name=legacy_table)
        op.drop_column(legacy_table, "confidence")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore the single ``confidence`` columns and drop ``confidence_score``.

    Re-adds a non-nullable ``confidence`` column (server default ``0.5``) and
    its index to ``notes`` and ``agent_observation``, then removes the
    ``confidence_score`` indexes and table.

    NOTE(review): scores stored in ``confidence_score`` are not copied back
    into the restored columns.
    """

    def legacy_confidence_column():
        # A fresh Column object is built for every table it is added to.
        return sa.Column(
            "confidence",
            sa.NUMERIC(precision=3, scale=2),
            server_default=sa.text("0.5"),
            autoincrement=False,
            nullable=False,
        )

    restored = (
        ("notes", "note_confidence_idx"),
        ("agent_observation", "agent_obs_confidence_idx"),
    )
    for table_name, index_name in restored:
        op.add_column(table_name, legacy_confidence_column())
        op.create_index(index_name, table_name, ["confidence"], unique=False)

    # Indexes are dropped explicitly before the table itself.
    for index_name in (
        "confidence_type_idx",
        "confidence_source_idx",
        "confidence_score_idx",
    ):
        op.drop_index(index_name, table_name="confidence_score")
    op.drop_table("confidence_score")
|
@ -654,11 +654,9 @@ async def create_note(
|
||||
"""
|
||||
if filename:
|
||||
path = pathlib.Path(filename)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(settings.NOTES_STORAGE_DIR)
|
||||
else:
|
||||
if not path.is_absolute():
|
||||
path = pathlib.Path(settings.NOTES_STORAGE_DIR) / path
|
||||
filename = path.as_posix()
|
||||
filename = path.relative_to(settings.NOTES_STORAGE_DIR).as_posix()
|
||||
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
|
@ -2,6 +2,7 @@ from memory.common.db.models.base import Base
|
||||
from memory.common.db.models.source_item import (
|
||||
Chunk,
|
||||
SourceItem,
|
||||
ConfidenceScore,
|
||||
clean_filename,
|
||||
)
|
||||
from memory.common.db.models.source_items import (
|
||||
@ -37,6 +38,7 @@ __all__ = [
|
||||
"Chunk",
|
||||
"clean_filename",
|
||||
"SourceItem",
|
||||
"ConfidenceScore",
|
||||
"MailMessage",
|
||||
"EmailAttachment",
|
||||
"AgentObservation",
|
||||
|
@ -22,9 +22,11 @@ from sqlalchemy import (
|
||||
Text,
|
||||
event,
|
||||
func,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import BYTEA
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
from sqlalchemy.types import Numeric
|
||||
|
||||
from memory.common import settings
|
||||
import memory.common.extract as extract
|
||||
@ -191,6 +193,41 @@ class Chunk(Base):
|
||||
return items
|
||||
|
||||
|
||||
class ConfidenceScore(Base):
    """A single confidence value for a source item along one named dimension.

    Several rows may exist per source item, one per ``confidence_type``,
    which replaces a lone scalar confidence with a set of labelled scores.
    """

    __tablename__ = "confidence_score"

    id = Column(BigInteger, primary_key=True)
    # Owning source item; the row is removed when that item is deleted.
    source_item_id = Column(
        BigInteger,
        ForeignKey("source_item.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Dimension label, e.g. "observation_accuracy", "interpretation",
    # "predictive_value".
    confidence_type = Column(Text, nullable=False)
    # Value in 0.0-1.0 (see score_range_check below).
    score = Column(Numeric(precision=3, scale=2), nullable=False)

    # Other side of SourceItem.confidence_scores.
    source_item = relationship("SourceItem", back_populates="confidence_scores")

    __table_args__ = (
        Index("confidence_source_idx", "source_item_id"),
        Index("confidence_type_idx", "confidence_type"),
        Index("confidence_score_idx", "score"),
        CheckConstraint("score >= 0.0 AND score <= 1.0", name="score_range_check"),
        # A source item holds at most one score for any given confidence_type.
        UniqueConstraint(
            "source_item_id", "confidence_type", name="unique_source_confidence_type"
        ),
    )

    def __repr__(self) -> str:
        """Return a short debugging representation with the type and score."""
        return "<ConfidenceScore(type={}, score={})>".format(
            self.confidence_type, self.score
        )
|
||||
|
||||
|
||||
class SourceItem(Base):
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
@ -216,6 +253,11 @@ class SourceItem(Base):
|
||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||
|
||||
# Confidence scores relationship
|
||||
confidence_scores = relationship(
|
||||
"ConfidenceScore", back_populates="source_item", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Discriminator column for SQLAlchemy inheritance
|
||||
type = Column(String(50))
|
||||
|
||||
@ -235,6 +277,35 @@ class SourceItem(Base):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
@property
def confidence_dict(self) -> dict[str, float]:
    """Return this item's confidence scores keyed by confidence type.

    Each stored score is converted with ``float()``; an item without any
    scores yields an empty dict.
    """
    as_floats: dict[str, float] = {}
    for entry in self.confidence_scores:
        as_floats[entry.confidence_type] = float(entry.score)
    return as_floats
|
||||
|
||||
def update_confidences(self, confidence_updates: dict[str, float]) -> None:
    """Merge the given scores into this item's confidence scores.

    A type that already has a score gets its value overwritten; a type
    that is not present yet gets a new ``ConfidenceScore`` appended. Types
    missing from ``confidence_updates`` are left as they are.

    Args:
        confidence_updates: Dict mapping confidence_type to score (0.0-1.0).
            An empty dict is a no-op.
    """
    if not confidence_updates:
        return

    # Index the scores already attached to the item by their type.
    existing_by_type = {}
    for entry in self.confidence_scores:
        existing_by_type[entry.confidence_type] = entry

    for kind, value in confidence_updates.items():
        existing = existing_by_type.get(kind)
        if existing:
            existing.score = value
            continue

        created = ConfidenceScore(
            source_item_id=self.id, confidence_type=kind, score=value
        )
        self.confidence_scores.append(created)
|
||||
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
content = cast(str | None, self.content)
|
||||
if content:
|
||||
|
@ -505,7 +505,6 @@ class Note(SourceItem):
|
||||
)
|
||||
note_type = Column(Text, nullable=True)
|
||||
subject = Column(Text, nullable=True)
|
||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.5) # 0.0-1.0
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "note",
|
||||
@ -514,7 +513,6 @@ class Note(SourceItem):
|
||||
__table_args__ = (
|
||||
Index("note_type_idx", "note_type"),
|
||||
Index("note_subject_idx", "subject"),
|
||||
Index("note_confidence_idx", "confidence"),
|
||||
)
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
@ -522,7 +520,7 @@ class Note(SourceItem):
|
||||
**super().as_payload(),
|
||||
"note_type": self.note_type,
|
||||
"subject": self.subject,
|
||||
"confidence": float(cast(Any, self.confidence)),
|
||||
"confidence": self.confidence_dict,
|
||||
}
|
||||
|
||||
@property
|
||||
@ -531,7 +529,7 @@ class Note(SourceItem):
|
||||
"subject": self.subject,
|
||||
"content": self.content,
|
||||
"note_type": self.note_type,
|
||||
"confidence": self.confidence,
|
||||
"confidence": self.confidence_dict,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
@ -573,7 +571,6 @@ class AgentObservation(SourceItem):
|
||||
Text, nullable=False
|
||||
) # belief, preference, pattern, contradiction, behavior
|
||||
subject = Column(Text, nullable=False) # What/who the observation is about
|
||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0
|
||||
evidence = Column(JSONB) # Supporting context, quotes, etc.
|
||||
agent_model = Column(Text, nullable=False) # Which AI model made this observation
|
||||
|
||||
@ -599,7 +596,6 @@ class AgentObservation(SourceItem):
|
||||
Index("agent_obs_session_idx", "session_id"),
|
||||
Index("agent_obs_type_idx", "observation_type"),
|
||||
Index("agent_obs_subject_idx", "subject"),
|
||||
Index("agent_obs_confidence_idx", "confidence"),
|
||||
Index("agent_obs_model_idx", "agent_model"),
|
||||
)
|
||||
|
||||
@ -613,7 +609,7 @@ class AgentObservation(SourceItem):
|
||||
**super().as_payload(),
|
||||
"observation_type": self.observation_type,
|
||||
"subject": self.subject,
|
||||
"confidence": float(cast(Any, self.confidence)),
|
||||
"confidence": self.confidence_dict,
|
||||
"evidence": self.evidence,
|
||||
"agent_model": self.agent_model,
|
||||
}
|
||||
@ -633,7 +629,7 @@ class AgentObservation(SourceItem):
|
||||
"content": self.content,
|
||||
"observation_type": self.observation_type,
|
||||
"evidence": self.evidence,
|
||||
"confidence": self.confidence,
|
||||
"confidence": self.confidence_dict,
|
||||
"agent_model": self.agent_model,
|
||||
"tags": self.tags,
|
||||
}
|
||||
@ -664,7 +660,6 @@ class AgentObservation(SourceItem):
|
||||
temporal_text = observation.generate_temporal_text(
|
||||
cast(str, self.subject),
|
||||
cast(str, self.content),
|
||||
cast(float, self.confidence),
|
||||
cast(datetime, self.inserted_at),
|
||||
)
|
||||
if temporal_text:
|
||||
|
@ -31,7 +31,6 @@ def generate_semantic_text(
|
||||
def generate_temporal_text(
|
||||
subject: str,
|
||||
content: str,
|
||||
confidence: float,
|
||||
created_at: datetime,
|
||||
) -> str:
|
||||
"""Generate text with temporal context for time-pattern search."""
|
||||
@ -55,8 +54,6 @@ def generate_temporal_text(
|
||||
f"Subject: {subject}",
|
||||
f"Observation: {content}",
|
||||
]
|
||||
if confidence is not None:
|
||||
parts.append(f"Confidence: {confidence}")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Note
|
||||
from memory.common.celery_app import app, SYNC_NOTE, SYNC_NOTES
|
||||
@ -23,7 +22,7 @@ def sync_note(
|
||||
content: str,
|
||||
filename: str | None = None,
|
||||
note_type: str | None = None,
|
||||
confidence: float | None = None,
|
||||
confidences: dict[str, float] = {},
|
||||
tags: list[str] = [],
|
||||
):
|
||||
logger.info(f"Syncing note {subject}")
|
||||
@ -32,6 +31,8 @@ def sync_note(
|
||||
|
||||
if filename:
|
||||
filename = filename.lstrip("/")
|
||||
if not filename.endswith(".md"):
|
||||
filename = f"{filename}.md"
|
||||
|
||||
with make_session() as session:
|
||||
existing_note = check_content_exists(session, Note, sha256=sha256)
|
||||
@ -45,7 +46,6 @@ def sync_note(
|
||||
note = Note(
|
||||
modality="note",
|
||||
mime_type="text/markdown",
|
||||
confidence=confidence or 0.5,
|
||||
)
|
||||
else:
|
||||
logger.info("Editing preexisting note")
|
||||
@ -58,11 +58,10 @@ def sync_note(
|
||||
|
||||
if note_type:
|
||||
note.note_type = note_type # type: ignore
|
||||
if confidence:
|
||||
note.confidence = confidence # type: ignore
|
||||
if tags:
|
||||
note.tags = tags # type: ignore
|
||||
|
||||
note.update_confidences(confidences)
|
||||
note.save_to_file()
|
||||
return process_content_item(note, session)
|
||||
|
||||
|
@ -21,7 +21,7 @@ def sync_observation(
|
||||
content: str,
|
||||
observation_type: str,
|
||||
evidence: dict | None = None,
|
||||
confidence: float = 0.5,
|
||||
confidences: dict[str, float] = {},
|
||||
session_id: str | None = None,
|
||||
agent_model: str = "unknown",
|
||||
tags: list[str] = [],
|
||||
@ -33,7 +33,6 @@ def sync_observation(
|
||||
content=content,
|
||||
subject=subject,
|
||||
observation_type=observation_type,
|
||||
confidence=confidence,
|
||||
evidence=evidence,
|
||||
tags=tags or [],
|
||||
session_id=session_id,
|
||||
@ -43,6 +42,7 @@ def sync_observation(
|
||||
sha256=sha256,
|
||||
modality="observation",
|
||||
)
|
||||
observation.update_confidences(confidences)
|
||||
|
||||
with make_session() as session:
|
||||
existing_observation = check_content_exists(
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -583,7 +583,6 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
||||
tags=["bla"],
|
||||
observation_type="belief",
|
||||
subject="humans",
|
||||
confidence=0.8,
|
||||
evidence={
|
||||
"quote": "All humans are mortal.",
|
||||
"source": "https://en.wikipedia.org/wiki/Human",
|
||||
@ -591,6 +590,7 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
||||
agent_model="gpt-4o",
|
||||
inserted_at=datetime(2025, 1, 1, 12, 0, 0),
|
||||
)
|
||||
item.update_confidences({"observation_accuracy": 0.8})
|
||||
metadata = item.as_payload()
|
||||
metadata["tags"] = {"bla"}
|
||||
expected = [
|
||||
@ -600,7 +600,7 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
||||
metadata | {"embedding_type": "semantic"},
|
||||
),
|
||||
(
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die.",
|
||||
[],
|
||||
metadata | {"embedding_type": "temporal"},
|
||||
),
|
||||
@ -625,7 +625,7 @@ def test_agent_observation_embeddings(mock_voyage_client):
|
||||
assert mock_voyage_client.embed.call_args == call(
|
||||
[
|
||||
"Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die.",
|
||||
"The user thinks that all men must die.",
|
||||
"All humans are mortal.",
|
||||
],
|
||||
|
@ -499,7 +499,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
"size": None,
|
||||
"observation_type": "preference",
|
||||
"subject": "programming preferences",
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"evidence": {
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -513,7 +513,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
"size": None,
|
||||
"observation_type": "preference",
|
||||
"subject": "programming preferences",
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"evidence": {
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -531,7 +531,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
"size": None,
|
||||
"observation_type": "preference",
|
||||
"subject": "programming preferences",
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"evidence": {
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -546,7 +546,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
"size": None,
|
||||
"observation_type": "preference",
|
||||
"subject": "programming preferences",
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"evidence": {
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -565,7 +565,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
"size": None,
|
||||
"observation_type": "preference",
|
||||
"subject": "programming preferences",
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"evidence": {
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -580,7 +580,7 @@ def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chun
|
||||
"size": None,
|
||||
"observation_type": "preference",
|
||||
"subject": "programming preferences",
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"evidence": {
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -603,7 +603,6 @@ def test_agent_observation_data_chunks(
|
||||
content="User prefers Python over JavaScript",
|
||||
subject="programming preferences",
|
||||
observation_type="preference",
|
||||
confidence=0.9,
|
||||
evidence={
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
@ -612,6 +611,7 @@ def test_agent_observation_data_chunks(
|
||||
session_id=session_id,
|
||||
tags=observation_tags,
|
||||
)
|
||||
observation.update_confidences({"observation_accuracy": 0.9})
|
||||
# Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
|
||||
object.__setattr__(observation, "inserted_at", datetime(2023, 1, 1, 12, 0, 0))
|
||||
|
||||
@ -634,7 +634,7 @@ def test_agent_observation_data_chunks(
|
||||
assert cast(str, semantic_chunk.collection_name) == "semantic"
|
||||
|
||||
temporal_chunk = result[1]
|
||||
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript | Confidence: 0.9"
|
||||
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript"
|
||||
assert temporal_chunk.data == [expected_temporal_text]
|
||||
|
||||
# Add session_id to expected metadata and remove tags if empty
|
||||
@ -654,11 +654,11 @@ def test_agent_observation_data_chunks_with_none_values():
|
||||
content="Content",
|
||||
subject="subject",
|
||||
observation_type="belief",
|
||||
confidence=0.7,
|
||||
evidence=None,
|
||||
agent_model="gpt-4",
|
||||
session_id=None,
|
||||
)
|
||||
observation.update_confidences({"observation_accuracy": 0.7})
|
||||
object.__setattr__(observation, "inserted_at", datetime(2023, 2, 15, 9, 30, 0))
|
||||
|
||||
result = observation.data_chunks()
|
||||
@ -671,7 +671,7 @@ def test_agent_observation_data_chunks_with_none_values():
|
||||
assert [i.data for i in result] == [
|
||||
["Subject: subject | Type: belief | Observation: Content"],
|
||||
[
|
||||
"Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
|
||||
"Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content"
|
||||
],
|
||||
["Content"],
|
||||
]
|
||||
@ -684,11 +684,11 @@ def test_agent_observation_data_chunks_merge_metadata_behavior():
|
||||
content="test",
|
||||
subject="test",
|
||||
observation_type="test",
|
||||
confidence=0.8,
|
||||
evidence={},
|
||||
agent_model="test",
|
||||
tags=["base_tag"], # Set base tags so they appear in both chunks
|
||||
)
|
||||
observation.update_confidences({"observation_accuracy": 0.9})
|
||||
object.__setattr__(observation, "inserted_at", datetime.now())
|
||||
|
||||
# Test that metadata merging preserves original values and adds new ones
|
||||
@ -723,11 +723,10 @@ def test_note_data_chunks(subject, content, expected):
|
||||
content=content,
|
||||
subject=subject,
|
||||
note_type="quicky",
|
||||
confidence=0.9,
|
||||
size=123,
|
||||
tags=["bla"],
|
||||
)
|
||||
|
||||
note.update_confidences({"observation_accuracy": 0.9})
|
||||
chunks = note.data_chunks()
|
||||
assert [chunk.content for chunk in chunks] == expected
|
||||
for chunk in chunks:
|
||||
@ -736,7 +735,7 @@ def test_note_data_chunks(subject, content, expected):
|
||||
if cast(str, chunk.content) == "test summary":
|
||||
tags |= {"tag1", "tag2"}
|
||||
assert chunk.item_metadata == {
|
||||
"confidence": 0.9,
|
||||
"confidence": {"observation_accuracy": 0.9},
|
||||
"note_type": "quicky",
|
||||
"size": 123,
|
||||
"source_id": None,
|
||||
|
@ -123,11 +123,10 @@ def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
|
||||
result = generate_temporal_text(
|
||||
subject="test_subject",
|
||||
content="test_content",
|
||||
confidence=0.8,
|
||||
created_at=test_date,
|
||||
)
|
||||
time_str = test_date.strftime("%H:%M")
|
||||
expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content | Confidence: 0.8"
|
||||
expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@ -146,7 +145,7 @@ def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
|
||||
def test_generate_temporal_text_days_of_week(weekday: int, day_name: str):
|
||||
test_date = datetime(2024, 1, 15 + weekday, 10, 30)
|
||||
result = generate_temporal_text(
|
||||
subject="subject", content="content", confidence=0.5, created_at=test_date
|
||||
subject="subject", content="content", created_at=test_date
|
||||
)
|
||||
assert f"on {day_name}" in result
|
||||
|
||||
@ -157,10 +156,8 @@ def test_generate_temporal_text_confidence_values(confidence: float):
|
||||
result = generate_temporal_text(
|
||||
subject="subject",
|
||||
content="content",
|
||||
confidence=confidence,
|
||||
created_at=test_date,
|
||||
)
|
||||
assert f"Confidence: {confidence}" in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -180,7 +177,7 @@ def test_generate_temporal_text_boundary_cases(
|
||||
test_date: datetime, expected_period: str
|
||||
):
|
||||
result = generate_temporal_text(
|
||||
subject="subject", content="content", confidence=0.8, created_at=test_date
|
||||
subject="subject", content="content", created_at=test_date
|
||||
)
|
||||
assert f"({expected_period})" in result
|
||||
|
||||
@ -190,22 +187,16 @@ def test_generate_temporal_text_complete_format():
|
||||
result = generate_temporal_text(
|
||||
subject="Important observation",
|
||||
content="User showed strong preference for X",
|
||||
confidence=0.95,
|
||||
created_at=test_date,
|
||||
)
|
||||
expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X | Confidence: 0.95"
|
||||
expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_generate_temporal_text_empty_strings():
|
||||
test_date = datetime(2024, 1, 15, 10, 30)
|
||||
result = generate_temporal_text(
|
||||
subject="", content="", confidence=0.0, created_at=test_date
|
||||
)
|
||||
assert (
|
||||
result
|
||||
== "Time: 10:30 on Monday (morning) | Subject: | Observation: | Confidence: 0.0"
|
||||
)
|
||||
result = generate_temporal_text(subject="", content="", created_at=test_date)
|
||||
assert result == "Time: 10:30 on Monday (morning) | Subject: | Observation:"
|
||||
|
||||
|
||||
def test_generate_temporal_text_special_characters():
|
||||
@ -213,8 +204,7 @@ def test_generate_temporal_text_special_characters():
|
||||
result = generate_temporal_text(
|
||||
subject="Subject with | pipe",
|
||||
content="Content with | pipe and @#$ symbols",
|
||||
confidence=0.75,
|
||||
created_at=test_date,
|
||||
)
|
||||
expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols | Confidence: 0.75"
|
||||
expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols"
|
||||
assert result == expected
|
||||
|
@ -16,7 +16,7 @@ def mock_note_data():
|
||||
"content": "This is test note content with enough text to be processed and embedded.",
|
||||
"filename": "test_note.md",
|
||||
"note_type": "observation",
|
||||
"confidence": 0.8,
|
||||
"confidences": {"observation_accuracy": 0.8},
|
||||
"tags": ["test", "note"],
|
||||
}
|
||||
|
||||
@ -90,7 +90,7 @@ def test_sync_note_success(mock_note_data, db_session, qdrant):
|
||||
assert note.modality == "note"
|
||||
assert note.mime_type == "text/markdown"
|
||||
assert note.note_type == "observation"
|
||||
assert float(note.confidence) == 0.8 # Convert Decimal to float for comparison
|
||||
assert note.confidence_dict == {"observation_accuracy": 0.8}
|
||||
assert note.filename is not None
|
||||
assert note.tags == ["test", "note"]
|
||||
|
||||
@ -114,7 +114,7 @@ def test_sync_note_minimal_data(mock_minimal_note, db_session, qdrant):
|
||||
assert note.subject == "Minimal Note"
|
||||
assert note.content == "Minimal content"
|
||||
assert note.note_type is None
|
||||
assert float(note.confidence) == 0.5 # Default value, convert Decimal to float
|
||||
assert note.confidence_dict == {}
|
||||
assert note.tags == [] # Default empty list
|
||||
assert note.filename is not None and "Minimal Note.md" in note.filename
|
||||
|
||||
@ -205,6 +205,9 @@ def test_sync_note_edit(mock_note_data, db_session):
|
||||
embed_status="RAW",
|
||||
filename="test_note.md",
|
||||
)
|
||||
existing_note.update_confidences(
|
||||
{"observation_accuracy": 0.2, "predictive_value": 0.3}
|
||||
)
|
||||
db_session.add(existing_note)
|
||||
db_session.commit()
|
||||
|
||||
@ -225,6 +228,10 @@ def test_sync_note_edit(mock_note_data, db_session):
|
||||
assert len(db_session.query(Note).all()) == 1
|
||||
db_session.refresh(existing_note)
|
||||
assert existing_note.content == "bla bla bla" # type: ignore
|
||||
assert existing_note.confidence_dict == {
|
||||
"observation_accuracy": 0.8,
|
||||
"predictive_value": 0.3,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -242,14 +249,14 @@ def test_sync_note_parameters(note_type, confidence, tags, db_session, qdrant):
|
||||
subject=f"Test Note {note_type}",
|
||||
content="Test content for parameter testing",
|
||||
note_type=note_type,
|
||||
confidence=confidence,
|
||||
confidences={"observation_accuracy": confidence},
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
note = db_session.query(Note).filter_by(subject=f"Test Note {note_type}").first()
|
||||
assert note is not None
|
||||
assert note.note_type == note_type
|
||||
assert float(note.confidence) == confidence # Convert Decimal to float
|
||||
assert note.confidence_dict == {"observation_accuracy": confidence}
|
||||
assert note.tags == tags
|
||||
|
||||
# Updated to match actual return format
|
||||
|
Loading…
x
Reference in New Issue
Block a user