From 0d2eefd83d775b85296046ea02dd42dabca32d03 Mon Sep 17 00:00:00 2001 From: Maytown Date: Mon, 12 May 2025 12:47:55 +0200 Subject: [PATCH] Refactoring: Adjusted to newly added rag_config column --- backend/open_webui/routers/retrieval.py | 244 ++++++++++++++++-------- 1 file changed, 161 insertions(+), 83 deletions(-) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index ad45d7285..78d207d2a 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -213,9 +213,9 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF Otherwise, return the embedding configuration stored in the database. """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Return the embedding configuration from the database - rag_config = knowledge_base.data.get("rag_config", {}) + rag_config = knowledge_base.rag_config return { "status": True, "embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE), @@ -256,9 +256,9 @@ async def get_reranking_config(request: Request, collectionForm: CollectionNameF Otherwise, return the reranking configuration stored in the database. """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Return the reranking configuration from the database - rag_config = knowledge_base.data.get("rag_config", {}) + rag_config = knowledge_base.rag_config return { "status": True, "reranking_model": rag_config.get("reranking_model", request.app.state.config.RAG_RERANKING_MODEL), @@ -287,77 +287,137 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 + collection_name: Optional[str] = None @router.post("/embedding/update") async def update_embedding_config( request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): - # TODO Update for individual rag config - log.info( - f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" - ) + """ + Update the embedding model configuration. + If DEFAULT_RAG_SETTINGS is True, update the global configuration. + Otherwise, update the RAG configuration in the database for the user's knowledge base. + """ try: - request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model + knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Update the RAG configuration in the database + rag_config = knowledge_base.rag_config + log.info( + f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}" + ) + # Update embedding-related fields + rag_config["embedding_engine"] = form_data.embedding_engine + rag_config["embedding_model"] = form_data.embedding_model + rag_config["embedding_batch_size"] = form_data.embedding_batch_size - if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config is not None: - request.app.state.config.RAG_OPENAI_API_BASE_URL = ( - form_data.openai_config.url - ) - request.app.state.config.RAG_OPENAI_API_KEY = ( - form_data.openai_config.key - ) + rag_config["openai_config"] = { + "url": form_data.openai_config.url, + "key": form_data.openai_config.key, + } if form_data.ollama_config is not None: - request.app.state.config.RAG_OLLAMA_BASE_URL = ( - form_data.ollama_config.url + rag_config["ollama_config"] = { + "url": form_data.ollama_config.url, + "key": form_data.ollama_config.key, + } + # Update the embedding function + if not request.app.state.ef.get("embedding_model"): + request.app.state.ef[rag_config["embedding_model"]] = get_ef( + rag_config["embedding_engine"], + rag_config["embedding_model"], ) - request.app.state.config.RAG_OLLAMA_API_KEY = ( - form_data.ollama_config.key + + request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function( + rag_config["embedding_engine"], + rag_config["embedding_model"], + request.app.state.ef[rag_config["embedding_model"]], + ( + rag_config["openai_config"]["url"] + if rag_config["embedding_engine"] == "openai" + else rag_config["ollama_config"]["url"] + ), + ( + rag_config["openai_config"]["key"] + if rag_config["embedding_engine"] == "openai" + else rag_config["ollama_config"]["key"] + ), + rag_config["embedding_batch_size"] ) - - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( - form_data.embedding_batch_size + + # Save the updated configuration to the database + Knowledges.update_knowledge_data_by_id( + id=form_data.collection_name, data={"rag_config": rag_config} ) - request.app.state.ef = get_ef( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - ) + return { + "status": True, + "embedding_engine": rag_config["embedding_engine"], + "embedding_model": rag_config["embedding_model"], + "embedding_batch_size": rag_config["embedding_batch_size"], + "openai_config": rag_config.get("openai_config", {}), + "ollama_config": rag_config.get("ollama_config", {}), + "message": "Embedding configuration updated in the database.", + } + else: + # Update the global configuration + log.info( + f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + ) + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - request.app.state.EMBEDDING_FUNCTION = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.ef, - ( - request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL - ), - ( - request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY - ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - ) + if form_data.openai_config is not None: + request.app.state.config.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url + request.app.state.config.RAG_OPENAI_API_KEY = form_data.openai_config.key - return { - "status": True, - "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "openai_config": { - "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, - "key": request.app.state.config.RAG_OPENAI_API_KEY, - }, - "ollama_config": { - "url": request.app.state.config.RAG_OLLAMA_BASE_URL, - "key": request.app.state.config.RAG_OLLAMA_API_KEY, - }, - } + if form_data.ollama_config is not None: + request.app.state.config.RAG_OLLAMA_BASE_URL = form_data.ollama_config.url + request.app.state.config.RAG_OLLAMA_API_KEY = form_data.ollama_config.key + + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size + + # Update the embedding function + if not request.app.state.ef.get(form_data.embedding_model): + request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) + + request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL], + ( + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + request.app.state.config.RAG_OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_API_KEY + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ) + + return { + "status": True, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "openai_config": { + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, + }, + "ollama_config": { + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, + }, + "message": "Embedding configuration updated globally.", + } except Exception as e: log.exception(f"Problem updating embedding model: {e}") raise HTTPException( @@ -381,14 +441,27 @@ async def update_reranking_config( """ try: knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name) - # TODO UPdate reranking accoridngly - if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database - rag_config = knowledge_base.data.get("rag_config", {}) + rag_config = knowledge_base.rag_config + log.info( + f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}" + ) rag_config["reranking_model"] = form_data.reranking_model Knowledges.update_knowledge_data_by_id( id=knowledge_base.id, data={"rag_config": rag_config} ) + try: + if not request.app.state.rf.get(rag_config["reranking_model"]): + request.app.state.rf[rag_config["reranking_model"]] = get_rf( + rag_config["reranking_model"], + True, + ) + + except Exception as e: + log.error(f"Error loading reranking model: {e}") + rag_config["ENABLE_RAG_HYBRID_SEARCH"] = False + return { "status": True, "reranking_model": rag_config["reranking_model"], @@ -402,10 +475,13 @@ async def update_reranking_config( request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model try: - request.app.state.rf = get_rf( - request.app.state.config.RAG_RERANKING_MODEL, - True, - ) + if request.app.state.rf.get(form_data.reranking_model): + request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model] + else: + request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf( + request.app.state.config.RAG_RERANKING_MODEL, + True, + ) except Exception as e: log.error(f"Error loading reranking model: {e}") request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False @@ -431,9 +507,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u Otherwise, return the RAG configuration stored in the database. """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Return the RAG configuration from the database - rag_config = knowledge_base.data.get("rag_config", {}) + rag_config = knowledge_base.rag_config web_config = rag_config.get("web", {}) return { "status": True, @@ -700,9 +776,9 @@ async def update_rag_config( """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database - rag_config = knowledge_base.data.get("rag_config", {}) + rag_config = knowledge_base.rag_config # Update only the provided fields in the rag_config for field, value in form_data.model_dump(exclude_unset=True).items(): @@ -710,7 +786,9 @@ async def update_rag_config( rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)} else: rag_config[field] = value - + if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True): + request.app.state.rf[rag_config["reranking_model"]] = None + Knowledges.update_knowledge_data_by_id( id=knowledge_base.id, data={"rag_config": rag_config} ) @@ -748,7 +826,7 @@ async def update_rag_config( ) # Free up memory if hybrid search is disabled if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: - request.app.state.rf = None + request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None request.app.state.config.TOP_K_RERANKER = ( form_data.TOP_K_RERANKER @@ -1052,15 +1130,15 @@ def save_docs_to_vector_db( log.info( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) - + rag_config = {} # Retrieve the knowledge base using the collection_name if knowledge_id: knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) # Retrieve the RAG configuration - if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): - rag_config = knowledge_base.data.get("rag_config", {}) - print("RAG CONFIG: ", rag_config) + if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + # Use knowledge-base-specific or default configurations text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER) chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE) @@ -1072,7 +1150,7 @@ def save_docs_to_vector_db( openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY) ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL) ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY) - + # Check if entries with the same hash (metadata.hash) already exist if metadata and "hash" in metadata: result = VECTOR_DB_CLIENT.query( @@ -1156,7 +1234,7 @@ def save_docs_to_vector_db( embedding_function = get_embedding_function( embedding_engine, embedding_model, - request.app.state.ef, + request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL), ( openai_api_base_url if embedding_engine == "openai" @@ -1224,16 +1302,16 @@ def process_file( knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name) # Retrieve the RAG configuration - if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): - rag_config = knowledge_base.data.get("rag_config", {}) + if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db elif form_data.knowledge_id: knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) # Retrieve the RAG configuration - if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): - rag_config = knowledge_base.data.get("rag_config", {}) + if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config # Use knowledge-base-specific or default configurations content_extraction_engine = rag_config.get( @@ -1906,7 +1984,7 @@ def query_doc_handler( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -1957,7 +2035,7 @@ def query_collection_handler( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=(