parralem search of collecitons

This commit is contained in:
Daniel O'Connell 2025-06-10 16:08:36 +02:00
parent 55809f3980
commit 5e836337e2

View File

@ -1,6 +1,7 @@
import base64
import io
import logging
import asyncio
from typing import Any, Callable, Optional, cast
import qdrant_client
@ -54,7 +55,7 @@ def annotated_chunk(
)
def query_chunks(
async def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk],
allowed_modalities: set[str],
@ -73,22 +74,46 @@ def query_chunks(
vectors = embedder(chunks, input_type="query")
return {
collection: [
r
for vector in vectors
for r in qdrant.search_vectors(
# Create all search tasks to run in parallel
search_tasks = []
task_metadata = [] # Keep track of which collection and vector each task corresponds to
for collection in allowed_modalities:
for vector in vectors:
task = asyncio.to_thread(
qdrant.search_vectors,
client=client,
collection_name=collection,
query_vector=vector,
limit=limit,
filter_params=filters,
)
if r.score >= min_score
]
for collection in allowed_modalities
search_tasks.append(task)
task_metadata.append((collection, vector))
# Run all searches in parallel
if not search_tasks:
return {}
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
# Group results by collection
results_by_collection: dict[str, list[qdrant_models.ScoredPoint]] = {
collection: [] for collection in allowed_modalities
}
for (collection, _), result in zip(task_metadata, search_results):
if isinstance(result, Exception):
logger.error(f"Search failed for collection {collection}: {result}")
continue
# Filter by min_score and add to collection results
result_list = cast(list[qdrant_models.ScoredPoint], result)
filtered_results = [r for r in result_list if r.score >= min_score]
results_by_collection[collection].extend(filtered_results)
return results_by_collection
def merge_range_filter(
filters: list[dict[str, Any]], key: str, val: Any
@ -174,7 +199,7 @@ async def search_embeddings(
search_filters = merge_filters(search_filters, key, val)
client = qdrant.get_qdrant_client()
results = query_chunks(
results = await query_chunks(
client,
data,
modalities,