From 925bfe840b46df424360f230a93e748289df0139 Mon Sep 17 00:00:00 2001 From: mikhail-khludnev Date: Tue, 18 Feb 2025 16:39:02 +0300 Subject: [PATCH] dedupe results from multiple queries --- backend/open_webui/retrieval/utils.py | 35 ++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 437183369..e5ba55878 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -138,37 +138,44 @@ def query_doc_with_hybrid_search( def merge_and_sort_query_results( - query_results: list[dict], k: int, reverse: bool = False + query_results: list[dict], k: int, reverse: bool = False ) -> list[dict]: # Initialize lists to store combined data combined_distances = [] combined_documents = [] combined_metadatas = [] + combined_ids = [] for data in query_results: combined_distances.extend(data["distances"][0]) combined_documents.extend(data["documents"][0]) combined_metadatas.extend(data["metadatas"][0]) + # DISTINCT(chunk_id,file_id) - in case if id (chunk_ids) become ordinals + combined_ids.extend([id + meta["file_id"] for id, meta in zip(data["ids"][0], data["metadatas"][0])]) - # Create a list of tuples (distance, document, metadata) - combined = list(zip(combined_distances, combined_documents, combined_metadatas)) + # Create a list of tuples (distance, document, metadata, ids) + combined = list(zip(combined_distances, combined_documents, combined_metadatas, combined_ids)) # Sort the list based on distances combined.sort(key=lambda x: x[0], reverse=reverse) - # We don't have anything :-( - if not combined: - sorted_distances = [] - sorted_documents = [] - sorted_metadatas = [] - else: + sorted_distances = [] + sorted_documents = [] + sorted_metadatas = [] + # Otherwise we don't have anything :-( + if combined: # Unzip the sorted list - sorted_distances, sorted_documents, sorted_metadatas = zip(*combined) - + all_distances, all_documents, all_metadatas, all_ids = zip(*combined) + seen_ids = set() # Slicing the lists to include only k elements - sorted_distances = list(sorted_distances)[:k] - sorted_documents = list(sorted_documents)[:k] - sorted_metadatas = list(sorted_metadatas)[:k] + for index, id in enumerate(all_ids): + if id not in seen_ids: + sorted_distances.append(all_distances[index]) + sorted_documents.append(all_documents[index]) + sorted_metadatas.append(all_metadatas[index]) + seen_ids.add(id) + if len(sorted_distances) >= k: + break # Create the output dictionary result = {