From 2eb15ea1fc1b33477a4aad130ccd03e5d4c4f9b6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 19:28:33 -0700 Subject: [PATCH] feat: SAFE_MODE --- backend/apps/webui/models/functions.py | 13 +++++++++++++ backend/config.py | 6 ++++++ backend/main.py | 7 ++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 1a055f327..966bd0231 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -221,6 +221,19 @@ class FunctionsTable: except: return None + def deactivate_all_functions(self) -> Optional[bool]: + try: + query = Function.update( + **{"is_active": False}, + updated_at=int(time.time()), + ) + + query.execute() + + return True + except: + return None + def delete_function_by_id(self, id: str) -> bool: try: query = Function.delete().where((Function.id == id)) diff --git a/backend/config.py b/backend/config.py index 842cea1ba..2b78cc252 100644 --- a/backend/config.py +++ b/backend/config.py @@ -167,6 +167,12 @@ for version in soup.find_all("h2"): CHANGELOG = changelog_json +#################################### +# SAFE_MODE +#################################### + +SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" + #################################### # WEBUI_BUILD_HASH #################################### diff --git a/backend/main.py b/backend/main.py index 991eb5839..4a889f1b3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -55,7 +55,6 @@ from apps.webui.models.functions import Functions from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id - from utils.utils import ( get_admin_user, get_verified_user, @@ -102,10 +101,16 @@ from config import ( SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + SAFE_MODE, AppConfig, ) from constants import ERROR_MESSAGES +if SAFE_MODE: + print("SAFE MODE ENABLED") + Functions.deactivate_all_functions() + + logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"])