diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index b133f3c0f..59557349e 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -451,6 +451,51 @@ AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = ( os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true" ) + +#################################### +# SENTENCE TRANSFORMERS +#################################### + + +SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "") +if SENTENCE_TRANSFORMERS_BACKEND == "": + SENTENCE_TRANSFORMERS_BACKEND = "torch" + + +SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get( + "SENTENCE_TRANSFORMERS_MODEL_KWARGS", "" +) +if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "": + SENTENCE_TRANSFORMERS_MODEL_KWARGS = None +else: + try: + SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads( + SENTENCE_TRANSFORMERS_MODEL_KWARGS + ) + except Exception: + SENTENCE_TRANSFORMERS_MODEL_KWARGS = None + + +SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get( + "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", "" +) +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "": + SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch" + + +SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get( + "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", "" +) +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "": + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None +else: + try: + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads( + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS + ) + except Exception: + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None + #################################### # OFFLINE_MODE #################################### @@ -460,6 +505,7 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" if OFFLINE_MODE: os.environ["HF_HUB_OFFLINE"] = "1" + #################################### # AUDIT LOGGING #################################### @@ -481,6 +527,7 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders" AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] + #################################### # OPENTELEMETRY #################################### diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 76d7fc76c..a582bd9ba 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -91,7 +91,12 @@ from open_webui.env import ( SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, + SENTENCE_TRANSFORMERS_BACKEND, + SENTENCE_TRANSFORMERS_MODEL_KWARGS, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, ) + from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) @@ -118,6 +123,8 @@ def get_ef( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + backend=SENTENCE_TRANSFORMERS_BACKEND, + model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") @@ -151,6 +158,8 @@ def get_rf( get_model_path(reranking_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, + model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, ) except Exception as e: log.error(f"CrossEncoder: {e}")