Fix: adjusted to handle both default and individual rag settings

This commit is contained in:
Maytown
2025-05-14 17:33:11 +02:00
parent 1ae3873c55
commit ba54452ab1
2 changed files with 107 additions and 41 deletions

View File

@@ -190,6 +190,8 @@ class ProcessUrlForm(CollectionNameForm):
class SearchForm(BaseModel):
query: str
class CollectionForm(BaseModel):
knowledge_id: Optional[str] = None
@router.get("/")
async def get_status(request: Request):
@@ -206,13 +208,15 @@ async def get_status(request: Request):
@router.post("/embedding")
async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
"""
Retrieve the embedding configuration.
If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
Otherwise, return the embedding configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the embedding configuration from the database
rag_config = knowledge_base.rag_config
@@ -249,13 +253,15 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
@router.post("/reranking")
async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
async def get_reranking_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
"""
Retrieve the reranking configuration.
If DEFAULT_RAG_SETTINGS is True, return the default reranking settings.
Otherwise, return the reranking configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the reranking configuration from the database
rag_config = knowledge_base.rag_config
@@ -287,7 +293,7 @@ class EmbeddingModelUpdateForm(BaseModel):
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
collection_name: Optional[str] = None
knowledge_id: Optional[str] = None
@router.post("/embedding/update")
@@ -300,7 +306,7 @@ async def update_embedding_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
@@ -312,14 +318,13 @@ async def update_embedding_config(
rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
if form_data.openai_config is not None:
rag_config["openai_config"] = {
rag_config["openai_config"] = {
"url": form_data.openai_config.url,
"key": form_data.openai_config.key,
}
if form_data.ollama_config is not None:
rag_config["ollama_config"] = {
rag_config["ollama_config"] = {
"url": form_data.ollama_config.url,
"key": form_data.ollama_config.key,
}
@@ -348,8 +353,8 @@ async def update_embedding_config(
)
# Save the updated configuration to the database
Knowledges.update_knowledge_data_by_id(
id=form_data.collection_name, data={"rag_config": rag_config}
Knowledges.update_rag_config_by_id(
id=form_data.knowledge_id, rag_config=rag_config
)
return {
@@ -428,7 +433,7 @@ async def update_embedding_config(
class RerankingModelUpdateForm(BaseModel):
reranking_model: str
collection_name: Optional[str]
knowledge_id: Optional[str] = None
@router.post("/reranking/update")
async def update_reranking_config(
@@ -440,16 +445,19 @@ async def update_reranking_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
log.info(
f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
)
rag_config["reranking_model"] = form_data.reranking_model
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
Knowledges.update_rag_config_by_id(
id=form_data.knowledge_id, rag_config=rag_config
)
try:
if not request.app.state.rf.get(rag_config["reranking_model"]):
@@ -500,13 +508,15 @@ async def update_reranking_config(
@router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)):
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
"""
Retrieve the full RAG configuration.
If DEFAULT_RAG_SETTINGS is True, return the default settings.
Otherwise, return the RAG configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the RAG configuration from the database
rag_config = knowledge_base.rag_config
@@ -764,18 +774,26 @@ class ConfigForm(BaseModel):
# Web search settings
web: Optional[WebConfig] = None
# knowledge base ID
knowledge_id: Optional[str] = None
class ConfigFormWrapper(BaseModel):
form_data: ConfigForm
@router.post("/config/update")
async def update_rag_config(
request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user)
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
):
"""
Update the RAG configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
form_data = wrapper.form_data
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
@@ -783,14 +801,15 @@ async def update_rag_config(
# Update only the provided fields in the rag_config
for field, value in form_data.model_dump(exclude_unset=True).items():
if field == "web" and value is not None:
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
rag_config["web"] = {**rag_config.get("web", {}), **value}
else:
rag_config[field] = value
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
request.app.state.rf[rag_config["reranking_model"]] = None
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
if rag_config.get("reranking_model"):
request.app.state.rf[rag_config["reranking_model"]] = None
Knowledges.update_rag_config_by_id(
id=knowledge_base.id, rag_config=rag_config
)
return rag_config
@@ -1090,6 +1109,7 @@ async def update_rag_config(
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
}