diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 17f1438da..1d94e58fe 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -1,9 +1,8 @@
 import logging
 import os
-import uuid
+import heapq
 from typing import Optional, Union
 
-import asyncio
 import requests
 
 from huggingface_hub import snapshot_download
@@ -34,8 +33,6 @@ class VectorSearchRetriever(BaseRetriever):
     def _get_relevant_documents(
         self,
         query: str,
-        *,
-        run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
@@ -47,15 +44,12 @@ class VectorSearchRetriever(BaseRetriever):
         metadatas = result.metadatas[0]
         documents = result.documents[0]
 
-        results = []
-        for idx in range(len(ids)):
-            results.append(
-                Document(
-                    metadata=metadatas[idx],
-                    page_content=documents[idx],
-                )
-            )
-        return results
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            )
+            for idx in range(len(ids))
+        ]
 
 
 def query_doc(
@@ -64,16 +58,15 @@ def query_doc(
     k: int,
 ):
     try:
-        result = VECTOR_DB_CLIENT.search(
+        # The walrus keeps the call and the truthiness check together; the
+        # unconditional return preserves the original contract of always
+        # returning whatever the search produced (even a falsy result).
+        if result := VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
             vectors=[query_embedding],
             limit=k,
-        )
-
-        if result:
+        ):
             log.info(f"query_doc:result {result.ids} {result.metadatas}")
-
         return result
     except Exception as e:
         print(e)
         raise e
@@ -135,44 +128,36 @@ def query_doc_with_hybrid_search(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
-) -> list[dict]:
-    # Initialize lists to store combined data
-    combined_distances = []
-    combined_documents = []
-    combined_metadatas = []
-
-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
-        combined_documents.extend(data["documents"][0])
-        combined_metadatas.extend(data["metadatas"][0])
-
-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
-
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0], reverse=reverse)
-
-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
-    else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
-
-        # Slicing the lists to include only k elements
-        sorted_distances = list(sorted_distances)[:k]
-        sorted_documents = list(sorted_documents)[:k]
-        sorted_metadatas = list(sorted_metadatas)[:k]
-
-    # Create the output dictionary
-    result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
-    }
-
-    return result
+) -> dict:
+    # Flatten every backend result into per-document
+    # (distance, document, metadata) triples, then select the top-k by
+    # distance with heapq instead of sorting everything.
+    combined = (
+        triple
+        for data in query_results
+        for triple in zip(
+            data["distances"][0], data["documents"][0], data["metadatas"][0]
+        )
+    )
+
+    select = heapq.nlargest if reverse else heapq.nsmallest
+    top_k = select(k, combined, key=lambda x: x[0])
+
+    # Covers both "no query results" and "results with empty inner lists".
+    if not top_k:
+        return {
+            "distances": [[]],
+            "documents": [[]],
+            "metadatas": [[]],
+        }
+
+    sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
+    return {
+        "distances": [list(sorted_distances)],
+        "documents": [list(sorted_documents)],
+        "metadatas": [list(sorted_metadatas)],
+    }
 
 
 def query_collection(
@@ -185,19 +170,19 @@ def query_collection(
     for query in queries:
         query_embedding = embedding_function(query)
         for collection_name in collection_names:
-            if collection_name:
-                try:
-                    result = query_doc(
-                        collection_name=collection_name,
-                        k=k,
-                        query_embedding=query_embedding,
-                    )
-                    if result is not None:
-                        results.append(result.model_dump())
-                except Exception as e:
-                    log.exception(f"Error when querying the collection: {e}")
-            else:
-                pass
+            if not collection_name:
+                continue
+
+            try:
+                # `is not None` (not truthiness) mirrors the original check.
+                if (
+                    result := query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                ) is not None:
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
 
     return merge_and_sort_query_results(results, k=k)
@@ -213,8 +198,8 @@ def query_collection_with_hybrid_search(
     results = []
     error = False
     for collection_name in collection_names:
-        try:
-            for query in queries:
+        for query in queries:
+            try:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
                     query=query,
@@ -224,11 +209,11 @@ def query_collection_with_hybrid_search(
                     r=r,
                 )
                 results.append(result)
-        except Exception as e:
-            log.exception(
-                "Error when querying the collection with " f"hybrid_search: {e}"
-            )
-            error = True
+            except Exception as e:
+                log.exception(
+                    "Error when querying the collection with " f"hybrid_search: {e}"
+                )
+                error = True
 
     if error:
         raise Exception(
@@ -259,10 +244,13 @@ def get_embedding_function(
     def generate_multiple(query, func):
         if isinstance(query, list):
-            embeddings = []
-            for i in range(0, len(query), embedding_batch_size):
-                embeddings.extend(func(query[i : i + embedding_batch_size]))
-            return embeddings
+            # func returns a list of embeddings per batch; the nested
+            # comprehension flattens the batches into one flat list,
+            # matching the previous extend() behavior.
+            return [
+                embedding
+                for i in range(0, len(query), embedding_batch_size)
+                for embedding in func(query[i : i + embedding_batch_size])
+            ]
         else:
             return func(query)
@@ -433,25 +421,23 @@ def generate_openai_batch_embeddings(
 def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
             f"{url}/api/embed",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
             json={"input": texts, "model": model},
         )
         r.raise_for_status()
         data = r.json()
-
-        if "embeddings" in data:
-            return data["embeddings"]
-        else:
-            raise "Something went wrong :/"
+        if "embeddings" not in data:
+            # Raising a string is a TypeError in Python 3; raise a real
+            # exception inside the try so the handler still returns None.
+            raise ValueError("Ollama response is missing 'embeddings'")
+        return data["embeddings"]
     except Exception as e:
         print(e)
         return None