From d1ed648680c82fa817b3b4720d17cce9d5295cb6 Mon Sep 17 00:00:00 2001
From: weberm1
Date: Wed, 21 May 2025 12:25:01 +0200
Subject: [PATCH] Fix: Added cleanup options for individual RAG settings;
 added reload options for RAG settings; adjusted permissions so that users
 can use individual RAG settings

---
 backend/open_webui/routers/retrieval.py | 128 ++++++++++++++++++++----
 1 file changed, 109 insertions(+), 19 deletions(-)

diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 49f739ce5..6318ec2cb 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -298,7 +298,7 @@ class EmbeddingModelUpdateForm(BaseModel):
 
 @router.post("/embedding/update")
 async def update_embedding_config(
-    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
 ):
     """
     Update the embedding model configuration.
@@ -313,6 +313,17 @@ async def update_embedding_config(
             log.info(
                 f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
             )
+
+            # Check whether the model is in use elsewhere; if not, free up memory
+            in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get("embedding_model"), model_type="embedding_model", id=form_data.knowledge_id)
+
+            if not in_use and rag_config.get("embedding_model") != request.app.state.config.RAG_EMBEDDING_MODEL:
+                del request.app.state.ef[rag_config.get("embedding_model")]
+                import gc
+                import torch
+                gc.collect()
+                torch.cuda.empty_cache()
+
             # Update embedding-related fields
             rag_config["embedding_engine"] = form_data.embedding_engine
             rag_config["embedding_model"] = form_data.embedding_model
@@ -329,7 +340,7 @@ async def update_embedding_config(
                     "key": form_data.ollama_config.key,
                 }
             # Update the embedding function
-            if not request.app.state.ef.get("embedding_model"):
+            if rag_config["embedding_model"] not in request.app.state.ef:
                 request.app.state.ef[rag_config["embedding_model"]] = get_ef(
                     rag_config["embedding_engine"],
                     rag_config["embedding_model"],
@@ -351,7 +362,16 @@ async def update_embedding_config(
                     ),
                     rag_config["embedding_batch_size"]
                 )
-
+                # Record the model so it can be reloaded on startup
+                request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
+                request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
+                # Record the model so it shows up as a selectable embedding model
+                if rag_config["embedding_model"] not in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
+                    request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
+                    request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
+            rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS
+            rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS
+
             # Save the updated configuration to the database
             Knowledges.update_rag_config_by_id(
                 id=form_data.knowledge_id, rag_config=rag_config
@@ -364,6 +384,8 @@ async def update_embedding_config(
                 "embedding_batch_size": rag_config["embedding_batch_size"],
                 "openai_config": rag_config.get("openai_config", {}),
                 "ollama_config": rag_config.get("ollama_config", {}),
+                "DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
+                "LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
                 "message": "Embedding configuration updated in the database.",
             }
         else:
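
Note on the cleanup introduced above: the per-knowledge branch drops the cached embedding function only when no other knowledge base references the model and the model is not the global default, then forces a garbage-collection pass so CUDA memory is actually released. A minimal standalone sketch of that pattern (the helper name `unload_if_unused` is illustrative, not part of this patch; `cache` stands for the `request.app.state.ef` dict mapping model names to loaded embedding functions):

    import gc

    import torch

    def unload_if_unused(cache: dict, model: str, in_use: bool, global_model: str) -> None:
        # Evict only models that are neither referenced elsewhere nor the global default.
        if not in_use and model != global_model and model in cache:
            del cache[model]  # drop the last strong reference to the model object
            gc.collect()  # collect the now-unreachable model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
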
rag_config["LOADED_EMBEDDING_MODELS"], "message": "Embedding configuration updated in the database.", } else: @@ -371,6 +393,16 @@ async def update_embedding_config( log.info( f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) + + # Check if model is in use elsewhere, otherwise free up memory + in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model") + if not in_use: + del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model @@ -385,7 +417,7 @@ async def update_embedding_config( request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size # Update the embedding function - if not request.app.state.ef.get(form_data.embedding_model): + if not form_data.embedding_model in request.app.state.ef: request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, @@ -407,6 +439,13 @@ async def update_embedding_config( ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) + # add model to state for reloading on startup + request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + # add model to state for selectable reranking models + if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]: + request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save() return { "status": True, @@ -421,6 +460,8 @@ async def update_embedding_config( "url": request.app.state.config.RAG_OLLAMA_BASE_URL, "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, + "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS, + "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS, "message": "Embedding configuration updated globally.", } except Exception as e: @@ -437,7 +478,7 @@ class RerankingModelUpdateForm(BaseModel): @router.post("/reranking/update") async def update_reranking_config( - request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_verified_user) ): """ Update the reranking model configuration. 
@@ -455,16 +496,28 @@ async def update_reranking_config(
             log.info(
                 f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
             )
+
             rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
 
             Knowledges.update_rag_config_by_id(
                 id=form_data.knowledge_id, rag_config=rag_config
             )
 
             try:
-                if not request.app.state.rf.get(rag_config["reranking_model"]):
+                if rag_config["reranking_model"] not in request.app.state.rf:
                     request.app.state.rf[rag_config["reranking_model"]] = get_rf(
                         rag_config["reranking_model"],
                         True,
                     )
+                    # Record the model so it can be reloaded on startup
+                    request.app.state.config.LOADED_RERANKING_MODELS.append(rag_config["reranking_model"])
+                    request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
+                    # Record the model so it shows up as a selectable reranking model
+                    if rag_config["reranking_model"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS:
+                        request.app.state.config.DOWNLOADED_RERANKING_MODELS.append(rag_config["reranking_model"])
+                        request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
+
+                rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
+                rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
             except Exception as e:
                 log.error(f"Error loading reranking model: {e}")
@@ -473,6 +526,8 @@ async def update_reranking_config(
             return {
                 "status": True,
                 "reranking_model": rag_config["reranking_model"],
+                "LOADED_RERANKING_MODELS": rag_config["LOADED_RERANKING_MODELS"],
+                "DOWNLOADED_RERANKING_MODELS": rag_config["DOWNLOADED_RERANKING_MODELS"],
                 "message": "Reranking model updated in the database.",
             }
         else:
@@ -483,13 +538,20 @@ async def update_reranking_config(
             request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
 
             try:
-                if request.app.state.rf.get(form_data.reranking_model):
-                    request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model]
-                else:
+                if form_data.reranking_model not in request.app.state.rf:
                     request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
                         request.app.state.config.RAG_RERANKING_MODEL,
                         True,
                     )
+                    # Record the model so it can be reloaded on startup
+                    request.app.state.config.LOADED_RERANKING_MODELS.append(request.app.state.config.RAG_RERANKING_MODEL)
+                    request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
+
+                    # Record the model so it shows up as a selectable reranking model
+                    if request.app.state.config.RAG_RERANKING_MODEL not in request.app.state.config.DOWNLOADED_RERANKING_MODELS:
+                        request.app.state.config.DOWNLOADED_RERANKING_MODELS.append(request.app.state.config.RAG_RERANKING_MODEL)
+                        request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
+
             except Exception as e:
                 log.error(f"Error loading reranking model: {e}")
                 request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@@ -497,6 +559,8 @@ async def update_reranking_config(
             return {
                 "status": True,
                 "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
+                "LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
+                "DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
                 "message": "Reranking model updated globally.",
             }
         except Exception as e:
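
Both endpoints lean on `Knowledges.is_model_in_use_elsewhere`, which this diff does not show. As used here, its contract is: return True when any other knowledge base's stored rag_config still references the model, optionally excluding the knowledge base identified by `id`. A conforming sketch, assuming a `get_knowledge_bases()` accessor returning rows with `id` and `rag_config` attributes (both assumptions, not confirmed by the patch):

    @classmethod
    def is_model_in_use_elsewhere(cls, model: str, model_type: str, id: str | None = None) -> bool:
        # True if any *other* knowledge base still references `model` under the
        # given key ("embedding_model" or "reranking_model").
        for kb in cls.get_knowledge_bases():
            if id is not None and kb.id == id:
                continue  # skip the knowledge base currently being updated
            if (kb.rag_config or {}).get(model_type) == model:
                return True
        return False
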
@@ -508,7 +572,7 @@ async def update_reranking_config(
 
 @router.post("/config")
-async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
+async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)):
     """
     Retrieve the full RAG configuration.
     If DEFAULT_RAG_SETTINGS is True, return the default settings.
@@ -600,7 +664,11 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
                 "YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
                 "YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
             },
-            "DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS)
+            "DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS),
+            "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
+            "DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
+            "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
+            "LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
         }
     else:
         # Return default RAG settings
@@ -683,7 +751,11 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
             "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
             "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
         },
-        "DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
+        "DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS,
+        "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
+        "DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
+        "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
+        "LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
     }
 
 
@@ -784,7 +856,7 @@ class ConfigFormWrapper(BaseModel):
 
 @router.post("/config/update")
 async def update_rag_config(
-    request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
+    request: Request, wrapper: ConfigFormWrapper, user=Depends(get_verified_user)
 ):
     """
     Update the RAG configuration.
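
For manual testing of the loosened permission (any verified user, not just admins, may now read the config), the endpoint can be exercised with a plain HTTP call. A sketch, assuming the retrieval router is mounted under /api/v1/retrieval and that CollectionForm carries a knowledge_id field like the sibling forms (both assumptions):

    import requests

    resp = requests.post(
        "http://localhost:8080/api/v1/retrieval/config",  # assumed mount point
        headers={"Authorization": "Bearer <token>"},  # any verified user's token
        json={"knowledge_id": "<knowledge-base-id>"},  # per-KB config; defaults otherwise
    )
    resp.raise_for_status()
    config = resp.json()
    print(config["LOADED_EMBEDDING_MODELS"], config["DOWNLOADED_RERANKING_MODELS"])
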
@@ -804,10 +876,19 @@ async def update_rag_config(
                 rag_config["web"] = {**rag_config.get("web", {}), **value}
             else:
                 rag_config[field] = value
-        if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
-            if rag_config.get("reranking_model"):
-                request.app.state.rf[rag_config["reranking_model"]] = None
 
+        # Free up memory if hybrid search is disabled and the model is not in use elsewhere
+        in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get("reranking_model"), model_type="reranking_model", id=form_data.knowledge_id)
+
+        if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True) and not in_use:
+            if rag_config.get("reranking_model"):
+                del request.app.state.rf[rag_config["reranking_model"]]
+                request.app.state.config.LOADED_RERANKING_MODELS.remove(rag_config["reranking_model"])
+                import gc
+                import torch
+                gc.collect()
+                torch.cuda.empty_cache()
+
         Knowledges.update_rag_config_by_id(
             id=knowledge_base.id, rag_config=rag_config
         )
@@ -843,9 +924,18 @@ async def update_rag_config(
             if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
             else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
         )
-        # Free up memory if hybrid search is disabled
-        if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None
+
+        # Free up memory if hybrid search is disabled and the model is not in use elsewhere
+        in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="reranking_model")
+
+        if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not in_use and request.app.state.config.RAG_RERANKING_MODEL:
+            del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL]
+            request.app.state.config.LOADED_RERANKING_MODELS.remove(request.app.state.config.RAG_RERANKING_MODEL)
+
+            import gc
+            import torch
+            gc.collect()
+            torch.cuda.empty_cache()
 
         request.app.state.config.TOP_K_RERANKER = (
             form_data.TOP_K_RERANKER
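
The eviction added in both branches has to keep the persisted LOADED_RERANKING_MODELS list in sync with the `rf` cache, and `list.remove` raises ValueError when the entry was never recorded. A guarded variant of the same pattern (the helper name is illustrative, not part of this patch):

    import gc

    import torch

    def evict_reranker(rf: dict, loaded: list, model: str) -> None:
        # Drop the cached reranking function and its registry entry; tolerate
        # the case where either was never recorded.
        rf.pop(model, None)
        if model in loaded:
            loaded.remove(model)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
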