Fix: Added cleanup options for individual RAG settings; added reload options for RAG settings; adjusted permissions so that verified users can use individual RAG settings

weberm1 2025-05-21 12:25:01 +02:00
parent 126908cbcd
commit d1ed648680

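The cleanup paths in this diff hinge on Knowledges.is_model_in_use_elsewhere, which is called but not defined in these hunks. A minimal sketch of such a check, assuming each knowledge base record exposes an id and a per-knowledge rag_config dict (the record shape and helper body are assumptions, not the committed code):

# Hypothetical sketch of the in-use check called throughout this diff.
from typing import Optional

def is_model_in_use_elsewhere(
    knowledge_bases: list[dict],   # assumed shape: {"id": ..., "rag_config": {...}}
    model: Optional[str],
    model_type: str,               # "embedding_model" or "reranking_model"
    id: Optional[str] = None,      # knowledge base to exclude (the one being updated)
) -> bool:
    """Return True if any other knowledge base still references `model`."""
    if not model:
        return False
    for kb in knowledge_bases:
        if id is not None and kb.get("id") == id:
            continue  # the caller's own settings don't keep the model alive
        if (kb.get("rag_config") or {}).get(model_type) == model:
            return True
    return False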

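Several hunks below repeat the same unload sequence once a model is no longer referenced: drop its entry from the in-memory cache on request.app.state (ef for embedding functions, rf for reranking functions), force a garbage-collection pass, and release cached GPU memory. A condensed sketch of that pattern, where models stands in for either cache (the torch.cuda.is_available() guard is an addition here, not part of the diff):

import gc

import torch

def unload_model(models: dict, name: str) -> None:
    # Dropping the cache entry removes the last strong reference to the weights.
    models.pop(name, None)
    gc.collect()
    # Hand cached CUDA allocator blocks back to the driver; no-op without CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()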
@@ -298,7 +298,7 @@ class EmbeddingModelUpdateForm(BaseModel):
@router.post("/embedding/update")
async def update_embedding_config(
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
):
"""
Update the embedding model configuration.
@@ -313,6 +313,17 @@ async def update_embedding_config(
log.info(
f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
)
# Check if model is in use elsewhere, otherwise free up memory
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('embedding_model'), model_type="embedding_model", id=form_data.knowledge_id)
if not in_use and rag_config.get("embedding_model") != request.app.state.config.RAG_EMBEDDING_MODEL and rag_config.get("embedding_model") in request.app.state.ef:
del request.app.state.ef[rag_config.get("embedding_model")]
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
# Update embedding-related fields
rag_config["embedding_engine"] = form_data.embedding_engine
rag_config["embedding_model"] = form_data.embedding_model
@@ -329,7 +340,7 @@ async def update_embedding_config(
"key": form_data.ollama_config.key,
}
# Update the embedding function
if not request.app.state.ef.get("embedding_model"):
if rag_config["embedding_model"] not in request.app.state.ef:
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
rag_config["embedding_engine"],
rag_config["embedding_model"],
@@ -351,6 +362,15 @@ async def update_embedding_config(
),
rag_config["embedding_batch_size"]
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable embedding models
if rag_config["embedding_model"] not in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS
rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS
# Save the updated configuration to the database
Knowledges.update_rag_config_by_id(
@@ -364,6 +384,8 @@ async def update_embedding_config(
"embedding_batch_size": rag_config["embedding_batch_size"],
"openai_config": rag_config.get("openai_config", {}),
"ollama_config": rag_config.get("ollama_config", {}),
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
"message": "Embedding configuration updated in the database.",
}
else:
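The hunks above and below maintain two persisted registries: LOADED_*_MODELS (models to re-instantiate on startup) and DOWNLOADED_*_MODELS (models offered for selection). The commit writes the append-and-save step out inline each time; a compact helper capturing it might look like this (hypothetical, assuming the persistent-config entries expose a .save() method, as the _state[...].save() calls suggest):

def register_model(config_state: dict, key: str, registry: list[str], model: str) -> None:
    # Append once, then flush the backing persistent-config entry to storage.
    if model not in registry:
        registry.append(model)
        config_state[key].save()  # mirrors request.app.state.config._state[key].save()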
@@ -371,6 +393,16 @@ async def update_embedding_config(
log.info(
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
# Check if model is in use elsewhere, otherwise free up memory
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model")
if not in_use and request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.ef:
del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL]
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@@ -385,7 +417,7 @@ async def update_embedding_config(
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
# Update the embedding function
if not request.app.state.ef.get(form_data.embedding_model):
if form_data.embedding_model not in request.app.state.ef:
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
@@ -407,6 +439,13 @@ async def update_embedding_config(
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable embedding models
if request.app.state.config.RAG_EMBEDDING_MODEL not in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]:
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
return {
"status": True,
@@ -421,6 +460,8 @@ async def update_embedding_config(
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
"message": "Embedding configuration updated globally.",
}
except Exception as e:
@@ -437,7 +478,7 @@ class RerankingModelUpdateForm(BaseModel):
@router.post("/reranking/update")
async def update_reranking_config(
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_verified_user)
):
"""
Update the reranking model configuration.
@@ -455,16 +496,28 @@ async def update_reranking_config(
log.info(
f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
)
rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
Knowledges.update_rag_config_by_id(
id=form_data.knowledge_id, rag_config=rag_config
)
try:
if not request.app.state.rf.get(rag_config["reranking_model"]):
if rag_config["reranking_model"] not in request.app.state.rf:
request.app.state.rf[rag_config["reranking_model"]] = get_rf(
rag_config["reranking_model"],
True,
)
# add model to state for reloading on startup
request.app.state.config.LOADED_RERANKING_MODELS.append(rag_config["reranking_model"])
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
# add model to state for selectable reranking models
if rag_config["reranking_model"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS:
request.app.state.config.DOWNLOADED_RERANKING_MODELS.append(rag_config["reranking_model"])
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
except Exception as e:
log.error(f"Error loading reranking model: {e}")
@@ -473,6 +526,8 @@ async def update_reranking_config(
return {
"status": True,
"reranking_model": rag_config["reranking_model"],
"LOADED_RERANKING_MODELS": rag_config["LOADED_RERANKING_MODELS"],
"DOWNLOADED_RERANKING_MODELS": rag_config["DOWNLOADED_RERANKING_MODELS"],
"message": "Reranking model updated in the database.",
}
else:
@@ -483,13 +538,20 @@ async def update_reranking_config(
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
try:
if request.app.state.rf.get(form_data.reranking_model):
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model]
else:
if form_data.reranking_model not in request.app.state.rf:
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
request.app.state.config.RAG_RERANKING_MODEL,
True,
)
# add model to state for reloading on startup
request.app.state.config.LOADED_RERANKING_MODELS.append(request.app.state.config.RAG_RERANKING_MODEL)
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
# add model to state for selectable reranking models
if request.app.state.config.RAG_RERANKING_MODEL not in request.app.state.config.DOWNLOADED_RERANKING_MODELS:
request.app.state.config.DOWNLOADED_RERANKING_MODELS.append(request.app.state.config.RAG_RERANKING_MODEL)
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@@ -497,6 +559,8 @@ async def update_reranking_config(
return {
"status": True,
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
"message": "Reranking model updated globally.",
}
except Exception as e:
@@ -508,7 +572,7 @@ async def update_reranking_config(
@router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)):
"""
Retrieve the full RAG configuration.
If DEFAULT_RAG_SETTINGS is True, return the default settings.
@@ -600,7 +664,11 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
"YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
"YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
},
"DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS)
"DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS),
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
}
else:
# Return default RAG settings
@@ -683,7 +751,11 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS,
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
}
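In the per-knowledge branch above, every field resolves as rag_config.get(key, global_default): the knowledge base's stored value wins, and anything it never overrode falls through to the app-wide setting. A minimal sketch of that merge rule (the helper name is invented for illustration):

def resolve_rag_settings(rag_config: dict, global_config: dict, keys: list[str]) -> dict:
    # Per-knowledge overrides win; unset keys fall back to the global default.
    return {key: rag_config.get(key, global_config[key]) for key in keys}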
@@ -784,7 +856,7 @@ class ConfigFormWrapper(BaseModel):
@router.post("/config/update")
async def update_rag_config(
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_verified_user)
):
"""
Update the RAG configuration.
@@ -804,9 +876,18 @@ async def update_rag_config(
rag_config["web"] = {**rag_config.get("web", {}), **value}
else:
rag_config[field] = value
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
# Free up memory if hybrid search is disabled and model is not in use elsewhere
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('reranking_model'), model_type="reranking_model", id=knowledge_base.id)
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True) and not in_use:
if rag_config.get("reranking_model") and rag_config["reranking_model"] in request.app.state.rf:
request.app.state.rf[rag_config["reranking_model"]] = None
del request.app.state.rf[rag_config["reranking_model"]]
request.app.state.config.LOADED_RERANKING_MODELS.remove(rag_config["reranking_model"])
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
Knowledges.update_rag_config_by_id(
id=knowledge_base.id, rag_config=rag_config
@@ -843,9 +924,18 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
)
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None
# Free up memory if hybrid search is disabled and model is not in use elsewhere
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="reranking_model")
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not in_use and request.app.state.config.RAG_RERANKING_MODEL in request.app.state.rf:
del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL]
request.app.state.config.LOADED_RERANKING_MODELS.remove(request.app.state.config.RAG_RERANKING_MODEL)
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
request.app.state.config.TOP_K_RERANKER = (
form_data.TOP_K_RERANKER