diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index d77e4eded..bfa3f5baf 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -549,7 +549,7 @@ async def chat_image_generation_handler(
 
 
 async def chat_completion_files_handler(
-    request: Request, body: dict, user: UserModel
+    request: Request, body: dict, user: UserModel, model_knowledge
 ) -> tuple[dict, dict[str, list]]:
     sources = []
 
@@ -587,6 +587,20 @@ async def chat_completion_files_handler(
             queries = [get_last_user_message(body["messages"])]
 
         try:
+            # check if individual rag config is used
+            rag_config = {}
+            if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = model_knowledge[0].get("rag_config")
+
+            k = rag_config.get("TOP_K", request.app.state.config.TOP_K)
+            reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
+            reranking_function = request.app.state.rf[reranking_model] if reranking_model else None
+            k_reranker = rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
+            r = rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
+            hybrid_search = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
+            full_context = rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT)
+            embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL)
+
             # Offload get_sources_from_files to a separate thread
             loop = asyncio.get_running_loop()
             with ThreadPoolExecutor() as executor:
@@ -596,15 +610,15 @@ async def chat_completion_files_handler(
                         request=request,
                         files=files,
                         queries=queries,
-                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
                             query, prefix=prefix, user=user
                         ),
-                        k=request.app.state.config.TOP_K,
-                        reranking_function=request.app.state.rf,
-                        k_reranker=request.app.state.config.TOP_K_RERANKER,
-                        r=request.app.state.config.RELEVANCE_THRESHOLD,
-                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                        full_context=request.app.state.config.RAG_FULL_CONTEXT,
+                        k=k,
+                        reranking_function=reranking_function,
+                        k_reranker=k_reranker,
+                        r=r,
+                        hybrid_search=hybrid_search,
+                        full_context=full_context,
                     ),
                 )
         except Exception as e:
@@ -862,7 +876,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         log.exception(e)
 
     try:
-        form_data, flags = await chat_completion_files_handler(request, form_data, user)
+        form_data, flags = await chat_completion_files_handler(request, form_data, user, model_knowledge)
         sources.extend(flags.get("sources", []))
     except Exception as e:
         log.exception(e)
@@ -898,20 +912,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                 f"With a 0 relevancy threshold for RAG, the context cannot be empty"
            )
 
+        # Adjusted RAG template step to use knowledge-base-specific configuration
+        rag_template_config = request.app.state.config.RAG_TEMPLATE
+
+        if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+            rag_template_config = model_knowledge[0].get("rag_config").get(
+                "RAG_TEMPLATE", request.app.state.config.RAG_TEMPLATE
+            )
+
         # Workaround for Ollama 2.0+ system prompt issue
         # TODO: replace with add_or_update_system_message
         if model.get("owned_by") == "ollama":
             form_data["messages"] = prepend_to_first_user_message_content(
-                rag_template(
-                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
-                ),
+                rag_template(rag_template_config, context_string, prompt),
                 form_data["messages"],
             )
         else:
             form_data["messages"] = add_or_update_system_message(
-                rag_template(
-                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
-                ),
+                rag_template(rag_template_config, context_string, prompt),
                 form_data["messages"],
             )
 
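The handler only consults `rag_config` when its `DEFAULT_RAG_SETTINGS` flag is falsy; every key it reads falls back to the corresponding global `request.app.state.config` value. Below is a minimal sketch of the per-knowledge-base `rag_config` dict the diff assumes is attached to `model_knowledge[0]`; the key names are taken from the lookups above, while all values are purely illustrative.

```python
# Hypothetical shape of the per-knowledge-base RAG config read by
# chat_completion_files_handler after this change. Only the keys the handler
# actually looks up are shown; all values are illustrative examples.
rag_config = {
    "DEFAULT_RAG_SETTINGS": False,  # False => apply these overrides instead of the globals
    "TOP_K": 5,
    "TOP_K_RERANKER": 3,
    "RELEVANCE_THRESHOLD": 0.0,
    "ENABLE_RAG_HYBRID_SEARCH": True,
    "RAG_FULL_CONTEXT": False,
    "RAG_EMBEDDING_MODEL": "example-embedding-model",  # illustrative model name
    "RAG_RERANKING_MODEL": "example-reranking-model",  # illustrative model name
    "RAG_TEMPLATE": "<custom RAG prompt template>",
}

# Shape expected by the new model_knowledge parameter (only index 0 is read).
model_knowledge = [{"rag_config": rag_config}]
```

Note that the diff also indexes `request.app.state.EMBEDDING_FUNCTION[embedding_model]` and `request.app.state.rf[reranking_model]`, so it presumes those attributes are now mappings keyed by model name rather than single callables; the companion changes that populate those registries are not part of this hunk.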