diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 67889deec..adb4f24ec 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -371,7 +371,7 @@ class RerankingModelUpdateForm(BaseModel): @router.post("/reranking/update") async def update_reranking_config( - request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: RerankingModelUpdateForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user) ): """ Update the reranking model configuration. @@ -379,7 +379,7 @@ async def update_reranking_config( Otherwise, update the RAG configuration in the database for the user's knowledge base. """ try: - knowledge_base = Knowledges.get_knowledge_base_by_user_id(user.id) + knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database @@ -1027,6 +1027,7 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, user=None, + knowledge_id: Optional[str] = None, ) -> bool: def _get_docs_info(docs: list[Document]) -> str: docs_info = set() @@ -1049,9 +1050,9 @@ def save_docs_to_vector_db( ) # Retrieve the knowledge base using the collection_name - knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name) + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) if not knowledge_base: - raise ValueError(f"Knowledge base not found for collection: {collection_name}") + raise ValueError(f"Knowledge base not found for collection: {knowledge_base}") # Retrieve the RAG configuration rag_config = {} @@ -1204,6 +1205,7 @@ class ProcessFileForm(BaseModel): def process_file( request: Request, form_data: ProcessFileForm, + knowledge_id: Optional[str] = None, user=Depends(get_verified_user), ): try: @@ -1215,9 +1217,9 @@ def process_file( collection_name = f"file-{file.id}" # Retrieve the knowledge base using the collection name - knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name) + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) if not knowledge_base: - raise ValueError(f"Knowledge base not found for collection: {collection_name}") + raise ValueError(f"Knowledge base not found for collection: {knowledge_base}") # Retrieve the RAG configuration rag_config = {} @@ -1372,6 +1374,7 @@ def process_file( }, add=(True if form_data.collection_name else False), user=user, + knowledge_id=knowledge_id, ) if result: @@ -1423,6 +1426,7 @@ def process_text( request: Request, form_data: ProcessTextForm, user=Depends(get_verified_user), + knowledge_id: Optional[str] = None, ): collection_name = form_data.collection_name if collection_name is None: @@ -1437,7 +1441,7 @@ def process_text( text_content = form_data.content log.debug(f"text_content: {text_content}") - result = save_docs_to_vector_db(request, docs, collection_name, user=user) + result = save_docs_to_vector_db(request, docs, collection_name, user=user, knowledge_id=knowledge_id) if result: return { "status": True,