diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index d91d39111..c2bed79ad 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -212,6 +212,10 @@ from open_webui.config import ( YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_PROXY_URL, DEFAULT_RAG_SETTINGS, + DOWNLOADED_EMBEDDING_MODELS, + DOWNLOADED_RERANKING_MODELS, + LOADED_EMBEDDING_MODELS, + LOADED_RERANKING_MODELS, # Retrieval (Web Search) ENABLE_WEB_SEARCH, WEB_SEARCH_ENGINE, @@ -708,6 +712,10 @@ app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY app.state.config.DEFAULT_RAG_SETTINGS = DEFAULT_RAG_SETTINGS +app.state.config.DOWNLOADED_EMBEDDING_MODELS = DOWNLOADED_EMBEDDING_MODELS +app.state.config.DOWNLOADED_RERANKING_MODELS = DOWNLOADED_RERANKING_MODELS +app.state.config.LOADED_EMBEDDING_MODELS = LOADED_EMBEDDING_MODELS +app.state.config.LOADED_RERANKING_MODELS = LOADED_RERANKING_MODELS app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT @@ -723,37 +731,41 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None try: - app.state.ef[app.state.config.RAG_EMBEDDING_MODEL] = get_ef( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, - ) - - app.state.rf[app.state.config.RAG_RERANKING_MODEL] = get_rf( - app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, - ) + for engine, model_list in app.state.config.LOADED_EMBEDDING_MODELS.items(): + for model in model_list: + app.state.ef[model] = get_ef( + engine, + model, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) + app.state.EMBEDDING_FUNCTION[model] = get_embedding_function( + engine, + model, + app.state.ef[model], + ( + app.state.config.RAG_OPENAI_API_BASE_URL + if engine == "openai" + else app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + app.state.config.RAG_OPENAI_API_KEY + if engine == "openai" + else app.state.config.RAG_OLLAMA_API_KEY + ), + app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ) + + for model in app.state.config.LOADED_RERANKING_MODELS: + app.state.ef[model] = get_ef( + model, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) + except Exception as e: log.error(f"Error updating models: {e}") pass -app.state.EMBEDDING_FUNCTION[app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.ef, - ( - app.state.config.RAG_OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_BASE_URL - ), - ( - app.state.config.RAG_OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_API_KEY - ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) ######################################## #