From 4c19aaaa640bf9dc1e1ae8e8a10425792c07b828 Mon Sep 17 00:00:00 2001 From: weberm1 Date: Fri, 6 Jun 2025 12:05:19 +0200 Subject: [PATCH] Fix: Added compatibility of azure openai for individual rag config - fixed query doc handler and query collection handler to handle individual rag embedding functions --- backend/open_webui/routers/retrieval.py | 221 ++++++++++++------------ 1 file changed, 107 insertions(+), 114 deletions(-) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 189654e66..883444bcf 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -332,16 +332,26 @@ async def update_embedding_config( rag_config["embedding_model"] = form_data.embedding_model rag_config["embedding_batch_size"] = form_data.embedding_batch_size - - rag_config["openai_config"] = { - "url": form_data.openai_config.url, - "key": form_data.openai_config.key, - } - - rag_config["ollama_config"] = { - "url": form_data.ollama_config.url, - "key": form_data.ollama_config.key, - } + # Update OpenAI, Ollama, and Azure OpenAI configurations if provided + if form_data.openai_config is not None: + rag_config["openai_config"] = { + "url": form_data.openai_config.url, + "key": form_data.openai_config.key, + } + + if form_data.ollama_config is not None: + rag_config["ollama_config"] = { + "url": form_data.ollama_config.url, + "key": form_data.ollama_config.key, + } + + if form_data.azure_openai_config is not None: + rag_config["azure_openai_config"] = { + "url": form_data.azure_openai_config.url, + "key": form_data.azure_openai_config.key, + "version": form_data.azure_openai_config.version, + } + # Update the embedding function if not rag_config["embedding_model"] in request.app.state.ef: request.app.state.ef[rag_config["embedding_model"]] = get_ef( @@ -363,10 +373,20 @@ async def update_embedding_config( if rag_config["embedding_engine"] == "openai" else rag_config["ollama_config"]["key"] ), - rag_config["embedding_batch_size"] + rag_config["embedding_batch_size"], + azure_api_version=( + rag_config["azure_openai_config"]["version"] + if rag_config["embedding_engine"] == "azure_openai" + else None + ) ) # add model to state for reloading on startup - request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) + if rag_config["embedding_engine"] == "azure_openai": + request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append( + {rag_config["embedding_model"]: rag_config.get("azure_openai_config", {}).get("version")} + ) + else: + request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() # add model to state for selectable reranking models if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]: @@ -387,9 +407,9 @@ async def update_embedding_config( "embedding_batch_size": rag_config["embedding_batch_size"], "openai_config": rag_config.get("openai_config", {}), "ollama_config": rag_config.get("ollama_config", {}), + "azure_openai_config": rag_config.get("azure_openai_config", {}), "DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"], "LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"], - "message": "Embedding configuration updated in the database.", } else: # Update the global configuration @@ -417,18 +437,21 @@ async def update_embedding_config( gc.collect() torch.cuda.empty_cache() - if request.app.state.config.RAG_EMBEDDING_ENGINE in [ - "ollama", - "openai", - "azure_openai", - ]: - if form_data.openai_config is not None: - request.app.state.config.RAG_OPENAI_API_BASE_URL = ( - form_data.openai_config.url - ) - request.app.state.config.RAG_OPENAI_API_KEY = ( - form_data.openai_config.key - ) + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model + + if request.app.state.config.RAG_EMBEDDING_ENGINE in [ + "ollama", + "openai", + "azure_openai", + ]: + if form_data.openai_config is not None: + request.app.state.config.RAG_OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.RAG_OPENAI_API_KEY = ( + form_data.openai_config.key + ) if form_data.ollama_config is not None: request.app.state.config.RAG_OLLAMA_BASE_URL = ( @@ -438,64 +461,20 @@ async def update_embedding_config( form_data.ollama_config.key ) - if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) + if form_data.azure_openai_config is not None: + request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( + form_data.azure_openai_config.url + ) + request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( + form_data.azure_openai_config.key + ) + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( + form_data.azure_openai_config.version + ) - if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) - - if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) - - if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) - - if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) - - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( - form_data.embedding_batch_size - ) + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + form_data.embedding_batch_size + ) # Update the embedding function if not form_data.embedding_model in request.app.state.ef: @@ -534,7 +513,12 @@ async def update_embedding_config( ), ) # add model to state for reloading on startup - request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai": + request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append( + {request.app.state.config.RAG_EMBEDDING_MODEL: request.app.state.config.RAG_AZURE_OPENAI_API_VERSION} + ) + else: + request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() # add model to state for selectable embedding models if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]: @@ -1541,9 +1525,9 @@ def save_docs_to_vector_db( log.info(f"adding to collection {collection_name}") embedding_function = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.ef, + embedding_engine, + embedding_model, + request.app.state.ef[embedding_model], ( openai_api_base_url if embedding_engine == "openai" @@ -1554,9 +1538,9 @@ def save_docs_to_vector_db( ) ), ( - request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY + openai_api_key + if embedding_engine == "openai" + else ollama_api_key ), embedding_batch_size, azure_api_version=( @@ -2370,7 +2354,24 @@ def query_doc_handler( user=Depends(get_verified_user), ): try: - if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + # Try to get individual rag config for this collection + rag_config = {} + knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + + # Use config from rag_config if present, else fallback to global config + enable_hybrid = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH) + embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL) + reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL) + top_k = form_data.k if form_data.k else rag_config.get("TOP_K", request.app.state.config.TOP_K) + top_k_reranker = form_data.k_reranker if form_data.k_reranker else rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER) + relevance_threshold = form_data.r if form_data.r else rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD) + hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None) + if hybrid_bm25_weight is None: + hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT) + + if enable_hybrid: collection_results = {} collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( collection_name=form_data.collection_name @@ -2379,32 +2380,23 @@ def query_doc_handler( collection_name=form_data.collection_name, collection_result=collection_results[form_data.collection_name], query=form_data.query, - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model]( query, prefix=prefix, user=user ), - k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], - k_reranker=form_data.k_reranker - or request.app.state.config.TOP_K_RERANKER, - r=( - form_data.r - if form_data.r - else request.app.state.config.RELEVANCE_THRESHOLD - ), - hybrid_bm25_weight=( - form_data.hybrid_bm25_weight - if form_data.hybrid_bm25_weight - else request.app.state.config.HYBRID_BM25_WEIGHT - ), + k=top_k, + reranking_function=request.app.state.rf[reranking_model], + k_reranker=top_k_reranker, + r=relevance_threshold, + hybrid_bm25_weight=hybrid_bm25_weight, user=user, ) else: return query_doc( collection_name=form_data.collection_name, - query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( + query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model]( form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user ), - k=form_data.k if form_data.k else request.app.state.config.TOP_K, + k=top_k, user=user, ) except Exception as e: @@ -2436,11 +2428,10 @@ def query_collection_handler( return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( - query, prefix=prefix, user=user - ), + user=user, + ef=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], + reranking_function=request.app.state.rf, k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -2453,14 +2444,16 @@ def query_collection_handler( if form_data.hybrid_bm25_weight else request.app.state.config.HYBRID_BM25_WEIGHT ), + embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL, + reranking_model=request.app.state.config.RAG_RERANKING_MODEL, ) else: return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( - query, prefix=prefix, user=user - ), + user=user, + ef=request.app.state.EMBEDDING_FUNCTION, + embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL, k=form_data.k if form_data.k else request.app.state.config.TOP_K, )