mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Fix: adjusted to handle both default and individual rag settings
This commit is contained in:
@@ -190,6 +190,8 @@ class ProcessUrlForm(CollectionNameForm):
|
||||
class SearchForm(BaseModel):
|
||||
query: str
|
||||
|
||||
class CollectionForm(BaseModel):
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
@router.get("/")
|
||||
async def get_status(request: Request):
|
||||
@@ -206,13 +208,15 @@ async def get_status(request: Request):
|
||||
|
||||
|
||||
@router.post("/embedding")
|
||||
async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
|
||||
async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
|
||||
"""
|
||||
Retrieve the embedding configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
|
||||
Otherwise, return the embedding configuration stored in the database.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
|
||||
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the embedding configuration from the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
@@ -249,13 +253,15 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
|
||||
|
||||
|
||||
@router.post("/reranking")
|
||||
async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
|
||||
async def get_reranking_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
|
||||
"""
|
||||
Retrieve the reranking configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, return the default reranking settings.
|
||||
Otherwise, return the reranking configuration stored in the database.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
|
||||
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the reranking configuration from the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
@@ -287,7 +293,7 @@ class EmbeddingModelUpdateForm(BaseModel):
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
collection_name: Optional[str] = None
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/embedding/update")
|
||||
@@ -300,7 +306,7 @@ async def update_embedding_config(
|
||||
Otherwise, update the RAG configuration in the database for the user's knowledge base.
|
||||
"""
|
||||
try:
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
@@ -312,14 +318,13 @@ async def update_embedding_config(
|
||||
rag_config["embedding_model"] = form_data.embedding_model
|
||||
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
|
||||
|
||||
if form_data.openai_config is not None:
|
||||
rag_config["openai_config"] = {
|
||||
|
||||
rag_config["openai_config"] = {
|
||||
"url": form_data.openai_config.url,
|
||||
"key": form_data.openai_config.key,
|
||||
}
|
||||
|
||||
if form_data.ollama_config is not None:
|
||||
rag_config["ollama_config"] = {
|
||||
rag_config["ollama_config"] = {
|
||||
"url": form_data.ollama_config.url,
|
||||
"key": form_data.ollama_config.key,
|
||||
}
|
||||
@@ -348,8 +353,8 @@ async def update_embedding_config(
|
||||
)
|
||||
|
||||
# Save the updated configuration to the database
|
||||
Knowledges.update_knowledge_data_by_id(
|
||||
id=form_data.collection_name, data={"rag_config": rag_config}
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=form_data.knowledge_id, rag_config=rag_config
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -428,7 +433,7 @@ async def update_embedding_config(
|
||||
|
||||
class RerankingModelUpdateForm(BaseModel):
|
||||
reranking_model: str
|
||||
collection_name: Optional[str]
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
@router.post("/reranking/update")
|
||||
async def update_reranking_config(
|
||||
@@ -440,16 +445,19 @@ async def update_reranking_config(
|
||||
Otherwise, update the RAG configuration in the database for the user's knowledge base.
|
||||
"""
|
||||
try:
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
|
||||
rag_config = knowledge_base.rag_config
|
||||
log.info(
|
||||
f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
|
||||
f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
|
||||
)
|
||||
rag_config["reranking_model"] = form_data.reranking_model
|
||||
Knowledges.update_knowledge_data_by_id(
|
||||
id=knowledge_base.id, data={"rag_config": rag_config}
|
||||
rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=form_data.knowledge_id, rag_config=rag_config
|
||||
)
|
||||
try:
|
||||
if not request.app.state.rf.get(rag_config["reranking_model"]):
|
||||
@@ -500,13 +508,15 @@ async def update_reranking_config(
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)):
|
||||
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
|
||||
"""
|
||||
Retrieve the full RAG configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, return the default settings.
|
||||
Otherwise, return the RAG configuration stored in the database.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
|
||||
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the RAG configuration from the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
@@ -764,18 +774,26 @@ class ConfigForm(BaseModel):
|
||||
# Web search settings
|
||||
web: Optional[WebConfig] = None
|
||||
|
||||
# knowledge base ID
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigFormWrapper(BaseModel):
|
||||
form_data: ConfigForm
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
async def update_rag_config(
|
||||
request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user)
|
||||
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
|
||||
):
|
||||
"""
|
||||
Update the RAG configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
|
||||
Otherwise, update the RAG configuration in the database for the user's knowledge base.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
|
||||
form_data = wrapper.form_data
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
@@ -783,14 +801,15 @@ async def update_rag_config(
|
||||
# Update only the provided fields in the rag_config
|
||||
for field, value in form_data.model_dump(exclude_unset=True).items():
|
||||
if field == "web" and value is not None:
|
||||
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
|
||||
rag_config["web"] = {**rag_config.get("web", {}), **value}
|
||||
else:
|
||||
rag_config[field] = value
|
||||
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
|
||||
request.app.state.rf[rag_config["reranking_model"]] = None
|
||||
|
||||
Knowledges.update_knowledge_data_by_id(
|
||||
id=knowledge_base.id, data={"rag_config": rag_config}
|
||||
if rag_config.get("reranking_model"):
|
||||
request.app.state.rf[rag_config["reranking_model"]] = None
|
||||
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=knowledge_base.id, rag_config=rag_config
|
||||
)
|
||||
|
||||
return rag_config
|
||||
@@ -1090,6 +1109,7 @@ async def update_rag_config(
|
||||
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
},
|
||||
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user