parralem search of collecitons

This commit is contained in:
Daniel O'Connell 2025-06-10 16:08:36 +02:00
parent 55809f3980
commit 5e836337e2

View File

@ -1,6 +1,7 @@
import base64 import base64
import io import io
import logging import logging
import asyncio
from typing import Any, Callable, Optional, cast from typing import Any, Callable, Optional, cast
import qdrant_client import qdrant_client
@ -54,7 +55,7 @@ def annotated_chunk(
) )
def query_chunks( async def query_chunks(
client: qdrant_client.QdrantClient, client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk], upload_data: list[extract.DataChunk],
allowed_modalities: set[str], allowed_modalities: set[str],
@ -73,22 +74,46 @@ def query_chunks(
vectors = embedder(chunks, input_type="query") vectors = embedder(chunks, input_type="query")
return { # Create all search tasks to run in parallel
collection: [ search_tasks = []
r task_metadata = [] # Keep track of which collection and vector each task corresponds to
for vector in vectors
for r in qdrant.search_vectors( for collection in allowed_modalities:
for vector in vectors:
task = asyncio.to_thread(
qdrant.search_vectors,
client=client, client=client,
collection_name=collection, collection_name=collection,
query_vector=vector, query_vector=vector,
limit=limit, limit=limit,
filter_params=filters, filter_params=filters,
) )
if r.score >= min_score search_tasks.append(task)
] task_metadata.append((collection, vector))
for collection in allowed_modalities
# Run all searches in parallel
if not search_tasks:
return {}
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
# Group results by collection
results_by_collection: dict[str, list[qdrant_models.ScoredPoint]] = {
collection: [] for collection in allowed_modalities
} }
for (collection, _), result in zip(task_metadata, search_results):
if isinstance(result, Exception):
logger.error(f"Search failed for collection {collection}: {result}")
continue
# Filter by min_score and add to collection results
result_list = cast(list[qdrant_models.ScoredPoint], result)
filtered_results = [r for r in result_list if r.score >= min_score]
results_by_collection[collection].extend(filtered_results)
return results_by_collection
def merge_range_filter( def merge_range_filter(
filters: list[dict[str, Any]], key: str, val: Any filters: list[dict[str, Any]], key: str, val: Any
@ -174,7 +199,7 @@ async def search_embeddings(
search_filters = merge_filters(search_filters, key, val) search_filters = merge_filters(search_filters, key, val)
client = qdrant.get_qdrant_client() client = qdrant.get_qdrant_client()
results = query_chunks( results = await query_chunks(
client, client,
data, data,
modalities, modalities,