From fd0170c179ae01dc36056efdca1f46e885d286a4 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 30 Dec 2024 16:55:29 -0800 Subject: [PATCH] revert --- backend/open_webui/retrieval/utils.py | 172 ++++++++++++++------------ 1 file changed, 93 insertions(+), 79 deletions(-) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 7e8771bd6..c95367e6c 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -1,8 +1,9 @@ import logging import os -import heapq +import uuid from typing import Optional, Union +import asyncio import requests from huggingface_hub import snapshot_download @@ -33,6 +34,8 @@ class VectorSearchRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, + *, + run_manager: CallbackManagerForRetrieverRun, ) -> list[Document]: result = VECTOR_DB_CLIENT.search( collection_name=self.collection_name, @@ -44,12 +47,15 @@ class VectorSearchRetriever(BaseRetriever): metadatas = result.metadatas[0] documents = result.documents[0] - return [ - Document( - metadata=metadatas[idx], - page_content=documents[idx], - ) for idx in range(len(ids)) - ] + results = [] + for idx in range(len(ids)): + results.append( + Document( + metadata=metadatas[idx], + page_content=documents[idx], + ) + ) + return results def query_doc( @@ -58,14 +64,16 @@ def query_doc( k: int, ): try: - if result := VECTOR_DB_CLIENT.search( + result = VECTOR_DB_CLIENT.search( collection_name=collection_name, vectors=[query_embedding], limit=k, - ): + ) + + if result: log.info(f"query_doc:result {result.ids} {result.metadatas}") - return result + return result except Exception as e: print(e) raise e @@ -127,38 +135,44 @@ def query_doc_with_hybrid_search( def merge_and_sort_query_results( query_results: list[dict], k: int, reverse: bool = False ) -> list[dict]: - if not query_results: - return { - "distances": [[]], - "documents": [[]], - "metadatas": [[]], - } - - combined = ( - (data.get("distances", [float('inf')])[0], - data.get("documents", [None])[0], - data.get("metadatas", [{}])[0]) - for data in query_results - ) - - if reverse: - top_k = heapq.nlargest(k, combined, key=lambda x: x[0]) + # Initialize lists to store combined data + combined_distances = [] + combined_documents = [] + combined_metadatas = [] + + for data in query_results: + combined_distances.extend(data["distances"][0]) + combined_documents.extend(data["documents"][0]) + combined_metadatas.extend(data["metadatas"][0]) + + # Create a list of tuples (distance, document, metadata) + combined = list(zip(combined_distances, combined_documents, combined_metadatas)) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0], reverse=reverse) + + # We don't have anything :-( + if not combined: + sorted_distances = [] + sorted_documents = [] + sorted_metadatas = [] else: - top_k = heapq.nsmallest(k, combined, key=lambda x: x[0]) - - if not top_k: - return { - "distances": [[]], - "documents": [[]], - "metadatas": [[]], - } - else: - sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k) - return { - "distances": [sorted_distances], - "documents": [sorted_documents], - "metadatas": [sorted_metadatas], - } + # Unzip the sorted list + sorted_distances, sorted_documents, sorted_metadatas = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_documents = list(sorted_documents)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + + # Create the output dictionary + result = { + "distances": [sorted_distances], + "documents": [sorted_documents], + "metadatas": [sorted_metadatas], + } + + return result def query_collection( @@ -171,18 +185,19 @@ def query_collection( for query in queries: query_embedding = embedding_function(query) for collection_name in collection_names: - if not collection_name: - continue - - try: - if result := query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ): - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") + if collection_name: + try: + result = query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass return merge_and_sort_query_results(results, k=k) @@ -198,8 +213,8 @@ def query_collection_with_hybrid_search( results = [] error = False for collection_name in collection_names: - for query in queries: - try: + try: + for query in queries: result = query_doc_with_hybrid_search( collection_name=collection_name, query=query, @@ -209,11 +224,11 @@ def query_collection_with_hybrid_search( r=r, ) results.append(result) - except Exception as e: - log.exception( - "Error when querying the collection with " f"hybrid_search: {e}" - ) - error = True + except Exception as e: + log.exception( + "Error when querying the collection with " f"hybrid_search: {e}" + ) + error = True if error: raise Exception( @@ -244,10 +259,10 @@ def get_embedding_function( def generate_multiple(query, func): if isinstance(query, list): - return [ - func(query[i : i + embedding_batch_size]) - for i in range(0, len(query), embedding_batch_size) - ] + embeddings = [] + for i in range(0, len(query), embedding_batch_size): + embeddings.extend(func(query[i : i + embedding_batch_size])) + return embeddings else: return func(query) @@ -421,26 +436,25 @@ def generate_openai_batch_embeddings( def generate_ollama_batch_embeddings( model: str, texts: list[str], url: str, key: str = "" ) -> Optional[list[list[float]]]: - r = requests.post( - f"{url}/api/embed", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {key}", - }, - json={"input": texts, "model": model}, - ) try: + r = requests.post( + f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {key}", + }, + json={"input": texts, "model": model}, + ) r.raise_for_status() + data = r.json() + + if "embeddings" in data: + return data["embeddings"] + else: + raise "Something went wrong :/" except Exception as e: print(e) return None - - data = r.json() - - if 'embeddings' not in data: - raise "Something went wrong :/" - - return data['embeddings'] def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):