mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 21:34:42 +02:00
search by new confidence scores
This commit is contained in:
parent
e5da3714de
commit
79567b19f2
@ -291,7 +291,7 @@ async def observe(
|
|||||||
content: str,
|
content: str,
|
||||||
subject: str,
|
subject: str,
|
||||||
observation_type: str = "general",
|
observation_type: str = "general",
|
||||||
confidence: float = 0.8,
|
confidences: dict[str, float] = {},
|
||||||
evidence: dict | None = None,
|
evidence: dict | None = None,
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
session_id: str | None = None,
|
session_id: str | None = None,
|
||||||
@ -345,12 +345,16 @@ async def observe(
|
|||||||
- "contradiction": An inconsistency with previous observations
|
- "contradiction": An inconsistency with previous observations
|
||||||
- "general": Doesn't fit other categories
|
- "general": Doesn't fit other categories
|
||||||
|
|
||||||
confidence: How certain you are (0.0-1.0):
|
confidences: How certain you are (0.0-1.0) in a given aspect of the observation:
|
||||||
- 1.0: User explicitly stated this
|
- 1.0: User explicitly stated this
|
||||||
- 0.9: Strongly implied or demonstrated repeatedly
|
- 0.9: Strongly implied or demonstrated repeatedly
|
||||||
- 0.8: Inferred with high confidence (default)
|
- 0.8: Inferred with high confidence (default)
|
||||||
- 0.7: Probable but with some uncertainty
|
- 0.7: Probable but with some uncertainty
|
||||||
- 0.6 or below: Speculative, use sparingly
|
- 0.6 or below: Speculative, use sparingly
|
||||||
|
Provided as a mapping of <aspect>: <confidence>
|
||||||
|
Examples:
|
||||||
|
- {"observation_accuracy": 0.95}
|
||||||
|
- {"observation_accuracy": 0.8, "interpretation": 0.5}
|
||||||
|
|
||||||
evidence: Supporting context as a dict. Include relevant details:
|
evidence: Supporting context as a dict. Include relevant details:
|
||||||
- "quote": Exact words from the user
|
- "quote": Exact words from the user
|
||||||
@ -392,7 +396,7 @@ async def observe(
|
|||||||
"of all evil'. They prioritize code purity over performance.",
|
"of all evil'. They prioritize code purity over performance.",
|
||||||
subject="programming_philosophy",
|
subject="programming_philosophy",
|
||||||
observation_type="belief",
|
observation_type="belief",
|
||||||
confidence=0.95,
|
confidences={"observation_accuracy": 0.95},
|
||||||
evidence={
|
evidence={
|
||||||
"quote": "State is the root of all evil in programming",
|
"quote": "State is the root of all evil in programming",
|
||||||
"context": "Discussing why they chose Haskell for their project"
|
"context": "Discussing why they chose Haskell for their project"
|
||||||
@ -408,7 +412,7 @@ async def observe(
|
|||||||
"typically between 11pm and 3am, claiming better focus",
|
"typically between 11pm and 3am, claiming better focus",
|
||||||
subject="work_schedule",
|
subject="work_schedule",
|
||||||
observation_type="behavior",
|
observation_type="behavior",
|
||||||
confidence=0.85,
|
confidences={"observation_accuracy": 0.85},
|
||||||
evidence={
|
evidence={
|
||||||
"context": "Mentioned across multiple conversations over 2 weeks"
|
"context": "Mentioned across multiple conversations over 2 weeks"
|
||||||
},
|
},
|
||||||
@ -422,7 +426,7 @@ async def observe(
|
|||||||
"previously argued strongly for monoliths in similar contexts",
|
"previously argued strongly for monoliths in similar contexts",
|
||||||
subject="architecture_preferences",
|
subject="architecture_preferences",
|
||||||
observation_type="contradiction",
|
observation_type="contradiction",
|
||||||
confidence=0.9,
|
confidences={"observation_accuracy": 0.9},
|
||||||
evidence={
|
evidence={
|
||||||
"quote": "Microservices are definitely the way to go",
|
"quote": "Microservices are definitely the way to go",
|
||||||
"context": "Designing a new system similar to one from 3 months ago"
|
"context": "Designing a new system similar to one from 3 months ago"
|
||||||
@ -438,7 +442,7 @@ async def observe(
|
|||||||
"subject": subject,
|
"subject": subject,
|
||||||
"content": content,
|
"content": content,
|
||||||
"observation_type": observation_type,
|
"observation_type": observation_type,
|
||||||
"confidence": confidence,
|
"confidences": confidences,
|
||||||
"evidence": evidence,
|
"evidence": evidence,
|
||||||
"tags": tags,
|
"tags": tags,
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
@ -457,7 +461,7 @@ async def search_observations(
|
|||||||
subject: str = "",
|
subject: str = "",
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
observation_types: list[str] | None = None,
|
observation_types: list[str] | None = None,
|
||||||
min_confidence: float = 0.5,
|
min_confidences: dict[str, float] = {},
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
@ -478,12 +482,6 @@ async def search_observations(
|
|||||||
- To build context about the user's expertise or interests
|
- To build context about the user's expertise or interests
|
||||||
- Whenever personalization would improve your response
|
- Whenever personalization would improve your response
|
||||||
|
|
||||||
How it works:
|
|
||||||
Uses hybrid search combining semantic similarity with keyword matching.
|
|
||||||
Searches across multiple embedding spaces (semantic meaning and temporal
|
|
||||||
context) to find relevant observations from different angles. This approach
|
|
||||||
ensures you find both conceptually related and specifically mentioned items.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Natural language description of what you're looking for. The search
|
query: Natural language description of what you're looking for. The search
|
||||||
matches both meaning and specific terms in observation content.
|
matches both meaning and specific terms in observation content.
|
||||||
@ -513,10 +511,9 @@ async def search_observations(
|
|||||||
- "general": Other observations
|
- "general": Other observations
|
||||||
Leave as None to search all types.
|
Leave as None to search all types.
|
||||||
|
|
||||||
min_confidence: Only return observations with confidence >= this value.
|
min_confidences: Only return observations whose confidence in each listed aspect is >= the given minimum, e.g.
|
||||||
- Use 0.8+ for high-confidence facts
|
{"observation_accuracy": 0.8, "interpretation": 0.5} returns facts where you were confident
|
||||||
- Use 0.5-0.7 to include inferred observations
|
that you observed the fact but are not necessarily sure about the interpretation.
|
||||||
- Default 0.5 includes most observations
|
|
||||||
Range: 0.0 to 1.0
|
Range: 0.0 to 1.0
|
||||||
|
|
||||||
limit: Maximum results to return (1-100). Default 10. Increase when you
|
limit: Maximum results to return (1-100). Default 10. Increase when you
|
||||||
@ -579,7 +576,6 @@ async def search_observations(
|
|||||||
temporal = observation.generate_temporal_text(
|
temporal = observation.generate_temporal_text(
|
||||||
subject=subject or "",
|
subject=subject or "",
|
||||||
content=query,
|
content=query,
|
||||||
confidence=0,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
results = await search(
|
results = await search(
|
||||||
@ -593,7 +589,7 @@ async def search_observations(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
filters=SearchFilters(
|
filters=SearchFilters(
|
||||||
subject=subject,
|
subject=subject,
|
||||||
confidence=min_confidence,
|
min_confidences=min_confidences,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
observation_types=observation_types,
|
observation_types=observation_types,
|
||||||
source_ids=filter_observation_source_ids(tags=tags),
|
source_ids=filter_observation_source_ids(tags=tags),
|
||||||
|
@ -188,6 +188,8 @@ class AgentObservationAdmin(ModelView, model=AgentObservation):
|
|||||||
"inserted_at",
|
"inserted_at",
|
||||||
]
|
]
|
||||||
column_searchable_list = ["subject", "observation_type"]
|
column_searchable_list = ["subject", "observation_type"]
|
||||||
|
column_default_sort = [("inserted_at", True)]
|
||||||
|
column_sortable_list = ["inserted_at"]
|
||||||
|
|
||||||
|
|
||||||
class NoteAdmin(ModelView, model=Note):
|
class NoteAdmin(ModelView, model=Note):
|
||||||
@ -201,6 +203,8 @@ class NoteAdmin(ModelView, model=Note):
|
|||||||
"inserted_at",
|
"inserted_at",
|
||||||
]
|
]
|
||||||
column_searchable_list = ["subject", "content"]
|
column_searchable_list = ["subject", "content"]
|
||||||
|
column_default_sort = [("inserted_at", True)]
|
||||||
|
column_sortable_list = ["inserted_at"]
|
||||||
|
|
||||||
|
|
||||||
def setup_admin(admin: Admin):
|
def setup_admin(admin: Admin):
|
||||||
|
@ -10,7 +10,7 @@ import Stemmer
|
|||||||
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
|
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
|
||||||
|
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import Chunk
|
from memory.common.db.models import Chunk, ConfidenceScore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -25,9 +25,24 @@ async def search_bm25(
|
|||||||
items_query = db.query(Chunk.id, Chunk.content).filter(
|
items_query = db.query(Chunk.id, Chunk.content).filter(
|
||||||
Chunk.collection_name.in_(modalities)
|
Chunk.collection_name.in_(modalities)
|
||||||
)
|
)
|
||||||
|
|
||||||
if source_ids := filters.get("source_ids"):
|
if source_ids := filters.get("source_ids"):
|
||||||
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
|
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
|
||||||
|
|
||||||
|
# Add confidence filtering if specified
|
||||||
|
if min_confidences := filters.get("min_confidences"):
|
||||||
|
for confidence_type, min_score in min_confidences.items():
|
||||||
|
items_query = items_query.join(
|
||||||
|
ConfidenceScore,
|
||||||
|
(ConfidenceScore.source_item_id == Chunk.source_id)
|
||||||
|
& (ConfidenceScore.confidence_type == confidence_type)
|
||||||
|
& (ConfidenceScore.score >= min_score),
|
||||||
|
)
|
||||||
|
|
||||||
items = items_query.all()
|
items = items_query.all()
|
||||||
|
if not items:
|
||||||
|
return []
|
||||||
|
|
||||||
item_ids = {
|
item_ids = {
|
||||||
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
|
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
|
||||||
for item in items
|
for item in items
|
||||||
|
@ -111,15 +111,26 @@ async def search_embeddings(
|
|||||||
- filters: Filters to apply to the search results
|
- filters: Filters to apply to the search results
|
||||||
- multimodal: Whether to search in multimodal collections
|
- multimodal: Whether to search in multimodal collections
|
||||||
"""
|
"""
|
||||||
query_filters = {}
|
query_filters = {"must": []}
|
||||||
if confidence := filters.get("confidence"):
|
|
||||||
query_filters["must"] += [{"key": "confidence", "range": {"gte": confidence}}]
|
# Handle structured confidence filtering
|
||||||
if tags := filters.get("tags"):
|
if min_confidences := filters.get("min_confidences"):
|
||||||
query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
|
confidence_filters = [
|
||||||
if observation_types := filters.get("observation_types"):
|
{
|
||||||
query_filters["must"] += [
|
"key": f"confidence.{confidence_type}",
|
||||||
{"key": "observation_type", "match": {"any": observation_types}}
|
"range": {"gte": min_confidence_score},
|
||||||
|
}
|
||||||
|
for confidence_type, min_confidence_score in min_confidences.items()
|
||||||
]
|
]
|
||||||
|
query_filters["must"].extend(confidence_filters)
|
||||||
|
|
||||||
|
if tags := filters.get("tags"):
|
||||||
|
query_filters["must"].append({"key": "tags", "match": {"any": tags}})
|
||||||
|
|
||||||
|
if observation_types := filters.get("observation_types"):
|
||||||
|
query_filters["must"].append(
|
||||||
|
{"key": "observation_type", "match": {"any": observation_types}}
|
||||||
|
)
|
||||||
|
|
||||||
client = qdrant.get_qdrant_client()
|
client = qdrant.get_qdrant_client()
|
||||||
results = query_chunks(
|
results = query_chunks(
|
||||||
@ -129,7 +140,7 @@ async def search_embeddings(
|
|||||||
embedding.embed_text if not multimodal else embedding.embed_mixed,
|
embedding.embed_text if not multimodal else embedding.embed_mixed,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
filters=query_filters,
|
filters=query_filters if query_filters["must"] else None,
|
||||||
)
|
)
|
||||||
search_results = {k: results.get(k, []) for k in modalities}
|
search_results = {k: results.get(k, []) for k in modalities}
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ class SearchResult(BaseModel):
|
|||||||
|
|
||||||
class SearchFilters(TypedDict):
|
class SearchFilters(TypedDict):
|
||||||
subject: NotRequired[str | None]
|
subject: NotRequired[str | None]
|
||||||
confidence: NotRequired[float]
|
min_confidences: NotRequired[dict[str, float]]
|
||||||
tags: NotRequired[list[str] | None]
|
tags: NotRequired[list[str] | None]
|
||||||
observation_types: NotRequired[list[str] | None]
|
observation_types: NotRequired[list[str] | None]
|
||||||
source_ids: NotRequired[list[int] | None]
|
source_ids: NotRequired[list[int] | None]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user