enh: sentence transformers env vars

Co-Authored-By: DrZoidberg09 <96449693+drzoidberg09@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek 2025-04-24 01:55:18 +09:00
parent 3ec6652f99
commit 732d7aee70
2 changed files with 56 additions and 0 deletions

View File

@ -451,6 +451,51 @@ AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
)
####################################
# SENTENCE TRANSFORMERS
####################################
SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
if SENTENCE_TRANSFORMERS_BACKEND == "":
SENTENCE_TRANSFORMERS_BACKEND = "torch"
SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
####################################
# OFFLINE_MODE
####################################
@ -460,6 +505,7 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
@ -481,6 +527,7 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
####################################
# OPENTELEMETRY
####################################

View File

@ -91,7 +91,12 @@ from open_webui.env import (
SRC_LOG_LEVELS,
DEVICE_TYPE,
DOCKER,
SENTENCE_TRANSFORMERS_BACKEND,
SENTENCE_TRANSFORMERS_MODEL_KWARGS,
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)
from open_webui.constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
@ -118,6 +123,8 @@ def get_ef(
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
backend=SENTENCE_TRANSFORMERS_BACKEND,
model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS,
)
except Exception as e:
log.debug(f"Error loading SentenceTransformer: {e}")
@ -151,6 +158,8 @@ def get_rf(
get_model_path(reranking_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)
except Exception as e:
log.error(f"CrossEncoder: {e}")