diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index aa7b57dfc..767d5cce8 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -1,9 +1,7 @@
 import logging
 import os
-import uuid
 from typing import Optional, Union
 
-import asyncio
 import requests
 import hashlib
 
@@ -12,10 +10,8 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 
-
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
 
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
@@ -102,6 +98,7 @@ def get_doc(collection_name: str, user: UserModel = None):
 
 def query_doc_with_hybrid_search(
     collection_name: str,
+    collection_data,
     query: str,
     embedding_function,
     k: int,
@@ -110,11 +107,9 @@ def query_doc_with_hybrid_search(
     r: float,
 ) -> dict:
     try:
-        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
-
         bm25_retriever = BM25Retriever.from_texts(
-            texts=result.documents[0],
-            metadatas=result.metadatas[0],
+            texts=collection_data.documents[0],
+            metadatas=collection_data.metadatas[0],
         )
         bm25_retriever.k = k
 
@@ -282,11 +277,24 @@ def query_collection_with_hybrid_search(
 ) -> dict:
     results = []
     error = False
+    # Fetch each collection's contents once up front so the per-query loop
+    # below reuses them instead of calling VECTOR_DB_CLIENT.get() per query.
+    collection_data = {}
+    for collection_name in collection_names:
+        try:
+            collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+        except Exception as e:
+            log.exception(f"Failed to fetch collection {collection_name}: {e}")
+            # Recorded as None: query_doc_with_hybrid_search reads .documents
+            # on it, so the failure resurfaces inside the try block below.
+            collection_data[collection_name] = None
+
     for collection_name in collection_names:
         try:
             for query in queries:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
+                    collection_data=collection_data[collection_name],
                     query=query,
                     embedding_function=embedding_function,
                     k=k,