feat: SAFE_MODE

This commit is contained in:
Timothy J. Baek 2024-06-23 19:28:33 -07:00
parent ab700a16be
commit 2eb15ea1fc
3 changed files with 25 additions and 1 deletions

View File

@ -221,6 +221,19 @@ class FunctionsTable:
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]:
try:
query = Function.update(
**{"is_active": False},
updated_at=int(time.time()),
)
query.execute()
return True
except:
return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
try: try:
query = Function.delete().where((Function.id == id)) query = Function.delete().where((Function.id == id))

View File

@ -167,6 +167,12 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
#################################### ####################################
# WEBUI_BUILD_HASH # WEBUI_BUILD_HASH
#################################### ####################################

View File

@ -55,7 +55,6 @@ from apps.webui.models.functions import Functions
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from utils.utils import ( from utils.utils import (
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
@ -102,10 +101,16 @@ from config import (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
SAFE_MODE,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])