diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8184b467b..5261d440f 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -62,8 +62,12 @@ from open_webui.routers import ( users, utils, ) - from open_webui.retrieval.utils import get_sources_from_files +from open_webui.routers.retrieval import ( + get_embedding_function, + update_embedding_model, + update_reranking_model, +) from open_webui.socket.main import ( @@ -73,15 +77,16 @@ from open_webui.socket.main import ( get_event_emitter, ) - from open_webui.internal.db import Session -from backend.open_webui.routers.webui import ( +from open_webui.routers.webui import ( app as webui_app, generate_function_chat_completion, get_all_models as get_open_webui_models, ) + + from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users @@ -523,6 +528,34 @@ app.state.sentence_transformer_rf = None app.state.YOUTUBE_LOADER_TRANSLATION = None +app.state.EMBEDDING_FUNCTION = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), + app.state.config.RAG_EMBEDDING_BATCH_SIZE, +) + +update_embedding_model( + app.state.config.RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, +) + +update_reranking_model( + app.state.config.RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, +) + + ######################################## # # IMAGES diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 7e0dc6018..5cd7209a8 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -155,17 +155,6 @@ def update_reranking_model( request.app.state.sentence_transformer_rf = None -update_embedding_model( - request.app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, -) - -update_reranking_model( - request.app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, -) - - ########################################## # # API routes @@ -176,24 +165,6 @@ update_reranking_model( router = APIRouter() -request.app.state.EMBEDDING_FUNCTION = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.sentence_transformer_ef, - ( - request.app.state.config.OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_BASE_URL - ), - ( - request.app.state.config.OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.OLLAMA_API_KEY - ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) - - class CollectionNameForm(BaseModel): collection_name: Optional[str] = None