diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index c083b9f67..70c3f4115 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -29,7 +29,6 @@ from open_webui.config import ( RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_CONTENT_PREFIX, RAG_EMBEDDING_PREFIX_FIELD_NAME, - RAG_BM25_WEIGHT, ) log = logging.getLogger(__name__) @@ -117,6 +116,7 @@ def query_doc_with_hybrid_search( reranking_function, k_reranker: int, r: float, + bm25_weight: float, ) -> dict: try: log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") @@ -132,18 +132,18 @@ def query_doc_with_hybrid_search( top_k=k, ) - if RAG_BM25_WEIGHT <= 0: + if bm25_weight <= 0: ensemble_retriever = EnsembleRetriever( retrievers=[vector_search_retriever], weights=[1.] ) - elif RAG_BM25_WEIGHT >= 1: + elif bm25_weight >= 1: ensemble_retriever = EnsembleRetriever( retrievers=[bm25_retriever], weights=[1.] ) else: ensemble_retriever = EnsembleRetriever( retrievers=[bm25_retriever, vector_search_retriever], - weights=[RAG_BM25_WEIGHT, 1. - RAG_BM25_WEIGHT] + weights=[bm25_weight, 1. 
- bm25_weight] ) compressor = RerankCompressor( @@ -325,6 +325,7 @@ def query_collection_with_hybrid_search( reranking_function, k_reranker: int, r: float, + bm25_weight: float, ) -> dict: results = [] error = False @@ -358,6 +359,7 @@ def query_collection_with_hybrid_search( reranking_function=reranking_function, k_reranker=k_reranker, r=r, + bm25_weight=bm25_weight, ) return result, None except Exception as e: @@ -445,6 +447,7 @@ def get_sources_from_files( reranking_function, k_reranker, r, + bm25_weight, hybrid_search, full_context=False, ): @@ -562,6 +565,7 @@ def get_sources_from_files( reranking_function=reranking_function, k_reranker=k_reranker, r=r, + bm25_weight=bm25_weight, ) except Exception as e: log.debug( diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index cdd71196c..e31dba299 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1782,6 +1782,11 @@ def query_doc_handler( if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), + bm25_weight=( + form_data.bm25_weight + if form_data.bm25_weight is not None + else request.app.state.config.BM25_WEIGHT + ), user=user, ) else: @@ -1833,6 +1838,11 @@ def query_collection_handler( if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), + bm25_weight=( + form_data.bm25_weight + if form_data.bm25_weight is not None + else request.app.state.config.BM25_WEIGHT + ), ) else: return query_collection( diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index c9095f931..c0ce2f063 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -603,6 +603,7 @@ async def chat_completion_files_handler( reranking_function=request.app.state.rf, k_reranker=request.app.state.config.TOP_K_RERANKER, r=request.app.state.config.RELEVANCE_THRESHOLD, + bm25_weight=request.app.state.config.BM25_WEIGHT, 
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, full_context=request.app.state.config.RAG_FULL_CONTEXT, ),