From 058eb765687a542e4ce542d5fc8aea921ef85035 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Fri, 10 May 2024 13:36:10 +0800 Subject: [PATCH] feat: save UI config changes to config.json --- backend/apps/audio/main.py | 31 ++- backend/apps/images/main.py | 105 ++++---- backend/apps/ollama/main.py | 56 +++-- backend/apps/openai/main.py | 47 ++-- backend/apps/rag/main.py | 237 ++++++++++-------- backend/apps/web/main.py | 8 +- backend/apps/web/routers/auths.py | 38 +-- backend/apps/web/routers/configs.py | 9 +- backend/apps/web/routers/users.py | 8 +- backend/config.py | 366 ++++++++++++++++++++-------- backend/main.py | 42 ++-- 11 files changed, 611 insertions(+), 336 deletions(-) diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 87732d7bc..c3dc6a2c4 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -45,6 +45,8 @@ from config import ( AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_VOICE, + config_get, + config_set, ) log = logging.getLogger(__name__) @@ -83,10 +85,10 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), + "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), } @@ -97,17 +99,22 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - app.state.OPENAI_API_BASE_URL = form_data.url - app.state.OPENAI_API_KEY = form_data.key - app.state.OPENAI_API_MODEL = form_data.model - app.state.OPENAI_API_VOICE = form_data.speaker + config_set(app.state.OPENAI_API_BASE_URL, form_data.url) + config_set(app.state.OPENAI_API_KEY, form_data.key) + config_set(app.state.OPENAI_API_MODEL, form_data.model) + config_set(app.state.OPENAI_API_VOICE, form_data.speaker) + + app.state.OPENAI_API_BASE_URL.save() + app.state.OPENAI_API_KEY.save() + app.state.OPENAI_API_MODEL.save() + app.state.OPENAI_API_VOICE.save() return { "status": True, - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), + "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), + "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), } diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index f45cf0d12..8ebfb0446 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -42,6 +42,8 @@ from config import ( IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, + config_get, + config_set, ) @@ -79,7 +81,10 @@ app.state.IMAGE_STEPS = IMAGE_STEPS @app.get("/config") async def get_config(request: Request, user=Depends(get_admin_user)): - return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} + return { + "engine": config_get(app.state.ENGINE), + "enabled": config_get(app.state.ENABLED), + } class ConfigUpdateForm(BaseModel): @@ -89,9 +94,12 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.ENGINE = form_data.engine - app.state.ENABLED = form_data.enabled - return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} + config_set(app.state.ENGINE, form_data.engine) + config_set(app.state.ENABLED, form_data.enabled) + return { + "engine": config_get(app.state.ENGINE), + "enabled": config_get(app.state.ENABLED), + } class EngineUrlUpdateForm(BaseModel): @@ -102,8 +110,8 @@ class EngineUrlUpdateForm(BaseModel): @app.get("/url") async def get_engine_url(user=Depends(get_admin_user)): return { - "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, - "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), + "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), } @@ -113,29 +121,29 @@ async def update_engine_url( ): if form_data.AUTOMATIC1111_BASE_URL == None: - app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL + config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL)) else: url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) - app.state.AUTOMATIC1111_BASE_URL = url + config_set(app.state.AUTOMATIC1111_BASE_URL, url) except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) if form_data.COMFYUI_BASE_URL == None: - app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL) else: url = form_data.COMFYUI_BASE_URL.strip("/") try: r = requests.head(url) - app.state.COMFYUI_BASE_URL = url + config_set(app.state.COMFYUI_BASE_URL, url) except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { - "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, - "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), + "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), "status": True, } @@ -148,8 +156,8 @@ class OpenAIConfigUpdateForm(BaseModel): @app.get("/openai/config") async def get_openai_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), } @@ -160,13 +168,13 @@ async def update_openai_config( if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - app.state.OPENAI_API_BASE_URL = form_data.url - app.state.OPENAI_API_KEY = form_data.key + config_set(app.state.OPENAI_API_BASE_URL, form_data.url) + config_set(app.state.OPENAI_API_KEY, form_data.key) return { "status": True, - "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), + "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), } @@ -176,7 +184,7 @@ class ImageSizeUpdateForm(BaseModel): @app.get("/size") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_SIZE": app.state.IMAGE_SIZE} + return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)} @app.post("/size/update") @@ -185,9 +193,9 @@ async def update_image_size( ): pattern = r"^\d+x\d+$" # Regular expression pattern if re.match(pattern, form_data.size): - app.state.IMAGE_SIZE = form_data.size + config_set(app.state.IMAGE_SIZE, form_data.size) return { - "IMAGE_SIZE": app.state.IMAGE_SIZE, + "IMAGE_SIZE": config_get(app.state.IMAGE_SIZE), "status": True, } else: @@ -203,7 +211,7 @@ class ImageStepsUpdateForm(BaseModel): @app.get("/steps") async def get_image_size(user=Depends(get_admin_user)): - return {"IMAGE_STEPS": app.state.IMAGE_STEPS} + return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)} @app.post("/steps/update") @@ -211,9 +219,9 @@ async def update_image_size( form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) ): if form_data.steps >= 0: - app.state.IMAGE_STEPS = form_data.steps + config_set(app.state.IMAGE_STEPS, form_data.steps) return { - "IMAGE_STEPS": app.state.IMAGE_STEPS, + "IMAGE_STEPS": config_get(app.state.IMAGE_STEPS), "status": True, } else: @@ -263,15 +271,25 @@ def get_models(user=Depends(get_current_user)): async def get_default_model(user=Depends(get_admin_user)): try: if app.state.ENGINE == "openai": - return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + return { + "model": ( + config_get(app.state.MODEL) + if config_get(app.state.MODEL) + else "dall-e-2" + ) + } elif app.state.ENGINE == "comfyui": - return {"model": app.state.MODEL if app.state.MODEL else ""} + return { + "model": ( + config_get(app.state.MODEL) if config_get(app.state.MODEL) else "" + ) + } else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() return {"model": options["sd_model_checkpoint"]} except Exception as e: - app.state.ENABLED = False + config_set(app.state.ENABLED, False) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -280,12 +298,9 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE == "openai": - app.state.MODEL = model - return app.state.MODEL - if app.state.ENGINE == "comfyui": - app.state.MODEL = model - return app.state.MODEL + if app.state.ENGINE in ["openai", "comfyui"]: + config_set(app.state.MODEL, model) + return config_get(app.state.MODEL) else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() @@ -382,7 +397,7 @@ def generate_image( user=Depends(get_current_user), ): - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x"))) r = None try: @@ -396,7 +411,11 @@ def generate_image( "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "prompt": form_data.prompt, "n": form_data.n, - "size": form_data.size if form_data.size else app.state.IMAGE_SIZE, + "size": ( + form_data.size + if form_data.size + else config_get(app.state.IMAGE_SIZE) + ), "response_format": "b64_json", } @@ -430,19 +449,19 @@ def generate_image( "n": form_data.n, } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + if config_get(app.state.IMAGE_STEPS) is not None: + data["steps"] = config_get(app.state.IMAGE_STEPS) - if form_data.negative_prompt != None: + if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt data = ImageGenerationPayload(**data) res = comfyui_generate_image( - app.state.MODEL, + config_get(app.state.MODEL), data, user.id, - app.state.COMFYUI_BASE_URL, + config_get(app.state.COMFYUI_BASE_URL), ) log.debug(f"res: {res}") @@ -469,10 +488,10 @@ def generate_image( "height": height, } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + if config_get(app.state.IMAGE_STEPS) is not None: + data["steps"] = config_get(app.state.IMAGE_STEPS) - if form_data.negative_prompt != None: + if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt r = requests.post( diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 042d0336d..7dfadbb0c 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,6 +46,8 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, + config_set, + config_get, ) from utils.misc import calculate_sha256 @@ -96,7 +98,7 @@ async def get_status(): @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} + return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} class UrlUpdateForm(BaseModel): @@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel): @app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_BASE_URLS = form_data.urls + config_set(app.state.OLLAMA_BASE_URLS, form_data.urls) log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} + return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} @app.get("/cancel/{request_id}") @@ -153,7 +155,9 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] + tasks = [ + fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS) + ] responses = await asyncio.gather(*tasks) models = { @@ -179,14 +183,15 @@ async def get_ollama_tags( if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + lambda model: model["name"] + in config_get(app.state.MODEL_FILTER_LIST), models["models"], ) ) return models return models else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -216,7 +221,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None): if url_idx == None: # returns lowest version - tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] + tasks = [ + fetch_url(f"{url}/api/version") + for url in config_get(app.state.OLLAMA_BASE_URLS) + ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -235,7 +243,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/version") r.raise_for_status() @@ -267,7 +275,7 @@ class ModelNameForm(BaseModel): async def pull_model( form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) ): - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -355,7 +363,7 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.debug(f"url: {url}") r = None @@ -417,7 +425,7 @@ async def create_model( form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) ): log.debug(f"form_data: {form_data}") - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -490,7 +498,7 @@ async def copy_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -537,7 +545,7 @@ async def delete_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -577,7 +585,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -634,7 +642,7 @@ async def generate_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -684,7 +692,7 @@ def generate_ollama_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") try: @@ -753,7 +761,7 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -856,7 +864,7 @@ async def generate_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -965,7 +973,7 @@ async def generate_openai_chat_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] log.info(f"url: {url}") r = None @@ -1064,7 +1072,7 @@ async def get_openai_models( } else: - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1198,7 +1206,7 @@ async def download_model( if url_idx == None: url_idx = 0 - url = app.state.OLLAMA_BASE_URLS[url_idx] + url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1217,7 +1225,7 @@ async def download_model( def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): if url_idx == None: url_idx = 0 - ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] + ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}" @@ -1282,7 +1290,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # if url_idx == None: # url_idx = 0 -# url = app.state.OLLAMA_BASE_URLS[url_idx] +# url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] # file_location = os.path.join(UPLOAD_DIR, file.filename) # total_size = file.size @@ -1319,7 +1327,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): async def deprecated_proxy( path: str, request: Request, user=Depends(get_verified_user) ): - url = app.state.OLLAMA_BASE_URLS[0] + url = config_get(app.state.OLLAMA_BASE_URLS)[0] target_url = f"{url}/{path}" body = await request.body() diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index b5d1e68d6..36fed104c 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -26,6 +26,8 @@ from config import ( CACHE_DIR, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, + config_set, + config_get, ) from typing import List, Optional @@ -75,32 +77,34 @@ class KeysUpdateForm(BaseModel): @app.get("/urls") async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} + return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} @app.post("/urls/update") async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): await get_all_models() - app.state.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} + config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls) + return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} @app.get("/keys") async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} + return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} @app.post("/keys/update") async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - app.state.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} + config_set(app.state.OPENAI_API_KEYS, form_data.keys) + return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + idx = config_get(app.state.OPENAI_API_BASE_URLS).index( + "https://api.openai.com/v1" + ) body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -114,13 +118,15 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" + headers["Authorization"] = ( + f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}" + ) headers["Content-Type"] = "application/json" r = None try: r = requests.post( - url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", + url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech", data=body, headers=headers, stream=True, @@ -180,7 +186,8 @@ def merge_models_lists(model_lists): [ {**model, "urlIdx": idx} for model in models - if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] + if "api.openai.com" + not in config_get(app.state.OPENAI_API_BASE_URLS)[idx] or "gpt" in model["id"] ] ) @@ -191,12 +198,15 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": + if ( + len(config_get(app.state.OPENAI_API_KEYS)) == 1 + and config_get(app.state.OPENAI_API_KEYS)[0] == "" + ): models = {"data": []} else: tasks = [ - fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) + fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx]) + for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS)) ] responses = await asyncio.gather(*tasks) @@ -228,18 +238,19 @@ async def get_all_models(): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if config_get(app.state.ENABLE_MODEL_FILTER): if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + lambda model: model["id"] + in config_get(app.state.MODEL_FILTER_LIST), models["data"], ) ) return models return models else: - url = app.state.OPENAI_API_BASE_URLS[url_idx] + url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx] r = None @@ -303,8 +314,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) - url = app.state.OPENAI_API_BASE_URLS[idx] - key = app.state.OPENAI_API_KEYS[idx] + url = config_get(app.state.OPENAI_API_BASE_URLS)[idx] + key = config_get(app.state.OPENAI_API_KEYS)[idx] target_url = f"{url}/{path}" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2e2a8e209..f05447a66 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -93,6 +93,8 @@ from config import ( RAG_TEMPLATE, ENABLE_RAG_LOCAL_WEB_FETCH, YOUTUBE_LOADER_LANGUAGE, + config_set, + config_get, ) from constants import ERROR_MESSAGES @@ -133,7 +135,7 @@ def update_embedding_model( embedding_model: str, update_model: bool = False, ): - if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": + if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "": app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -158,22 +160,22 @@ def update_reranking_model( update_embedding_model( - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_MODEL), RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - app.state.RAG_RERANKING_MODEL, + config_get(app.state.RAG_RERANKING_MODEL), RAG_RERANKING_MODEL_AUTO_UPDATE, ) app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_ENGINE), + config_get(app.state.RAG_EMBEDDING_MODEL), app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + config_get(app.state.OPENAI_API_KEY), + config_get(app.state.OPENAI_API_BASE_URL), ) origins = ["*"] @@ -200,12 +202,12 @@ class UrlForm(CollectionNameForm): async def get_status(): return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, - "template": app.state.RAG_TEMPLATE, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.RAG_RERANKING_MODEL, + "chunk_size": config_get(app.state.CHUNK_SIZE), + "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), + "template": config_get(app.state.RAG_TEMPLATE), + "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), + "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), + "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), } @@ -213,18 +215,21 @@ async def get_status(): async def get_embedding_config(user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), + "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), "openai_config": { - "url": app.state.OPENAI_API_BASE_URL, - "key": app.state.OPENAI_API_KEY, + "url": config_get(app.state.OPENAI_API_BASE_URL), + "key": config_get(app.state.OPENAI_API_KEY), }, } @app.get("/reranking") async def get_reraanking_config(user=Depends(get_admin_user)): - return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} + return { + "status": True, + "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), + } class OpenAIConfigForm(BaseModel): @@ -246,31 +251,31 @@ async def update_embedding_config( f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine) + config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model) - if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]: if form_data.openai_config != None: - app.state.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.OPENAI_API_KEY = form_data.openai_config.key + config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url) + config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key) - update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) + update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True) app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_ENGINE), + config_get(app.state.RAG_EMBEDDING_MODEL), app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + config_get(app.state.OPENAI_API_KEY), + config_get(app.state.OPENAI_API_BASE_URL), ) return { "status": True, - "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE), + "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL), "openai_config": { - "url": app.state.OPENAI_API_BASE_URL, - "key": app.state.OPENAI_API_KEY, + "url": config_get(app.state.OPENAI_API_BASE_URL), + "key": config_get(app.state.OPENAI_API_KEY), }, } except Exception as e: @@ -293,13 +298,13 @@ async def update_reranking_config( f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - app.state.RAG_RERANKING_MODEL = form_data.reranking_model + config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model) - update_reranking_model(app.state.RAG_RERANKING_MODEL, True) + update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True) return { "status": True, - "reranking_model": app.state.RAG_RERANKING_MODEL, + "reranking_model": config_get(app.state.RAG_RERANKING_MODEL), } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -313,14 +318,16 @@ async def update_reranking_config( async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), "chunk": { - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "chunk_size": config_get(app.state.CHUNK_SIZE), + "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), }, - "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": config_get( + app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + ), "youtube": { - "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -345,50 +352,69 @@ class ConfigUpdateForm(BaseModel): @app.post("/config/update") async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.PDF_EXTRACT_IMAGES = ( - form_data.pdf_extract_images - if form_data.pdf_extract_images != None - else app.state.PDF_EXTRACT_IMAGES + config_set( + app.state.PDF_EXTRACT_IMAGES, + ( + form_data.pdf_extract_images + if form_data.pdf_extract_images is not None + else config_get(app.state.PDF_EXTRACT_IMAGES) + ), ) - app.state.CHUNK_SIZE = ( - form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE + config_set( + app.state.CHUNK_SIZE, + ( + form_data.chunk.chunk_size + if form_data.chunk is not None + else config_get(app.state.CHUNK_SIZE) + ), ) - app.state.CHUNK_OVERLAP = ( - form_data.chunk.chunk_overlap - if form_data.chunk != None - else app.state.CHUNK_OVERLAP + config_set( + app.state.CHUNK_OVERLAP, + ( + form_data.chunk.chunk_overlap + if form_data.chunk is not None + else config_get(app.state.CHUNK_OVERLAP) + ), ) - app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web_loader_ssl_verification - if form_data.web_loader_ssl_verification != None - else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + config_set( + app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ( + form_data.web_loader_ssl_verification + if form_data.web_loader_ssl_verification != None + else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION) + ), ) - app.state.YOUTUBE_LOADER_LANGUAGE = ( - form_data.youtube.language - if form_data.youtube != None - else app.state.YOUTUBE_LOADER_LANGUAGE + config_set( + app.state.YOUTUBE_LOADER_LANGUAGE, + ( + form_data.youtube.language + if form_data.youtube is not None + else config_get(app.state.YOUTUBE_LOADER_LANGUAGE) + ), ) app.state.YOUTUBE_LOADER_TRANSLATION = ( form_data.youtube.translation - if form_data.youtube != None + if form_data.youtube is not None else app.state.YOUTUBE_LOADER_TRANSLATION ) return { "status": True, - "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES), "chunk": { - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "chunk_size": config_get(app.state.CHUNK_SIZE), + "chunk_overlap": config_get(app.state.CHUNK_OVERLAP), }, - "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": config_get( + app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + ), "youtube": { - "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE), "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, } @@ -398,7 +424,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ async def get_rag_template(user=Depends(get_current_user)): return { "status": True, - "template": app.state.RAG_TEMPLATE, + "template": config_get(app.state.RAG_TEMPLATE), } @@ -406,10 +432,10 @@ async def get_rag_template(user=Depends(get_current_user)): async def get_query_settings(user=Depends(get_admin_user)): return { "status": True, - "template": app.state.RAG_TEMPLATE, - "k": app.state.TOP_K, - "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, + "template": config_get(app.state.RAG_TEMPLATE), + "k": config_get(app.state.TOP_K), + "r": config_get(app.state.RELEVANCE_THRESHOLD), + "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), } @@ -424,16 +450,22 @@ class QuerySettingsForm(BaseModel): async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE - app.state.TOP_K = form_data.k if form_data.k else 4 - app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False + config_set( + app.state.RAG_TEMPLATE, + form_data.template if form_data.template else RAG_TEMPLATE, + ) + config_set(app.state.TOP_K, form_data.k if form_data.k else 4) + config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0) + config_set( + app.state.ENABLE_RAG_HYBRID_SEARCH, + form_data.hybrid if form_data.hybrid else False, + ) return { "status": True, - "template": app.state.RAG_TEMPLATE, - "k": app.state.TOP_K, - "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, + "template": config_get(app.state.RAG_TEMPLATE), + "k": config_get(app.state.TOP_K), + "r": config_get(app.state.RELEVANCE_THRESHOLD), + "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH), } @@ -451,21 +483,25 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - if app.state.ENABLE_RAG_HYBRID_SEARCH: + if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), reranking_function=app.state.sentence_transformer_rf, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + r=( + form_data.r + if form_data.r + else config_get(app.state.RELEVANCE_THRESHOLD) + ), ) else: return query_doc( collection_name=form_data.collection_name, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), ) except Exception as e: log.exception(e) @@ -489,21 +525,25 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - if app.state.ENABLE_RAG_HYBRID_SEARCH: + if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH): return query_collection_with_hybrid_search( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), reranking_function=app.state.sentence_transformer_rf, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + r=( + form_data.r + if form_data.r + else config_get(app.state.RELEVANCE_THRESHOLD) + ), ) else: return query_collection( collection_names=form_data.collection_names, query=form_data.query, embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.TOP_K, + k=form_data.k if form_data.k else config_get(app.state.TOP_K), ) except Exception as e: @@ -520,8 +560,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, - language=app.state.YOUTUBE_LOADER_LANGUAGE, - translation=app.state.YOUTUBE_LOADER_TRANSLATION, + language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE), + translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION), ) data = loader.load() @@ -548,7 +588,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: loader = get_web_loader( - form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION + form_data.url, + verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION), ) data = loader.load() @@ -604,8 +645,8 @@ def resolve_hostname(hostname): def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, - chunk_overlap=app.state.CHUNK_OVERLAP, + chunk_size=config_get(app.state.CHUNK_SIZE), + chunk_overlap=config_get(app.state.CHUNK_OVERLAP), add_start_index=True, ) @@ -622,8 +663,8 @@ def store_text_in_vector_db( text, metadata, collection_name, overwrite: bool = False ) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, - chunk_overlap=app.state.CHUNK_OVERLAP, + chunk_size=config_get(app.state.CHUNK_SIZE), + chunk_overlap=config_get(app.state.CHUNK_OVERLAP), add_start_index=True, ) docs = text_splitter.create_documents([text], metadatas=[metadata]) @@ -646,11 +687,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = CHROMA_CLIENT.create_collection(name=collection_name) embedding_func = get_embedding_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, + config_get(app.state.RAG_EMBEDDING_ENGINE), + config_get(app.state.RAG_EMBEDDING_MODEL), app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, + config_get(app.state.OPENAI_API_KEY), + config_get(app.state.OPENAI_API_BASE_URL), ) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) @@ -724,7 +765,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ] if file_ext == "pdf": - loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) + loader = PyPDFLoader( + file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES) + ) elif file_ext == "csv": loader = CSVLoader(file_path) elif file_ext == "rst": diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 66cdfb3d4..2bed33543 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -21,6 +21,8 @@ from config import ( USER_PERMISSIONS, WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + JWT_EXPIRES_IN, + config_get, ) app = FastAPI() @@ -28,7 +30,7 @@ app = FastAPI() origins = ["*"] app.state.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.JWT_EXPIRES_IN = "-1" +app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS @@ -61,6 +63,6 @@ async def get_status(): return { "status": True, "auth": WEBUI_AUTH, - "default_models": app.state.DEFAULT_MODELS, - "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, + "default_models": config_get(app.state.DEFAULT_MODELS), + "default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS), } diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 9fa962dda..0bc4967f9 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -33,7 +33,7 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER +from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set router = APIRouter() @@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)), ) return { @@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): - if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH: + if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) @@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): role = ( "admin" if Users.get_num_users() == 0 - else request.app.state.DEFAULT_USER_ROLE + else config_get(request.app.state.DEFAULT_USER_ROLE) ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( @@ -194,13 +194,15 @@ async def signup(request: Request, form_data: SignupForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + expires_delta=parse_duration( + config_get(request.app.state.JWT_EXPIRES_IN) + ), ) # response.set_cookie(key='token', value=token, httponly=True) - if request.app.state.WEBHOOK_URL: + if config_get(request.app.state.WEBHOOK_URL): post_webhook( - request.app.state.WEBHOOK_URL, + config_get(request.app.state.WEBHOOK_URL), WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", @@ -276,13 +278,15 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): @router.get("/signup/enabled", response_model=bool) async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): - return request.app.state.ENABLE_SIGNUP + return config_get(request.app.state.ENABLE_SIGNUP) @router.get("/signup/enabled/toggle", response_model=bool) async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): - request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP - return request.app.state.ENABLE_SIGNUP + config_set( + request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP) + ) + return config_get(request.app.state.ENABLE_SIGNUP) ############################ @@ -292,7 +296,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): @router.get("/signup/user/role") async def get_default_user_role(request: Request, user=Depends(get_admin_user)): - return request.app.state.DEFAULT_USER_ROLE + return config_get(request.app.state.DEFAULT_USER_ROLE) class UpdateRoleForm(BaseModel): @@ -304,8 +308,8 @@ async def update_default_user_role( request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) ): if form_data.role in ["pending", "user", "admin"]: - request.app.state.DEFAULT_USER_ROLE = form_data.role - return request.app.state.DEFAULT_USER_ROLE + config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role) + return config_get(request.app.state.DEFAULT_USER_ROLE) ############################ @@ -315,7 +319,7 @@ async def update_default_user_role( @router.get("/token/expires") async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): - return request.app.state.JWT_EXPIRES_IN + return config_get(request.app.state.JWT_EXPIRES_IN) class UpdateJWTExpiresDurationForm(BaseModel): @@ -332,10 +336,10 @@ async def update_token_expires_duration( # Check if the input string matches the pattern if re.match(pattern, form_data.duration): - request.app.state.JWT_EXPIRES_IN = form_data.duration - return request.app.state.JWT_EXPIRES_IN + config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration) + return config_get(request.app.state.JWT_EXPIRES_IN) else: - return request.app.state.JWT_EXPIRES_IN + return config_get(request.app.state.JWT_EXPIRES_IN) ############################ diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index 0bad55a6a..d726cd2dc 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -9,6 +9,7 @@ import time import uuid from apps.web.models.users import Users +from config import config_set, config_get from utils.utils import ( get_password_hash, @@ -44,8 +45,8 @@ class SetDefaultSuggestionsForm(BaseModel): async def set_global_default_models( request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) ): - request.app.state.DEFAULT_MODELS = form_data.models - return request.app.state.DEFAULT_MODELS + config_set(request.app.state.DEFAULT_MODELS, form_data.models) + return config_get(request.app.state.DEFAULT_MODELS) @router.post("/default/suggestions", response_model=List[PromptSuggestion]) @@ -55,5 +56,5 @@ async def set_global_default_suggestions( user=Depends(get_admin_user), ): data = form_data.model_dump() - request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] - return request.app.state.DEFAULT_PROMPT_SUGGESTIONS + config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"]) + return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS) diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py index 59f6c21b7..302432540 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/web/routers/users.py @@ -15,7 +15,7 @@ from apps.web.models.auths import Auths from utils.utils import get_current_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES -from config import SRC_LOG_LEVELS +from config import SRC_LOG_LEVELS, config_set, config_get log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) @router.get("/permissions/user") async def get_user_permissions(request: Request, user=Depends(get_admin_user)): - return request.app.state.USER_PERMISSIONS + return config_get(request.app.state.USER_PERMISSIONS) @router.post("/permissions/user") async def update_user_permissions( request: Request, form_data: dict, user=Depends(get_admin_user) ): - request.app.state.USER_PERMISSIONS = form_data - return request.app.state.USER_PERMISSIONS + config_set(request.app.state.USER_PERMISSIONS, form_data) + return config_get(request.app.state.USER_PERMISSIONS) ############################ diff --git a/backend/config.py b/backend/config.py index 5c6247a9f..028e6caf0 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,6 +5,7 @@ import chromadb from chromadb import Settings from base64 import b64encode from bs4 import BeautifulSoup +from typing import TypeVar, Generic, Union from pathlib import Path import json @@ -17,7 +18,6 @@ import shutil from secrets import token_bytes from constants import ERROR_MESSAGES - #################################### # Load .env file #################################### @@ -29,7 +29,6 @@ try: except ImportError: print("dotenv not installed, skipping...") - #################################### # LOGGING #################################### @@ -71,7 +70,6 @@ for source in log_sources: log.setLevel(SRC_LOG_LEVELS["CONFIG"]) - WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") if WEBUI_NAME != "Open WebUI": WEBUI_NAME += " (Open WebUI)" @@ -80,7 +78,6 @@ WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" - #################################### # ENV (dev,test,prod) #################################### @@ -151,26 +148,14 @@ for version in soup.find_all("h2"): changelog_json[version_number] = version_data - CHANGELOG = changelog_json - #################################### # WEBUI_VERSION #################################### WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") -#################################### -# WEBUI_AUTH (Required for security) -#################################### - -WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) - - #################################### # DATA/FRONTEND BUILD DIR #################################### @@ -184,6 +169,93 @@ try: except: CONFIG_DATA = {} + +#################################### +# Config helpers +#################################### + + +def save_config(): + try: + with open(f"{DATA_DIR}/config.json", "w") as f: + json.dump(CONFIG_DATA, f, indent="\t") + except Exception as e: + log.exception(e) + + +def get_config_value(config_path: str): + path_parts = config_path.split(".") + cur_config = CONFIG_DATA + for key in path_parts: + if key in cur_config: + cur_config = cur_config[key] + else: + return None + return cur_config + + +T = TypeVar("T") + + +class WrappedConfig(Generic[T]): + def __init__(self, env_name: str, config_path: str, env_value: T): + self.env_name = env_name + self.config_path = config_path + self.env_value = env_value + self.config_value = get_config_value(config_path) + if self.config_value is not None: + log.info(f"'{env_name}' loaded from config.json") + self.value = self.config_value + else: + self.value = env_value + + def __str__(self): + return str(self.value) + + def save(self): + # Don't save if the value is the same as the env value and the config value + if self.env_value == self.value: + if self.config_value == self.value: + return + log.info(f"Saving '{self.env_name}' to config.json") + path_parts = self.config_path.split(".") + config = CONFIG_DATA + for key in path_parts[:-1]: + if key not in config: + config[key] = {} + config = config[key] + config[path_parts[-1]] = self.value + save_config() + self.config_value = self.value + + +def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True): + if isinstance(config, WrappedConfig): + config.value = value + if save_config: + config.save() + else: + config = value + + +def config_get(config: Union[WrappedConfig[T], T]) -> T: + if isinstance(config, WrappedConfig): + return config.value + return config + + +#################################### +# WEBUI_AUTH (Required for security) +#################################### + +WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None +) +JWT_EXPIRES_IN = WrappedConfig( + "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") +) + #################################### # Static DIR #################################### @@ -225,7 +297,6 @@ if CUSTOM_NAME: log.exception(e) pass - #################################### # File Upload DIR #################################### @@ -233,7 +304,6 @@ if CUSTOM_NAME: UPLOAD_DIR = f"{DATA_DIR}/uploads" Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) - #################################### # Cache DIR #################################### @@ -241,7 +311,6 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) - #################################### # Docs DIR #################################### @@ -282,7 +351,6 @@ if not os.path.exists(LITELLM_CONFIG_PATH): create_config_file(LITELLM_CONFIG_PATH) log.info("Config file created successfully.") - #################################### # OLLAMA_BASE_URL #################################### @@ -313,12 +381,13 @@ if ENV == "prod": elif K8S_FLAG: OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434" - OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] - +OLLAMA_BASE_URLS = WrappedConfig( + "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS +) #################################### # OPENAI_API @@ -327,7 +396,6 @@ OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") - if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -335,7 +403,7 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] - +OPENAI_API_KEYS = WrappedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS) OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = ( @@ -346,37 +414,42 @@ OPENAI_API_BASE_URLS = [ url.strip() if url != "" else "https://api.openai.com/v1" for url in OPENAI_API_BASE_URLS.split(";") ] +OPENAI_API_BASE_URLS = WrappedConfig( + "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS +) OPENAI_API_KEY = "" try: - OPENAI_API_KEY = OPENAI_API_KEYS[ - OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + OPENAI_API_KEY = OPENAI_API_KEYS.value[ + OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except: pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" - #################################### # WEBUI #################################### -ENABLE_SIGNUP = ( - False - if WEBUI_AUTH == False - else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" +ENABLE_SIGNUP = WrappedConfig( + "ENABLE_SIGNUP", + "ui.enable_signup", + ( + False + if not WEBUI_AUTH + else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" + ), +) +DEFAULT_MODELS = WrappedConfig( + "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) ) -DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None) - -DEFAULT_PROMPT_SUGGESTIONS = ( - CONFIG_DATA["ui"]["prompt_suggestions"] - if "ui" in CONFIG_DATA - and "prompt_suggestions" in CONFIG_DATA["ui"] - and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list - else [ +DEFAULT_PROMPT_SUGGESTIONS = WrappedConfig( + "DEFAULT_PROMPT_SUGGESTIONS", + "ui.prompt_suggestions", + [ { "title": ["Help me study", "vocabulary for a college entrance exam"], "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", @@ -404,23 +477,42 @@ DEFAULT_PROMPT_SUGGESTIONS = ( "title": ["Overcome procrastination", "give me tips"], "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", }, - ] + ], ) - -DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") - -USER_PERMISSIONS_CHAT_DELETION = ( - os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" +DEFAULT_USER_ROLE = WrappedConfig( + "DEFAULT_USER_ROLE", + "ui.default_user_role", + os.getenv("DEFAULT_USER_ROLE", "pending"), ) -USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} +USER_PERMISSIONS_CHAT_DELETION = WrappedConfig( + "USER_PERMISSIONS_CHAT_DELETION", + "ui.user_permissions.chat.deletion", + os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true", +) -ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true" +USER_PERMISSIONS = WrappedConfig( + "USER_PERMISSIONS", + "ui.user_permissions", + {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}, +) + +ENABLE_MODEL_FILTER = WrappedConfig( + "ENABLE_MODEL_FILTER", + "model_filter.enable", + os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true", +) MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") -MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] +MODEL_FILTER_LIST = WrappedConfig( + "MODEL_FILTER_LIST", + "model_filter.list", + [model.strip() for model in MODEL_FILTER_LIST.split(";")], +) -WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") +WEBHOOK_URL = WrappedConfig( + "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "") +) ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" @@ -458,26 +550,45 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) -RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) -RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) - -ENABLE_RAG_HYBRID_SEARCH = ( - os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" +RAG_TOP_K = WrappedConfig( + "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) +) +RAG_RELEVANCE_THRESHOLD = WrappedConfig( + "RAG_RELEVANCE_THRESHOLD", + "rag.relevance_threshold", + float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) - -ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true" +ENABLE_RAG_HYBRID_SEARCH = WrappedConfig( + "ENABLE_RAG_HYBRID_SEARCH", + "rag.enable_hybrid_search", + os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) -RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") - -PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true" - -RAG_EMBEDDING_MODEL = os.environ.get( - "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" +ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = WrappedConfig( + "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", + "rag.enable_web_loader_ssl_verification", + os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true", ) -log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), + +RAG_EMBEDDING_ENGINE = WrappedConfig( + "RAG_EMBEDDING_ENGINE", + "rag.embedding_engine", + os.environ.get("RAG_EMBEDDING_ENGINE", ""), +) + +PDF_EXTRACT_IMAGES = WrappedConfig( + "PDF_EXTRACT_IMAGES", + "rag.pdf_extract_images", + os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", +) + +RAG_EMBEDDING_MODEL = WrappedConfig( + "RAG_EMBEDDING_MODEL", + "rag.embedding_model", + os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"), +) +log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"), RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -487,9 +598,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "") -if not RAG_RERANKING_MODEL == "": - log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"), +RAG_RERANKING_MODEL = WrappedConfig( + "RAG_RERANKING_MODEL", + "rag.reranking_model", + os.environ.get("RAG_RERANKING_MODEL", ""), +) +if RAG_RERANKING_MODEL.value != "": + log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"), RAG_RERANKING_MODEL_AUTO_UPDATE = ( os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" @@ -499,7 +614,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) - if CHROMA_HTTP_HOST != "": CHROMA_CLIENT = chromadb.HttpClient( host=CHROMA_HTTP_HOST, @@ -518,7 +632,6 @@ else: database=CHROMA_DATABASE, ) - # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") @@ -527,9 +640,14 @@ if USE_CUDA.lower() == "true": else: DEVICE_TYPE = "cpu" - -CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500")) -CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100")) +CHUNK_SIZE = WrappedConfig( + "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) +) +CHUNK_OVERLAP = WrappedConfig( + "CHUNK_OVERLAP", + "rag.chunk_overlap", + int(os.environ.get("CHUNK_OVERLAP", "100")), +) DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags. @@ -545,16 +663,32 @@ And answer according to the language of the user's question. Given the context information, answer the query. Query: [query]""" -RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) +RAG_TEMPLATE = WrappedConfig( + "RAG_TEMPLATE", + "rag.template", + os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE), +) -RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) -RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) +RAG_OPENAI_API_BASE_URL = WrappedConfig( + "RAG_OPENAI_API_BASE_URL", + "rag.openai_api_base_url", + os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +RAG_OPENAI_API_KEY = WrappedConfig( + "RAG_OPENAI_API_KEY", + "rag.openai_api_key", + os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), +) ENABLE_RAG_LOCAL_WEB_FETCH = ( os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" ) -YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",") +YOUTUBE_LOADER_LANGUAGE = WrappedConfig( + "YOUTUBE_LOADER_LANGUAGE", + "rag.youtube_loader_language", + os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), +) #################################### # Transcribe @@ -566,39 +700,82 @@ WHISPER_MODEL_AUTO_UPDATE = ( os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" ) - #################################### # Images #################################### -IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "") - -ENABLE_IMAGE_GENERATION = ( - os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true" +IMAGE_GENERATION_ENGINE = WrappedConfig( + "IMAGE_GENERATION_ENGINE", + "image_generation.engine", + os.getenv("IMAGE_GENERATION_ENGINE", ""), ) -AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") -COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") - -IMAGES_OPENAI_API_BASE_URL = os.getenv( - "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL +ENABLE_IMAGE_GENERATION = WrappedConfig( + "ENABLE_IMAGE_GENERATION", + "image_generation.enable", + os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true", +) +AUTOMATIC1111_BASE_URL = WrappedConfig( + "AUTOMATIC1111_BASE_URL", + "image_generation.automatic1111.base_url", + os.getenv("AUTOMATIC1111_BASE_URL", ""), ) -IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY) -IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512") +COMFYUI_BASE_URL = WrappedConfig( + "COMFYUI_BASE_URL", + "image_generation.comfyui.base_url", + os.getenv("COMFYUI_BASE_URL", ""), +) -IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50)) +IMAGES_OPENAI_API_BASE_URL = WrappedConfig( + "IMAGES_OPENAI_API_BASE_URL", + "image_generation.openai.api_base_url", + os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +IMAGES_OPENAI_API_KEY = WrappedConfig( + "IMAGES_OPENAI_API_KEY", + "image_generation.openai.api_key", + os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), +) -IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "") +IMAGE_SIZE = WrappedConfig( + "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") +) + +IMAGE_STEPS = WrappedConfig( + "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50)) +) + +IMAGE_GENERATION_MODEL = WrappedConfig( + "IMAGE_GENERATION_MODEL", + "image_generation.model", + os.getenv("IMAGE_GENERATION_MODEL", ""), +) #################################### # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) -AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) -AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1") -AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy") +AUDIO_OPENAI_API_BASE_URL = WrappedConfig( + "AUDIO_OPENAI_API_BASE_URL", + "audio.openai.api_base_url", + os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +AUDIO_OPENAI_API_KEY = WrappedConfig( + "AUDIO_OPENAI_API_KEY", + "audio.openai.api_key", + os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), +) +AUDIO_OPENAI_API_MODEL = WrappedConfig( + "AUDIO_OPENAI_API_MODEL", + "audio.openai.api_model", + os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"), +) +AUDIO_OPENAI_API_VOICE = WrappedConfig( + "AUDIO_OPENAI_API_VOICE", + "audio.openai.api_voice", + os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), +) #################################### # LiteLLM @@ -612,7 +789,6 @@ if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: raise ValueError("Invalid port number for LITELLM_PROXY_PORT") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") - #################################### # Database #################################### diff --git a/backend/main.py b/backend/main.py index 139819f7c..6f94a8dad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,6 +58,8 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, + config_get, + config_set, ) from constants import ERROR_MESSAGES @@ -243,9 +245,11 @@ async def get_app_config(): "version": VERSION, "auth": WEBUI_AUTH, "default_locale": default_locale, - "images": images_app.state.ENABLED, - "default_models": webui_app.state.DEFAULT_MODELS, - "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, + "images": config_get(images_app.state.ENABLED), + "default_models": config_get(webui_app.state.DEFAULT_MODELS), + "default_prompt_suggestions": config_get( + webui_app.state.DEFAULT_PROMPT_SUGGESTIONS + ), "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -254,8 +258,8 @@ async def get_app_config(): @app.get("/api/config/model/filter") async def get_model_filter_config(user=Depends(get_admin_user)): return { - "enabled": app.state.ENABLE_MODEL_FILTER, - "models": app.state.MODEL_FILTER_LIST, + "enabled": config_get(app.state.ENABLE_MODEL_FILTER), + "models": config_get(app.state.MODEL_FILTER_LIST), } @@ -268,28 +272,28 @@ class ModelFilterConfigForm(BaseModel): async def update_model_filter_config( form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): - app.state.ENABLE_MODEL_FILTER = form_data.enabled - app.state.MODEL_FILTER_LIST = form_data.models + config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled) + config_set(app.state.MODEL_FILTER_LIST, form_data.models) - ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) + ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) - openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) + openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) - litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER - litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) + litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) return { - "enabled": app.state.ENABLE_MODEL_FILTER, - "models": app.state.MODEL_FILTER_LIST, + "enabled": config_get(app.state.ENABLE_MODEL_FILTER), + "models": config_get(app.state.MODEL_FILTER_LIST), } @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { - "url": app.state.WEBHOOK_URL, + "url": config_get(app.state.WEBHOOK_URL), } @@ -299,12 +303,12 @@ class UrlForm(BaseModel): @app.post("/api/webhook") async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): - app.state.WEBHOOK_URL = form_data.url + config_set(app.state.WEBHOOK_URL, form_data.url) - webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL + webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL) return { - "url": app.state.WEBHOOK_URL, + "url": config_get(app.state.WEBHOOK_URL), }