mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 23:24:43 +02:00
parallel search of collections
This commit is contained in:
parent
55809f3980
commit
5e836337e2
@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import asyncio
|
||||||
from typing import Any, Callable, Optional, cast
|
from typing import Any, Callable, Optional, cast
|
||||||
|
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
@ -54,7 +55,7 @@ def annotated_chunk(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def query_chunks(
|
async def query_chunks(
|
||||||
client: qdrant_client.QdrantClient,
|
client: qdrant_client.QdrantClient,
|
||||||
upload_data: list[extract.DataChunk],
|
upload_data: list[extract.DataChunk],
|
||||||
allowed_modalities: set[str],
|
allowed_modalities: set[str],
|
||||||
@ -73,22 +74,46 @@ def query_chunks(
|
|||||||
|
|
||||||
vectors = embedder(chunks, input_type="query")
|
vectors = embedder(chunks, input_type="query")
|
||||||
|
|
||||||
return {
|
# Create all search tasks to run in parallel
|
||||||
collection: [
|
search_tasks = []
|
||||||
r
|
task_metadata = [] # Keep track of which collection and vector each task corresponds to
|
||||||
for vector in vectors
|
|
||||||
for r in qdrant.search_vectors(
|
for collection in allowed_modalities:
|
||||||
|
for vector in vectors:
|
||||||
|
task = asyncio.to_thread(
|
||||||
|
qdrant.search_vectors,
|
||||||
client=client,
|
client=client,
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
query_vector=vector,
|
query_vector=vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
filter_params=filters,
|
filter_params=filters,
|
||||||
)
|
)
|
||||||
if r.score >= min_score
|
search_tasks.append(task)
|
||||||
]
|
task_metadata.append((collection, vector))
|
||||||
for collection in allowed_modalities
|
|
||||||
|
# Run all searches in parallel
|
||||||
|
if not search_tasks:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Group results by collection
|
||||||
|
results_by_collection: dict[str, list[qdrant_models.ScoredPoint]] = {
|
||||||
|
collection: [] for collection in allowed_modalities
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (collection, _), result in zip(task_metadata, search_results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"Search failed for collection {collection}: {result}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Filter by min_score and add to collection results
|
||||||
|
result_list = cast(list[qdrant_models.ScoredPoint], result)
|
||||||
|
filtered_results = [r for r in result_list if r.score >= min_score]
|
||||||
|
results_by_collection[collection].extend(filtered_results)
|
||||||
|
|
||||||
|
return results_by_collection
|
||||||
|
|
||||||
|
|
||||||
def merge_range_filter(
|
def merge_range_filter(
|
||||||
filters: list[dict[str, Any]], key: str, val: Any
|
filters: list[dict[str, Any]], key: str, val: Any
|
||||||
@ -174,7 +199,7 @@ async def search_embeddings(
|
|||||||
search_filters = merge_filters(search_filters, key, val)
|
search_filters = merge_filters(search_filters, key, val)
|
||||||
|
|
||||||
client = qdrant.get_qdrant_client()
|
client = qdrant.get_qdrant_client()
|
||||||
results = query_chunks(
|
results = await query_chunks(
|
||||||
client,
|
client,
|
||||||
data,
|
data,
|
||||||
modalities,
|
modalities,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user