Fix search + proper integration tests

This commit is contained in:
Daniel O'Connell 2025-06-02 02:53:32 +02:00
parent b10a1fb130
commit 29b8ce6860
28 changed files with 2646 additions and 645 deletions

View File

@ -3,4 +3,5 @@ uvicorn==0.29.0
python-jose==3.3.0 python-jose==3.3.0
python-multipart==0.0.9 python-multipart==0.0.9
sqladmin sqladmin
mcp==1.9.2 mcp==1.9.2
bm25s[full]==0.2.13

View File

@ -6,5 +6,5 @@ dotenv==0.9.9
voyageai==0.3.2 voyageai==0.3.2
qdrant-client==1.9.0 qdrant-client==1.9.0
anthropic==0.18.1 anthropic==0.18.1
# Pin the httpx version, as newer versions break the anthropic client
bm25s[full]==0.2.13 httpx==0.27.0

View File

@ -38,6 +38,7 @@ from memory.workers.tasks.maintenance import (
CLEAN_COLLECTION, CLEAN_COLLECTION,
REINGEST_CHUNK, REINGEST_CHUNK,
REINGEST_EMPTY_SOURCE_ITEMS, REINGEST_EMPTY_SOURCE_ITEMS,
REINGEST_ALL_EMPTY_SOURCE_ITEMS,
REINGEST_ITEM, REINGEST_ITEM,
REINGEST_MISSING_CHUNKS, REINGEST_MISSING_CHUNKS,
UPDATE_METADATA_FOR_ITEM, UPDATE_METADATA_FOR_ITEM,
@ -67,6 +68,7 @@ TASK_MAPPINGS = {
"reingest_chunk": REINGEST_CHUNK, "reingest_chunk": REINGEST_CHUNK,
"reingest_item": REINGEST_ITEM, "reingest_item": REINGEST_ITEM,
"reingest_empty_source_items": REINGEST_EMPTY_SOURCE_ITEMS, "reingest_empty_source_items": REINGEST_EMPTY_SOURCE_ITEMS,
"reingest_all_empty_source_items": REINGEST_ALL_EMPTY_SOURCE_ITEMS,
"update_metadata_for_item": UPDATE_METADATA_FOR_ITEM, "update_metadata_for_item": UPDATE_METADATA_FOR_ITEM,
"update_metadata_for_source_items": UPDATE_METADATA_FOR_SOURCE_ITEMS, "update_metadata_for_source_items": UPDATE_METADATA_FOR_SOURCE_ITEMS,
}, },
@ -316,6 +318,13 @@ def maintenance_reingest_empty_source_items(ctx, item_type):
execute_task(ctx, "maintenance", "reingest_empty_source_items", item_type=item_type) execute_task(ctx, "maintenance", "reingest_empty_source_items", item_type=item_type)
@maintenance.command("reingest-all-empty-source-items")
@click.pass_context
def maintenance_reingest_all_empty_source_items(ctx):
"""Reingest all empty source items."""
execute_task(ctx, "maintenance", "reingest_all_empty_source_items")
@maintenance.command("reingest-chunk") @maintenance.command("reingest-chunk")
@click.option("--chunk-id", required=True, help="Chunk ID to reingest") @click.option("--chunk-id", required=True, help="Chunk ID to reingest")
@click.pass_context @click.pass_context

View File

@ -15,14 +15,60 @@ from memory.common.db.connection import make_session
from memory.common import extract from memory.common import extract
from memory.common.db.models import AgentObservation from memory.common.db.models import AgentObservation
from memory.api.search import search, SearchFilters from memory.api.search.search import search, SearchFilters
from memory.common.formatters import observation from memory.common.formatters import observation
from memory.workers.tasks.content_processing import process_content_item from memory.workers.tasks.content_processing import process_content_item
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create MCP server instance # Create MCP server instance
mcp = FastMCP("memory", stateless=True) mcp = FastMCP("memory", stateless_http=True)
def filter_observation_source_ids(
tags: list[str] | None = None, observation_types: list[str] | None = None
):
if not tags and not observation_types:
return None
with make_session() as session:
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if observation_types:
items_query = items_query.filter(
AgentObservation.observation_type.in_(observation_types)
)
source_ids = [item.id for item in items_query.all()]
return source_ids
def filter_source_ids(
modalities: set[str],
tags: list[str] | None = None,
):
if not tags:
return None
with make_session() as session:
items_query = session.query(SourceItem.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if modalities:
items_query = items_query.filter(SourceItem.modality.in_(modalities))
source_ids = [item.id for item in items_query.all()]
return source_ids
@mcp.tool() @mcp.tool()
@ -48,20 +94,6 @@ async def get_all_tags() -> list[str]:
- Projects: "project:website-redesign" - Projects: "project:website-redesign"
- Contexts: "context:work", "context:late-night" - Contexts: "context:work", "context:late-night"
- Domains: "domain:finance" - Domains: "domain:finance"
Example:
# Get all tags to ensure consistency
tags = await get_all_tags()
# Returns: ["ai-safety", "context:work", "functional-programming",
# "machine-learning", "project:thesis", ...]
# Use to check if a topic has been discussed before
if "quantum-computing" in tags:
# Search for related observations
observations = await search_observations(
query="quantum computing",
tags=["quantum-computing"]
)
""" """
with make_session() as session: with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct() tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
@ -93,26 +125,6 @@ async def get_all_subjects() -> list[str]:
- "ai_beliefs", "ai_safety_beliefs" - "ai_beliefs", "ai_safety_beliefs"
- "learning_preferences" - "learning_preferences"
- "communication_style" - "communication_style"
Example:
# Get all subjects to ensure consistency
subjects = await get_all_subjects()
# Returns: ["ai_safety_beliefs", "architecture_preferences",
# "programming_philosophy", "work_schedule", ...]
# Use to check what we know about the user
if "programming_style" in subjects:
# Get all programming-related observations
observations = await search_observations(
query="programming",
subject="programming_style"
)
Best practices:
- Always check existing subjects before creating new ones
- Use snake_case for consistency
- Be specific but not too granular
- Group related observations under same subject
""" """
with make_session() as session: with make_session() as session:
return sorted( return sorted(
@ -146,20 +158,6 @@ async def get_all_observation_types() -> list[str]:
Returns: Returns:
List of observation types that have actually been used in the system. List of observation types that have actually been used in the system.
Example:
# Check what types of observations exist
types = await get_all_observation_types()
# Returns: ["behavior", "belief", "contradiction", "preference"]
# Use to analyze observation distribution
for obs_type in types:
observations = await search_observations(
query="",
observation_types=[obs_type],
limit=100
)
print(f"{obs_type}: {len(observations)} observations")
""" """
with make_session() as session: with make_session() as session:
return sorted( return sorted(
@ -173,7 +171,11 @@ async def get_all_observation_types() -> list[str]:
@mcp.tool() @mcp.tool()
async def search_knowledge_base( async def search_knowledge_base(
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10 query: str,
previews: bool = False,
modalities: set[str] = set(),
tags: list[str] = [],
limit: int = 10,
) -> list[dict]: ) -> list[dict]:
""" """
Search through the user's stored knowledge and content. Search through the user's stored knowledge and content.
@ -283,14 +285,22 @@ async def search_knowledge_base(
""" """
logger.info(f"MCP search for: {query}") logger.info(f"MCP search for: {query}")
if not modalities:
modalities = set(ALL_COLLECTIONS.keys())
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
upload_data = extract.extract_text(query) upload_data = extract.extract_text(query)
results = await search( results = await search(
upload_data, upload_data,
previews=previews, previews=previews,
modalities=modalities, modalities=modalities,
limit=limit, limit=limit,
min_text_score=0.3, min_text_score=0.4,
min_multimodal_score=0.3, min_multimodal_score=0.25,
filters=SearchFilters(
tags=tags,
source_ids=filter_source_ids(tags=tags, modalities=modalities),
),
) )
# Convert SearchResult objects to dictionaries for MCP # Convert SearchResult objects to dictionaries for MCP
@ -456,12 +466,13 @@ async def observe(
mime_type="text/plain", mime_type="text/plain",
sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(), sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
inserted_at=datetime.now(timezone.utc), inserted_at=datetime.now(timezone.utc),
modality="observation",
) )
try: try:
with make_session() as session: with make_session() as session:
process_content_item(observation, session) process_content_item(observation, session)
if not observation.id: if not cast(int | None, observation.id):
raise ValueError("Observation not created") raise ValueError("Observation not created")
logger.info( logger.info(
@ -600,24 +611,6 @@ async def search_observations(
- Higher confidence observations are more reliable - Higher confidence observations are more reliable
- Recent observations may override older ones on same topic - Recent observations may override older ones on same topic
""" """
source_ids = None
if tags or observation_types:
with make_session() as session:
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if observation_types:
items_query = items_query.filter(
AgentObservation.observation_type.in_(observation_types)
)
source_ids = [item.id for item in items_query.all()]
if not source_ids:
return []
semantic_text = observation.generate_semantic_text( semantic_text = observation.generate_semantic_text(
subject=subject or "", subject=subject or "",
observation_type="".join(observation_types or []), observation_type="".join(observation_types or []),
@ -637,18 +630,24 @@ async def search_observations(
extract.DataChunk(data=[temporal]), extract.DataChunk(data=[temporal]),
], ],
previews=True, previews=True,
modalities=["semantic", "temporal"], modalities={"semantic", "temporal"},
limit=limit, limit=limit,
min_text_score=0.8,
filters=SearchFilters( filters=SearchFilters(
subject=subject, subject=subject,
confidence=min_confidence, confidence=min_confidence,
tags=tags, tags=tags,
observation_types=observation_types, observation_types=observation_types,
source_ids=source_ids, source_ids=filter_observation_source_ids(tags=tags),
), ),
timeout=2,
) )
return [ return [
cast(dict, cast(dict, result.model_dump()).get("content")) for result in results {
"content": r.content,
"tags": r.tags,
"created_at": r.created_at.isoformat() if r.created_at else None,
"metadata": r.metadata,
}
for r in results
] ]

View File

@ -3,6 +3,7 @@ FastAPI application for the knowledge base.
""" """
import contextlib import contextlib
import os
import pathlib import pathlib
import logging import logging
from typing import Annotated, Optional from typing import Annotated, Optional
@ -105,12 +106,16 @@ def get_file_by_path(path: str):
return FileResponse(path=file_path, filename=file_path.name) return FileResponse(path=file_path, filename=file_path.name)
def main(): def main(reload: bool = False):
"""Run the FastAPI server in debug mode with auto-reloading.""" """Run the FastAPI server in debug mode with auto-reloading."""
import uvicorn import uvicorn
uvicorn.run( uvicorn.run(
"memory.api.app:app", host="0.0.0.0", port=8000, reload=True, log_level="debug" "memory.api.app:app",
host="0.0.0.0",
port=8000,
reload=reload,
log_level="debug",
) )
@ -118,4 +123,4 @@ if __name__ == "__main__":
from memory.common.qdrant import setup_qdrant from memory.common.qdrant import setup_qdrant
setup_qdrant() setup_qdrant()
main() main(os.getenv("RELOAD", "false") == "true")

View File

@ -1,382 +0,0 @@
"""
Search endpoints for the knowledge base API.
"""
import asyncio
import base64
from hashlib import sha256
import io
import logging
from collections import defaultdict
from typing import Any, Callable, Optional, TypedDict, NotRequired
import bm25s
import Stemmer
import qdrant_client
from PIL import Image
from pydantic import BaseModel
from qdrant_client.http import models as qdrant_models
from memory.common import embedding, extract, qdrant, settings
from memory.common.collections import (
ALL_COLLECTIONS,
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
class AnnotatedChunk(BaseModel):
id: str
score: float
metadata: dict
preview: Optional[str | None] = None
class SourceData(BaseModel):
"""Holds source item data to avoid SQLAlchemy session issues"""
id: int
size: int | None
mime_type: str | None
filename: str | None
content: str | dict | None
content_length: int
class SearchResponse(BaseModel):
collection: str
results: list[dict]
class SearchResult(BaseModel):
id: int
size: int
mime_type: str
chunks: list[AnnotatedChunk]
content: Optional[str | dict] = None
filename: Optional[str] = None
class SearchFilters(TypedDict):
subject: NotRequired[str | None]
confidence: NotRequired[float]
tags: NotRequired[list[str] | None]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
async def with_timeout(
call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Run a function with a timeout.
Args:
call: The function to run
timeout: The timeout in seconds
"""
try:
return await asyncio.wait_for(call, timeout=timeout)
except TimeoutError:
logger.warning(f"Search timed out after {timeout}s")
return []
except Exception as e:
logger.error(f"Search failed: {e}")
return []
def annotated_chunk(
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceData, AnnotatedChunk]:
def serialize_item(item: bytes | str | Image.Image) -> str | None:
if not previews and not isinstance(item, str):
return None
if not previews and isinstance(item, str):
return item[:100]
if isinstance(item, Image.Image):
buffer = io.BytesIO()
format = item.format or "PNG"
item.save(buffer, format=format)
mime_type = f"image/{format.lower()}"
return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
elif isinstance(item, bytes):
return base64.b64encode(item).decode("utf-8")
elif isinstance(item, str):
return item
else:
raise ValueError(f"Unsupported item type: {type(item)}")
metadata = search_result.payload or {}
metadata = {
k: v
for k, v in metadata.items()
if k not in ["content", "filename", "size", "content_type", "tags"]
}
# Prefetch all needed source data while in session
source = chunk.source
source_data = SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content=source.display_contents,
content_length=len(source.content) if source.content else 0,
)
return source_data, AnnotatedChunk(
id=str(chunk.id),
score=search_result.score,
metadata=metadata,
preview=serialize_item(chunk.data[0]) if chunk.data else None,
)
def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
items = defaultdict(list)
source_lookup = {}
for source, chunk in chunks:
items[source.id].append(chunk)
source_lookup[source.id] = source
return [
SearchResult(
id=source.id,
size=source.size or source.content_length,
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
content=source.content,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
)
for source_id, chunks in items.items()
for source in [source_lookup[source_id]]
]
def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk],
allowed_modalities: set[str],
embedder: Callable,
min_score: float = 0.0,
limit: int = 10,
filters: dict[str, Any] | None = None,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
if not upload_data or not allowed_modalities:
return {}
chunks = [chunk for chunk in upload_data if chunk.data]
if not chunks:
logger.error(f"No chunks to embed for {allowed_modalities}")
return {}
logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
for c in chunks:
logger.error(f"Chunk: {c.data}")
vectors = embedder([c.data for c in chunks], input_type="query")
return {
collection: [
r
for vector in vectors
for r in qdrant.search_vectors(
client=client,
collection_name=collection,
query_vector=vector,
limit=limit,
filter_params=filters,
)
if r.score >= min_score
]
for collection in allowed_modalities
}
async def search_bm25(
query: str,
modalities: list[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities)
)
if source_ids := filters.get("source_ids"):
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
items = items_query.all()
item_ids = {
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
for item in items
}
corpus = [item.content.lower().strip() for item in items]
stemmer = Stemmer.Stemmer("english")
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
query_tokens = bm25s.tokenize(query, stemmer=stemmer)
results, scores = retriever.retrieve(
query_tokens, k=min(limit, len(corpus)), corpus=corpus
)
item_scores = {
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
results = []
for chunk in chunks:
# Prefetch all needed source data while in session
source = chunk.source
source_data = SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content=source.display_contents,
content_length=len(source.content) if source.content else 0,
)
annotated = AnnotatedChunk(
id=str(chunk.id),
score=item_scores[chunk.id],
metadata=source.as_payload(),
preview=None,
)
results.append((source_data, annotated))
return results
async def search_embeddings(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = SearchFilters(),
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Search across knowledge base using text query and optional files.
Parameters:
- data: List of data to search in (e.g., text, images, files)
- previews: Whether to include previews in the search results
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
- limit: Maximum number of results
- min_score: Minimum score to include in the search results
- filters: Filters to apply to the search results
- multimodal: Whether to search in multimodal collections
"""
query_filters = {
"must": [
{"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
],
}
if tags := filters.get("tags"):
query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
if observation_types := filters.get("observation_types"):
query_filters["must"] += [
{"key": "observation_type", "match": {"any": observation_types}}
]
client = qdrant.get_qdrant_client()
results = query_chunks(
client,
data,
modalities,
embedding.embed_text if not multimodal else embedding.embed_mixed,
min_score=min_score,
limit=limit,
filters=query_filters,
)
search_results = {k: results.get(k, []) for k in modalities}
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
return [
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: list[str] = [],
limit: int = 10,
min_text_score: float = 0.3,
min_multimodal_score: float = 0.3,
filters: SearchFilters = {},
) -> list[SearchResult]:
"""
Search across knowledge base using text query and optional files.
Parameters:
- query: Optional text search query
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
- files: Optional files to include in the search context
- limit: Maximum number of results per modality
Returns:
- List of search results sorted by score
"""
allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
text_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & TEXT_COLLECTIONS,
limit,
min_text_score,
filters,
multimodal=False,
)
)
multimodal_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & MULTIMODAL_COLLECTIONS,
limit,
min_multimodal_score,
filters,
multimodal=True,
)
)
bm25_results = with_timeout(
search_bm25(
" ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
modalities,
limit=limit,
filters=filters,
)
)
results = await asyncio.gather(
text_embeddings_results,
multimodal_embeddings_results,
bm25_results,
return_exceptions=False,
)
results = group_chunks([c for r in results for c in r])
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)

View File

@ -0,0 +1,4 @@
from .search import search
from .utils import SearchResult, SearchFilters
__all__ = ["search", "SearchResult", "SearchFilters"]

View File

@ -0,0 +1,68 @@
"""
Search endpoints for the knowledge base API.
"""
from hashlib import sha256
import logging
import bm25s
import Stemmer
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
async def search_bm25(
query: str,
modalities: set[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities)
)
if source_ids := filters.get("source_ids"):
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
items = items_query.all()
item_ids = {
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
for item in items
}
corpus = [item.content.lower().strip() for item in items]
stemmer = Stemmer.Stemmer("english")
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
query_tokens = bm25s.tokenize(query, stemmer=stemmer)
results, scores = retriever.retrieve(
query_tokens, k=min(limit, len(corpus)), corpus=corpus
)
item_scores = {
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
results = []
for chunk in chunks:
# Prefetch all needed source data while in session
source_data = SourceData.from_chunk(chunk)
annotated = AnnotatedChunk(
id=str(chunk.id),
score=item_scores[chunk.id],
metadata=chunk.source.as_payload(),
preview=None,
search_method="bm25",
)
results.append((source_data, annotated))
return results

View File

@ -0,0 +1,144 @@
import base64
import io
import logging
from typing import Any, Callable, Optional
import qdrant_client
from PIL import Image
from qdrant_client.http import models as qdrant_models
from memory.common import embedding, extract, qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
logger = logging.getLogger(__name__)
def annotated_chunk(
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceData, AnnotatedChunk]:
def serialize_item(item: bytes | str | Image.Image) -> str | None:
if not previews and not isinstance(item, str):
return None
if not previews and isinstance(item, str):
return item[:100]
if isinstance(item, Image.Image):
buffer = io.BytesIO()
format = item.format or "PNG"
item.save(buffer, format=format)
mime_type = f"image/{format.lower()}"
return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
elif isinstance(item, bytes):
return base64.b64encode(item).decode("utf-8")
elif isinstance(item, str):
return item
else:
raise ValueError(f"Unsupported item type: {type(item)}")
metadata = search_result.payload or {}
metadata = {
k: v
for k, v in metadata.items()
if k not in ["content", "filename", "size", "content_type", "tags"]
}
# Prefetch all needed source data while in session
return SourceData.from_chunk(chunk), AnnotatedChunk(
id=str(chunk.id),
score=search_result.score,
metadata=metadata,
preview=serialize_item(chunk.data[0]) if chunk.data else None,
search_method="embeddings",
)
def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk],
allowed_modalities: set[str],
embedder: Callable,
min_score: float = 0.3,
limit: int = 10,
filters: dict[str, Any] | None = None,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
if not upload_data or not allowed_modalities:
return {}
chunks = [chunk for chunk in upload_data if chunk.data]
if not chunks:
logger.error(f"No chunks to embed for {allowed_modalities}")
return {}
vectors = embedder(chunks, input_type="query")
return {
collection: [
r
for vector in vectors
for r in qdrant.search_vectors(
client=client,
collection_name=collection,
query_vector=vector,
limit=limit,
filter_params=filters,
)
if r.score >= min_score
]
for collection in allowed_modalities
}
async def search_embeddings(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = SearchFilters(),
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Search across knowledge base using text query and optional files.
Parameters:
- data: List of data to search in (e.g., text, images, files)
- previews: Whether to include previews in the search results
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
- limit: Maximum number of results
- min_score: Minimum score to include in the search results
- filters: Filters to apply to the search results
- multimodal: Whether to search in multimodal collections
"""
query_filters = {}
if confidence := filters.get("confidence"):
query_filters["must"] += [{"key": "confidence", "range": {"gte": confidence}}]
if tags := filters.get("tags"):
query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
if observation_types := filters.get("observation_types"):
query_filters["must"] += [
{"key": "observation_type", "match": {"any": observation_types}}
]
client = qdrant.get_qdrant_client()
results = query_chunks(
client,
data,
modalities,
embedding.embed_text if not multimodal else embedding.embed_mixed,
min_score=min_score,
limit=limit,
filters=query_filters,
)
search_results = {k: results.get(k, []) for k in modalities}
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
return [
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]

View File

@ -0,0 +1,94 @@
"""
Search endpoints for the knowledge base API.
"""
import asyncio
import logging
from typing import Optional
from memory.api.search.embeddings import search_embeddings
from memory.api.search.bm25 import search_bm25
from memory.api.search.utils import SearchFilters, SearchResult
from memory.api.search.utils import group_chunks, with_timeout
from memory.common import extract
from memory.common.collections import (
ALL_COLLECTIONS,
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
logger = logging.getLogger(__name__)
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_text_score: float = 0.4,
min_multimodal_score: float = 0.25,
filters: SearchFilters = {},
timeout: int = 2,
) -> list[SearchResult]:
"""
Search across knowledge base using text query and optional files.
Parameters:
- query: Optional text search query
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
- files: Optional files to include in the search context
- limit: Maximum number of results per modality
Returns:
- List of search results sorted by score
"""
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
text_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & TEXT_COLLECTIONS,
limit,
min_text_score,
filters,
multimodal=False,
),
timeout,
)
multimodal_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & MULTIMODAL_COLLECTIONS,
limit,
min_multimodal_score,
filters,
multimodal=True,
),
timeout,
)
bm25_results = with_timeout(
search_bm25(
" ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
modalities,
limit=limit,
filters=filters,
),
timeout,
)
results = await asyncio.gather(
text_embeddings_results,
multimodal_embeddings_results,
bm25_results,
return_exceptions=False,
)
text_results, multi_results, bm25_results = results
all_results = text_results + multi_results
if len(all_results) < limit:
all_results += bm25_results
results = group_chunks(all_results, previews or False)
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)

View File

@ -0,0 +1,134 @@
import asyncio
from datetime import datetime
import logging
from collections import defaultdict
from typing import Optional, TypedDict, NotRequired
from pydantic import BaseModel
from memory.common import settings
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
class AnnotatedChunk(BaseModel):
id: str
score: float
metadata: dict
preview: Optional[str | None] = None
search_method: str | None = None
class SourceData(BaseModel):
"""Holds source item data to avoid SQLAlchemy session issues"""
id: int
size: int | None
mime_type: str | None
filename: str | None
content_length: int
contents: dict | None
created_at: datetime | None
@staticmethod
def from_chunk(chunk: Chunk) -> "SourceData":
source = chunk.source
display_contents = source.display_contents or {}
return SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content_length=len(source.content) if source.content else 0,
contents=display_contents,
created_at=source.inserted_at,
)
class SearchResponse(BaseModel):
collection: str
results: list[dict]
class SearchResult(BaseModel):
id: int
size: int
mime_type: str
chunks: list[AnnotatedChunk]
content: Optional[str | dict] = None
filename: Optional[str] = None
tags: list[str] | None = None
metadata: dict | None = None
created_at: datetime | None = None
class SearchFilters(TypedDict):
subject: NotRequired[str | None]
confidence: NotRequired[float]
tags: NotRequired[list[str] | None]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
async def with_timeout(
call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Run a function with a timeout.
Args:
call: The function to run
timeout: The timeout in seconds
"""
try:
return await asyncio.wait_for(call, timeout=timeout)
except TimeoutError:
logger.warning(f"Search timed out after {timeout}s")
return []
except Exception as e:
logger.error(f"Search failed: {e}")
return []
def group_chunks(
chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
) -> list[SearchResult]:
items = defaultdict(list)
source_lookup = {}
for source, chunk in chunks:
items[source.id].append(chunk)
source_lookup[source.id] = source
def get_content(text: str | dict | None) -> str | dict | None:
if preview or not text or not isinstance(text, str) or len(text) < 250:
return text
return text[:250] + "..."
def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
contents = source.contents or {}
tags = contents.pop("tags", [])
content = contents.pop("content", None)
return SearchResult(
id=source.id,
size=source.size or source.content_length,
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
content=get_content(content),
tags=tags,
metadata=contents,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
created_at=source.created_at,
)
return [
make_result(source, chunks)
for source_id, chunks in items.items()
for source in [source_lookup[source_id]]
]

View File

@ -102,6 +102,7 @@ TEXT_COLLECTIONS = {
MULTIMODAL_COLLECTIONS = { MULTIMODAL_COLLECTIONS = {
coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal") coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
} }
OBSERVATION_COLLECTIONS = {"semantic", "temporal"}
TYPES = { TYPES = {
"doc": ["application/pdf", "application/docx", "application/msword"], "doc": ["application/pdf", "application/docx", "application/msword"],

View File

@ -84,7 +84,7 @@ def clean_filename(filename: str) -> str:
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]: def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
for i, image in enumerate(images): for i, image in enumerate(images):
if not image.filename: # type: ignore if not getattr(image, "filename", None): # type: ignore
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
image.save(filename) image.save(filename)
image.filename = str(filename) # type: ignore image.filename = str(filename) # type: ignore
@ -100,16 +100,6 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh
] ]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
data = m.copy()
if tags := set(data.pop("tags", [])):
final["tags"] = tags | final.get("tags", set())
final |= data
return final
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]: def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
if not content.strip(): if not content.strip():
return [] return []
@ -241,14 +231,11 @@ class SourceItem(Base):
return [chunk.id for chunk in self.chunks] return [chunk.id for chunk in self.chunks]
def _chunk_contents(self) -> Sequence[extract.DataChunk]: def _chunk_contents(self) -> Sequence[extract.DataChunk]:
chunks: list[extract.DataChunk] = []
content = cast(str | None, self.content) content = cast(str | None, self.content)
if content: if content:
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)] chunks = extract.extract_text(content)
else:
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2: chunks = []
summary, tags = summarizer.summarize(content)
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
mime_type = cast(str | None, self.mime_type) mime_type = cast(str | None, self.mime_type)
if mime_type and mime_type.startswith("image/"): if mime_type and mime_type.startswith("image/"):
@ -272,12 +259,14 @@ class SourceItem(Base):
file_paths=image_names, file_paths=image_names,
collection_name=modality, collection_name=modality,
embedding_model=collections.collection_model(modality, text, images), embedding_model=collections.collection_model(modality, text, images),
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata), item_metadata=extract.merge_metadata(
self.as_payload(), data.metadata, metadata
),
) )
return chunk return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data) for data in self._chunk_contents()] return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
def as_payload(self) -> dict: def as_payload(self) -> dict:
return { return {
@ -291,4 +280,7 @@ class SourceItem(Base):
return { return {
"tags": self.tags, "tags": self.tags,
"size": self.size, "size": self.size,
"content": self.content,
"filename": self.filename,
"mime_type": self.mime_type,
} }

View File

@ -33,7 +33,6 @@ from memory.common.db.models.source_item import (
SourceItem, SourceItem,
Chunk, Chunk,
clean_filename, clean_filename,
merge_metadata,
chunk_mixed, chunk_mixed,
) )
@ -326,27 +325,24 @@ class BookSection(SourceItem):
} }
return {k: v for k, v in vals.items() if v} return {k: v for k, v in vals.items() if v}
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: def _chunk_contents(self) -> Sequence[extract.DataChunk]:
content = cast(str, self.content.strip()) content = cast(str, self.content.strip())
if not content: if not content:
return [] return []
if len([p for p in self.pages if p.strip()]) == 1: if len([p for p in self.pages if p.strip()]) == 1:
return [ chunks = extract.extract_text(content, metadata={"type": "page"})
self._make_chunk( if len(chunks) > 1:
extract.DataChunk(data=[content]), metadata | {"type": "page"} chunks[-1].metadata["type"] = "summary"
) return chunks
]
summary, tags = summarizer.summarize(content) summary, tags = summarizer.summarize(content)
return [ return [
self._make_chunk( extract.DataChunk(
extract.DataChunk(data=[content]), data=[content], metadata={"type": "section", "tags": tags}
merge_metadata(metadata, {"type": "section", "tags": tags}),
), ),
self._make_chunk( extract.DataChunk(
extract.DataChunk(data=[summary]), data=[summary], metadata={"type": "summary", "tags": tags}
merge_metadata(metadata, {"type": "summary", "tags": tags}),
), ),
] ]
@ -596,7 +592,7 @@ class AgentObservation(SourceItem):
) )
semantic_chunk = extract.DataChunk( semantic_chunk = extract.DataChunk(
data=[semantic_text], data=[semantic_text],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}), metadata=extract.merge_metadata(metadata, {"embedding_type": "semantic"}),
modality="semantic", modality="semantic",
) )
@ -609,7 +605,7 @@ class AgentObservation(SourceItem):
) )
temporal_chunk = extract.DataChunk( temporal_chunk = extract.DataChunk(
data=[temporal_text], data=[temporal_text],
metadata=merge_metadata(metadata, {"embedding_type": "temporal"}), metadata=extract.merge_metadata(metadata, {"embedding_type": "temporal"}),
modality="temporal", modality="temporal",
) )
@ -617,14 +613,14 @@ class AgentObservation(SourceItem):
self._make_chunk( self._make_chunk(
extract.DataChunk( extract.DataChunk(
data=[i], data=[i],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}), metadata=extract.merge_metadata(
metadata, {"embedding_type": "semantic"}
),
modality="semantic", modality="semantic",
) )
) )
for i in [ for i in [
self.content, self.content,
self.subject,
self.observation_type,
self.evidence.get("quote", ""), self.evidence.get("quote", ""),
] ]
if i if i

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Iterable, Literal, cast from typing import Literal, cast
import voyageai import voyageai
@ -15,12 +15,22 @@ from memory.common.db.models import Chunk, SourceItem
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def as_string(
chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk],
) -> str:
if isinstance(chunk, str):
return chunk.strip()
if isinstance(chunk, list):
return "\n".join(as_string(i) for i in chunk).strip()
return ""
def embed_chunks( def embed_chunks(
chunks: list[list[extract.MulitmodalChunk]], chunks: list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL, model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document", input_type: Literal["document", "query"] = "document",
) -> list[Vector]: ) -> list[Vector]:
logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]} {len(chunks)}") logger.debug(f"Embedding chunks: {model} - {str(chunks)} {len(chunks)}")
vo = voyageai.Client() # type: ignore vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL: if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed( return vo.multimodal_embed(
@ -29,17 +39,18 @@ def embed_chunks(
input_type=input_type, input_type=input_type,
).embeddings ).embeddings
texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks] texts = [as_string(c) for c in chunks]
logger.debug(f"Embedding texts: {texts}")
return cast( return cast(
list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
) )
def break_chunk( def break_chunk(
chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS chunk: extract.DataChunk, chunk_size: int = DEFAULT_CHUNK_TOKENS
) -> list[extract.MulitmodalChunk]: ) -> list[extract.MulitmodalChunk]:
result = [] result = []
for c in chunk: for c in chunk.data:
if isinstance(c, str): if isinstance(c, str):
result += chunk_text(c, chunk_size, OVERLAP_TOKENS) result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
else: else:
@ -48,12 +59,12 @@ def break_chunk(
def embed_text( def embed_text(
chunks: list[list[extract.MulitmodalChunk]], chunks: list[extract.DataChunk],
model: str = settings.TEXT_EMBEDDING_MODEL, model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document", input_type: Literal["document", "query"] = "document",
chunk_size: int = DEFAULT_CHUNK_TOKENS, chunk_size: int = DEFAULT_CHUNK_TOKENS,
) -> list[Vector]: ) -> list[Vector]:
chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks] chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks if chunk.data]
if not any(chunked_chunks): if not any(chunked_chunks):
return [] return []
@ -61,12 +72,12 @@ def embed_text(
def embed_mixed( def embed_mixed(
items: list[list[extract.MulitmodalChunk]], items: list[extract.DataChunk],
model: str = settings.MIXED_EMBEDDING_MODEL, model: str = settings.MIXED_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document", input_type: Literal["document", "query"] = "document",
chunk_size: int = DEFAULT_CHUNK_TOKENS, chunk_size: int = DEFAULT_CHUNK_TOKENS,
) -> list[Vector]: ) -> list[Vector]:
chunked_chunks = [break_chunk(item, chunk_size) for item in items] chunked_chunks = [break_chunk(item, chunk_size) for item in items if item.data]
return embed_chunks(chunked_chunks, model, input_type) return embed_chunks(chunked_chunks, model, input_type)

View File

@ -6,7 +6,7 @@ import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Generator, Sequence, cast from typing import Any, Generator, Sequence, cast
from memory.common import chunker from memory.common import chunker, summarizer
import pymupdf # PyMuPDF import pymupdf # PyMuPDF
import pypandoc import pypandoc
from PIL import Image from PIL import Image
@ -16,6 +16,16 @@ logger = logging.getLogger(__name__)
MulitmodalChunk = Image.Image | str MulitmodalChunk = Image.Image | str
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
data = m.copy()
if tags := set(data.pop("tags", []) or []):
final["tags"] = tags | final.get("tags", set())
final |= data
return final
@dataclass @dataclass
class DataChunk: class DataChunk:
data: Sequence[MulitmodalChunk] data: Sequence[MulitmodalChunk]
@ -109,7 +119,9 @@ def extract_image(content: bytes | str | pathlib.Path) -> list[DataChunk]:
def extract_text( def extract_text(
content: bytes | str | pathlib.Path, chunk_size: int | None = None content: bytes | str | pathlib.Path,
chunk_size: int | None = None,
metadata: dict[str, Any] = {},
) -> list[DataChunk]: ) -> list[DataChunk]:
if isinstance(content, pathlib.Path): if isinstance(content, pathlib.Path):
content = content.read_text() content = content.read_text()
@ -117,8 +129,20 @@ def extract_text(
content = content.decode("utf-8") content = content.decode("utf-8")
content = cast(str, content) content = cast(str, content)
chunks = chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS) chunks = [
return [DataChunk(data=[c], mime_type="text/plain") for c in chunks if c.strip()] DataChunk(data=[c], modality="text", metadata=metadata)
for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
summary, tags = summarizer.summarize(content)
chunks.append(
DataChunk(
data=[summary],
metadata=merge_metadata(metadata, {"tags": tags}),
modality="text",
)
)
return chunks
def extract_data_chunks( def extract_data_chunks(

View File

@ -3,7 +3,7 @@ from typing import Any, cast, Generator, Sequence
import qdrant_client import qdrant_client
from qdrant_client.http import models as qdrant_models from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.exceptions import UnexpectedResponse, ApiException
from memory.common import settings from memory.common import settings
from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector
@ -193,14 +193,18 @@ def delete_points(
collection_name: Name of the collection collection_name: Name of the collection
ids: List of vector IDs to delete ids: List of vector IDs to delete
""" """
client.delete( try:
collection_name=collection_name, client.delete(
points_selector=qdrant_models.PointIdsList( collection_name=collection_name,
points=ids, # type: ignore points_selector=qdrant_models.PointIdsList(
), points=ids, # type: ignore
) ),
)
logger.debug(f"Deleted {len(ids)} vectors from {collection_name}") logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
except (ApiException, UnexpectedResponse) as e:
logger.error(f"Error deleting points from {collection_name}: {e}")
raise IOError(f"Error deleting points from {collection_name}: {e}")
def get_collection_info( def get_collection_info(

View File

@ -1,5 +1,6 @@
import json import json
import logging import logging
import traceback
from typing import Any from typing import Any
from memory.common import settings, chunker from memory.common import settings, chunker
@ -131,6 +132,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
summary = result.get("summary", "") summary = result.get("summary", "")
tags = result.get("tags", []) tags = result.get("tags", [])
except Exception as e: except Exception as e:
traceback.print_exc()
logger.error(f"Summarization failed: {e}") logger.error(f"Summarization failed: {e}")
tokens = chunker.approx_token_count(summary) tokens = chunker.approx_token_count(summary)

View File

@ -60,6 +60,7 @@ def section_processor(
end_page=section.end_page, end_page=section.end_page,
parent_section_id=None, # Will be set after flush parent_section_id=None, # Will be set after flush
content=content, content=content,
filename=book.file_path,
size=len(content), size=len(content),
mime_type="text/plain", mime_type="text/plain",
sha256=create_content_hash( sha256=create_content_hash(

View File

@ -21,6 +21,7 @@ REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk" REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item" REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items" REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
REINGEST_ALL_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_all_empty_source_items"
UPDATE_METADATA_FOR_SOURCE_ITEMS = ( UPDATE_METADATA_FOR_SOURCE_ITEMS = (
f"{MAINTENANCE_ROOT}.update_metadata_for_source_items" f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
) )
@ -76,9 +77,9 @@ def reingest_chunk(chunk_id: str, collection: str):
data = chunk.data data = chunk.data
if collection in collections.MULTIMODAL_COLLECTIONS: if collection in collections.MULTIMODAL_COLLECTIONS:
vector = embedding.embed_mixed(data)[0] vector = embedding.embed_mixed([data])[0]
elif len(data) == 1 and isinstance(data[0], str): elif collection in collections.TEXT_COLLECTIONS:
vector = embedding.embed_text([data[0]])[0] vector = embedding.embed_text([data])[0]
else: else:
raise ValueError(f"Unsupported data type for collection {collection}") raise ValueError(f"Unsupported data type for collection {collection}")
@ -123,7 +124,10 @@ def reingest_item(item_id: str, item_type: str):
chunk_ids = [str(c.id) for c in item.chunks if c.id] chunk_ids = [str(c.id) for c in item.chunks if c.id]
if chunk_ids: if chunk_ids:
client = qdrant.get_qdrant_client() client = qdrant.get_qdrant_client()
qdrant.delete_points(client, item.modality, chunk_ids) try:
qdrant.delete_points(client, item.modality, chunk_ids)
except IOError as e:
logger.error(f"Error deleting chunks for {item_id}: {e}")
for chunk in item.chunks: for chunk in item.chunks:
session.delete(chunk) session.delete(chunk)
@ -151,6 +155,13 @@ def reingest_empty_source_items(item_type: str):
return {"status": "success", "items": len(item_ids)} return {"status": "success", "items": len(item_ids)}
@app.task(name=REINGEST_ALL_EMPTY_SOURCE_ITEMS)
def reingest_all_empty_source_items():
logger.info("Reingesting all empty source items")
for item_type in SourceItem.registry._class_registry.keys():
reingest_empty_source_items.delay(item_type) # type: ignore
def check_batch(batch: Sequence[Chunk]) -> dict: def check_batch(batch: Sequence[Chunk]) -> dict:
client = qdrant.get_qdrant_client() client = qdrant.get_qdrant_client()
by_collection = defaultdict(list) by_collection = defaultdict(list)

View File

@ -234,8 +234,10 @@ def mock_voyage_client():
def embeder(chunks, *args, **kwargs): def embeder(chunks, *args, **kwargs):
return Mock(embeddings=[[0.1] * 1024] * len(chunks)) return Mock(embeddings=[[0.1] * 1024] * len(chunks))
real_client = voyageai.Client
with patch.object(voyageai, "Client", autospec=True) as mock_client: with patch.object(voyageai, "Client", autospec=True) as mock_client:
client = mock_client() client = mock_client()
client.real_client = real_client
client.embed = embeder client.embed = embeder
client.multimodal_embed = embeder client.multimodal_embed = embeder
yield client yield client
@ -251,7 +253,7 @@ def mock_openai_client():
choices=[ choices=[
Mock( Mock(
message=Mock( message=Mock(
content='{"summary": "test", "tags": ["tag1", "tag2"]}' content='{"summary": "test summary", "tags": ["tag1", "tag2"]}'
) )
) )
] ]
@ -267,7 +269,9 @@ def mock_anthropic_client():
client.messages = Mock() client.messages = Mock()
client.messages.create = Mock( client.messages.create = Mock(
return_value=Mock( return_value=Mock(
content=[Mock(text='{"summary": "test", "tags": ["tag1", "tag2"]}')] content=[
Mock(text='{"summary": "test summary", "tags": ["tag1", "tag2"]}')
]
) )
) )
yield client yield client

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

249
tests/data/contents.py Normal file
View File

@ -0,0 +1,249 @@
import hashlib
import pathlib
from bs4 import BeautifulSoup
from markdownify import markdownify
from PIL import Image
DATA_DIR = pathlib.Path(__file__).parent
SAMPLE_HTML = f"""
<html>
<body>
<h1>The Evolution of Programming Languages</h1>
<p>Programming languages have undergone tremendous evolution since the early days of computing.
From the machine code and assembly languages of the 1940s to the high-level, expressive languages
we use today, each generation has built upon the lessons learned from its predecessors. Languages
like FORTRAN and COBOL pioneered the concept of human-readable code, while later innovations like
object-oriented programming in languages such as Smalltalk and C++ revolutionized how we structure
and organize our programs.</p>
<img src="{DATA_DIR / "lang_timeline.png"}" alt="Timeline of programming language evolution" width="600" height="400">
<p>The rise of functional programming paradigms has brought mathematical rigor and immutability
to the forefront of software development. Languages like Haskell, Lisp, and more recently Rust
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.</p>
patch
<p>Modern development has also seen the emergence of domain-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that
performance and safety need not be mutually exclusive. The ongoing development of WebAssembly
promises to bring high-performance computing to web browsers in ways previously unimaginable.</p>
<img src="{DATA_DIR / "code_complexity.jpg"}" alt="Visual representation of code complexity over time" width="500" height="300">
<p>Looking toward the future, we see emerging trends in quantum programming languages, AI-assisted
code generation, and the continued evolution toward more expressive type systems. The challenge
for tomorrow's language designers will be balancing expressiveness with simplicity, performance
with safety, and innovation with backward compatibility. As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.</p>
<p>The emergence of cloud computing and distributed systems has also driven new paradigms in
language design. Languages like Go and Elixir have been specifically crafted to excel in
concurrent and distributed environments, while the rise of microservices has renewed interest
in polyglot programming approaches. These developments reflect a broader shift toward languages
that are not just powerful tools for individual developers, but robust foundations for building
scalable, resilient systems that can handle the demands of modern internet-scale applications.</p>
<p>Perhaps most intriguingly, the intersection of programming languages with artificial intelligence
is opening entirely new frontiers. Differentiable programming languages are enabling new forms of
machine learning research, while large language models are beginning to reshape how we think about
code generation and developer tooling. As we stand on the brink of an era where AI systems may
become active participants in the programming process itself, the very nature of what constitutes
a programming languageand who or what programs in itmay be fundamentally transformed.</p>
</body>
</html>
"""
SECOND_PAGE = """
<div>
<h2>The Impact of Open Source on Language Development</h2>
<p>The open source movement has fundamentally transformed how programming languages are developed,
distributed, and evolved. Unlike the proprietary languages of earlier decades, modern language
development often occurs in public repositories where thousands of contributors can participate
in the design process. Languages like Python, JavaScript, and Rust have benefited enormously
from this collaborative approach, with their ecosystems growing rapidly through community-driven
package managers and extensive third-party libraries.</p>
<p>This democratization of language development has led to faster innovation cycles and more
responsive adaptation to developer needs. When a language feature proves problematic or a new
paradigm emerges, open source languages can quickly incorporate changes through their community
governance processes. The result has been an unprecedented period of language experimentation
and refinement, where ideas can be tested, refined, and adopted across multiple language
communities simultaneously.</p>
<p>Furthermore, the open source model has enabled the rise of domain-specific languages that
might never have been commercially viable under traditional development models. From specialized
query languages for databases to configuration management tools, the low barrier to entry for
language creation has fostered an explosion of linguistic diversity in computing, each tool
optimized for specific problem domains and user communities.</p>
<p>The collaborative nature of open source development has also revolutionized language tooling
and developer experience. Modern languages benefit from rich ecosystems of editors, debuggers,
profilers, and static analysis tools, all developed by passionate communities who understand
the daily challenges faced by practitioners. This has created a virtuous cycle where better
tooling attracts more developers, who in turn contribute improvements that make the language
even more accessible and powerful.</p>
<p>Version control systems like Git have enabled unprecedented transparency in language evolution,
allowing developers to trace the reasoning behind every design decision through detailed commit
histories and issue discussions. This historical record serves not only as documentation but as
a learning resource for future language designers, helping them understand the trade-offs and
considerations that shaped successful language features.</p>
<p>The economic implications of open source language development cannot be overstated. By removing
licensing barriers and vendor lock-in, open source languages have democratized access to powerful
programming tools across the globe. This has enabled innovation in regions and sectors that might
otherwise have been excluded from the software revolution, fostering a truly global community of
software creators and problem solvers.</p>
</div>
"""
CHUNKS: list[str] = [
"""The Evolution of Programming Languages
======================================
Programming languages have undergone tremendous evolution since the early days of computing.
From the machine code and assembly languages of the 1940s to the high\\-level, expressive languages
we use today, each generation has built upon the lessons learned from its predecessors. Languages
like FORTRAN and COBOL pioneered the concept of human\\-readable code, while later innovations like
object\\-oriented programming in languages such as Smalltalk and C\\+\\+ revolutionized how we structure
and organize our programs.
![Timeline of programming language evolution](/Users/dan/code/memory/tests/data/lang_timeline.png)
The rise of functional programming paradigms has brought mathematical rigor and immutability
to the forefront of software development. Languages like Haskell, Lisp, and more recently Rust
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.
Modern development has also seen the emergence of domain\\-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that
performance and safety need not be mutually exclusive. The ongoing development of WebAssembly
promises to bring high\\-performance computing to web browsers in ways previously unimaginable.
![Visual representation of code complexity over time](/Users/dan/code/memory/tests/data/code_complexity.jpg)
Looking toward the future, we see emerging trends in quantum programming languages, AI\\-assisted
code generation, and the continued evolution toward more expressive type systems. The challenge
for tomorrow's language designers will be balancing expressiveness with simplicity, performance
with safety, and innovation with backward compatibility. As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.""",
"""
As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.
The emergence of cloud computing and distributed systems has also driven new paradigms in
language design. Languages like Go and Elixir have been specifically crafted to excel in
concurrent and distributed environments, while the rise of microservices has renewed interest
in polyglot programming approaches. These developments reflect a broader shift toward languages
that are not just powerful tools for individual developers, but robust foundations for building
scalable, resilient systems that can handle the demands of modern internet\\-scale applications.
Perhaps most intriguingly, the intersection of programming languages with artificial intelligence
is opening entirely new frontiers. Differentiable programming languages are enabling new forms of
machine learning research, while large language models are beginning to reshape how we think about
code generation and developer tooling. As we stand on the brink of an era where AI systems may
become active participants in the programming process itself, the very nature of what constitutes
a programming languageand who or what programs in itmay be fundamentally transformed.""",
]
TWO_PAGE_CHUNKS: list[str] = [
"""
The Evolution of Programming Languages
======================================
Programming languages have undergone tremendous evolution since the early days of computing.
From the machine code and assembly languages of the 1940s to the high\-level, expressive languages
we use today, each generation has built upon the lessons learned from its predecessors. Languages
like FORTRAN and COBOL pioneered the concept of human\-readable code, while later innovations like
object\-oriented programming in languages such as Smalltalk and C\+\+ revolutionized how we structure
and organize our programs.
![Timeline of programming language evolution](/Users/dan/code/memory/tests/data/lang_timeline.png)
The rise of functional programming paradigms has brought mathematical rigor and immutability
to the forefront of software development. Languages like Haskell, Lisp, and more recently Rust
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.
Modern development has also seen the emergence of domain\-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that
performance and safety need not be mutually exclusive. The ongoing development of WebAssembly
promises to bring high\-performance computing to web browsers in ways previously unimaginable.
![Visual representation of code complexity over time](/Users/dan/code/memory/tests/data/code_complexity.jpg)
Looking toward the future, we see emerging trends in quantum programming languages, AI\-assisted
code generation, and the continued evolution toward more expressive type systems. The challenge
for tomorrow's language designers will be balancing expressiveness with simplicity, performance
with safety, and innovation with backward compatibility. As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.
""",
"""
As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.
The emergence of cloud computing and distributed systems has also driven new paradigms in
language design. Languages like Go and Elixir have been specifically crafted to excel in
concurrent and distributed environments, while the rise of microservices has renewed interest
in polyglot programming approaches. These developments reflect a broader shift toward languages
that are not just powerful tools for individual developers, but robust foundations for building
scalable, resilient systems that can handle the demands of modern internet\-scale applications.
Perhaps most intriguingly, the intersection of programming languages with artificial intelligence
is opening entirely new frontiers. Differentiable programming languages are enabling new forms of
machine learning research, while large language models are beginning to reshape how we think about
code generation and developer tooling. As we stand on the brink of an era where AI systems may
become active participants in the programming process itself, the very nature of what constitutes
a programming languageand who or what programs in itmay be fundamentally transformed.
The Impact of Open Source on Language Development
-------------------------------------------------
The open source movement has fundamentally transformed how programming languages are developed,
distributed, and evolved. Unlike the proprietary languages of earlier decades, modern language
development often occurs in public repositories where thousands of contributors can participate
in the design process. Languages like Python, JavaScript, and Rust have benefited enormously
from this collaborative approach, with their ecosystems growing rapidly through community\-driven
package managers and extensive third\-party libraries.
This democratization of language development has led to faster innovation cycles and more
responsive adaptation to developer needs. When a language feature proves problematic or a new
paradigm emerges, open source languages can quickly incorporate changes through their community
governance processes. The result has been an unprecedented period of language experimentation
and refinement, where ideas can be tested, refined, and adopted across multiple language
communities simultaneously.""",
"""
The result has been an unprecedented period of language experimentation
and refinement, where ideas can be tested, refined, and adopted across multiple language
communities simultaneously.
Furthermore, the open source model has enabled the rise of domain\-specific languages that
might never have been commercially viable under traditional development models. From specialized
query languages for databases to configuration management tools, the low barrier to entry for
language creation has fostered an explosion of linguistic diversity in computing, each tool
optimized for specific problem domains and user communities.
The collaborative nature of open source development has also revolutionized language tooling
and developer experience. Modern languages benefit from rich ecosystems of editors, debuggers,
profilers, and static analysis tools, all developed by passionate communities who understand
the daily challenges faced by practitioners. This has created a virtuous cycle where better
tooling attracts more developers, who in turn contribute improvements that make the language
even more accessible and powerful.
Version control systems like Git have enabled unprecedented transparency in language evolution,
allowing developers to trace the reasoning behind every design decision through detailed commit
histories and issue discussions. This historical record serves not only as documentation but as
a learning resource for future language designers, helping them understand the trade\-offs and
considerations that shaped successful language features.
The economic implications of open source language development cannot be overstated. By removing
licensing barriers and vendor lock\-in, open source languages have democratized access to powerful
programming tools across the globe. This has enabled innovation in regions and sectors that might
otherwise have been excluded from the software revolution, fostering a truly global community of
software creators and problem solvers.
""",
]
SAMPLE_MARKDOWN = markdownify(SAMPLE_HTML)
SAMPLE_TEXT = BeautifulSoup(SAMPLE_HTML, "html.parser").get_text()
SECOND_PAGE_MARKDOWN = markdownify(SECOND_PAGE)
SECOND_PAGE_TEXT = BeautifulSoup(SECOND_PAGE, "html.parser").get_text()
def image_hash(image: Image.Image) -> str:
return hashlib.sha256(image.tobytes()).hexdigest()
LANG_TIMELINE = Image.open(DATA_DIR / "lang_timeline.png")
CODE_COMPLEXITY = Image.open(DATA_DIR / "code_complexity.jpg")
LANG_TIMELINE_HASH = image_hash(LANG_TIMELINE)
CODE_COMPLEXITY_HASH = image_hash(CODE_COMPLEXITY)

Binary file not shown.

After

Width:  |  Height:  |  Size: 309 KiB

File diff suppressed because it is too large Load Diff

View File

@ -1,23 +1,17 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from unittest.mock import patch, Mock from unittest.mock import patch
from typing import cast from typing import cast
import pytest import pytest
from PIL import Image from PIL import Image
from datetime import datetime
from memory.common import settings, chunker, extract from memory.common import settings, chunker, extract
from memory.common.db.models.sources import Book
from memory.common.db.models.source_items import ( from memory.common.db.models.source_items import (
Chunk, Chunk,
MailMessage, MailMessage,
EmailAttachment,
BookSection,
BlogPost,
) )
from memory.common.db.models.source_item import ( from memory.common.db.models.source_item import (
SourceItem, SourceItem,
image_filenames, image_filenames,
add_pics, add_pics,
merge_metadata,
clean_filename, clean_filename,
) )
@ -56,114 +50,6 @@ def test_clean_filename(input_filename, expected):
assert clean_filename(input_filename) == expected assert clean_filename(input_filename) == expected
@pytest.mark.parametrize(
"dicts,expected",
[
# Empty input
([], {}),
# Single dict without tags
([{"key": "value"}], {"key": "value"}),
# Single dict with tags as list
(
[{"key": "value", "tags": ["tag1", "tag2"]}],
{"key": "value", "tags": {"tag1", "tag2"}},
),
# Single dict with tags as set
(
[{"key": "value", "tags": {"tag1", "tag2"}}],
{"key": "value", "tags": {"tag1", "tag2"}},
),
# Multiple dicts without tags
(
[{"key1": "value1"}, {"key2": "value2"}],
{"key1": "value1", "key2": "value2"},
),
# Multiple dicts with non-overlapping tags
(
[
{"key1": "value1", "tags": ["tag1"]},
{"key2": "value2", "tags": ["tag2"]},
],
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2"}},
),
# Multiple dicts with overlapping tags
(
[
{"key1": "value1", "tags": ["tag1", "tag2"]},
{"key2": "value2", "tags": ["tag2", "tag3"]},
],
{"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2", "tag3"}},
),
# Overlapping keys - later dict wins
(
[
{"key": "value1", "other": "data1"},
{"key": "value2", "another": "data2"},
],
{"key": "value2", "other": "data1", "another": "data2"},
),
# Mixed tags types (list and set)
(
[
{"key1": "value1", "tags": ["tag1", "tag2"]},
{"key2": "value2", "tags": {"tag3", "tag4"}},
],
{
"key1": "value1",
"key2": "value2",
"tags": {"tag1", "tag2", "tag3", "tag4"},
},
),
# Empty tags
(
[{"key": "value", "tags": []}, {"key2": "value2", "tags": []}],
{"key": "value", "key2": "value2"},
),
# None values
(
[{"key1": None, "key2": "value"}, {"key3": None}],
{"key1": None, "key2": "value", "key3": None},
),
# Complex nested structures
(
[
{"nested": {"inner": "value1"}, "list": [1, 2, 3], "tags": ["tag1"]},
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": ["tag2"]},
],
{"nested": {"inner": "value2"}, "list": [4, 5], "tags": {"tag1", "tag2"}},
),
# Boolean and numeric values
(
[
{"bool": True, "int": 42, "float": 3.14, "tags": ["numeric"]},
{"bool": False, "int": 100},
],
{"bool": False, "int": 100, "float": 3.14, "tags": {"numeric"}},
),
# Three or more dicts
(
[
{"a": 1, "tags": ["t1"]},
{"b": 2, "tags": ["t2", "t3"]},
{"c": 3, "a": 10, "tags": ["t3", "t4"]},
],
{"a": 10, "b": 2, "c": 3, "tags": {"t1", "t2", "t3", "t4"}},
),
# Dict with only tags
([{"tags": ["tag1", "tag2"]}], {"tags": {"tag1", "tag2"}}),
# Empty dicts
([{}, {}], {}),
# Mix of empty and non-empty dicts
(
[{}, {"key": "value", "tags": ["tag"]}, {}],
{"key": "value", "tags": {"tag"}},
),
],
)
def test_merge_metadata(dicts, expected):
assert merge_metadata(*dicts) == expected
def test_image_filenames_with_existing_filenames(tmp_path): def test_image_filenames_with_existing_filenames(tmp_path):
"""Test image_filenames when images already have filenames""" """Test image_filenames when images already have filenames"""
chunk_id = "test_chunk_123" chunk_id = "test_chunk_123"

View File

@ -0,0 +1,626 @@
import hashlib
from datetime import datetime
from typing import Sequence, cast
from unittest.mock import ANY, Mock, call
import pymupdf # PyMuPDF
import pytest
from memory.common import settings
from memory.common.db.models.source_item import Chunk, SourceItem
from memory.common.db.models.source_items import (
AgentObservation,
BlogPost,
BookSection,
Comic,
EmailAttachment,
ForumPost,
MailMessage,
)
from memory.common.db.models.sources import Book
from memory.common.embedding import embed_source_item
from memory.common.extract import page_to_image
from tests.data.contents import (
CHUNKS,
LANG_TIMELINE_HASH,
SAMPLE_MARKDOWN,
SAMPLE_TEXT,
image_hash,
)
def compare_chunks(
chunks: Sequence[Chunk],
expected: Sequence[tuple[str | None, list[str], dict]],
):
data = [
(c.content, [image_hash(i) for i in c.images], c.item_metadata) for c in chunks
]
assert data == expected
def test_base_source_item_text_embeddings(mock_voyage_client):
item = SourceItem(
id=1,
content=SAMPLE_MARKDOWN,
mime_type="text/html",
modality="text",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(CHUNKS[0].strip(), cast(list[str], []), metadata),
(CHUNKS[1].strip(), cast(list[str], []), metadata),
("test summary", [], metadata | {"tags": {"tag1", "tag2", "bla"}}),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert not mock_voyage_client.multimodal_embed.call_count
assert mock_voyage_client.embed.call_args == call(
[CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
def test_base_source_item_mixed_embeddings(mock_voyage_client):
item = SourceItem(
id=1,
content=SAMPLE_MARKDOWN,
filename=DATA_DIR / "lang_timeline.png",
mime_type="image/png",
modality="photo",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(CHUNKS[0].strip(), [], metadata),
(CHUNKS[1].strip(), [], metadata),
("test summary", [], metadata | {"tags": {"tag1", "tag2", "bla"}}),
(None, [LANG_TIMELINE_HASH], {"size": 3465, "source_id": 1, "tags": {"bla"}}),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert mock_voyage_client.multimodal_embed.call_count == 1
assert mock_voyage_client.embed.call_args == call(
[CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
assert mock_voyage_client.multimodal_embed.call_args == call(
[[ANY]],
model=settings.MIXED_EMBEDDING_MODEL,
input_type="document",
)
assert [
image_hash(i) for i in mock_voyage_client.multimodal_embed.call_args[0][0][0]
] == [LANG_TIMELINE_HASH]
def test_mail_message_embeddings(mock_voyage_client):
item = MailMessage(
id=1,
content=SAMPLE_MARKDOWN,
mime_type="text/html",
modality="text",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
message_id="123",
subject="Test Subject",
sender="test@example.com",
recipients=["test@example.com"],
folder="INBOX",
sent_at=datetime(2025, 1, 1, 12, 0, 0),
)
metadata = item.as_payload()
metadata["tags"] = {"bla", "test@example.com"}
expected = [
(CHUNKS[0].strip(), [], metadata),
(CHUNKS[1].strip(), [], metadata),
(
"test summary",
[],
metadata | {"tags": {"tag1", "tag2", "bla", "test@example.com"}},
),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert not mock_voyage_client.multimodal_embed.call_count
assert mock_voyage_client.embed.call_args == call(
[CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
def test_email_attachment_embeddings_text(mock_voyage_client):
item = EmailAttachment(
id=1,
content=SAMPLE_MARKDOWN,
mime_type="text/html",
modality="text",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(CHUNKS[0].strip(), cast(list[str], []), metadata),
(CHUNKS[1].strip(), cast(list[str], []), metadata),
(
"test summary",
[],
metadata | {"tags": {"tag1", "tag2", "bla"}},
),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert not mock_voyage_client.multimodal_embed.call_count
assert mock_voyage_client.embed.call_args == call(
[CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
def test_email_attachment_embeddings_photo(mock_voyage_client):
item = EmailAttachment(
id=1,
content=SAMPLE_MARKDOWN,
filename=DATA_DIR / "lang_timeline.png",
mime_type="image/png",
modality="photo",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(None, [LANG_TIMELINE_HASH], metadata),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 0
assert mock_voyage_client.multimodal_embed.call_count == 1
assert mock_voyage_client.multimodal_embed.call_args == call(
[[ANY]],
model=settings.MIXED_EMBEDDING_MODEL,
input_type="document",
)
assert [
image_hash(i) for i in mock_voyage_client.multimodal_embed.call_args[0][0][0]
] == [LANG_TIMELINE_HASH]
def test_email_attachment_embeddings_pdf(mock_voyage_client):
item = EmailAttachment(
id=1,
content=SAMPLE_MARKDOWN,
filename=DATA_DIR / "regulamin.pdf",
mime_type="application/pdf",
modality="doc",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
with pymupdf.open(item.filename) as pdf:
expected = [
(
None,
[image_hash(page_to_image(page))],
metadata
| {
"page": page.number,
"width": page.rect.width,
"height": page.rect.height,
},
)
for page in pdf.pages()
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 0
assert mock_voyage_client.multimodal_embed.call_count == 1
assert mock_voyage_client.multimodal_embed.call_args == call(
[[ANY], [ANY]],
model=settings.MIXED_EMBEDDING_MODEL,
input_type="document",
)
assert [
[image_hash(a) for a in i]
for i in mock_voyage_client.multimodal_embed.call_args[0][0]
] == [page for _, page, _ in expected]
def test_email_attachment_embeddings_comic(mock_voyage_client):
item = Comic(
id=1,
content=SAMPLE_MARKDOWN,
filename=DATA_DIR / "lang_timeline.png",
mime_type="image/png",
modality="comic",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
title="The Evolution of Programming Languages",
author="John Doe",
published=datetime(2025, 1, 1, 12, 0, 0),
volume="1",
issue="1",
page=1,
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(
"The Evolution of Programming Languages by John Doe",
[LANG_TIMELINE_HASH],
metadata,
),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 0
assert mock_voyage_client.multimodal_embed.call_count == 1
assert mock_voyage_client.multimodal_embed.call_args == call(
[["The Evolution of Programming Languages by John Doe", ANY]],
model=settings.MIXED_EMBEDDING_MODEL,
input_type="document",
)
assert (
image_hash(mock_voyage_client.multimodal_embed.call_args[0][0][0][1])
== LANG_TIMELINE_HASH
)
def test_book_section_embeddings_single_page(mock_voyage_client):
item = BookSection(
id=1,
content=SAMPLE_MARKDOWN,
mime_type="text/html",
modality="text",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
book_id=1,
section_title="The Evolution of Programming Languages",
section_number=1,
section_level=1,
start_page=1,
end_page=1,
pages=[SAMPLE_TEXT],
book=Book(
id=1,
title="Programming Languages",
author="John Doe",
published=datetime(2025, 1, 1, 12, 0, 0),
),
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(CHUNKS[0].strip(), cast(list[str], []), metadata | {"type": "page"}),
(CHUNKS[1].strip(), cast(list[str], []), metadata | {"type": "page"}),
(
"test summary",
[],
metadata | {"tags": {"tag1", "tag2", "bla"}, "type": "summary"},
),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert not mock_voyage_client.multimodal_embed.call_count
assert mock_voyage_client.embed.call_args == call(
[CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
def test_book_section_embeddings_multiple_pages(mock_voyage_client):
item = BookSection(
id=1,
content=SAMPLE_MARKDOWN + "\n\n" + SECOND_PAGE,
mime_type="text/html",
modality="text",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
book_id=1,
section_title="The Evolution of Programming Languages",
section_number=1,
section_level=1,
start_page=1,
end_page=2,
pages=[SAMPLE_TEXT, SECOND_PAGE_TEXT],
book=Book(
id=1,
title="Programming Languages",
author="John Doe",
published=datetime(2025, 1, 1, 12, 0, 0),
),
)
metadata = item.as_payload()
metadata["tags"] = {"bla", "tag1", "tag2"}
expected = [
(item.content.strip(), cast(list[str], []), metadata | {"type": "section"}),
("test summary", [], metadata | {"type": "summary"}),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert not mock_voyage_client.multimodal_embed.call_count
assert mock_voyage_client.embed.call_args == call(
[item.content.strip(), "test summary"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
@pytest.mark.parametrize(
"class_, modality",
(
(BlogPost, "blog"),
(ForumPost, "forum"),
),
)
def test_post_embeddings_single_page(mock_voyage_client, class_, modality):
item = class_(
id=1,
content=SAMPLE_MARKDOWN,
mime_type="text/html",
modality=modality,
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
images=[LANG_TIMELINE.filename, CODE_COMPLEXITY.filename], # type: ignore
)
metadata = item.as_payload()
metadata["tags"] = {"bla", "tag1", "tag2"}
expected = [
(item.content.strip(), [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH], metadata),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert not mock_voyage_client.embed.call_count
assert mock_voyage_client.multimodal_embed.call_count == 1
assert mock_voyage_client.multimodal_embed.call_args == call(
[[item.content.strip(), ANY, ANY]],
model=settings.MIXED_EMBEDDING_MODEL,
input_type="document",
)
assert [
image_hash(i)
for i in mock_voyage_client.multimodal_embed.call_args[0][0][0][1:]
] == [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH]
@pytest.mark.parametrize(
"class_, modality",
(
(BlogPost, "blog"),
(ForumPost, "forum"),
),
)
def test_post_embeddings_multi_page(mock_voyage_client, class_, modality):
item = class_(
id=1,
content=SAMPLE_MARKDOWN + "\n\n" + SECOND_PAGE_MARKDOWN,
mime_type="text/html",
modality=modality,
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN + SECOND_PAGE_MARKDOWN),
tags=["bla"],
images=[LANG_TIMELINE.filename, CODE_COMPLEXITY.filename], # type: ignore
)
metadata = item.as_payload()
metadata["tags"] = {"bla", "tag1", "tag2"}
all_contents = (
item.content.strip(),
[LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH],
metadata,
)
first_chunk = (
TWO_PAGE_CHUNKS[0].strip(),
[LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH],
metadata,
)
second_chunk = (TWO_PAGE_CHUNKS[1].strip(), [], metadata)
third_chunk = (TWO_PAGE_CHUNKS[2].strip(), [], metadata)
summary = ("test summary", [], metadata)
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(
item.data_chunks(),
[all_contents, first_chunk, second_chunk, third_chunk, summary],
)
# embed_source_item first does text embedding, then mixed embedding
# so the order of chunks is different than in data_chunks()
compare_chunks(
embed_source_item(item),
[
second_chunk,
third_chunk,
summary,
all_contents,
first_chunk,
],
)
assert mock_voyage_client.embed.call_count == 1
assert mock_voyage_client.multimodal_embed.call_count == 1
assert mock_voyage_client.embed.call_args == call(
[
TWO_PAGE_CHUNKS[1].strip(),
TWO_PAGE_CHUNKS[2].strip(),
"test summary",
],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
assert mock_voyage_client.multimodal_embed.call_args == call(
[[item.content.strip(), ANY, ANY], [TWO_PAGE_CHUNKS[0].strip(), ANY, ANY]],
model=settings.MIXED_EMBEDDING_MODEL,
input_type="document",
)
assert [
image_hash(i)
for i in mock_voyage_client.multimodal_embed.call_args[0][0][0][1:]
] == [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH]
assert [
image_hash(i)
for i in mock_voyage_client.multimodal_embed.call_args[0][0][1][1:]
] == [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH]
def test_agent_observation_embeddings(mock_voyage_client):
item = AgentObservation(
id=1,
content="The user thinks that all men must die.",
mime_type="text/html",
modality="observation",
sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
size=len(SAMPLE_MARKDOWN),
tags=["bla"],
observation_type="belief",
subject="humans",
confidence=0.8,
evidence={
"quote": "All humans are mortal.",
"source": "https://en.wikipedia.org/wiki/Human",
},
agent_model="gpt-4o",
inserted_at=datetime(2025, 1, 1, 12, 0, 0),
)
metadata = item.as_payload()
metadata["tags"] = {"bla"}
expected = [
(
"Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
[],
metadata | {"embedding_type": "semantic"},
),
(
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
[],
metadata | {"embedding_type": "temporal"},
),
(
"The user thinks that all men must die.",
[],
metadata | {"embedding_type": "semantic"},
),
("All humans are mortal.", [], metadata | {"embedding_type": "semantic"}),
]
mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
mock_voyage_client.multimodal_embed = Mock(
return_value=Mock(embeddings=[[0.1] * 1024] * 3)
)
compare_chunks(item.data_chunks(), expected)
compare_chunks(embed_source_item(item), expected)
assert mock_voyage_client.embed.call_count == 1
assert not mock_voyage_client.multimodal_embed.call_count
assert mock_voyage_client.embed.call_args == call(
[
"Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
"The user thinks that all men must die.",
"All humans are mortal.",
],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)

View File

@ -14,7 +14,6 @@ from memory.common.db.models.source_items import (
BlogPost, BlogPost,
AgentObservation, AgentObservation,
) )
from memory.common.db.models.source_item import merge_metadata
@pytest.fixture @pytest.fixture
@ -356,7 +355,8 @@ def test_book_section_data_chunks(pages, expected_chunks):
chunks = book_section.data_chunks() chunks = book_section.data_chunks()
expected = [ expected = [
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks (c, extract.merge_metadata(book_section.as_payload(), m))
for c, m in expected_chunks
] ]
assert [(c.content, c.item_metadata) for c in chunks] == expected assert [(c.content, c.item_metadata) for c in chunks] == expected
for c in chunks: for c in chunks: