diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 767d5cce8..5ed47baaa 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -16,6 +16,8 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.models.users import UserModel from open_webui.models.files import Files +from open_webui.retrieval.vector.main import GetResult + from open_webui.env import ( SRC_LOG_LEVELS, OFFLINE_MODE, @@ -98,7 +100,7 @@ def get_doc(collection_name: str, user: UserModel = None): def query_doc_with_hybrid_search( collection_name: str, - collection_data, + collection_result: GetResult, query: str, embedding_function, k: int, @@ -108,8 +110,8 @@ def query_doc_with_hybrid_search( ) -> dict: try: bm25_retriever = BM25Retriever.from_texts( - texts=collection_data.documents[0], - metadatas=collection_data.metadatas[0], + texts=collection_result.documents[0], + metadatas=collection_result.metadatas[0], ) bm25_retriever.k = k @@ -135,9 +137,9 @@ def query_doc_with_hybrid_search( result = compression_retriever.invoke(query) - distances = [d.metadata.get("score") for d in collection_data] - documents = [d.page_content for d in collection_data] - metadatas = [d.metadata for d in collection_data] + distances = [d.metadata.get("score") for d in result] + documents = [d.page_content for d in result] + metadatas = [d.metadata for d in result] # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker if k < k_reranker: @@ -146,7 +148,8 @@ def query_doc_with_hybrid_search( ) sorted_items = sorted_items[:k] distances, documents, metadatas = map(list, zip(*sorted_items)) - collection_data = { + + result = { "distances": [distances], "documents": [documents], "metadatas": [metadatas], @@ -154,9 +157,9 @@ def query_doc_with_hybrid_search( log.info( "query_doc_with_hybrid_search:result " - + f'{collection_data["metadatas"]} {collection_data["distances"]}' + + f'{result["metadatas"]} {result["distances"]}' ) - return collection_data + return result except Exception as e: raise e @@ -279,20 +282,22 @@ def query_collection_with_hybrid_search( error = False # Fetch collection data once per collection sequentially # Avoid fetching the same data multiple times later - collection_data = {} + collection_results = {} for collection_name in collection_names: try: - collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name) + collection_results[collection_name] = VECTOR_DB_CLIENT.get( + collection_name=collection_name + ) except Exception as e: log.exception(f"Failed to fetch collection {collection_name}: {e}") - collection_data[collection_name] = None + collection_results[collection_name] = None for collection_name in collection_names: try: for query in queries: result = query_doc_with_hybrid_search( collection_name=collection_name, - collection_data=collection_data[collection_name], + collection_result=collection_results[collection_name], query=query, embedding_function=embedding_function, k=k,