From 5e836337e2af0054b72ecad6d59e4fa7463385f6 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 10 Jun 2025 16:08:36 +0200 Subject: [PATCH] parralem search of collecitons --- src/memory/api/search/embeddings.py | 45 ++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/src/memory/api/search/embeddings.py b/src/memory/api/search/embeddings.py index 6c6b40c..ea4682d 100644 --- a/src/memory/api/search/embeddings.py +++ b/src/memory/api/search/embeddings.py @@ -1,6 +1,7 @@ import base64 import io import logging +import asyncio from typing import Any, Callable, Optional, cast import qdrant_client @@ -54,7 +55,7 @@ def annotated_chunk( ) -def query_chunks( +async def query_chunks( client: qdrant_client.QdrantClient, upload_data: list[extract.DataChunk], allowed_modalities: set[str], @@ -73,22 +74,46 @@ def query_chunks( vectors = embedder(chunks, input_type="query") - return { - collection: [ - r - for vector in vectors - for r in qdrant.search_vectors( + # Create all search tasks to run in parallel + search_tasks = [] + task_metadata = [] # Keep track of which collection and vector each task corresponds to + + for collection in allowed_modalities: + for vector in vectors: + task = asyncio.to_thread( + qdrant.search_vectors, client=client, collection_name=collection, query_vector=vector, limit=limit, filter_params=filters, ) - if r.score >= min_score - ] - for collection in allowed_modalities + search_tasks.append(task) + task_metadata.append((collection, vector)) + + # Run all searches in parallel + if not search_tasks: + return {} + + search_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + # Group results by collection + results_by_collection: dict[str, list[qdrant_models.ScoredPoint]] = { + collection: [] for collection in allowed_modalities } + for (collection, _), result in zip(task_metadata, search_results): + if isinstance(result, Exception): + logger.error(f"Search failed for collection {collection}: {result}") + continue + + # Filter by min_score and add to collection results + result_list = cast(list[qdrant_models.ScoredPoint], result) + filtered_results = [r for r in result_list if r.score >= min_score] + results_by_collection[collection].extend(filtered_results) + + return results_by_collection + def merge_range_filter( filters: list[dict[str, Any]], key: str, val: Any @@ -174,7 +199,7 @@ async def search_embeddings( search_filters = merge_filters(search_filters, key, val) client = qdrant.get_qdrant_client() - results = query_chunks( + results = await query_chunks( client, data, modalities,