diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 486311902..09913ae01 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -39,6 +39,13 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse +from open_webui.socket.main import ( + app as socket_app, + periodic_usage_pool_cleanup, + get_event_call, + get_event_emitter, +) + from open_webui.routers import ( audio, images, @@ -63,35 +70,19 @@ from open_webui.routers import ( users, utils, ) -from open_webui.retrieval.utils import get_sources_from_files from open_webui.routers.retrieval import ( get_embedding_function, - update_embedding_model, - update_reranking_model, + get_ef, + get_rf, ) +from open_webui.retrieval.utils import get_sources_from_files -from open_webui.socket.main import ( - app as socket_app, - periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, -) - from open_webui.internal.db import Session - -from open_webui.routers.webui import ( - app as webui_app, - generate_function_chat_completion, - get_all_models as get_open_webui_models, -) - - from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.models.users import UserModel, Users -from open_webui.utils.plugin import load_function_module_by_id from open_webui.constants import TASKS @@ -279,7 +270,7 @@ from open_webui.env import ( OFFLINE_MODE, ) - +from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, @@ -528,8 +519,8 @@ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.EMBEDDING_FUNCTION = None -app.state.sentence_transformer_ef = None -app.state.sentence_transformer_rf = None +app.state.ef = None +app.state.rf = None app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -537,29 +528,34 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + app.state.ef, ( - app.state.config.OPENAI_API_BASE_URL + app.state.config.RAG_OPENAI_API_BASE_URL if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + else app.state.config.RAG_OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY + app.state.config.RAG_OPENAI_API_KEY if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + else app.state.config.RAG_OLLAMA_API_KEY ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) -update_embedding_model( - app.state.config.RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, -) +try: + app.state.ef = get_ef( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + ) -update_reranking_model( - app.state.config.RAG_RERANKING_MODEL, - RAG_RERANKING_MODEL_AUTO_UPDATE, -) + app.state.rf = get_rf( + app.state.config.RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + ) +except Exception as e: + log.error(f"Error updating models: {e}") + pass ######################################## @@ -990,11 +986,11 @@ async def chat_completion_files_handler( sources = get_sources_from_files( files=files, queries=queries, - embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, - k=retrieval_app.state.config.TOP_K, - reranking_function=retrieval_app.state.sentence_transformer_rf, - r=retrieval_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=app.state.config.TOP_K, + reranking_function=app.state.rf, + r=app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) log.debug(f"rag_contexts:sources: {sources}") diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 5cd7209a8..c40208ac1 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ########################################## -def update_embedding_model( - request: Request, +def get_ef( + engine: str, embedding_model: str, auto_update: bool = False, ): - if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "": + ef = None + if embedding_model and engine == "": from sentence_transformers import SentenceTransformer try: - request.app.state.sentence_transformer_ef = SentenceTransformer( + ef = SentenceTransformer( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") - request.app.state.sentence_transformer_ef = None - else: - request.app.state.sentence_transformer_ef = None + + return ef -def update_reranking_model( - request: Request, +def get_rf( reranking_model: str, auto_update: bool = False, ): + rf = None if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): try: from open_webui.retrieval.models.colbert import ColBERT - request.app.state.sentence_transformer_rf = ColBERT( + rf = ColBERT( get_model_path(reranking_model, auto_update), env="docker" if DOCKER else None, ) + except Exception as e: log.error(f"ColBERT: {e}") - request.app.state.sentence_transformer_rf = None - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers try: - request.app.state.sentence_transformer_rf = ( - sentence_transformers.CrossEncoder( - get_model_path(reranking_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - ) + rf = sentence_transformers.CrossEncoder( + get_model_path(reranking_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, ) except: log.error("CrossEncoder error") - request.app.state.sentence_transformer_rf = None - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - else: - request.app.state.sentence_transformer_rf = None + raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) + return rf ########################################## @@ -261,12 +257,15 @@ async def update_embedding_config( form_data.embedding_batch_size ) - update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.ef = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) request.app.state.EMBEDDING_FUNCTION = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.sentence_transformer_ef, + request.app.state.ef, ( request.app.state.config.OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" @@ -316,7 +315,14 @@ async def update_reranking_config( try: request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True) + try: + request.app.state.rf = get_rf( + request.app.state.config.RAG_RERANKING_MODEL, + True, + ) + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False return { "status": True, @@ -739,7 +745,7 @@ def save_docs_to_vector_db( embedding_function = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.sentence_transformer_ef, + request.app.state.ef, ( request.app.state.config.OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" @@ -1286,7 +1292,7 @@ def query_doc_handler( query=form_data.query, embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.sentence_transformer_rf, + reranking_function=request.app.state.rf, r=( form_data.r if form_data.r @@ -1328,7 +1334,7 @@ def query_collection_handler( queries=[form_data.query], embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.sentence_transformer_rf, + reranking_function=request.app.state.rf, r=( form_data.r if form_data.r