make bm25_weight a regular parameter of query_doc.. / get_sources_from_files functions

This commit is contained in:
Jan Kessler 2025-05-20 11:21:14 +02:00
parent b5ddaf6417
commit 308d8ac04a
No known key found for this signature in database
GPG Key ID: FCF0DCB4ADFC53E7
3 changed files with 19 additions and 4 deletions

View File

@ -29,7 +29,6 @@ from open_webui.config import (
RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX, RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME, RAG_EMBEDDING_PREFIX_FIELD_NAME,
RAG_BM25_WEIGHT,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -117,6 +116,7 @@ def query_doc_with_hybrid_search(
reranking_function, reranking_function,
k_reranker: int, k_reranker: int,
r: float, r: float,
bm25_weight: float,
) -> dict: ) -> dict:
try: try:
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
@ -132,18 +132,18 @@ def query_doc_with_hybrid_search(
top_k=k, top_k=k,
) )
if RAG_BM25_WEIGHT <= 0: if bm25_weight <= 0:
ensemble_retriever = EnsembleRetriever( ensemble_retriever = EnsembleRetriever(
retrievers=[vector_search_retriever], weights=[1.] retrievers=[vector_search_retriever], weights=[1.]
) )
elif RAG_BM25_WEIGHT >= 1: elif bm25_weight >= 1:
ensemble_retriever = EnsembleRetriever( ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever], weights=[1.] retrievers=[bm25_retriever], weights=[1.]
) )
else: else:
ensemble_retriever = EnsembleRetriever( ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever], retrievers=[bm25_retriever, vector_search_retriever],
weights=[RAG_BM25_WEIGHT, 1. - RAG_BM25_WEIGHT] weights=[bm25_weight, 1. - bm25_weight]
) )
compressor = RerankCompressor( compressor = RerankCompressor(
@ -325,6 +325,7 @@ def query_collection_with_hybrid_search(
reranking_function, reranking_function,
k_reranker: int, k_reranker: int,
r: float, r: float,
bm25_weight: float,
) -> dict: ) -> dict:
results = [] results = []
error = False error = False
@ -358,6 +359,7 @@ def query_collection_with_hybrid_search(
reranking_function=reranking_function, reranking_function=reranking_function,
k_reranker=k_reranker, k_reranker=k_reranker,
r=r, r=r,
bm25_weight=bm25_weight,
) )
return result, None return result, None
except Exception as e: except Exception as e:
@ -445,6 +447,7 @@ def get_sources_from_files(
reranking_function, reranking_function,
k_reranker, k_reranker,
r, r,
bm25_weight,
hybrid_search, hybrid_search,
full_context=False, full_context=False,
): ):
@ -562,6 +565,7 @@ def get_sources_from_files(
reranking_function=reranking_function, reranking_function=reranking_function,
k_reranker=k_reranker, k_reranker=k_reranker,
r=r, r=r,
bm25_weight=bm25_weight,
) )
except Exception as e: except Exception as e:
log.debug( log.debug(

View File

@ -1782,6 +1782,11 @@ def query_doc_handler(
if form_data.r if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD else request.app.state.config.RELEVANCE_THRESHOLD
), ),
bm25_weight=(
form_data.bm25_weight
if form_data.bm25_weight
else request.app.state.config.BM25_WEIGHT
),
user=user, user=user,
) )
else: else:
@ -1833,6 +1838,11 @@ def query_collection_handler(
if form_data.r if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD else request.app.state.config.RELEVANCE_THRESHOLD
), ),
bm25_weight=(
form_data.bm25_weight
if form_data.bm25_weight
else request.app.state.config.BM25_WEIGHT
),
) )
else: else:
return query_collection( return query_collection(

View File

@ -603,6 +603,7 @@ async def chat_completion_files_handler(
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
k_reranker=request.app.state.config.TOP_K_RERANKER, k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD, r=request.app.state.config.RELEVANCE_THRESHOLD,
bm25_weight=request.app.state.config.BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT, full_context=request.app.state.config.RAG_FULL_CONTEXT,
), ),