Mirror of https://github.com/mruwnik/memory.git, synced 2025-06-08 13:24:41 +02:00

Fix search + proper integration tests

This commit is contained in:
parent b10a1fb130
commit 29b8ce6860
@@ -3,4 +3,5 @@ uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin
 mcp==1.9.2
+bm25s[full]==0.2.13
@@ -6,5 +6,5 @@ dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.18.1
-bm25s[full]==0.2.13
+# Pin the httpx version, as newer versions break the anthropic client
 httpx==0.27.0
@@ -38,6 +38,7 @@ from memory.workers.tasks.maintenance import (
     CLEAN_COLLECTION,
     REINGEST_CHUNK,
     REINGEST_EMPTY_SOURCE_ITEMS,
+    REINGEST_ALL_EMPTY_SOURCE_ITEMS,
     REINGEST_ITEM,
     REINGEST_MISSING_CHUNKS,
     UPDATE_METADATA_FOR_ITEM,
@@ -67,6 +68,7 @@ TASK_MAPPINGS = {
         "reingest_chunk": REINGEST_CHUNK,
         "reingest_item": REINGEST_ITEM,
         "reingest_empty_source_items": REINGEST_EMPTY_SOURCE_ITEMS,
+        "reingest_all_empty_source_items": REINGEST_ALL_EMPTY_SOURCE_ITEMS,
         "update_metadata_for_item": UPDATE_METADATA_FOR_ITEM,
         "update_metadata_for_source_items": UPDATE_METADATA_FOR_SOURCE_ITEMS,
     },
@@ -316,6 +318,13 @@ def maintenance_reingest_empty_source_items(ctx, item_type):
     execute_task(ctx, "maintenance", "reingest_empty_source_items", item_type=item_type)


+@maintenance.command("reingest-all-empty-source-items")
+@click.pass_context
+def maintenance_reingest_all_empty_source_items(ctx):
+    """Reingest all empty source items."""
+    execute_task(ctx, "maintenance", "reingest_all_empty_source_items")
+
+
 @maintenance.command("reingest-chunk")
 @click.option("--chunk-id", required=True, help="Chunk ID to reingest")
 @click.pass_context
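For reference, the new command can be exercised without a worker deployment through click's test runner. A minimal sketch; the CLI entry-point import path is assumed, since it is not shown in this diff:

```python
from click.testing import CliRunner

from memory.api.cli import cli  # assumed entry point for the command groups

runner = CliRunner()
result = runner.invoke(cli, ["maintenance", "reingest-all-empty-source-items"])
print(result.exit_code, result.output)
```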
@@ -15,14 +15,60 @@ from memory.common.db.connection import make_session

 from memory.common import extract
 from memory.common.db.models import AgentObservation
-from memory.api.search import search, SearchFilters
+from memory.api.search.search import search, SearchFilters
 from memory.common.formatters import observation
 from memory.workers.tasks.content_processing import process_content_item
+from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS

 logger = logging.getLogger(__name__)

 # Create MCP server instance
-mcp = FastMCP("memory", stateless=True)
+mcp = FastMCP("memory", stateless_http=True)
+
+
+def filter_observation_source_ids(
+    tags: list[str] | None = None, observation_types: list[str] | None = None
+):
+    if not tags and not observation_types:
+        return None
+
+    with make_session() as session:
+        items_query = session.query(AgentObservation.id)
+
+        if tags:
+            # Use PostgreSQL array overlap operator with proper array casting
+            items_query = items_query.filter(
+                AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
+            )
+        if observation_types:
+            items_query = items_query.filter(
+                AgentObservation.observation_type.in_(observation_types)
+            )
+        source_ids = [item.id for item in items_query.all()]
+
+        return source_ids
+
+
+def filter_source_ids(
+    modalities: set[str],
+    tags: list[str] | None = None,
+):
+    if not tags:
+        return None
+
+    with make_session() as session:
+        items_query = session.query(SourceItem.id)
+
+        if tags:
+            # Use PostgreSQL array overlap operator with proper array casting
+            items_query = items_query.filter(
+                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
+            )
+        if modalities:
+            items_query = items_query.filter(SourceItem.modality.in_(modalities))
+        source_ids = [item.id for item in items_query.all()]
+
+        return source_ids


 @mcp.tool()
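Both helpers above lean on PostgreSQL's `&&` array-overlap operator. A minimal standalone sketch of the same filter, using the `make_session`, `SourceItem`, and import names that appear in this codebase:

```python
from sqlalchemy import cast as sql_cast, Text
from sqlalchemy.dialects.postgresql import ARRAY

from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem

with make_session() as session:
    # Generates: WHERE source_item.tags && CAST(ARRAY['ai', 'rust'] AS TEXT[])
    # i.e. "the row's tags share at least one element with the given list"
    ids = [
        row.id
        for row in session.query(SourceItem.id)
        .filter(SourceItem.tags.op("&&")(sql_cast(["ai", "rust"], ARRAY(Text))))
        .all()
    ]
```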
@@ -48,20 +94,6 @@ async def get_all_tags() -> list[str]:
     - Projects: "project:website-redesign"
     - Contexts: "context:work", "context:late-night"
     - Domains: "domain:finance"
-
-    Example:
-        # Get all tags to ensure consistency
-        tags = await get_all_tags()
-        # Returns: ["ai-safety", "context:work", "functional-programming",
-        #           "machine-learning", "project:thesis", ...]
-
-        # Use to check if a topic has been discussed before
-        if "quantum-computing" in tags:
-            # Search for related observations
-            observations = await search_observations(
-                query="quantum computing",
-                tags=["quantum-computing"]
-            )
     """
     with make_session() as session:
         tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
@@ -93,26 +125,6 @@ async def get_all_subjects() -> list[str]:
     - "ai_beliefs", "ai_safety_beliefs"
     - "learning_preferences"
     - "communication_style"
-
-    Example:
-        # Get all subjects to ensure consistency
-        subjects = await get_all_subjects()
-        # Returns: ["ai_safety_beliefs", "architecture_preferences",
-        #           "programming_philosophy", "work_schedule", ...]
-
-        # Use to check what we know about the user
-        if "programming_style" in subjects:
-            # Get all programming-related observations
-            observations = await search_observations(
-                query="programming",
-                subject="programming_style"
-            )
-
-    Best practices:
-    - Always check existing subjects before creating new ones
-    - Use snake_case for consistency
-    - Be specific but not too granular
-    - Group related observations under same subject
     """
     with make_session() as session:
         return sorted(
@@ -146,20 +158,6 @@ async def get_all_observation_types() -> list[str]:

     Returns:
         List of observation types that have actually been used in the system.
-
-    Example:
-        # Check what types of observations exist
-        types = await get_all_observation_types()
-        # Returns: ["behavior", "belief", "contradiction", "preference"]
-
-        # Use to analyze observation distribution
-        for obs_type in types:
-            observations = await search_observations(
-                query="",
-                observation_types=[obs_type],
-                limit=100
-            )
-            print(f"{obs_type}: {len(observations)} observations")
     """
     with make_session() as session:
         return sorted(
@@ -173,7 +171,11 @@ async def get_all_observation_types() -> list[str]:

 @mcp.tool()
 async def search_knowledge_base(
-    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
+    query: str,
+    previews: bool = False,
+    modalities: set[str] = set(),
+    tags: list[str] = [],
+    limit: int = 10,
 ) -> list[dict]:
     """
     Search through the user's stored knowledge and content.
@@ -283,14 +285,22 @@ async def search_knowledge_base(
     """
     logger.info(f"MCP search for: {query}")

     if not modalities:
         modalities = set(ALL_COLLECTIONS.keys())
+    modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS

     upload_data = extract.extract_text(query)
     results = await search(
         upload_data,
         previews=previews,
         modalities=modalities,
         limit=limit,
-        min_text_score=0.3,
-        min_multimodal_score=0.3,
+        min_text_score=0.4,
+        min_multimodal_score=0.25,
+        filters=SearchFilters(
+            tags=tags,
+            source_ids=filter_source_ids(tags=tags, modalities=modalities),
+        ),
     )

     # Convert SearchResult objects to dictionaries for MCP
@@ -456,12 +466,13 @@ async def observe(
         mime_type="text/plain",
         sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
         inserted_at=datetime.now(timezone.utc),
+        modality="observation",
     )
     try:
         with make_session() as session:
             process_content_item(observation, session)

-            if not observation.id:
+            if not cast(int | None, observation.id):
                 raise ValueError("Observation not created")

             logger.info(
@@ -600,24 +611,6 @@ async def search_observations(
     - Higher confidence observations are more reliable
     - Recent observations may override older ones on same topic
     """
-    source_ids = None
-    if tags or observation_types:
-        with make_session() as session:
-            items_query = session.query(AgentObservation.id)
-
-            if tags:
-                # Use PostgreSQL array overlap operator with proper array casting
-                items_query = items_query.filter(
-                    AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
-                )
-            if observation_types:
-                items_query = items_query.filter(
-                    AgentObservation.observation_type.in_(observation_types)
-                )
-            source_ids = [item.id for item in items_query.all()]
-            if not source_ids:
-                return []
-
     semantic_text = observation.generate_semantic_text(
         subject=subject or "",
         observation_type="".join(observation_types or []),
@@ -637,18 +630,24 @@ async def search_observations(
             extract.DataChunk(data=[temporal]),
         ],
         previews=True,
-        modalities=["semantic", "temporal"],
+        modalities={"semantic", "temporal"},
         limit=limit,
+        min_text_score=0.8,
         filters=SearchFilters(
             subject=subject,
             confidence=min_confidence,
             tags=tags,
             observation_types=observation_types,
-            source_ids=source_ids,
+            source_ids=filter_observation_source_ids(tags=tags),
         ),
+        timeout=2,
     )

     return [
-        cast(dict, cast(dict, result.model_dump()).get("content")) for result in results
+        {
+            "content": r.content,
+            "tags": r.tags,
+            "created_at": r.created_at.isoformat() if r.created_at else None,
+            "metadata": r.metadata,
+        }
+        for r in results
     ]
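The reshaped `search_observations` return value is now a plain list of dicts. A short consumer sketch based on the keys visible above (must run inside an async context):

```python
# Inside an async function, with the MCP tool available:
observations = await search_observations(query="programming", tags=["ai"])
for obs in observations:
    print(obs["created_at"], obs["tags"])  # ISO timestamp (or None) and tag list
    print(obs["content"])                  # the observation text
    print(obs["metadata"])                 # remaining payload fields
```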
@@ -3,6 +3,7 @@ FastAPI application for the knowledge base.
 """

+import contextlib
 import os
 import pathlib
 import logging
 from typing import Annotated, Optional
@@ -105,12 +106,16 @@ def get_file_by_path(path: str):
     return FileResponse(path=file_path, filename=file_path.name)


-def main():
+def main(reload: bool = False):
     """Run the FastAPI server in debug mode with auto-reloading."""
     import uvicorn

     uvicorn.run(
-        "memory.api.app:app", host="0.0.0.0", port=8000, reload=True, log_level="debug"
+        "memory.api.app:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=reload,
+        log_level="debug",
     )


@@ -118,4 +123,4 @@ if __name__ == "__main__":
     from memory.common.qdrant import setup_qdrant

     setup_qdrant()
-    main()
+    main(os.getenv("RELOAD", "false") == "true")
@@ -1,382 +0,0 @@
"""
Search endpoints for the knowledge base API.
"""

import asyncio
import base64
from hashlib import sha256
import io
import logging
from collections import defaultdict
from typing import Any, Callable, Optional, TypedDict, NotRequired

import bm25s
import Stemmer
import qdrant_client
from PIL import Image
from pydantic import BaseModel
from qdrant_client.http import models as qdrant_models

from memory.common import embedding, extract, qdrant, settings
from memory.common.collections import (
    ALL_COLLECTIONS,
    MULTIMODAL_COLLECTIONS,
    TEXT_COLLECTIONS,
)
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk

logger = logging.getLogger(__name__)


class AnnotatedChunk(BaseModel):
    id: str
    score: float
    metadata: dict
    preview: Optional[str | None] = None


class SourceData(BaseModel):
    """Holds source item data to avoid SQLAlchemy session issues"""

    id: int
    size: int | None
    mime_type: str | None
    filename: str | None
    content: str | dict | None
    content_length: int


class SearchResponse(BaseModel):
    collection: str
    results: list[dict]


class SearchResult(BaseModel):
    id: int
    size: int
    mime_type: str
    chunks: list[AnnotatedChunk]
    content: Optional[str | dict] = None
    filename: Optional[str] = None


class SearchFilters(TypedDict):
    subject: NotRequired[str | None]
    confidence: NotRequired[float]
    tags: NotRequired[list[str] | None]
    observation_types: NotRequired[list[str] | None]
    source_ids: NotRequired[list[int] | None]


async def with_timeout(
    call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """
    Run a function with a timeout.

    Args:
        call: The function to run
        timeout: The timeout in seconds
    """
    try:
        return await asyncio.wait_for(call, timeout=timeout)
    except TimeoutError:
        logger.warning(f"Search timed out after {timeout}s")
        return []
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return []


def annotated_chunk(
    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceData, AnnotatedChunk]:
    def serialize_item(item: bytes | str | Image.Image) -> str | None:
        if not previews and not isinstance(item, str):
            return None
        if not previews and isinstance(item, str):
            return item[:100]

        if isinstance(item, Image.Image):
            buffer = io.BytesIO()
            format = item.format or "PNG"
            item.save(buffer, format=format)
            mime_type = f"image/{format.lower()}"
            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
        elif isinstance(item, bytes):
            return base64.b64encode(item).decode("utf-8")
        elif isinstance(item, str):
            return item
        else:
            raise ValueError(f"Unsupported item type: {type(item)}")

    metadata = search_result.payload or {}
    metadata = {
        k: v
        for k, v in metadata.items()
        if k not in ["content", "filename", "size", "content_type", "tags"]
    }

    # Prefetch all needed source data while in session
    source = chunk.source
    source_data = SourceData(
        id=source.id,
        size=source.size,
        mime_type=source.mime_type,
        filename=source.filename,
        content=source.display_contents,
        content_length=len(source.content) if source.content else 0,
    )

    return source_data, AnnotatedChunk(
        id=str(chunk.id),
        score=search_result.score,
        metadata=metadata,
        preview=serialize_item(chunk.data[0]) if chunk.data else None,
    )


def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
    items = defaultdict(list)
    source_lookup = {}

    for source, chunk in chunks:
        items[source.id].append(chunk)
        source_lookup[source.id] = source

    return [
        SearchResult(
            id=source.id,
            size=source.size or source.content_length,
            mime_type=source.mime_type or "text/plain",
            filename=source.filename
            and source.filename.replace(
                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
            ),
            content=source.content,
            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
        )
        for source_id, chunks in items.items()
        for source in [source_lookup[source_id]]
    ]


def query_chunks(
    client: qdrant_client.QdrantClient,
    upload_data: list[extract.DataChunk],
    allowed_modalities: set[str],
    embedder: Callable,
    min_score: float = 0.0,
    limit: int = 10,
    filters: dict[str, Any] | None = None,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
    if not upload_data or not allowed_modalities:
        return {}

    chunks = [chunk for chunk in upload_data if chunk.data]
    if not chunks:
        logger.error(f"No chunks to embed for {allowed_modalities}")
        return {}

    logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
    for c in chunks:
        logger.error(f"Chunk: {c.data}")
    vectors = embedder([c.data for c in chunks], input_type="query")

    return {
        collection: [
            r
            for vector in vectors
            for r in qdrant.search_vectors(
                client=client,
                collection_name=collection,
                query_vector=vector,
                limit=limit,
                filter_params=filters,
            )
            if r.score >= min_score
        ]
        for collection in allowed_modalities
    }


async def search_bm25(
    query: str,
    modalities: list[str],
    limit: int = 10,
    filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
    with make_session() as db:
        items_query = db.query(Chunk.id, Chunk.content).filter(
            Chunk.collection_name.in_(modalities)
        )
        if source_ids := filters.get("source_ids"):
            items_query = items_query.filter(Chunk.source_id.in_(source_ids))
        items = items_query.all()
        item_ids = {
            sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
            for item in items
        }
        corpus = [item.content.lower().strip() for item in items]

    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    results, scores = retriever.retrieve(
        query_tokens, k=min(limit, len(corpus)), corpus=corpus
    )

    item_scores = {
        item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
        for doc, score in zip(results[0], scores[0])
    }

    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
        results = []
        for chunk in chunks:
            # Prefetch all needed source data while in session
            source = chunk.source
            source_data = SourceData(
                id=source.id,
                size=source.size,
                mime_type=source.mime_type,
                filename=source.filename,
                content=source.display_contents,
                content_length=len(source.content) if source.content else 0,
            )

            annotated = AnnotatedChunk(
                id=str(chunk.id),
                score=item_scores[chunk.id],
                metadata=source.as_payload(),
                preview=None,
            )
            results.append((source_data, annotated))

        return results


async def search_embeddings(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: set[str] = set(),
    limit: int = 10,
    min_score: float = 0.3,
    filters: SearchFilters = SearchFilters(),
    multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - data: List of data to search in (e.g., text, images, files)
    - previews: Whether to include previews in the search results
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - limit: Maximum number of results
    - min_score: Minimum score to include in the search results
    - filters: Filters to apply to the search results
    - multimodal: Whether to search in multimodal collections
    """
    query_filters = {
        "must": [
            {"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
        ],
    }
    if tags := filters.get("tags"):
        query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
    if observation_types := filters.get("observation_types"):
        query_filters["must"] += [
            {"key": "observation_type", "match": {"any": observation_types}}
        ]

    client = qdrant.get_qdrant_client()
    results = query_chunks(
        client,
        data,
        modalities,
        embedding.embed_text if not multimodal else embedding.embed_mixed,
        min_score=min_score,
        limit=limit,
        filters=query_filters,
    )
    search_results = {k: results.get(k, []) for k in modalities}

    found_chunks = {
        str(r.id): r for results in search_results.values() for r in results
    }
    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
        return [
            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
            for chunk in chunks
        ]


async def search(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: list[str] = [],
    limit: int = 10,
    min_text_score: float = 0.3,
    min_multimodal_score: float = 0.3,
    filters: SearchFilters = {},
) -> list[SearchResult]:
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - query: Optional text search query
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - files: Optional files to include in the search context
    - limit: Maximum number of results per modality

    Returns:
    - List of search results sorted by score
    """
    allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())

    text_embeddings_results = with_timeout(
        search_embeddings(
            data,
            previews,
            allowed_modalities & TEXT_COLLECTIONS,
            limit,
            min_text_score,
            filters,
            multimodal=False,
        )
    )
    multimodal_embeddings_results = with_timeout(
        search_embeddings(
            data,
            previews,
            allowed_modalities & MULTIMODAL_COLLECTIONS,
            limit,
            min_multimodal_score,
            filters,
            multimodal=True,
        )
    )
    bm25_results = with_timeout(
        search_bm25(
            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
            modalities,
            limit=limit,
            filters=filters,
        )
    )

    results = await asyncio.gather(
        text_embeddings_results,
        multimodal_embeddings_results,
        bm25_results,
        return_exceptions=False,
    )

    results = group_chunks([c for r in results for c in r])
    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
4  src/memory/api/search/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .search import search
from .utils import SearchResult, SearchFilters

__all__ = ["search", "SearchResult", "SearchFilters"]
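A hedged usage sketch of the re-exported package API; the collection names are illustrative and a running database/Qdrant stack is assumed:

```python
import asyncio

from memory.api.search import search, SearchFilters
from memory.common import extract

async def demo():
    # extract_text() turns the query into DataChunks, mirroring how the
    # MCP tool prepares its input
    chunks = extract.extract_text("functional programming in Rust")
    return await search(
        chunks,
        modalities={"blog", "book"},  # hypothetical collection names
        limit=5,
        filters=SearchFilters(tags=["rust"]),
    )

results = asyncio.run(demo())
```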
68  src/memory/api/search/bm25.py  Normal file
@@ -0,0 +1,68 @@
"""
Search endpoints for the knowledge base API.
"""

from hashlib import sha256
import logging

import bm25s
import Stemmer
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters

from memory.common.db.connection import make_session
from memory.common.db.models import Chunk

logger = logging.getLogger(__name__)


async def search_bm25(
    query: str,
    modalities: set[str],
    limit: int = 10,
    filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
    with make_session() as db:
        items_query = db.query(Chunk.id, Chunk.content).filter(
            Chunk.collection_name.in_(modalities)
        )
        if source_ids := filters.get("source_ids"):
            items_query = items_query.filter(Chunk.source_id.in_(source_ids))
        items = items_query.all()
        item_ids = {
            sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
            for item in items
        }
        corpus = [item.content.lower().strip() for item in items]

    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    results, scores = retriever.retrieve(
        query_tokens, k=min(limit, len(corpus)), corpus=corpus
    )

    item_scores = {
        item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
        for doc, score in zip(results[0], scores[0])
    }

    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
        results = []
        for chunk in chunks:
            # Prefetch all needed source data while in session
            source_data = SourceData.from_chunk(chunk)

            annotated = AnnotatedChunk(
                id=str(chunk.id),
                score=item_scores[chunk.id],
                metadata=chunk.source.as_payload(),
                preview=None,
                search_method="bm25",
            )
            results.append((source_data, annotated))

        return results
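The bm25s indexing/retrieval pattern used above, reduced to a self-contained example that runs without the database:

```python
import bm25s
import Stemmer

corpus = [
    "rust memory safety without garbage collection",
    "python dynamic typing and rapid prototyping",
    "haskell pure functions and lazy evaluation",
]
stemmer = Stemmer.Stemmer("english")
# Tokenize once for the corpus (stopword removal + stemming), index it,
# then score queries against the index
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25()
retriever.index(corpus_tokens)

query_tokens = bm25s.tokenize("memory safety", stemmer=stemmer)
docs, scores = retriever.retrieve(query_tokens, k=2, corpus=corpus)
print(docs[0], scores[0])  # best-matching documents and their BM25 scores
```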
144  src/memory/api/search/embeddings.py  Normal file
@@ -0,0 +1,144 @@
import base64
import io
import logging
from typing import Any, Callable, Optional

import qdrant_client
from PIL import Image
from qdrant_client.http import models as qdrant_models

from memory.common import embedding, extract, qdrant
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters

logger = logging.getLogger(__name__)


def annotated_chunk(
    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceData, AnnotatedChunk]:
    def serialize_item(item: bytes | str | Image.Image) -> str | None:
        if not previews and not isinstance(item, str):
            return None
        if not previews and isinstance(item, str):
            return item[:100]

        if isinstance(item, Image.Image):
            buffer = io.BytesIO()
            format = item.format or "PNG"
            item.save(buffer, format=format)
            mime_type = f"image/{format.lower()}"
            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
        elif isinstance(item, bytes):
            return base64.b64encode(item).decode("utf-8")
        elif isinstance(item, str):
            return item
        else:
            raise ValueError(f"Unsupported item type: {type(item)}")

    metadata = search_result.payload or {}
    metadata = {
        k: v
        for k, v in metadata.items()
        if k not in ["content", "filename", "size", "content_type", "tags"]
    }

    # Prefetch all needed source data while in session
    return SourceData.from_chunk(chunk), AnnotatedChunk(
        id=str(chunk.id),
        score=search_result.score,
        metadata=metadata,
        preview=serialize_item(chunk.data[0]) if chunk.data else None,
        search_method="embeddings",
    )


def query_chunks(
    client: qdrant_client.QdrantClient,
    upload_data: list[extract.DataChunk],
    allowed_modalities: set[str],
    embedder: Callable,
    min_score: float = 0.3,
    limit: int = 10,
    filters: dict[str, Any] | None = None,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
    if not upload_data or not allowed_modalities:
        return {}

    chunks = [chunk for chunk in upload_data if chunk.data]
    if not chunks:
        logger.error(f"No chunks to embed for {allowed_modalities}")
        return {}

    vectors = embedder(chunks, input_type="query")

    return {
        collection: [
            r
            for vector in vectors
            for r in qdrant.search_vectors(
                client=client,
                collection_name=collection,
                query_vector=vector,
                limit=limit,
                filter_params=filters,
            )
            if r.score >= min_score
        ]
        for collection in allowed_modalities
    }


async def search_embeddings(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: set[str] = set(),
    limit: int = 10,
    min_score: float = 0.3,
    filters: SearchFilters = SearchFilters(),
    multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - data: List of data to search in (e.g., text, images, files)
    - previews: Whether to include previews in the search results
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - limit: Maximum number of results
    - min_score: Minimum score to include in the search results
    - filters: Filters to apply to the search results
    - multimodal: Whether to search in multimodal collections
    """
    # Start with an empty "must" list so the appends below cannot KeyError
    # (the committed version initialized this as a bare {})
    query_filters: dict[str, list] = {"must": []}
    if confidence := filters.get("confidence"):
        query_filters["must"] += [{"key": "confidence", "range": {"gte": confidence}}]
    if tags := filters.get("tags"):
        query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
    if observation_types := filters.get("observation_types"):
        query_filters["must"] += [
            {"key": "observation_type", "match": {"any": observation_types}}
        ]

    client = qdrant.get_qdrant_client()
    results = query_chunks(
        client,
        data,
        modalities,
        embedding.embed_text if not multimodal else embedding.embed_mixed,
        min_score=min_score,
        limit=limit,
        filters=query_filters,
    )
    search_results = {k: results.get(k, []) for k in modalities}

    found_chunks = {
        str(r.id): r for results in search_results.values() for r in results
    }
    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
        return [
            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
            for chunk in chunks
        ]
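The dict-shaped `query_filters` built above maps onto qdrant-client's payload-filter models. An equivalent, explicitly typed sketch with illustrative values:

```python
from qdrant_client.http import models as qdrant_models

# All conditions in "must" have to hold for a point to be returned
qdrant_filter = qdrant_models.Filter(
    must=[
        qdrant_models.FieldCondition(
            key="tags", match=qdrant_models.MatchAny(any=["rust", "ai"])
        ),
        qdrant_models.FieldCondition(
            key="confidence", range=qdrant_models.Range(gte=0.5)
        ),
    ]
)
```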
94  src/memory/api/search/search.py  Normal file
@@ -0,0 +1,94 @@
"""
Search endpoints for the knowledge base API.
"""

import asyncio
import logging
from typing import Optional

from memory.api.search.embeddings import search_embeddings
from memory.api.search.bm25 import search_bm25
from memory.api.search.utils import SearchFilters, SearchResult

from memory.api.search.utils import group_chunks, with_timeout
from memory.common import extract
from memory.common.collections import (
    ALL_COLLECTIONS,
    MULTIMODAL_COLLECTIONS,
    TEXT_COLLECTIONS,
)

logger = logging.getLogger(__name__)


async def search(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: set[str] = set(),
    limit: int = 10,
    min_text_score: float = 0.4,
    min_multimodal_score: float = 0.25,
    filters: SearchFilters = {},
    timeout: int = 2,
) -> list[SearchResult]:
    """
    Search across knowledge base using text query and optional files.

    Parameters:
    - query: Optional text search query
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - files: Optional files to include in the search context
    - limit: Maximum number of results per modality

    Returns:
    - List of search results sorted by score
    """
    allowed_modalities = modalities & ALL_COLLECTIONS.keys()

    text_embeddings_results = with_timeout(
        search_embeddings(
            data,
            previews,
            allowed_modalities & TEXT_COLLECTIONS,
            limit,
            min_text_score,
            filters,
            multimodal=False,
        ),
        timeout,
    )
    multimodal_embeddings_results = with_timeout(
        search_embeddings(
            data,
            previews,
            allowed_modalities & MULTIMODAL_COLLECTIONS,
            limit,
            min_multimodal_score,
            filters,
            multimodal=True,
        ),
        timeout,
    )
    bm25_results = with_timeout(
        search_bm25(
            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
            modalities,
            limit=limit,
            filters=filters,
        ),
        timeout,
    )

    results = await asyncio.gather(
        text_embeddings_results,
        multimodal_embeddings_results,
        bm25_results,
        return_exceptions=False,
    )
    text_results, multi_results, bm25_results = results
    all_results = text_results + multi_results
    if len(all_results) < limit:
        all_results += bm25_results

    results = group_chunks(all_results, previews or False)
    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
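The timeout behavior `search()` depends on, in isolation: each backend coroutine is wrapped in `asyncio.wait_for`, so a slow source degrades to an empty result instead of stalling the whole gather:

```python
import asyncio

async def slow_backend() -> list[int]:
    await asyncio.sleep(10)  # simulates a hung search backend
    return [1, 2, 3]

async def main():
    try:
        results = await asyncio.wait_for(slow_backend(), timeout=2)
    except asyncio.TimeoutError:  # alias of builtin TimeoutError on 3.11+
        results = []  # degrade gracefully, like with_timeout() does
    print(results)  # -> []

asyncio.run(main())
```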
134  src/memory/api/search/utils.py  Normal file
@@ -0,0 +1,134 @@
import asyncio
from datetime import datetime
import logging
from collections import defaultdict
from typing import Optional, TypedDict, NotRequired

from pydantic import BaseModel

from memory.common import settings
from memory.common.db.models import Chunk

logger = logging.getLogger(__name__)


class AnnotatedChunk(BaseModel):
    id: str
    score: float
    metadata: dict
    preview: Optional[str | None] = None
    search_method: str | None = None


class SourceData(BaseModel):
    """Holds source item data to avoid SQLAlchemy session issues"""

    id: int
    size: int | None
    mime_type: str | None
    filename: str | None
    content_length: int
    contents: dict | None
    created_at: datetime | None

    @staticmethod
    def from_chunk(chunk: Chunk) -> "SourceData":
        source = chunk.source
        display_contents = source.display_contents or {}
        return SourceData(
            id=source.id,
            size=source.size,
            mime_type=source.mime_type,
            filename=source.filename,
            content_length=len(source.content) if source.content else 0,
            contents=display_contents,
            created_at=source.inserted_at,
        )


class SearchResponse(BaseModel):
    collection: str
    results: list[dict]


class SearchResult(BaseModel):
    id: int
    size: int
    mime_type: str
    chunks: list[AnnotatedChunk]
    content: Optional[str | dict] = None
    filename: Optional[str] = None
    tags: list[str] | None = None
    metadata: dict | None = None
    created_at: datetime | None = None


class SearchFilters(TypedDict):
    subject: NotRequired[str | None]
    confidence: NotRequired[float]
    tags: NotRequired[list[str] | None]
    observation_types: NotRequired[list[str] | None]
    source_ids: NotRequired[list[int] | None]


async def with_timeout(
    call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """
    Run a function with a timeout.

    Args:
        call: The function to run
        timeout: The timeout in seconds
    """
    try:
        return await asyncio.wait_for(call, timeout=timeout)
    except TimeoutError:
        logger.warning(f"Search timed out after {timeout}s")
        return []
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return []


def group_chunks(
    chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
) -> list[SearchResult]:
    items = defaultdict(list)
    source_lookup = {}

    for source, chunk in chunks:
        items[source.id].append(chunk)
        source_lookup[source.id] = source

    def get_content(text: str | dict | None) -> str | dict | None:
        if preview or not text or not isinstance(text, str) or len(text) < 250:
            return text

        return text[:250] + "..."

    def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
        contents = source.contents or {}
        tags = contents.pop("tags", [])
        content = contents.pop("content", None)

        return SearchResult(
            id=source.id,
            size=source.size or source.content_length,
            mime_type=source.mime_type or "text/plain",
            filename=source.filename
            and source.filename.replace(
                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
            ),
            content=get_content(content),
            tags=tags,
            metadata=contents,
            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
            created_at=source.created_at,
        )

    return [
        make_result(source, chunks)
        for source_id, chunks in items.items()
        for source in [source_lookup[source_id]]
    ]
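A small demo of `group_chunks` semantics (field values are made up): chunks that share a source collapse into one `SearchResult`, with its chunks sorted by descending score:

```python
from memory.api.search.utils import AnnotatedChunk, SourceData, group_chunks

src = SourceData(
    id=1, size=None, mime_type="text/plain", filename=None,
    content_length=120,
    contents={"content": "some text", "tags": ["demo"]},
    created_at=None,
)
chunks = [
    (src, AnnotatedChunk(id="a", score=0.2, metadata={}, search_method="bm25")),
    (src, AnnotatedChunk(id="b", score=0.9, metadata={}, search_method="embeddings")),
]
[result] = group_chunks(chunks)
print([c.id for c in result.chunks])  # ["b", "a"] - highest score first
print(result.tags, result.content)    # ["demo"] "some text"
```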
@@ -102,6 +102,7 @@ TEXT_COLLECTIONS = {
 MULTIMODAL_COLLECTIONS = {
     coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
 }
+OBSERVATION_COLLECTIONS = {"semantic", "temporal"}

 TYPES = {
     "doc": ["application/pdf", "application/docx", "application/msword"],
@@ -84,7 +84,7 @@ def clean_filename(filename: str) -> str:

 def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
     for i, image in enumerate(images):
-        if not image.filename:  # type: ignore
+        if not getattr(image, "filename", None):  # type: ignore
             filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}"  # type: ignore
             image.save(filename)
             image.filename = str(filename)  # type: ignore
@@ -100,16 +100,6 @@ def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalCh
     ]


-def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
-    final = {}
-    for m in metadata:
-        data = m.copy()
-        if tags := set(data.pop("tags", [])):
-            final["tags"] = tags | final.get("tags", set())
-        final |= data
-    return final
-
-
 def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
     if not content.strip():
         return []
@@ -241,14 +231,11 @@ class SourceItem(Base):
         return [chunk.id for chunk in self.chunks]

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
-        chunks: list[extract.DataChunk] = []
         content = cast(str | None, self.content)
         if content:
-            chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
-
-            if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
-                summary, tags = summarizer.summarize(content)
-                chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
+            chunks = extract.extract_text(content)
+        else:
+            chunks = []

         mime_type = cast(str | None, self.mime_type)
         if mime_type and mime_type.startswith("image/"):
@@ -272,12 +259,14 @@ class SourceItem(Base):
             file_paths=image_names,
             collection_name=modality,
             embedding_model=collections.collection_model(modality, text, images),
-            item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
+            item_metadata=extract.merge_metadata(
+                self.as_payload(), data.metadata, metadata
+            ),
         )
         return chunk

     def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
-        return [self._make_chunk(data) for data in self._chunk_contents()]
+        return [self._make_chunk(data, metadata) for data in self._chunk_contents()]

     def as_payload(self) -> dict:
         return {
@@ -291,4 +280,7 @@ class SourceItem(Base):
         return {
             "tags": self.tags,
             "size": self.size,
+            "content": self.content,
+            "filename": self.filename,
+            "mime_type": self.mime_type,
         }
@@ -33,7 +33,6 @@ from memory.common.db.models.source_item import (
     SourceItem,
     Chunk,
     clean_filename,
-    merge_metadata,
     chunk_mixed,
 )

@@ -326,27 +325,24 @@ class BookSection(SourceItem):
         }
         return {k: v for k, v in vals.items() if v}

-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         content = cast(str, self.content.strip())
         if not content:
             return []

         if len([p for p in self.pages if p.strip()]) == 1:
-            return [
-                self._make_chunk(
-                    extract.DataChunk(data=[content]), metadata | {"type": "page"}
-                )
-            ]
+            chunks = extract.extract_text(content, metadata={"type": "page"})
+            if len(chunks) > 1:
+                chunks[-1].metadata["type"] = "summary"
+            return chunks

         summary, tags = summarizer.summarize(content)
         return [
-            self._make_chunk(
-                extract.DataChunk(data=[content]),
-                merge_metadata(metadata, {"type": "section", "tags": tags}),
+            extract.DataChunk(
+                data=[content], metadata={"type": "section", "tags": tags}
             ),
-            self._make_chunk(
-                extract.DataChunk(data=[summary]),
-                merge_metadata(metadata, {"type": "summary", "tags": tags}),
+            extract.DataChunk(
+                data=[summary], metadata={"type": "summary", "tags": tags}
             ),
         ]

@@ -596,7 +592,7 @@ class AgentObservation(SourceItem):
         )
         semantic_chunk = extract.DataChunk(
             data=[semantic_text],
-            metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+            metadata=extract.merge_metadata(metadata, {"embedding_type": "semantic"}),
             modality="semantic",
         )

@@ -609,7 +605,7 @@ class AgentObservation(SourceItem):
         )
         temporal_chunk = extract.DataChunk(
             data=[temporal_text],
-            metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
+            metadata=extract.merge_metadata(metadata, {"embedding_type": "temporal"}),
             modality="temporal",
         )

@@ -617,14 +613,14 @@ class AgentObservation(SourceItem):
         self._make_chunk(
             extract.DataChunk(
                 data=[i],
-                metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+                metadata=extract.merge_metadata(
+                    metadata, {"embedding_type": "semantic"}
+                ),
                 modality="semantic",
             )
         )
         for i in [
             self.content,
             self.subject,
             self.observation_type,
             self.evidence.get("quote", ""),
         ]
         if i
@@ -1,5 +1,5 @@
 import logging
-from typing import Iterable, Literal, cast
+from typing import Literal, cast

 import voyageai

@@ -15,12 +15,22 @@ from memory.common.db.models import Chunk, SourceItem
 logger = logging.getLogger(__name__)


+def as_string(
+    chunk: extract.MulitmodalChunk | list[extract.MulitmodalChunk],
+) -> str:
+    if isinstance(chunk, str):
+        return chunk.strip()
+    if isinstance(chunk, list):
+        return "\n".join(as_string(i) for i in chunk).strip()
+    return ""
+
+
 def embed_chunks(
     chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
-    logger.debug(f"Embedding chunks: {model} - {str(chunks)[:100]} {len(chunks)}")
+    logger.debug(f"Embedding chunks: {model} - {str(chunks)} {len(chunks)}")
     vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
@@ -29,17 +39,18 @@ def embed_chunks(
             input_type=input_type,
         ).embeddings

-    texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
+    texts = [as_string(c) for c in chunks]
+    logger.debug(f"Embedding texts: {texts}")
     return cast(
         list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
     )


 def break_chunk(
-    chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
+    chunk: extract.DataChunk, chunk_size: int = DEFAULT_CHUNK_TOKENS
 ) -> list[extract.MulitmodalChunk]:
     result = []
-    for c in chunk:
+    for c in chunk.data:
         if isinstance(c, str):
             result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
         else:
@@ -48,12 +59,12 @@ def break_chunk(


 def embed_text(
-    chunks: list[list[extract.MulitmodalChunk]],
+    chunks: list[extract.DataChunk],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
+    chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks if chunk.data]
     if not any(chunked_chunks):
         return []

@@ -61,12 +72,12 @@ def embed_text(


 def embed_mixed(
-    items: list[list[extract.MulitmodalChunk]],
+    items: list[extract.DataChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    chunked_chunks = [break_chunk(item, chunk_size) for item in items]
+    chunked_chunks = [break_chunk(item, chunk_size) for item in items if item.data]
     return embed_chunks(chunked_chunks, model, input_type)
@@ -6,7 +6,7 @@ import tempfile
 from contextlib import contextmanager
 from typing import Any, Generator, Sequence, cast

-from memory.common import chunker
+from memory.common import chunker, summarizer
 import pymupdf  # PyMuPDF
 import pypandoc
 from PIL import Image
@@ -16,6 +16,16 @@ logger = logging.getLogger(__name__)
 MulitmodalChunk = Image.Image | str


+def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
+    final = {}
+    for m in metadata:
+        data = m.copy()
+        if tags := set(data.pop("tags", []) or []):
+            final["tags"] = tags | final.get("tags", set())
+        final |= data
+    return final
+
+
 @dataclass
 class DataChunk:
     data: Sequence[MulitmodalChunk]
@@ -109,7 +119,9 @@ def extract_image(content: bytes | str | pathlib.Path) -> list[DataChunk]:


 def extract_text(
-    content: bytes | str | pathlib.Path, chunk_size: int | None = None
+    content: bytes | str | pathlib.Path,
+    chunk_size: int | None = None,
+    metadata: dict[str, Any] = {},
 ) -> list[DataChunk]:
     if isinstance(content, pathlib.Path):
         content = content.read_text()
@@ -117,8 +129,20 @@ def extract_text(
         content = content.decode("utf-8")

     content = cast(str, content)
-    chunks = chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
-    return [DataChunk(data=[c], mime_type="text/plain") for c in chunks if c.strip()]
+    chunks = [
+        DataChunk(data=[c], modality="text", metadata=metadata)
+        for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
+    ]
+    if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
+        summary, tags = summarizer.summarize(content)
+        chunks.append(
+            DataChunk(
+                data=[summary],
+                metadata=merge_metadata(metadata, {"tags": tags}),
+                modality="text",
+            )
+        )
+    return chunks


 def extract_data_chunks(
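The merge semantics of the relocated `merge_metadata`, in one worked example: later dicts win for scalar keys, while `tags` accumulate as a set across all inputs:

```python
from memory.common import extract

merged = extract.merge_metadata(
    {"tags": ["a"], "type": "page"},
    {"tags": ["b"], "type": "summary"},
)
assert merged["type"] == "summary"   # later values overwrite scalar keys
assert merged["tags"] == {"a", "b"}  # tags from every dict are unioned
```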
@@ -3,7 +3,7 @@ from typing import Any, cast, Generator, Sequence

 import qdrant_client
 from qdrant_client.http import models as qdrant_models
-from qdrant_client.http.exceptions import UnexpectedResponse
+from qdrant_client.http.exceptions import UnexpectedResponse, ApiException
 from memory.common import settings
 from memory.common.collections import ALL_COLLECTIONS, Collection, DistanceType, Vector

@@ -193,14 +193,18 @@ def delete_points(
         collection_name: Name of the collection
         ids: List of vector IDs to delete
     """
-    client.delete(
-        collection_name=collection_name,
-        points_selector=qdrant_models.PointIdsList(
-            points=ids,  # type: ignore
-        ),
-    )
-
-    logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
+    try:
+        client.delete(
+            collection_name=collection_name,
+            points_selector=qdrant_models.PointIdsList(
+                points=ids,  # type: ignore
+            ),
+        )
+
+        logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
+    except (ApiException, UnexpectedResponse) as e:
+        logger.error(f"Error deleting points from {collection_name}: {e}")
+        raise IOError(f"Error deleting points from {collection_name}: {e}")


 def get_collection_info(
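A usage sketch for the new error contract (the collection name is illustrative): callers can now handle Qdrant failures as a plain `IOError`, which is exactly how `reingest_item` does it elsewhere in this commit:

```python
from memory.common import qdrant

client = qdrant.get_qdrant_client()
try:
    qdrant.delete_points(client, "blog", ["chunk-id-1", "chunk-id-2"])
except IOError as e:
    # deletion failures no longer abort the caller; log and move on
    print(f"deletion failed, continuing: {e}")
```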
@@ -1,5 +1,6 @@
 import json
 import logging
+import traceback
 from typing import Any

 from memory.common import settings, chunker
@@ -131,6 +132,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
         summary = result.get("summary", "")
         tags = result.get("tags", [])
     except Exception as e:
+        traceback.print_exc()
         logger.error(f"Summarization failed: {e}")

     tokens = chunker.approx_token_count(summary)
@@ -60,6 +60,7 @@ def section_processor(
         end_page=section.end_page,
         parent_section_id=None,  # Will be set after flush
         content=content,
+        filename=book.file_path,
         size=len(content),
         mime_type="text/plain",
         sha256=create_content_hash(
@@ -21,6 +21,7 @@ REINGEST_MISSING_CHUNKS = f"{MAINTENANCE_ROOT}.reingest_missing_chunks"
 REINGEST_CHUNK = f"{MAINTENANCE_ROOT}.reingest_chunk"
 REINGEST_ITEM = f"{MAINTENANCE_ROOT}.reingest_item"
 REINGEST_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_empty_source_items"
+REINGEST_ALL_EMPTY_SOURCE_ITEMS = f"{MAINTENANCE_ROOT}.reingest_all_empty_source_items"
 UPDATE_METADATA_FOR_SOURCE_ITEMS = (
     f"{MAINTENANCE_ROOT}.update_metadata_for_source_items"
 )
@@ -76,9 +77,9 @@ def reingest_chunk(chunk_id: str, collection: str):

     data = chunk.data
     if collection in collections.MULTIMODAL_COLLECTIONS:
-        vector = embedding.embed_mixed(data)[0]
-    elif len(data) == 1 and isinstance(data[0], str):
-        vector = embedding.embed_text([data[0]])[0]
+        vector = embedding.embed_mixed([data])[0]
+    elif collection in collections.TEXT_COLLECTIONS:
+        vector = embedding.embed_text([data])[0]
     else:
         raise ValueError(f"Unsupported data type for collection {collection}")

@@ -123,7 +124,10 @@ def reingest_item(item_id: str, item_type: str):
     chunk_ids = [str(c.id) for c in item.chunks if c.id]
     if chunk_ids:
         client = qdrant.get_qdrant_client()
-        qdrant.delete_points(client, item.modality, chunk_ids)
+        try:
+            qdrant.delete_points(client, item.modality, chunk_ids)
+        except IOError as e:
+            logger.error(f"Error deleting chunks for {item_id}: {e}")

     for chunk in item.chunks:
         session.delete(chunk)
@@ -151,6 +155,13 @@ def reingest_empty_source_items(item_type: str):
     return {"status": "success", "items": len(item_ids)}


+@app.task(name=REINGEST_ALL_EMPTY_SOURCE_ITEMS)
+def reingest_all_empty_source_items():
+    logger.info("Reingesting all empty source items")
+    for item_type in SourceItem.registry._class_registry.keys():
+        reingest_empty_source_items.delay(item_type)  # type: ignore
+
+
 def check_batch(batch: Sequence[Chunk]) -> dict:
     client = qdrant.get_qdrant_client()
     by_collection = defaultdict(list)
@@ -234,8 +234,10 @@ def mock_voyage_client():
     def embeder(chunks, *args, **kwargs):
         return Mock(embeddings=[[0.1] * 1024] * len(chunks))

+    real_client = voyageai.Client
     with patch.object(voyageai, "Client", autospec=True) as mock_client:
         client = mock_client()
+        client.real_client = real_client
         client.embed = embeder
         client.multimodal_embed = embeder
         yield client
@@ -251,7 +253,7 @@ def mock_openai_client():
             choices=[
                 Mock(
                     message=Mock(
-                        content='{"summary": "test", "tags": ["tag1", "tag2"]}'
+                        content='{"summary": "test summary", "tags": ["tag1", "tag2"]}'
                     )
                 )
             ]
@@ -267,7 +269,9 @@ def mock_anthropic_client():
     client.messages = Mock()
     client.messages.create = Mock(
         return_value=Mock(
-            content=[Mock(text='{"summary": "test", "tags": ["tag1", "tag2"]}')]
+            content=[
+                Mock(text='{"summary": "test summary", "tags": ["tag1", "tag2"]}')
+            ]
         )
     )
     yield client
BIN  tests/data/code_complexity.jpg  Normal file
Binary file not shown (68 KiB).
249
tests/data/contents.py
Normal file
249
tests/data/contents.py
Normal file
@ -0,0 +1,249 @@
|
||||
import hashlib
|
||||
import pathlib
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify
|
||||
from PIL import Image
|
||||
|
||||
DATA_DIR = pathlib.Path(__file__).parent
|
||||
|
||||
SAMPLE_HTML = f"""
<html>
<body>
<h1>The Evolution of Programming Languages</h1>

<p>Programming languages have undergone tremendous evolution since the early days of computing.
From the machine code and assembly languages of the 1940s to the high-level, expressive languages
we use today, each generation has built upon the lessons learned from its predecessors. Languages
like FORTRAN and COBOL pioneered the concept of human-readable code, while later innovations like
object-oriented programming in languages such as Smalltalk and C++ revolutionized how we structure
and organize our programs.</p>

<img src="{DATA_DIR / "lang_timeline.png"}" alt="Timeline of programming language evolution" width="600" height="400">

<p>The rise of functional programming paradigms has brought mathematical rigor and immutability
to the forefront of software development. Languages like Haskell, Lisp, and more recently Rust
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.</p>

<p>Modern development has also seen the emergence of domain-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that
performance and safety need not be mutually exclusive. The ongoing development of WebAssembly
promises to bring high-performance computing to web browsers in ways previously unimaginable.</p>

<img src="{DATA_DIR / "code_complexity.jpg"}" alt="Visual representation of code complexity over time" width="500" height="300">

<p>Looking toward the future, we see emerging trends in quantum programming languages, AI-assisted
code generation, and the continued evolution toward more expressive type systems. The challenge
for tomorrow's language designers will be balancing expressiveness with simplicity, performance
with safety, and innovation with backward compatibility. As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.</p>

<p>The emergence of cloud computing and distributed systems has also driven new paradigms in
language design. Languages like Go and Elixir have been specifically crafted to excel in
concurrent and distributed environments, while the rise of microservices has renewed interest
in polyglot programming approaches. These developments reflect a broader shift toward languages
that are not just powerful tools for individual developers, but robust foundations for building
scalable, resilient systems that can handle the demands of modern internet-scale applications.</p>

<p>Perhaps most intriguingly, the intersection of programming languages with artificial intelligence
is opening entirely new frontiers. Differentiable programming languages are enabling new forms of
machine learning research, while large language models are beginning to reshape how we think about
code generation and developer tooling. As we stand on the brink of an era where AI systems may
become active participants in the programming process itself, the very nature of what constitutes
a programming language—and who or what programs in it—may be fundamentally transformed.</p>
</body>
</html>
"""

SECOND_PAGE = """
<div>
<h2>The Impact of Open Source on Language Development</h2>

<p>The open source movement has fundamentally transformed how programming languages are developed,
distributed, and evolved. Unlike the proprietary languages of earlier decades, modern language
development often occurs in public repositories where thousands of contributors can participate
in the design process. Languages like Python, JavaScript, and Rust have benefited enormously
from this collaborative approach, with their ecosystems growing rapidly through community-driven
package managers and extensive third-party libraries.</p>

<p>This democratization of language development has led to faster innovation cycles and more
responsive adaptation to developer needs. When a language feature proves problematic or a new
paradigm emerges, open source languages can quickly incorporate changes through their community
governance processes. The result has been an unprecedented period of language experimentation
and refinement, where ideas can be tested, refined, and adopted across multiple language
communities simultaneously.</p>

<p>Furthermore, the open source model has enabled the rise of domain-specific languages that
might never have been commercially viable under traditional development models. From specialized
query languages for databases to configuration management tools, the low barrier to entry for
language creation has fostered an explosion of linguistic diversity in computing, each tool
optimized for specific problem domains and user communities.</p>

<p>The collaborative nature of open source development has also revolutionized language tooling
and developer experience. Modern languages benefit from rich ecosystems of editors, debuggers,
profilers, and static analysis tools, all developed by passionate communities who understand
the daily challenges faced by practitioners. This has created a virtuous cycle where better
tooling attracts more developers, who in turn contribute improvements that make the language
even more accessible and powerful.</p>

<p>Version control systems like Git have enabled unprecedented transparency in language evolution,
allowing developers to trace the reasoning behind every design decision through detailed commit
histories and issue discussions. This historical record serves not only as documentation but as
a learning resource for future language designers, helping them understand the trade-offs and
considerations that shaped successful language features.</p>

<p>The economic implications of open source language development cannot be overstated. By removing
licensing barriers and vendor lock-in, open source languages have democratized access to powerful
programming tools across the globe. This has enabled innovation in regions and sectors that might
otherwise have been excluded from the software revolution, fostering a truly global community of
software creators and problem solvers.</p>
</div>
"""

CHUNKS: list[str] = [
    """The Evolution of Programming Languages
======================================

Programming languages have undergone tremendous evolution since the early days of computing.
From the machine code and assembly languages of the 1940s to the high\\-level, expressive languages
we use today, each generation has built upon the lessons learned from its predecessors. Languages
like FORTRAN and COBOL pioneered the concept of human\\-readable code, while later innovations like
object\\-oriented programming in languages such as Smalltalk and C\\+\\+ revolutionized how we structure
and organize our programs.



The rise of functional programming paradigms has brought mathematical rigor and immutability
to the forefront of software development. Languages like Haskell, Lisp, and more recently Rust
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.

Modern development has also seen the emergence of domain\\-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that
performance and safety need not be mutually exclusive. The ongoing development of WebAssembly
promises to bring high\\-performance computing to web browsers in ways previously unimaginable.



Looking toward the future, we see emerging trends in quantum programming languages, AI\\-assisted
code generation, and the continued evolution toward more expressive type systems. The challenge
for tomorrow's language designers will be balancing expressiveness with simplicity, performance
with safety, and innovation with backward compatibility. As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.""",
    """
As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.

The emergence of cloud computing and distributed systems has also driven new paradigms in
language design. Languages like Go and Elixir have been specifically crafted to excel in
concurrent and distributed environments, while the rise of microservices has renewed interest
in polyglot programming approaches. These developments reflect a broader shift toward languages
that are not just powerful tools for individual developers, but robust foundations for building
scalable, resilient systems that can handle the demands of modern internet\\-scale applications.

Perhaps most intriguingly, the intersection of programming languages with artificial intelligence
is opening entirely new frontiers. Differentiable programming languages are enabling new forms of
machine learning research, while large language models are beginning to reshape how we think about
code generation and developer tooling. As we stand on the brink of an era where AI systems may
become active participants in the programming process itself, the very nature of what constitutes
a programming language—and who or what programs in it—may be fundamentally transformed.""",
]
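
# NOTE (editorial sketch): CHUNKS[1] repeats the tail of CHUNKS[0], i.e. the
# chunker is expected to emit overlapping windows so context survives chunk
# boundaries. A minimal illustration of that behaviour (the sizes are made up;
# this is not the project's actual chunker):
def _overlapping_chunks_sketch(text: str, size: int = 1500, overlap: int = 200) -> list[str]:
    chunks, start = [], 0
    while start < len(text):
        chunks.append(text[start : start + size])
        start += size - overlap  # step forward by less than `size` to overlap
    return chunks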
TWO_PAGE_CHUNKS: list[str] = [
    """
The Evolution of Programming Languages
======================================

Programming languages have undergone tremendous evolution since the early days of computing.
From the machine code and assembly languages of the 1940s to the high\\-level, expressive languages
we use today, each generation has built upon the lessons learned from its predecessors. Languages
like FORTRAN and COBOL pioneered the concept of human\\-readable code, while later innovations like
object\\-oriented programming in languages such as Smalltalk and C\\+\\+ revolutionized how we structure
and organize our programs.



The rise of functional programming paradigms has brought mathematical rigor and immutability
to the forefront of software development. Languages like Haskell, Lisp, and more recently Rust
and Elm have demonstrated the power of pure functions and type systems in creating more reliable
and maintainable code. These paradigms emphasize the elimination of side effects and the treatment
of computation as the evaluation of mathematical functions.

Modern development has also seen the emergence of domain\\-specific languages and the resurgence
of interest in memory safety. The advent of languages like Python and JavaScript has democratized
programming by lowering the barrier to entry, while systems languages like Rust have proven that
performance and safety need not be mutually exclusive. The ongoing development of WebAssembly
promises to bring high\\-performance computing to web browsers in ways previously unimaginable.



Looking toward the future, we see emerging trends in quantum programming languages, AI\\-assisted
code generation, and the continued evolution toward more expressive type systems. The challenge
for tomorrow's language designers will be balancing expressiveness with simplicity, performance
with safety, and innovation with backward compatibility. As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.
""",
    """
As computing continues to permeate every
aspect of human life, the languages we use to command these machines will undoubtedly continue
to evolve and shape the digital landscape.

The emergence of cloud computing and distributed systems has also driven new paradigms in
language design. Languages like Go and Elixir have been specifically crafted to excel in
concurrent and distributed environments, while the rise of microservices has renewed interest
in polyglot programming approaches. These developments reflect a broader shift toward languages
that are not just powerful tools for individual developers, but robust foundations for building
scalable, resilient systems that can handle the demands of modern internet\\-scale applications.

Perhaps most intriguingly, the intersection of programming languages with artificial intelligence
is opening entirely new frontiers. Differentiable programming languages are enabling new forms of
machine learning research, while large language models are beginning to reshape how we think about
code generation and developer tooling. As we stand on the brink of an era where AI systems may
become active participants in the programming process itself, the very nature of what constitutes
a programming language—and who or what programs in it—may be fundamentally transformed.

The Impact of Open Source on Language Development
-------------------------------------------------

The open source movement has fundamentally transformed how programming languages are developed,
distributed, and evolved. Unlike the proprietary languages of earlier decades, modern language
development often occurs in public repositories where thousands of contributors can participate
in the design process. Languages like Python, JavaScript, and Rust have benefited enormously
from this collaborative approach, with their ecosystems growing rapidly through community\\-driven
package managers and extensive third\\-party libraries.

This democratization of language development has led to faster innovation cycles and more
responsive adaptation to developer needs. When a language feature proves problematic or a new
paradigm emerges, open source languages can quickly incorporate changes through their community
governance processes. The result has been an unprecedented period of language experimentation
and refinement, where ideas can be tested, refined, and adopted across multiple language
communities simultaneously.""",
    """
The result has been an unprecedented period of language experimentation
and refinement, where ideas can be tested, refined, and adopted across multiple language
communities simultaneously.

Furthermore, the open source model has enabled the rise of domain\\-specific languages that
might never have been commercially viable under traditional development models. From specialized
query languages for databases to configuration management tools, the low barrier to entry for
language creation has fostered an explosion of linguistic diversity in computing, each tool
optimized for specific problem domains and user communities.

The collaborative nature of open source development has also revolutionized language tooling
and developer experience. Modern languages benefit from rich ecosystems of editors, debuggers,
profilers, and static analysis tools, all developed by passionate communities who understand
the daily challenges faced by practitioners. This has created a virtuous cycle where better
tooling attracts more developers, who in turn contribute improvements that make the language
even more accessible and powerful.

Version control systems like Git have enabled unprecedented transparency in language evolution,
allowing developers to trace the reasoning behind every design decision through detailed commit
histories and issue discussions. This historical record serves not only as documentation but as
a learning resource for future language designers, helping them understand the trade\\-offs and
considerations that shaped successful language features.

The economic implications of open source language development cannot be overstated. By removing
licensing barriers and vendor lock\\-in, open source languages have democratized access to powerful
programming tools across the globe. This has enabled innovation in regions and sectors that might
otherwise have been excluded from the software revolution, fostering a truly global community of
software creators and problem solvers.
""",
]

SAMPLE_MARKDOWN = markdownify(SAMPLE_HTML)
SAMPLE_TEXT = BeautifulSoup(SAMPLE_HTML, "html.parser").get_text()
SECOND_PAGE_MARKDOWN = markdownify(SECOND_PAGE)
SECOND_PAGE_TEXT = BeautifulSoup(SECOND_PAGE, "html.parser").get_text()


def image_hash(image: Image.Image) -> str:
    return hashlib.sha256(image.tobytes()).hexdigest()


LANG_TIMELINE = Image.open(DATA_DIR / "lang_timeline.png")
CODE_COMPLEXITY = Image.open(DATA_DIR / "code_complexity.jpg")
LANG_TIMELINE_HASH = image_hash(LANG_TIMELINE)
CODE_COMPLEXITY_HASH = image_hash(CODE_COMPLEXITY)
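
# NOTE (editorial): image_hash gives images a content-based identity. Hashing
# the decoded pixel buffer (tobytes) rather than the file on disk means any
# image with identical pixels compares equal, which is what the embedding
# tests rely on when they hash the images passed to the mocked client:
#
#     assert image_hash(LANG_TIMELINE.copy()) == LANG_TIMELINE_HASH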
BIN tests/data/lang_timeline.png (Normal file, 309 KiB; binary file not shown)
1118 tests/integration/test_real_queries.py (Normal file; file diff suppressed because it is too large)
@ -1,23 +1,17 @@
from sqlalchemy.orm import Session
from unittest.mock import patch, Mock
from unittest.mock import patch
from typing import cast
import pytest
from PIL import Image
from datetime import datetime
from memory.common import settings, chunker, extract
from memory.common.db.models.sources import Book
from memory.common.db.models.source_items import (
    Chunk,
    MailMessage,
    EmailAttachment,
    BookSection,
    BlogPost,
)
from memory.common.db.models.source_item import (
    SourceItem,
    image_filenames,
    add_pics,
    merge_metadata,
    clean_filename,
)
@ -56,114 +50,6 @@ def test_clean_filename(input_filename, expected):
    assert clean_filename(input_filename) == expected


@pytest.mark.parametrize(
    "dicts,expected",
    [
        # Empty input
        ([], {}),
        # Single dict without tags
        ([{"key": "value"}], {"key": "value"}),
        # Single dict with tags as list
        (
            [{"key": "value", "tags": ["tag1", "tag2"]}],
            {"key": "value", "tags": {"tag1", "tag2"}},
        ),
        # Single dict with tags as set
        (
            [{"key": "value", "tags": {"tag1", "tag2"}}],
            {"key": "value", "tags": {"tag1", "tag2"}},
        ),
        # Multiple dicts without tags
        (
            [{"key1": "value1"}, {"key2": "value2"}],
            {"key1": "value1", "key2": "value2"},
        ),
        # Multiple dicts with non-overlapping tags
        (
            [
                {"key1": "value1", "tags": ["tag1"]},
                {"key2": "value2", "tags": ["tag2"]},
            ],
            {"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2"}},
        ),
        # Multiple dicts with overlapping tags
        (
            [
                {"key1": "value1", "tags": ["tag1", "tag2"]},
                {"key2": "value2", "tags": ["tag2", "tag3"]},
            ],
            {"key1": "value1", "key2": "value2", "tags": {"tag1", "tag2", "tag3"}},
        ),
        # Overlapping keys - later dict wins
        (
            [
                {"key": "value1", "other": "data1"},
                {"key": "value2", "another": "data2"},
            ],
            {"key": "value2", "other": "data1", "another": "data2"},
        ),
        # Mixed tags types (list and set)
        (
            [
                {"key1": "value1", "tags": ["tag1", "tag2"]},
                {"key2": "value2", "tags": {"tag3", "tag4"}},
            ],
            {
                "key1": "value1",
                "key2": "value2",
                "tags": {"tag1", "tag2", "tag3", "tag4"},
            },
        ),
        # Empty tags
        (
            [{"key": "value", "tags": []}, {"key2": "value2", "tags": []}],
            {"key": "value", "key2": "value2"},
        ),
        # None values
        (
            [{"key1": None, "key2": "value"}, {"key3": None}],
            {"key1": None, "key2": "value", "key3": None},
        ),
        # Complex nested structures
        (
            [
                {"nested": {"inner": "value1"}, "list": [1, 2, 3], "tags": ["tag1"]},
                {"nested": {"inner": "value2"}, "list": [4, 5], "tags": ["tag2"]},
            ],
            {"nested": {"inner": "value2"}, "list": [4, 5], "tags": {"tag1", "tag2"}},
        ),
        # Boolean and numeric values
        (
            [
                {"bool": True, "int": 42, "float": 3.14, "tags": ["numeric"]},
                {"bool": False, "int": 100},
            ],
            {"bool": False, "int": 100, "float": 3.14, "tags": {"numeric"}},
        ),
        # Three or more dicts
        (
            [
                {"a": 1, "tags": ["t1"]},
                {"b": 2, "tags": ["t2", "t3"]},
                {"c": 3, "a": 10, "tags": ["t3", "t4"]},
            ],
            {"a": 10, "b": 2, "c": 3, "tags": {"t1", "t2", "t3", "t4"}},
        ),
        # Dict with only tags
        ([{"tags": ["tag1", "tag2"]}], {"tags": {"tag1", "tag2"}}),
        # Empty dicts
        ([{}, {}], {}),
        # Mix of empty and non-empty dicts
        (
            [{}, {"key": "value", "tags": ["tag"]}, {}],
            {"key": "value", "tags": {"tag"}},
        ),
    ],
)
def test_merge_metadata(dicts, expected):
    assert merge_metadata(*dicts) == expected
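
# NOTE (editorial sketch): for reference, a minimal implementation consistent
# with every case above; the real merge_metadata lives in
# memory.common.db.models.source_item. Later dicts win on plain keys, tags are
# unioned into a set, and an empty tag union is dropped entirely.
def _merge_metadata_sketch(*dicts: dict) -> dict:
    merged: dict = {}
    tags: set = set()
    for d in dicts:
        d = dict(d)                     # don't mutate the caller's dict
        tags |= set(d.pop("tags", []))  # collect tags separately
        merged.update(d)                # later dicts overwrite earlier keys
    if tags:
        merged["tags"] = tags           # only present when non-empty
    return merged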


def test_image_filenames_with_existing_filenames(tmp_path):
    """Test image_filenames when images already have filenames"""
    chunk_id = "test_chunk_123"
626 tests/memory/common/db/models/test_source_item_embeddings.py (Normal file)
@ -0,0 +1,626 @@
import hashlib
from datetime import datetime
from typing import Sequence, cast
from unittest.mock import ANY, Mock, call

import pymupdf  # PyMuPDF
import pytest

from memory.common import settings
from memory.common.db.models.source_item import Chunk, SourceItem
from memory.common.db.models.source_items import (
    AgentObservation,
    BlogPost,
    BookSection,
    Comic,
    EmailAttachment,
    ForumPost,
    MailMessage,
)
from memory.common.db.models.sources import Book
from memory.common.embedding import embed_source_item
from memory.common.extract import page_to_image
from tests.data.contents import (
    CHUNKS,
    CODE_COMPLEXITY,
    CODE_COMPLEXITY_HASH,
    DATA_DIR,
    LANG_TIMELINE,
    LANG_TIMELINE_HASH,
    SAMPLE_MARKDOWN,
    SAMPLE_TEXT,
    SECOND_PAGE,
    SECOND_PAGE_MARKDOWN,
    SECOND_PAGE_TEXT,
    TWO_PAGE_CHUNKS,
    image_hash,
)


def compare_chunks(
    chunks: Sequence[Chunk],
    expected: Sequence[tuple[str | None, list[str], dict]],
):
    data = [
        (c.content, [image_hash(i) for i in c.images], c.item_metadata) for c in chunks
    ]
    assert data == expected


def test_base_source_item_text_embeddings(mock_voyage_client):
    item = SourceItem(
        id=1,
        content=SAMPLE_MARKDOWN,
        mime_type="text/html",
        modality="text",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (CHUNKS[0].strip(), cast(list[str], []), metadata),
        (CHUNKS[1].strip(), cast(list[str], []), metadata),
        ("test summary", [], metadata | {"tags": {"tag1", "tag2", "bla"}}),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert not mock_voyage_client.multimodal_embed.call_count

    assert mock_voyage_client.embed.call_args == call(
        [CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )


def test_base_source_item_mixed_embeddings(mock_voyage_client):
    item = SourceItem(
        id=1,
        content=SAMPLE_MARKDOWN,
        filename=DATA_DIR / "lang_timeline.png",
        mime_type="image/png",
        modality="photo",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (CHUNKS[0].strip(), [], metadata),
        (CHUNKS[1].strip(), [], metadata),
        ("test summary", [], metadata | {"tags": {"tag1", "tag2", "bla"}}),
        (None, [LANG_TIMELINE_HASH], {"size": 3465, "source_id": 1, "tags": {"bla"}}),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert mock_voyage_client.multimodal_embed.call_count == 1

    assert mock_voyage_client.embed.call_args == call(
        [CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )
    assert mock_voyage_client.multimodal_embed.call_args == call(
        [[ANY]],
        model=settings.MIXED_EMBEDDING_MODEL,
        input_type="document",
    )
    assert [
        image_hash(i) for i in mock_voyage_client.multimodal_embed.call_args[0][0][0]
    ] == [LANG_TIMELINE_HASH]


def test_mail_message_embeddings(mock_voyage_client):
    item = MailMessage(
        id=1,
        content=SAMPLE_MARKDOWN,
        mime_type="text/html",
        modality="text",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
        message_id="123",
        subject="Test Subject",
        sender="test@example.com",
        recipients=["test@example.com"],
        folder="INBOX",
        sent_at=datetime(2025, 1, 1, 12, 0, 0),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla", "test@example.com"}
    expected = [
        (CHUNKS[0].strip(), [], metadata),
        (CHUNKS[1].strip(), [], metadata),
        (
            "test summary",
            [],
            metadata | {"tags": {"tag1", "tag2", "bla", "test@example.com"}},
        ),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert not mock_voyage_client.multimodal_embed.call_count

    assert mock_voyage_client.embed.call_args == call(
        [CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )


def test_email_attachment_embeddings_text(mock_voyage_client):
    item = EmailAttachment(
        id=1,
        content=SAMPLE_MARKDOWN,
        mime_type="text/html",
        modality="text",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (CHUNKS[0].strip(), cast(list[str], []), metadata),
        (CHUNKS[1].strip(), cast(list[str], []), metadata),
        (
            "test summary",
            [],
            metadata | {"tags": {"tag1", "tag2", "bla"}},
        ),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert not mock_voyage_client.multimodal_embed.call_count

    assert mock_voyage_client.embed.call_args == call(
        [CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )


def test_email_attachment_embeddings_photo(mock_voyage_client):
    item = EmailAttachment(
        id=1,
        content=SAMPLE_MARKDOWN,
        filename=DATA_DIR / "lang_timeline.png",
        mime_type="image/png",
        modality="photo",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (None, [LANG_TIMELINE_HASH], metadata),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 0
    assert mock_voyage_client.multimodal_embed.call_count == 1

    assert mock_voyage_client.multimodal_embed.call_args == call(
        [[ANY]],
        model=settings.MIXED_EMBEDDING_MODEL,
        input_type="document",
    )
    assert [
        image_hash(i) for i in mock_voyage_client.multimodal_embed.call_args[0][0][0]
    ] == [LANG_TIMELINE_HASH]


def test_email_attachment_embeddings_pdf(mock_voyage_client):
    item = EmailAttachment(
        id=1,
        content=SAMPLE_MARKDOWN,
        filename=DATA_DIR / "regulamin.pdf",
        mime_type="application/pdf",
        modality="doc",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    with pymupdf.open(item.filename) as pdf:
        expected = [
            (
                None,
                [image_hash(page_to_image(page))],
                metadata
                | {
                    "page": page.number,
                    "width": page.rect.width,
                    "height": page.rect.height,
                },
            )
            for page in pdf.pages()
        ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 0
    assert mock_voyage_client.multimodal_embed.call_count == 1

    assert mock_voyage_client.multimodal_embed.call_args == call(
        [[ANY], [ANY]],
        model=settings.MIXED_EMBEDDING_MODEL,
        input_type="document",
    )
    assert [
        [image_hash(a) for a in i]
        for i in mock_voyage_client.multimodal_embed.call_args[0][0]
    ] == [page for _, page, _ in expected]

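
# NOTE (editorial sketch): page_to_image comes from memory.common.extract; its
# implementation is not shown in this commit. A plausible stand-in (an
# assumption, not the project's code) rasterizes the pymupdf page and wraps
# the result in a PIL image:
def _page_to_image_sketch(page: pymupdf.Page) -> "Image.Image":
    import io

    from PIL import Image

    pixmap = page.get_pixmap()  # render the page to a pixmap
    return Image.open(io.BytesIO(pixmap.tobytes("png")))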

def test_email_attachment_embeddings_comic(mock_voyage_client):
    item = Comic(
        id=1,
        content=SAMPLE_MARKDOWN,
        filename=DATA_DIR / "lang_timeline.png",
        mime_type="image/png",
        modality="comic",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
        title="The Evolution of Programming Languages",
        author="John Doe",
        published=datetime(2025, 1, 1, 12, 0, 0),
        volume="1",
        issue="1",
        page=1,
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (
            "The Evolution of Programming Languages by John Doe",
            [LANG_TIMELINE_HASH],
            metadata,
        ),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 0
    assert mock_voyage_client.multimodal_embed.call_count == 1

    assert mock_voyage_client.multimodal_embed.call_args == call(
        [["The Evolution of Programming Languages by John Doe", ANY]],
        model=settings.MIXED_EMBEDDING_MODEL,
        input_type="document",
    )
    assert (
        image_hash(mock_voyage_client.multimodal_embed.call_args[0][0][0][1])
        == LANG_TIMELINE_HASH
    )


def test_book_section_embeddings_single_page(mock_voyage_client):
    item = BookSection(
        id=1,
        content=SAMPLE_MARKDOWN,
        mime_type="text/html",
        modality="text",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
        book_id=1,
        section_title="The Evolution of Programming Languages",
        section_number=1,
        section_level=1,
        start_page=1,
        end_page=1,
        pages=[SAMPLE_TEXT],
        book=Book(
            id=1,
            title="Programming Languages",
            author="John Doe",
            published=datetime(2025, 1, 1, 12, 0, 0),
        ),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (CHUNKS[0].strip(), cast(list[str], []), metadata | {"type": "page"}),
        (CHUNKS[1].strip(), cast(list[str], []), metadata | {"type": "page"}),
        (
            "test summary",
            [],
            metadata | {"tags": {"tag1", "tag2", "bla"}, "type": "summary"},
        ),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert not mock_voyage_client.multimodal_embed.call_count

    assert mock_voyage_client.embed.call_args == call(
        [CHUNKS[0].strip(), CHUNKS[1].strip(), "test summary"],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )


def test_book_section_embeddings_multiple_pages(mock_voyage_client):
    item = BookSection(
        id=1,
        content=SAMPLE_MARKDOWN + "\n\n" + SECOND_PAGE,
        mime_type="text/html",
        modality="text",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
        book_id=1,
        section_title="The Evolution of Programming Languages",
        section_number=1,
        section_level=1,
        start_page=1,
        end_page=2,
        pages=[SAMPLE_TEXT, SECOND_PAGE_TEXT],
        book=Book(
            id=1,
            title="Programming Languages",
            author="John Doe",
            published=datetime(2025, 1, 1, 12, 0, 0),
        ),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla", "tag1", "tag2"}
    expected = [
        (item.content.strip(), cast(list[str], []), metadata | {"type": "section"}),
        ("test summary", [], metadata | {"type": "summary"}),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert not mock_voyage_client.multimodal_embed.call_count

    assert mock_voyage_client.embed.call_args == call(
        [item.content.strip(), "test summary"],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )


@pytest.mark.parametrize(
    "class_, modality",
    (
        (BlogPost, "blog"),
        (ForumPost, "forum"),
    ),
)
def test_post_embeddings_single_page(mock_voyage_client, class_, modality):
    item = class_(
        id=1,
        content=SAMPLE_MARKDOWN,
        mime_type="text/html",
        modality=modality,
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
        images=[LANG_TIMELINE.filename, CODE_COMPLEXITY.filename],  # type: ignore
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla", "tag1", "tag2"}
    expected = [
        (item.content.strip(), [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH], metadata),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert not mock_voyage_client.embed.call_count
    assert mock_voyage_client.multimodal_embed.call_count == 1

    assert mock_voyage_client.multimodal_embed.call_args == call(
        [[item.content.strip(), ANY, ANY]],
        model=settings.MIXED_EMBEDDING_MODEL,
        input_type="document",
    )
    assert [
        image_hash(i)
        for i in mock_voyage_client.multimodal_embed.call_args[0][0][0][1:]
    ] == [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH]


@pytest.mark.parametrize(
    "class_, modality",
    (
        (BlogPost, "blog"),
        (ForumPost, "forum"),
    ),
)
def test_post_embeddings_multi_page(mock_voyage_client, class_, modality):
    item = class_(
        id=1,
        content=SAMPLE_MARKDOWN + "\n\n" + SECOND_PAGE_MARKDOWN,
        mime_type="text/html",
        modality=modality,
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN + SECOND_PAGE_MARKDOWN),
        tags=["bla"],
        images=[LANG_TIMELINE.filename, CODE_COMPLEXITY.filename],  # type: ignore
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla", "tag1", "tag2"}

    all_contents = (
        item.content.strip(),
        [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH],
        metadata,
    )
    first_chunk = (
        TWO_PAGE_CHUNKS[0].strip(),
        [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH],
        metadata,
    )
    second_chunk = (TWO_PAGE_CHUNKS[1].strip(), [], metadata)
    third_chunk = (TWO_PAGE_CHUNKS[2].strip(), [], metadata)
    summary = ("test summary", [], metadata)

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(
        item.data_chunks(),
        [all_contents, first_chunk, second_chunk, third_chunk, summary],
    )
    # embed_source_item first does text embedding, then mixed embedding,
    # so the order of chunks is different than in data_chunks()
    compare_chunks(
        embed_source_item(item),
        [
            second_chunk,
            third_chunk,
            summary,
            all_contents,
            first_chunk,
        ],
    )

    assert mock_voyage_client.embed.call_count == 1
    assert mock_voyage_client.multimodal_embed.call_count == 1

    assert mock_voyage_client.embed.call_args == call(
        [
            TWO_PAGE_CHUNKS[1].strip(),
            TWO_PAGE_CHUNKS[2].strip(),
            "test summary",
        ],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )
    assert mock_voyage_client.multimodal_embed.call_args == call(
        [[item.content.strip(), ANY, ANY], [TWO_PAGE_CHUNKS[0].strip(), ANY, ANY]],
        model=settings.MIXED_EMBEDDING_MODEL,
        input_type="document",
    )
    assert [
        image_hash(i)
        for i in mock_voyage_client.multimodal_embed.call_args[0][0][0][1:]
    ] == [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH]
    assert [
        image_hash(i)
        for i in mock_voyage_client.multimodal_embed.call_args[0][0][1][1:]
    ] == [LANG_TIMELINE_HASH, CODE_COMPLEXITY_HASH]

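
# NOTE (editorial sketch): the reordering asserted above suggests that
# embed_source_item batches chunks by embedding kind instead of preserving
# data_chunks() order. Inferred only from these assertions, not from the
# actual implementation:
def _embed_order_sketch(item):
    chunks = item.data_chunks()
    text_only = [c for c in chunks if not c.images]  # sent to embed()
    mixed = [c for c in chunks if c.images]          # sent to multimodal_embed()
    return text_only + mixed                         # text chunks come back first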

def test_agent_observation_embeddings(mock_voyage_client):
    item = AgentObservation(
        id=1,
        content="The user thinks that all men must die.",
        mime_type="text/html",
        modality="observation",
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
        observation_type="belief",
        subject="humans",
        confidence=0.8,
        evidence={
            "quote": "All humans are mortal.",
            "source": "https://en.wikipedia.org/wiki/Human",
        },
        agent_model="gpt-4o",
        inserted_at=datetime(2025, 1, 1, 12, 0, 0),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
    expected = [
        (
            "Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
            [],
            metadata | {"embedding_type": "semantic"},
        ),
        (
            "Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
            [],
            metadata | {"embedding_type": "temporal"},
        ),
        (
            "The user thinks that all men must die.",
            [],
            metadata | {"embedding_type": "semantic"},
        ),
        ("All humans are mortal.", [], metadata | {"embedding_type": "semantic"}),
    ]

    mock_voyage_client.embed = Mock(return_value=Mock(embeddings=[[0.1] * 1024] * 3))
    mock_voyage_client.multimodal_embed = Mock(
        return_value=Mock(embeddings=[[0.1] * 1024] * 3)
    )
    compare_chunks(item.data_chunks(), expected)
    compare_chunks(embed_source_item(item), expected)

    assert mock_voyage_client.embed.call_count == 1
    assert not mock_voyage_client.multimodal_embed.call_count

    assert mock_voyage_client.embed.call_args == call(
        [
            "Subject: humans | Type: belief | Observation: The user thinks that all men must die. | Quote: All humans are mortal.",
            "Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die. | Confidence: 0.8",
            "The user thinks that all men must die.",
            "All humans are mortal.",
        ],
        model=settings.TEXT_EMBEDDING_MODEL,
        input_type="document",
    )
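
# NOTE (editorial sketch): the semantic and temporal strings asserted above are
# presumably built by memory.common.formatters.observation; these hypothetical
# formatters only mirror the field order visible in the expected values:
def _semantic_text_sketch(subject, observation_type, content, quote) -> str:
    return f"Subject: {subject} | Type: {observation_type} | Observation: {content} | Quote: {quote}"


def _temporal_text_sketch(time_desc, subject, content, confidence) -> str:
    return f"Time: {time_desc} | Subject: {subject} | Observation: {content} | Confidence: {confidence}"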
@ -14,7 +14,6 @@ from memory.common.db.models.source_items import (
    BlogPost,
    AgentObservation,
)
from memory.common.db.models.source_item import merge_metadata


@pytest.fixture
@ -356,7 +355,8 @@ def test_book_section_data_chunks(pages, expected_chunks):

    chunks = book_section.data_chunks()
    expected = [
        (c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
        (c, extract.merge_metadata(book_section.as_payload(), m))
        for c, m in expected_chunks
    ]
    assert [(c.content, c.item_metadata) for c in chunks] == expected
    for c in chunks: