diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a5aee4bb8..6bf1908c5 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -224,6 +224,11 @@ from open_webui.config import ( PDF_EXTRACT_IMAGES, YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_PROXY_URL, + DEFAULT_RAG_SETTINGS, + DOWNLOADED_EMBEDDING_MODELS, + DOWNLOADED_RERANKING_MODELS, + LOADED_EMBEDDING_MODELS, + LOADED_RERANKING_MODELS, # Retrieval (Web Search) ENABLE_WEB_SEARCH, WEB_SEARCH_ENGINE, @@ -741,6 +746,11 @@ app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY +app.state.config.DEFAULT_RAG_SETTINGS = DEFAULT_RAG_SETTINGS +app.state.config.DOWNLOADED_EMBEDDING_MODELS = DOWNLOADED_EMBEDDING_MODELS +app.state.config.DOWNLOADED_RERANKING_MODELS = DOWNLOADED_RERANKING_MODELS +app.state.config.LOADED_EMBEDDING_MODELS = LOADED_EMBEDDING_MODELS +app.state.config.LOADED_RERANKING_MODELS = LOADED_RERANKING_MODELS app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT @@ -748,49 +758,53 @@ app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH -app.state.EMBEDDING_FUNCTION = None -app.state.ef = None -app.state.rf = None +app.state.EMBEDDING_FUNCTION = {} +app.state.ef = {} +app.state.rf = {} app.state.YOUTUBE_LOADER_TRANSLATION = None try: - app.state.ef = get_ef( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, - ) + # Load all embedding models that are currently in use + for engine, model_list in app.state.config.LOADED_EMBEDDING_MODELS.items(): + for model in model_list: + app.state.ef[model] = get_ef( + engine, + model, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) + app.state.EMBEDDING_FUNCTION[model] = get_embedding_function( + engine, + model, + app.state.ef[model], + ( + app.state.config.RAG_OPENAI_API_BASE_URL + if engine == "openai" + else app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + app.state.config.RAG_OPENAI_API_KEY + if engine == "openai" + else app.state.config.RAG_OLLAMA_API_KEY + ), + app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ) - app.state.rf = get_rf( - app.state.config.RAG_RERANKING_ENGINE, - app.state.config.RAG_RERANKING_MODEL, - app.state.config.RAG_EXTERNAL_RERANKER_URL, - app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - RAG_RERANKING_MODEL_AUTO_UPDATE, - ) + # Load all reranking models that are currently in use + for engine, model_list in app.state.config.LOADED_RERANKING_MODELS.items(): + for model in model_list: + app.state.rf[model["RAG_RERANKING_MODEL"]] = get_rf( + engine, + model["RAG_RERANKING_MODEL"], + model["RAG_EXTERNAL_RERANKER_URL"], + model["RAG_EXTERNAL_RERANKER_API_KEY"], + ) except Exception as e: log.error(f"Error updating models: {e}") pass -app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.ef, - ( - app.state.config.RAG_OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_BASE_URL - ), - ( - app.state.config.RAG_OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_API_KEY - ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) - ######################################## # # CODE EXECUTION