mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
Fix: Added cleanup options for individual RAG settings; added reload options for RAG settings; adjusted permissions so that verified users can use individual RAG settings
This commit is contained in:
parent
126908cbcd
commit
d1ed648680
@ -298,7 +298,7 @@ class EmbeddingModelUpdateForm(BaseModel):
|
||||
|
||||
@router.post("/embedding/update")
|
||||
async def update_embedding_config(
|
||||
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
Update the embedding model configuration.
|
||||
@ -313,6 +313,17 @@ async def update_embedding_config(
|
||||
log.info(
|
||||
f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
|
||||
)
|
||||
|
||||
# Check if model is in use elsewhere, otherwise free up memory
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('embedding_model'), model_type="embedding_model", id=form_data.knowledge_id)
|
||||
|
||||
if not in_use and not request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] == rag_config.get("embedding_model"):
|
||||
del request.app.state.ef[rag_config.get("embedding_model")]
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Update embedding-related fields
|
||||
rag_config["embedding_engine"] = form_data.embedding_engine
|
||||
rag_config["embedding_model"] = form_data.embedding_model
|
||||
@ -329,7 +340,7 @@ async def update_embedding_config(
|
||||
"key": form_data.ollama_config.key,
|
||||
}
|
||||
# Update the embedding function
|
||||
if not request.app.state.ef.get("embedding_model"):
|
||||
if not rag_config["embedding_model"] in request.app.state.ef:
|
||||
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
|
||||
rag_config["embedding_engine"],
|
||||
rag_config["embedding_model"],
|
||||
@ -351,6 +362,15 @@ async def update_embedding_config(
|
||||
),
|
||||
rag_config["embedding_batch_size"]
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
# add model to state for selectable embedding models
|
||||
if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
|
||||
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
|
||||
rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS
|
||||
rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS
|
||||
|
||||
# Save the updated configuration to the database
|
||||
Knowledges.update_rag_config_by_id(
|
||||
@ -364,6 +384,8 @@ async def update_embedding_config(
|
||||
"embedding_batch_size": rag_config["embedding_batch_size"],
|
||||
"openai_config": rag_config.get("openai_config", {}),
|
||||
"ollama_config": rag_config.get("ollama_config", {}),
|
||||
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
|
||||
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
|
||||
"message": "Embedding configuration updated in the database.",
|
||||
}
|
||||
else:
|
||||
@ -371,6 +393,16 @@ async def update_embedding_config(
|
||||
log.info(
|
||||
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
|
||||
# Check if model is in use elsewhere, otherwise free up memory
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model")
|
||||
if not in_use:
|
||||
del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL]
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
@ -385,7 +417,7 @@ async def update_embedding_config(
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||
|
||||
# Update the embedding function
|
||||
if not request.app.state.ef.get(form_data.embedding_model):
|
||||
if not form_data.embedding_model in request.app.state.ef:
|
||||
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
@ -407,6 +439,13 @@ async def update_embedding_config(
|
||||
),
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
# add model to state for selectable embedding models
|
||||
if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]:
|
||||
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
@ -421,6 +460,8 @@ async def update_embedding_config(
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
},
|
||||
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
|
||||
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
|
||||
"message": "Embedding configuration updated globally.",
|
||||
}
|
||||
except Exception as e:
|
||||
@ -437,7 +478,7 @@ class RerankingModelUpdateForm(BaseModel):
|
||||
|
||||
@router.post("/reranking/update")
|
||||
async def update_reranking_config(
|
||||
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
Update the reranking model configuration.
|
||||
@ -455,16 +496,28 @@ async def update_reranking_config(
|
||||
log.info(
|
||||
f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
|
||||
)
|
||||
|
||||
rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=form_data.knowledge_id, rag_config=rag_config
|
||||
)
|
||||
try:
|
||||
if not request.app.state.rf.get(rag_config["reranking_model"]):
|
||||
if not rag_config["reranking_model"] in request.app.state.rf:
|
||||
request.app.state.rf[rag_config["reranking_model"]] = get_rf(
|
||||
rag_config["reranking_model"],
|
||||
True,
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_RERANKING_MODELS.append(rag_config["reranking_model"])
|
||||
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
|
||||
# add model to state for selectable reranking models
|
||||
if rag_config["reranking_model"] not in request.app.state.DOWNLOADED_RERANKING_MODELS:
|
||||
request.app.state.config.DOWNLOADED_RERANKING_MODELS.append(rag_config["reranking_model"])
|
||||
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
|
||||
|
||||
|
||||
rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
|
||||
rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
@ -473,6 +526,8 @@ async def update_reranking_config(
|
||||
return {
|
||||
"status": True,
|
||||
"reranking_model": rag_config["reranking_model"],
|
||||
"LOADED_RERANKING_MODELS": rag_config["LOADED_RERANKING_MODELS"],
|
||||
"DOWNLOADED_RERANKING_MODELS": rag_config["DOWNLOADED_RERANKING_MODELS"],
|
||||
"message": "Reranking model updated in the database.",
|
||||
}
|
||||
else:
|
||||
@ -483,13 +538,20 @@ async def update_reranking_config(
|
||||
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
||||
|
||||
try:
|
||||
if request.app.state.rf.get(form_data.reranking_model):
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model]
|
||||
else:
|
||||
if not form_data.reranking_model in request.app.state.rf:
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
True,
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_RERANKING_MODELS.append(request.app.state.config.RAG_RERANKING_MODEL)
|
||||
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
|
||||
|
||||
# add model to state for selectable reranking models
|
||||
if not request.app.state.config.RAG_RERANKING_MODEL in request.app.state.config.DOWNLOADED_RERANKING_MODELS:
|
||||
request.app.state.config.DOWNLOADED_RERANKING_MODELS.append(request.app.state.config.RAG_RERANKING_MODEL)
|
||||
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
@ -497,6 +559,8 @@ async def update_reranking_config(
|
||||
return {
|
||||
"status": True,
|
||||
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
|
||||
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
|
||||
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
|
||||
"message": "Reranking model updated globally.",
|
||||
}
|
||||
except Exception as e:
|
||||
@ -508,7 +572,7 @@ async def update_reranking_config(
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
|
||||
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)):
|
||||
"""
|
||||
Retrieve the full RAG configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, return the default settings.
|
||||
@ -600,7 +664,11 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
|
||||
"YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
|
||||
"YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
|
||||
},
|
||||
"DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS)
|
||||
"DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS),
|
||||
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
|
||||
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
|
||||
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
|
||||
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
|
||||
}
|
||||
else:
|
||||
# Return default RAG settings
|
||||
@ -683,7 +751,11 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
|
||||
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
},
|
||||
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
|
||||
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS,
|
||||
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
|
||||
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
|
||||
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
|
||||
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
|
||||
}
|
||||
|
||||
|
||||
@ -784,7 +856,7 @@ class ConfigFormWrapper(BaseModel):
|
||||
|
||||
@router.post("/config/update")
|
||||
async def update_rag_config(
|
||||
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
|
||||
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
Update the RAG configuration.
|
||||
@ -804,9 +876,18 @@ async def update_rag_config(
|
||||
rag_config["web"] = {**rag_config.get("web", {}), **value}
|
||||
else:
|
||||
rag_config[field] = value
|
||||
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
|
||||
|
||||
# Free up memory if hybrid search is disabled and model is not in use elsewhere
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('reranking_model'), model_type="reranking_model", id=form_data.knowledge_id)
|
||||
|
||||
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True) and not in_use:
|
||||
if rag_config.get("reranking_model"):
|
||||
request.app.state.rf[rag_config["reranking_model"]] = None
|
||||
del request.app.state.rf[rag_config["reranking_model"]]
|
||||
request.app.state.LOADED_RERANKING_MODELS.remove(rag_config["reranking_model"])
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=knowledge_base.id, rag_config=rag_config
|
||||
@ -843,9 +924,18 @@ async def update_rag_config(
|
||||
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
||||
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
)
|
||||
# Free up memory if hybrid search is disabled
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None
|
||||
|
||||
# Free up memory if hybrid search is disabled and model is not in use elsewhere
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="reranking_model")
|
||||
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not in_use and request.app.state.config.RAG_RERANKING_MODEL:
|
||||
del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL]
|
||||
request.app.state.config.LOADED_RERANKING_MODELS.remove(request.app.state.config.RAG_RERANKING_MODEL)
|
||||
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
request.app.state.config.TOP_K_RERANKER = (
|
||||
form_data.TOP_K_RERANKER
|
||||
|
Loading…
Reference in New Issue
Block a user