mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 23:24:43 +02:00
parralem search of collecitons
This commit is contained in:
parent
55809f3980
commit
5e836337e2
@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import qdrant_client
|
||||
@ -54,7 +55,7 @@ def annotated_chunk(
|
||||
)
|
||||
|
||||
|
||||
def query_chunks(
|
||||
async def query_chunks(
|
||||
client: qdrant_client.QdrantClient,
|
||||
upload_data: list[extract.DataChunk],
|
||||
allowed_modalities: set[str],
|
||||
@ -73,22 +74,46 @@ def query_chunks(
|
||||
|
||||
vectors = embedder(chunks, input_type="query")
|
||||
|
||||
return {
|
||||
collection: [
|
||||
r
|
||||
for vector in vectors
|
||||
for r in qdrant.search_vectors(
|
||||
# Create all search tasks to run in parallel
|
||||
search_tasks = []
|
||||
task_metadata = [] # Keep track of which collection and vector each task corresponds to
|
||||
|
||||
for collection in allowed_modalities:
|
||||
for vector in vectors:
|
||||
task = asyncio.to_thread(
|
||||
qdrant.search_vectors,
|
||||
client=client,
|
||||
collection_name=collection,
|
||||
query_vector=vector,
|
||||
limit=limit,
|
||||
filter_params=filters,
|
||||
)
|
||||
if r.score >= min_score
|
||||
]
|
||||
for collection in allowed_modalities
|
||||
search_tasks.append(task)
|
||||
task_metadata.append((collection, vector))
|
||||
|
||||
# Run all searches in parallel
|
||||
if not search_tasks:
|
||||
return {}
|
||||
|
||||
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
# Group results by collection
|
||||
results_by_collection: dict[str, list[qdrant_models.ScoredPoint]] = {
|
||||
collection: [] for collection in allowed_modalities
|
||||
}
|
||||
|
||||
for (collection, _), result in zip(task_metadata, search_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Search failed for collection {collection}: {result}")
|
||||
continue
|
||||
|
||||
# Filter by min_score and add to collection results
|
||||
result_list = cast(list[qdrant_models.ScoredPoint], result)
|
||||
filtered_results = [r for r in result_list if r.score >= min_score]
|
||||
results_by_collection[collection].extend(filtered_results)
|
||||
|
||||
return results_by_collection
|
||||
|
||||
|
||||
def merge_range_filter(
|
||||
filters: list[dict[str, Any]], key: str, val: Any
|
||||
@ -174,7 +199,7 @@ async def search_embeddings(
|
||||
search_filters = merge_filters(search_filters, key, val)
|
||||
|
||||
client = qdrant.get_qdrant_client()
|
||||
results = query_chunks(
|
||||
results = await query_chunks(
|
||||
client,
|
||||
data,
|
||||
modalities,
|
||||
|
Loading…
x
Reference in New Issue
Block a user